Browse Source

完成tcp server

yjp 1 year ago
parent
commit
22ecb209eb
7 changed files with 487 additions and 33 deletions
  1. 6 6
      network/connection.go
  2. 161 0
      network/tcp_client.go
  3. 213 0
      network/tcp_server.go
  4. 87 0
      network/tcp_test.go
  5. 7 9
      network/udp_client.go
  6. 6 13
      network/udp_server.go
  7. 7 5
      network/udp_test.go

+ 6 - 6
network/connection.go

@@ -97,8 +97,8 @@ func writeUDP(conn *net.UDPConn, data []byte, opts ...connectionWriteOptions) er
 	return nil
 }
 
-// ReadTCP 读取TCP数据
-func ReadTCP(conn net.Conn, bufferSize int, readCallback func(data []byte) (bool, error), opts ...connectionReadOptions) error {
+// readTCP 读取TCP数据
+func readTCP(conn net.Conn, bufferSize int, readCallback func(data []byte) (bool, error), opts ...connectionReadOptions) error {
 	for {
 		buffer := make([]byte, bufferSize)
 
@@ -127,8 +127,8 @@ func ReadTCP(conn net.Conn, bufferSize int, readCallback func(data []byte) (bool
 	}
 }
 
-// WriteTCP 写TCP数据
-func WriteTCP(conn net.Conn, data []byte, opts ...connectionWriteOptions) error {
+// writeTCP 写TCP数据
+func writeTCP(conn net.Conn, data []byte, opts ...connectionWriteOptions) error {
 	writeBytesCount := 0
 
 	for {
@@ -153,8 +153,8 @@ func WriteTCP(conn net.Conn, data []byte, opts ...connectionWriteOptions) error
 	}
 }
 
-// CloseConnection 关闭连接
-func CloseConnection(conn net.Conn) {
+// closeConnection 关闭连接
+func closeConnection(conn net.Conn) {
 	err := conn.Close()
 	if err != nil {
 		fmt.Println("Close Connection Error:", err.Error())

+ 161 - 0
network/tcp_client.go

@@ -1 +1,162 @@
 package network
+
+import (
+	"fmt"
+	"net"
+	"time"
+)
+
+const (
+	tcpClientReceiveBufferSize = 1024
+)
+
+// TCPClientResponseCallback 响应回调
+// 参数:
+// dataReader: 响应数据DataReader
+// 返回值:
+// 是否读取完成,false继续读取
+type TCPClientResponseCallback func(dataReader *DataReader) bool
+
+type TCPClientOption func(opt *TCPClientOptions)
+
+func WithTCPClientRequestNonBlockCount(count int) TCPClientOption {
+	return func(opt *TCPClientOptions) {
+		opt.requestNonBlockCount = count
+	}
+}
+
+func WithTCPClientReceiveBufferSize(receiveBufferSize int) TCPClientOption {
+	return func(opt *TCPClientOptions) {
+		opt.receiveBufferSize = receiveBufferSize
+	}
+}
+
+func WithTCPClientWriteTimeout(timeout time.Duration) TCPClientOption {
+	return func(opt *TCPClientOptions) {
+		opt.writeTimeout = timeout
+	}
+}
+
+func WithTCPClientReadTimeout(timeout time.Duration) TCPClientOption {
+	return func(opt *TCPClientOptions) {
+		opt.readTimeout = timeout
+	}
+}
+
+func WithTCPClientResponseCallback(responseCallback TCPClientResponseCallback) TCPClientOption {
+	return func(opt *TCPClientOptions) {
+		opt.responseCallback = responseCallback
+	}
+}
+
+type TCPClientOptions struct {
+	// 不阻塞的请求数量,默认为阻塞
+	requestNonBlockCount int
+
+	// 默认1024字节,一般保证足够收取一个数据包
+	receiveBufferSize int
+
+	// 写超时,不设置就是阻塞写
+	writeTimeout time.Duration
+
+	// 读超时,不设置就是阻塞读
+	readTimeout time.Duration
+
+	// 响应数据回调
+	responseCallback TCPClientResponseCallback
+}
+
+func NewTCPClientOptions(opts ...TCPClientOption) *TCPClientOptions {
+	options := new(TCPClientOptions)
+
+	for _, opt := range opts {
+		opt(options)
+	}
+
+	if options.receiveBufferSize == 0 {
+		options.receiveBufferSize = tcpClientReceiveBufferSize
+	}
+
+	return options
+}
+
+type TCPClient struct {
+	options     *TCPClientOptions
+	conn        *net.TCPConn
+	requestChan chan []byte
+	doneChan    chan any
+}
+
+// Connect 建立连接
+func (client *TCPClient) Connect(serverAddress string, options *TCPClientOptions) error {
+	serverAddr, err := net.ResolveTCPAddr("tcp", serverAddress)
+	if err != nil {
+		panic(err)
+	}
+
+	conn, err := net.DialTCP("tcp", nil, serverAddr)
+	if err != nil {
+		panic(err)
+	}
+
+	client.options = options
+	client.conn = conn
+	client.requestChan = make(chan []byte, options.requestNonBlockCount)
+	client.doneChan = make(chan any)
+
+	// 启动发送请求协程
+	go client.sendRequest()
+
+	return nil
+}
+
+// Disconnect 断开连接
+func (client *TCPClient) Disconnect() {
+	client.doneChan <- nil
+	close(client.doneChan)
+	client.doneChan = nil
+
+	close(client.requestChan)
+	client.requestChan = nil
+
+	closeConnection(client.conn)
+	client.conn = nil
+}
+
+// Send 发送数据包,data应该为大端字节序
+func (client *TCPClient) Send(data []byte) {
+	client.requestChan <- data
+}
+
+func (client *TCPClient) sendRequest() {
+	dealRequestDoneChannels := make([]chan any, 0)
+
+	for {
+		select {
+		case <-client.doneChan:
+			for _, dealRequestDoneChan := range dealRequestDoneChannels {
+				dealRequestDoneChan <- nil
+				close(dealRequestDoneChan)
+			}
+
+			return
+		case data := <-client.requestChan:
+			err := writeTCP(client.conn, data, withWriteDeadline(client.options.writeTimeout))
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+
+			if client.options.responseCallback != nil {
+				err = readTCP(client.conn, client.options.receiveBufferSize, func(data []byte) (bool, error) {
+					readOver := client.options.responseCallback(NewDataReader(data))
+					return readOver, nil
+				}, withReadDeadline(client.options.readTimeout))
+				if err != nil {
+					fmt.Println(err)
+					continue
+				}
+			}
+		}
+	}
+}

+ 213 - 0
network/tcp_server.go

@@ -1 +1,214 @@
 package network
+
+import (
+	"errors"
+	"fmt"
+	"git.sxidc.com/go-tools/utils/syncutils"
+	"io"
+	"net"
+	"time"
+)
+
+const (
+	tcpServerReceiveBufferSize = 1024
+)
+
+var TCPServerIgnoreResponse = errors.New("忽略响应")
+
+// TCPServerRequestCallback 请求回调
+// 参数:
+// dataReader: 请求数据DataReader
+// 返回值:
+// 响应数据DataReader: 可以使用DataWriter构造,然后使用ToReader转换成DataReader,返回nil代表还要继续接收数据,不做响应
+type TCPServerRequestCallback func(dataReader *DataReader) *DataReader
+
+type TCPServerOption func(opt *TCPServerOptions)
+
+func WithTCPServerReceiveBufferSize(receiveBufferSize int) TCPServerOption {
+	return func(opt *TCPServerOptions) {
+		opt.receiveBufferSize = receiveBufferSize
+	}
+}
+
+func WithTCPServerWriteTimeout(timeout time.Duration) TCPServerOption {
+	return func(opt *TCPServerOptions) {
+		opt.writeTimeout = timeout
+	}
+}
+
+func WithTCPServerReadTimeout(timeout time.Duration) TCPServerOption {
+	return func(opt *TCPServerOptions) {
+		opt.readTimeout = timeout
+	}
+}
+
+func WithTCPServerRequestCallback(requestCallback TCPServerRequestCallback) TCPServerOption {
+	return func(opt *TCPServerOptions) {
+		opt.requestCallback = requestCallback
+	}
+}
+
+type TCPServerOptions struct {
+	// 默认1024字节,一般保证足够收取一个数据包
+	receiveBufferSize int
+
+	// 写超时,不设置就是阻塞写
+	writeTimeout time.Duration
+
+	// 读超时,不设置就是阻塞读
+	readTimeout time.Duration
+
+	// 请求数据回调
+	requestCallback TCPServerRequestCallback
+}
+
+func NewTCPServerOptions(opts ...TCPServerOption) *TCPServerOptions {
+	options := new(TCPServerOptions)
+
+	for _, opt := range opts {
+		opt(options)
+	}
+
+	if options.receiveBufferSize == 0 {
+		options.receiveBufferSize = tcpServerReceiveBufferSize
+	}
+
+	return options
+}
+
+type TCPServer struct {
+	options  *TCPServerOptions
+	listener *net.TCPListener
+	doneChan chan any
+}
+
+// Connect 建立连接
+func (server *TCPServer) Connect(address string, options *TCPServerOptions) error {
+	addr, err := net.ResolveTCPAddr("tcp", address)
+	if err != nil {
+		return err
+	}
+
+	listener, err := net.ListenTCP("tcp", addr)
+	if err != nil {
+		return err
+	}
+
+	server.options = options
+	server.listener = listener
+	server.doneChan = make(chan any)
+
+	// 启动读取请求协程
+	go server.accept()
+
+	return nil
+}
+
+// Disconnect 断开连接
+func (server *TCPServer) Disconnect() error {
+	err := server.listener.Close()
+	if err != nil {
+		return err
+	}
+	server.listener = nil
+
+	server.doneChan <- nil
+	close(server.doneChan)
+	server.doneChan = nil
+
+	return nil
+}
+
+func (server *TCPServer) accept() {
+	readRequestDoneChannels := syncutils.NewSyncVar(make([]chan any, 0), false)
+
+	for {
+		select {
+		case <-server.doneChan:
+			readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
+				for _, channel := range channels {
+					channel <- nil
+					close(channel)
+				}
+
+				return make([]chan any, 0)
+			})
+
+			return
+		default:
+			conn, err := server.listener.AcceptTCP()
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+
+			readRequestDoneChan := make(chan any)
+			readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
+				channels = append(channels, readRequestDoneChan)
+				return channels
+			})
+
+			go server.readRequest(conn, readRequestDoneChan, func(doneChan chan any, err error) {
+				readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
+					for i, toRemoveDoneChan := range channels {
+						if doneChan == toRemoveDoneChan {
+							close(toRemoveDoneChan)
+							channels = append(channels[:i], channels[i+1:]...)
+							break
+						}
+					}
+
+					return channels
+				})
+			})
+		}
+	}
+}
+
+func (server *TCPServer) readRequest(conn *net.TCPConn, doneChan chan any, errCloseCallback func(doneChan chan any, err error)) {
+	for {
+		select {
+		case <-doneChan:
+			if conn != nil {
+				closeConnection(conn)
+			}
+
+			return
+		default:
+			err := readTCP(conn, server.options.receiveBufferSize, func(data []byte) (bool, error) {
+				// 没有提供请求响应函数
+				if server.options.requestCallback == nil {
+					return false, nil
+				}
+
+				// 交给上层回调处理,返回处理结果和响应数据
+				responseDataReader := server.options.requestCallback(NewDataReader(data))
+				if responseDataReader != nil {
+					server.response(conn, responseDataReader.GetBytes())
+					return true, nil
+				}
+
+				return true, nil
+			}, withReadDeadline(server.options.readTimeout))
+			if err != nil {
+				// 对端关闭
+				if err == io.EOF {
+					closeConnection(conn)
+					errCloseCallback(doneChan, err)
+					return
+				}
+
+				fmt.Println(err)
+				continue
+			}
+		}
+	}
+}
+
+func (server *TCPServer) response(conn net.Conn, data []byte) {
+	err := writeTCP(conn, data, withWriteDeadline(server.options.writeTimeout))
+	if err != nil {
+		fmt.Println("Response Error:", err)
+		return
+	}
+}

+ 87 - 0
network/tcp_test.go

@@ -0,0 +1,87 @@
+package network
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+const (
+	testTCPServerAddress           = "127.0.0.1:10060"
+	testTCPServerTimeout           = time.Second
+	testTCPServerReceiveBufferSize = 1024
+
+	testTCPClientTimeout           = time.Second
+	testTCPClientReceiveBufferSize = 1024
+)
+
+func TestTCP(t *testing.T) {
+	server := &TCPServer{}
+
+	err := server.Connect(testTCPServerAddress, NewTCPServerOptions(
+		WithTCPServerReadTimeout(testTCPServerTimeout),
+		WithTCPServerWriteTimeout(testTCPServerTimeout),
+		WithTCPServerReceiveBufferSize(testTCPServerReceiveBufferSize),
+		WithTCPServerRequestCallback(func(dataReader *DataReader) *DataReader {
+			requestBytes, err := dataReader.Bytes(dataReader.Len())
+			if err != nil {
+				t.Fatal(err)
+				return nil
+			}
+
+			responseWriter := NewDataWriter()
+			err = responseWriter.Bytes([]byte(strings.ToUpper(string(requestBytes))))
+			if err != nil {
+				t.Fatal(err)
+				return nil
+			}
+
+			return responseWriter.ToReader()
+		}),
+	))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func(server *TCPServer) {
+		err := server.Disconnect()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}(server)
+
+	wg := &sync.WaitGroup{}
+	wg.Add(2)
+
+	client := &TCPClient{}
+	err = client.Connect(testTCPServerAddress, NewTCPClientOptions(
+		WithTCPClientReadTimeout(testTCPClientTimeout),
+		WithTCPClientWriteTimeout(testTCPClientTimeout),
+		WithTCPClientReceiveBufferSize(testTCPClientReceiveBufferSize),
+		WithTCPClientRequestNonBlockCount(2),
+		WithTCPClientResponseCallback(func(dataReader *DataReader) bool {
+			defer wg.Done()
+
+			requestBytes, err := dataReader.Bytes(dataReader.Len())
+			if err != nil {
+				t.Fatal(err)
+				return true
+			}
+
+			fmt.Println(string(requestBytes))
+			return true
+		}),
+	))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer client.Disconnect()
+
+	client.Send([]byte("test1"))
+	client.Send([]byte("test2"))
+
+	wg.Wait()
+}

+ 7 - 9
network/udp_client.go

@@ -13,8 +13,6 @@ const (
 // UDPClientResponseCallback 响应回调
 // 参数:
 // dataReader: 响应数据DataReader
-// 返回值:
-// 是否存在错误
 type UDPClientResponseCallback func(dataReader *DataReader)
 
 type UDPClientOption func(opt *UDPClientOptions)
@@ -119,7 +117,7 @@ func (client *UDPClient) Disconnect() {
 	close(client.requestChan)
 	client.requestChan = nil
 
-	CloseConnection(client.conn)
+	closeConnection(client.conn)
 	client.conn = nil
 }
 
@@ -147,13 +145,13 @@ func (client *UDPClient) sendRequest() {
 				continue
 			}
 
-			responseBytes, _, err := readUDP(client.conn, client.options.receiveBufferSize, withReadDeadline(client.options.readTimeout))
-			if err != nil {
-				fmt.Println(err)
-				continue
-			}
-
 			if client.options.responseCallback != nil {
+				responseBytes, _, err := readUDP(client.conn, client.options.receiveBufferSize, withReadDeadline(client.options.readTimeout))
+				if err != nil {
+					fmt.Println(err)
+					continue
+				}
+
 				go client.options.responseCallback(NewDataReader(responseBytes))
 			}
 		}

+ 6 - 13
network/udp_server.go

@@ -17,9 +17,8 @@ var UDPServerIgnoreResponse = errors.New("忽略响应")
 // 参数:
 // dataReader: 请求数据DataReader
 // 返回值:
-// 响应数据DataReader: 可以使用DataWriter构造,然后使用ToReader转换成DataReader
-// 是否存在错误: 如果是UDPServerIgnoreResponse,则忽略,不进行响应
-type UDPServerRequestCallback func(dataReader *DataReader) (*DataReader, error)
+// 响应数据DataReader: 可以使用DataWriter构造,然后使用ToReader转换成DataReader,返回nil代表还要继续接收数据,不做响应
+type UDPServerRequestCallback func(dataReader *DataReader) *DataReader
 
 type UDPServerOption func(opt *UDPServerOptions)
 
@@ -110,7 +109,7 @@ func (server *UDPServer) Disconnect() {
 	close(server.doneChan)
 	server.doneChan = nil
 
-	CloseConnection(server.conn)
+	closeConnection(server.conn)
 	server.conn = nil
 }
 
@@ -154,18 +153,12 @@ func (server *UDPServer) dealRequest(data []byte, rAddr *net.UDPAddr, doneChan c
 			}
 
 			// 交给上层回调处理,返回处理结果和响应数据
-			responseDataReader, err := server.options.requestCallback(NewDataReader(data))
-			if err != nil {
-				// 忽略响应
-				if errors.Is(err, UDPServerIgnoreResponse) {
-					return
-				}
-
-				server.response(server.conn, rAddr, []byte(err.Error()))
+			responseDataReader := server.options.requestCallback(NewDataReader(data))
+			if responseDataReader != nil {
+				server.response(server.conn, rAddr, responseDataReader.GetBytes())
 				return
 			}
 
-			server.response(server.conn, rAddr, responseDataReader.GetBytes())
 			return
 		}
 	}

