Browse Source

完成UDP的服务端开发,未测试

yjp 1 year ago
parent
commit
75d2bd6057
2 changed files with 81 additions and 41 deletions
  1. 2 2
      network/connection.go
  2. 79 39
      network/udp_server.go

+ 2 - 2
network/connection.go

@@ -63,8 +63,8 @@ func readUDP(conn *net.UDPConn, bufferSize int, opts ...ConnectionReadOptions) (
 	return buffer[:n], rAddr, nil
 }
 
-// WriteUDPWithRemoteAddr 指定远端(对端)地址写出UDP包
-func WriteUDPWithRemoteAddr(conn *net.UDPConn, rAddr *net.UDPAddr, data []byte, opts ...ConnectionWriteOptions) error {
+// writeUDPWithRemoteAddr 指定远端(对端)地址写出UDP包
+func writeUDPWithRemoteAddr(conn *net.UDPConn, rAddr *net.UDPAddr, data []byte, opts ...ConnectionWriteOptions) error {
 	for _, opt := range opts {
 		err := opt(conn)
 		if err != nil {

+ 79 - 39
network/udp_server.go

@@ -10,30 +10,53 @@ const (
 	udpServerReceiveBufferSize = 1024
 )
 
+// UDPServerRequestCallback 请求回调
+// 参数:
+// data: 请求数据
+// 返回值:
+// send: 是否发送响应
+// responseBytes: 响应数据包
+// err: 是否存在错误
+type UDPServerRequestCallback func(data []byte) (send bool, responseBytes []byte, err error)
+
 type UDPServerOption func(opt *UDPServerOptions)
 
-func WithReceiveBufferSize(receiveBufferSize int) UDPServerOption {
+func WithUDPServerReceiveBufferSize(receiveBufferSize int) UDPServerOption {
+	return func(opt *UDPServerOptions) {
+		opt.receiveBufferSize = receiveBufferSize
+	}
+}
+
+func WithUDPServerWriteTimeout(timeout time.Duration) UDPServerOption {
 	return func(opt *UDPServerOptions) {
-		opt.ReceiveBufferSize = receiveBufferSize
+		opt.writeTimeout = timeout
 	}
 }
 
-func WithWriteTimeout(timeout time.Duration) UDPServerOption {
+func WithUDPServerReadTimeout(timeout time.Duration) UDPServerOption {
 	return func(opt *UDPServerOptions) {
-		opt.WriteTimeout = timeout
+		opt.readTimeout = timeout
 	}
 }
 
-func WithReadTimeout(timeout time.Duration) UDPServerOption {
+func WithUDPServerRequestCallback(requestCallback UDPServerRequestCallback) UDPServerOption {
 	return func(opt *UDPServerOptions) {
-		opt.ReadTimeout = timeout
+		opt.requestCallback = requestCallback
 	}
 }
 
 type UDPServerOptions struct {
-	ReceiveBufferSize int
-	WriteTimeout      time.Duration
-	ReadTimeout       time.Duration
+	// 默认1024字节,一般保证足够收取一个数据包
+	receiveBufferSize int
+
+	// 写超时,不设置就是阻塞写
+	writeTimeout time.Duration
+
+	// 读超时,不设置就是阻塞读
+	readTimeout time.Duration
+
+	// 请求数据回调
+	requestCallback UDPServerRequestCallback
 }
 
 func NewUDPServerOptions(opts ...UDPServerOption) *UDPServerOptions {
@@ -43,24 +66,17 @@ func NewUDPServerOptions(opts ...UDPServerOption) *UDPServerOptions {
 		opt(options)
 	}
 
-	if options.ReceiveBufferSize == 0 {
-		options.ReceiveBufferSize = udpServerReceiveBufferSize
+	if options.receiveBufferSize == 0 {
+		options.receiveBufferSize = udpServerReceiveBufferSize
 	}
 
 	return options
 }
 
 type UDPServer struct {
-	options             *UDPServerOptions
-	conn                *net.UDPConn
-	doneChan            chan any
-	dealRequestChan     chan *remoteData
-	dealRequestDoneChan chan any
-}
-
-type remoteData struct {
-	data  []byte
-	rAddr *net.UDPAddr
+	options  *UDPServerOptions
+	conn     *net.UDPConn
+	doneChan chan any
 }
 
 // Connect 建立连接
@@ -81,7 +97,7 @@ func (server *UDPServer) Connect(address string, options *UDPServerOptions) erro
 	server.doneChan = make(chan any)
 
 	// 启动读取请求协程
-	go server.readRequest()
+	server.readRequest()
 
 	return nil
 }
@@ -97,38 +113,62 @@ func (server *UDPServer) Disconnect() {
 }
 
 func (server *UDPServer) readRequest() {
-	server.dealRequestChan = make(chan *remoteData)
-	server.dealRequestDoneChan = make(chan any)
-
-	go server.dealRequestAndResponse()
+	dealRequestDoneChannels := make([]chan any, 0)
 
 	for {
 		select {
 		case <-server.doneChan:
-			server.dealRequestDoneChan <- nil
-			close(server.dealRequestDoneChan)
-			server.dealRequestDoneChan = nil
-
-			close(server.dealRequestChan)
-			server.dealRequestChan = nil
+			for _, dealRequestDoneChan := range dealRequestDoneChannels {
+				dealRequestDoneChan <- nil
+				close(dealRequestDoneChan)
+			}
 
 			return
 		default:
-			// 读取任意客户端发来的请求
-			data, rAddr, err := readUDP(server.conn, server.options.ReceiveBufferSize, WithReadDeadline(server.options.ReadTimeout))
+			// 读取任意客户端发来的请求,超时就是没有客户端发出请求
+			data, rAddr, err := readUDP(server.conn, server.options.receiveBufferSize, WithReadDeadline(server.options.readTimeout))
 			if err != nil {
 				fmt.Println(err)
 				continue
 			}
 
-			server.dealRequestChan <- &remoteData{
-				data:  data,
-				rAddr: rAddr,
+			// 接收到请求
+			dealRequestDoneChan := make(chan any)
+			dealRequestDoneChannels = append(dealRequestDoneChannels, dealRequestDoneChan)
+			go server.dealRequest(data, rAddr, dealRequestDoneChan)
+		}
+	}
+}
+
+func (server *UDPServer) dealRequest(data []byte, rAddr *net.UDPAddr, doneChan chan any) {
+	for {
+		select {
+		case <-doneChan:
+			return
+		default:
+			if server.options.requestCallback == nil {
+				server.response(server.conn, rAddr, data)
+				return
+			}
+
+			send, responseBytes, err := server.options.requestCallback(data)
+			if !send {
+				return
+			}
+
+			if err != nil {
+				server.response(server.conn, rAddr, []byte(err.Error()))
 			}
+
+			server.response(server.conn, rAddr, responseBytes)
 		}
 	}
 }
 
-func (server *UDPServer) dealRequestAndResponse() {
-	// 回调上层
+func (server *UDPServer) response(conn *net.UDPConn, rAddr *net.UDPAddr, data []byte) {
+	err := writeUDPWithRemoteAddr(conn, rAddr, data, WithWriteDeadline(server.options.writeTimeout))
+	if err != nil {
+		fmt.Println("Response Error:", err)
+		return
+	}
 }