Browse Source

完成udp客户端编写

yjp 1 year ago
parent
commit
0957c8057b
4 changed files with 179 additions and 14 deletions
  1. 2 2
      network/connection.go
  2. 173 0
      network/udp_client.go
  3. 2 3
      network/udp_server.go
  4. 2 9
      network/udp_test.go

+ 2 - 2
network/connection.go

@@ -80,8 +80,8 @@ func writeUDPWithRemoteAddr(conn *net.UDPConn, rAddr *net.UDPAddr, data []byte,
 	return nil
 }
 
-// WriteUDP 写出UDP包
-func WriteUDP(conn *net.UDPConn, data []byte, opts ...connectionWriteOptions) error {
+// writeUDP 写出UDP包
+func writeUDP(conn *net.UDPConn, data []byte, opts ...connectionWriteOptions) error {
 	for _, opt := range opts {
 		err := opt(conn)
 		if err != nil {

+ 173 - 0
network/udp_client.go

@@ -1 +1,174 @@
 package network
+
+import (
+	"bytes"
+	"fmt"
+	"net"
+	"time"
+)
+
+const (
+	udpClientReceiveBufferSize = 1024
+)
+
+// UDPClientResponseCallback 响应回调
+// 参数:
+// data: 响应数据,大端模式,可以用DataReader处理
+// 返回值:
+// 是否存在错误
+type UDPClientResponseCallback func(data []byte)
+
+type UDPClientOption func(opt *UDPClientOptions)
+
+func WithUDPClientRequestNonBlockCount(count int) UDPClientOption {
+	return func(opt *UDPClientOptions) {
+		opt.requestNonBlockCount = count
+	}
+}
+
+func WithUDPClientReceiveBufferSize(receiveBufferSize int) UDPClientOption {
+	return func(opt *UDPClientOptions) {
+		opt.receiveBufferSize = receiveBufferSize
+	}
+}
+
+func WithUDPClientWriteTimeout(timeout time.Duration) UDPClientOption {
+	return func(opt *UDPClientOptions) {
+		opt.writeTimeout = timeout
+	}
+}
+
+func WithUDPClientReadTimeout(timeout time.Duration) UDPClientOption {
+	return func(opt *UDPClientOptions) {
+		opt.readTimeout = timeout
+	}
+}
+
+func WithUDPClientResponseCallback(responseCallback UDPClientResponseCallback) UDPClientOption {
+	return func(opt *UDPClientOptions) {
+		opt.responseCallback = responseCallback
+	}
+}
+
+type UDPClientOptions struct {
+	// 不阻塞的请求数量,默认为阻塞
+	requestNonBlockCount int
+
+	// 默认1024字节,一般保证足够收取一个数据包
+	receiveBufferSize int
+
+	// 写超时,不设置就是阻塞写
+	writeTimeout time.Duration
+
+	// 读超时,不设置就是阻塞读
+	readTimeout time.Duration
+
+	// 响应数据回调
+	responseCallback UDPClientResponseCallback
+}
+
+func NewUDPClientOptions(opts ...UDPClientOption) *UDPClientOptions {
+	options := new(UDPClientOptions)
+
+	for _, opt := range opts {
+		opt(options)
+	}
+
+	if options.receiveBufferSize == 0 {
+		options.receiveBufferSize = udpClientReceiveBufferSize
+	}
+
+	return options
+}
+
+type UDPClient struct {
+	options     *UDPClientOptions
+	conn        *net.UDPConn
+	requestChan chan []byte
+	doneChan    chan any
+}
+
+// Connect 建立连接
+func (client *UDPClient) Connect(serverAddress string, options *UDPClientOptions) error {
+	serverAddr, err := net.ResolveUDPAddr("udp", serverAddress)
+	if err != nil {
+		panic(err)
+	}
+
+	conn, err := net.DialUDP("udp", nil, serverAddr)
+	if err != nil {
+		panic(err)
+	}
+
+	client.options = options
+	client.conn = conn
+	client.requestChan = make(chan []byte, options.requestNonBlockCount)
+	client.doneChan = make(chan any)
+
+	// 启动发送请求协程
+	client.sendRequest()
+
+	return nil
+}
+
+// Disconnect 断开连接
+func (client *UDPClient) Disconnect() {
+	client.doneChan <- nil
+	close(client.doneChan)
+	client.doneChan = nil
+
+	close(client.requestChan)
+	client.requestChan = nil
+
+	CloseConnection(client.conn)
+	client.conn = nil
+}
+
+// Send 发送数据包,data应该为大端字节序
+func (client *UDPClient) Send(data []byte) {
+	client.requestChan <- data
+}
+
+func (client *UDPClient) sendRequest() {
+	dealRequestDoneChannels := make([]chan any, 0)
+
+	for {
+		select {
+		case <-client.doneChan:
+			for _, dealRequestDoneChan := range dealRequestDoneChannels {
+				dealRequestDoneChan <- nil
+				close(dealRequestDoneChan)
+			}
+
+			return
+		case data := <-client.requestChan:
+			err := writeUDP(client.conn, data, withWriteDeadline(client.options.writeTimeout))
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+
+			responseBytes, _, err := readUDP(client.conn, client.options.receiveBufferSize, withReadDeadline(client.options.readTimeout))
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+
+			if client.options.responseCallback != nil {
+				go client.dealResponse(responseBytes)
+			}
+		}
+	}
+}
+
+func (client *UDPClient) dealResponse(responseBytes []byte) {
+	dataBuffer := bytes.NewReader(responseBytes)
+	reader := NewDataReader(dataBuffer)
+	data, err := reader.Bytes(len(responseBytes))
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	client.options.responseCallback(data)
+}

+ 2 - 3
network/udp.go → network/udp_server.go

@@ -15,9 +15,9 @@ var UDPServerIgnoreResponse = errors.New("忽略响应")
 
 // UDPServerRequestCallback 请求回调
 // 参数:
-// data: 请求数据,大端,可以用DataReader解析
+// data: 请求数据,大端模式,可以用DataReader处理
 // 返回值:
-// 响应数据: 大端,可以用DataWriter写入buffer再返回
+// 响应数据: 大端模式,可以用DataWriter构造
 // 是否存在错误: 如果是UDPServerIgnoreResponse,则忽略,不进行响应
 type UDPServerRequestCallback func(data []byte) ([]byte, error)
 
@@ -150,7 +150,6 @@ func (server *UDPServer) dealRequest(data []byte, rAddr *net.UDPAddr, doneChan c
 		default:
 			// 没有提供请求响应函数
 			if server.options.requestCallback == nil {
-				server.response(server.conn, rAddr, data)
 				return
 			}
 

+ 2 - 9
network/udp_test.go

@@ -21,17 +21,10 @@ func TestUDP(t *testing.T) {
 		WithUDPServerWriteTimeout(testUDPServerTimeout),
 		WithUDPServerReceiveBufferSize(testUDPServerReceiveBufferSize),
 		WithUDPServerRequestCallback(func(data []byte) ([]byte, error) {
-			requestBuffer := bytes.NewReader(data)
-			requestDataReader := NewDataReader(requestBuffer)
-			requestBytes, err := requestDataReader.Bytes(len(data))
-			if err != nil {
-				return nil, err
-			}
-
-			responseBytes := []byte(strings.ToUpper(string(requestBytes)))
+			responseBytes := []byte(strings.ToUpper(string(data)))
 			responseBuffer := &bytes.Buffer{}
 			responseReader := NewDataWriter(responseBuffer)
-			err = responseReader.Bytes(responseBytes)
+			err := responseReader.Bytes(responseBytes)
 			if err != nil {
 				return nil, err
 			}