+ 7 - 5
network/udp_test.go

@@ -24,19 +24,21 @@ func TestUDP(t *testing.T) {
 		WithUDPServerReadTimeout(testUDPServerTimeout),
 		WithUDPServerWriteTimeout(testUDPServerTimeout),
 		WithUDPServerReceiveBufferSize(testUDPServerReceiveBufferSize),
-		WithUDPServerRequestCallback(func(dataReader *DataReader) (*DataReader, error) {
+		WithUDPServerRequestCallback(func(dataReader *DataReader) *DataReader {
 			requestBytes, err := dataReader.Bytes(dataReader.Len())
 			if err != nil {
-				return nil, err
+				t.Fatal(err)
+				return nil
 			}
 
 			responseWriter := NewDataWriter()
 			err = responseWriter.Bytes([]byte(strings.ToUpper(string(requestBytes))))
 			if err != nil {
-				return nil, err
+				t.Fatal(err)
+				return nil
 			}
 
-			return responseWriter.ToReader(), nil
+			return responseWriter.ToReader()
 		}),
 	))
 	if err != nil {
@@ -57,7 +59,7 @@ func TestUDP(t *testing.T) {
 		WithUDPClientResponseCallback(func(dataReader *DataReader) {
 			requestBytes, err := dataReader.Bytes(dataReader.Len())
 			if err != nil {
-				fmt.Println(err)
+				t.Fatal(err)
 				return
 			}