package network

import (
	"errors"
	"fmt"
	"net"
	"sync"
	"time"
)

const (
	// udpServerReceiveBufferSize is the receive buffer size, in bytes, used
	// when no positive size is configured.
	udpServerReceiveBufferSize = 1024
)

// UDPServerRequestCallback is the request callback.
//
// Parameters:
//   - dataReader: a DataReader over the request data.
//
// Return value:
//   - the response data as a DataReader; it can be built with a DataWriter and
//     converted with ToReader. Returning nil means no response is sent and the
//     server simply keeps receiving data.
type UDPServerRequestCallback func(dataReader *DataReader) *DataReader

// UDPServerOption mutates a UDPServerOptions; see NewUDPServerOptions.
type UDPServerOption func(opt *UDPServerOptions)

// WithUDPServerReceiveBufferSize sets the size, in bytes, of the buffer used to
// read one datagram. NewUDPServerOptions replaces non-positive values with the
// 1024-byte default.
func WithUDPServerReceiveBufferSize(receiveBufferSize int) UDPServerOption {
	return func(opt *UDPServerOptions) {
		opt.receiveBufferSize = receiveBufferSize
	}
}

// WithUDPServerWriteTimeout sets the timeout applied to each response write.
func WithUDPServerWriteTimeout(timeout time.Duration) UDPServerOption {
	return func(opt *UDPServerOptions) {
		opt.writeTimeout = timeout
	}
}

// WithUDPServerReadTimeout sets the timeout applied to each request read.
func WithUDPServerReadTimeout(timeout time.Duration) UDPServerOption {
	return func(opt *UDPServerOptions) {
		opt.readTimeout = timeout
	}
}

// WithUDPServerRequestCallback sets the callback invoked for each received request.
func WithUDPServerRequestCallback(requestCallback UDPServerRequestCallback) UDPServerOption {
	return func(opt *UDPServerOptions) {
		opt.requestCallback = requestCallback
	}
}

// UDPServerOptions holds the configuration of a UDPServer. Build it with
// NewUDPServerOptions.
type UDPServerOptions struct {
	// receiveBufferSize defaults to 1024 bytes; it should normally be large
	// enough to hold one whole datagram.
	receiveBufferSize int
	// writeTimeout is the write timeout; when unset, writes block.
	writeTimeout time.Duration
	// readTimeout is the read timeout; when unset, reads block.
	readTimeout time.Duration
	// requestCallback receives each request's data.
	requestCallback UDPServerRequestCallback
}

// NewUDPServerOptions builds a UDPServerOptions by applying opts in order.
// A non-positive receive buffer size is replaced by the 1024-byte default.
func NewUDPServerOptions(opts ...UDPServerOption) *UDPServerOptions {
	options := new(UDPServerOptions)
	for _, opt := range opts {
		opt(options)
	}
	// A negative size cannot be a valid buffer length, so treat it as unset.
	if options.receiveBufferSize <= 0 {
		options.receiveBufferSize = udpServerReceiveBufferSize
	}
	return options
}

// UDPServer listens on a UDP address and hands each received datagram to the
// configured request callback on its own goroutine. A UDPServer must not be
// copied after first use.
type UDPServer struct {
	options *UDPServerOptions
	conn    *net.UDPConn
	// doneChan is closed by Disconnect to tell the reader and handlers to stop.
	doneChan chan struct{}
	// wg tracks the reader goroutine and every in-flight request handler so
	// Disconnect can wait for them before clearing conn and doneChan.
	wg sync.WaitGroup
}

// Connect resolves address, starts listening on it and launches the goroutine
// that reads incoming requests. A nil options falls back to
// NewUDPServerOptions() defaults. It returns an error if the server is already
// connected; call Disconnect first.
func (server *UDPServer) Connect(address string, options *UDPServerOptions) error {
	// Reconnecting without Disconnect would leak the previous listener and reader.
	if server.doneChan != nil {
		return errors.New("udp server already connected")
	}
	if options == nil {
		options = NewUDPServerOptions()
	}

	addr, err := net.ResolveUDPAddr("udp", address)
	if err != nil {
		return err
	}

	// Listen on the port.
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return err
	}

	server.options = options
	server.conn = conn
	server.doneChan = make(chan struct{})

	// Start the request reader goroutine.
	server.wg.Add(1)
	go server.readRequest()

	return nil
}

// Disconnect stops the server. It signals the reader and handlers to stop,
// closes the connection so a blocked read returns, and waits until the reader
// and all in-flight request callbacks have returned. Calling it on a server
// that is not connected is a no-op.
func (server *UDPServer) Disconnect() {
	if server.doneChan == nil {
		return
	}
	// Closing (instead of sending on) the channel is a broadcast that never
	// blocks, even while the reader is stuck inside a read.
	close(server.doneChan)
	// NOTE(review): assumes closeConnection closes conn; closing a net.Conn
	// unblocks any pending read with an error.
	closeConnection(server.conn)
	server.wg.Wait()
	// Safe to clear only now: no goroutine reads these fields any more.
	server.doneChan = nil
	server.conn = nil
}

// readRequest runs on its own goroutine until doneChan is closed or the
// connection reports it is closed. It reads one request at a time and
// dispatches each to dealRequest on a new goroutine.
func (server *UDPServer) readRequest() {
	defer server.wg.Done()

	// Disconnect resets these fields only after wg.Wait, so reading them once
	// here is race-free.
	conn, done, options := server.conn, server.doneChan, server.options

	for {
		select {
		case <-done:
			return
		default:
		}

		// Read a request from any client; a timeout just means no client sent one.
		data, rAddr, err := readUDP(conn, options.receiveBufferSize, withReadDeadline(options.readTimeout))
		if err != nil {
			select {
			case <-done:
				// Disconnect closed the connection; this error is expected.
				return
			default:
			}
			if errors.Is(err, net.ErrClosed) {
				// Closed elsewhere: every further read would fail immediately.
				return
			}
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				// Idle period with no request; not an error worth printing.
				continue
			}
			fmt.Println(err)
			continue
		}

		// A request arrived; handle it without blocking further reads.
		server.wg.Add(1)
		go server.dealRequest(data, rAddr, done)
	}
}

// dealRequest handles one request. It passes data to the request callback and,
// if the callback returns a non-nil DataReader, sends its bytes back to rAddr.
// It does nothing when no callback is configured or the server is stopping.
func (server *UDPServer) dealRequest(data []byte, rAddr *net.UDPAddr, doneChan <-chan struct{}) {
	defer server.wg.Done()

	select {
	case <-doneChan:
		return
	default:
	}

	// No request callback provided.
	callback := server.options.requestCallback
	if callback == nil {
		return
	}

	// Let the upper layer handle the request; nil means "do not respond".
	responseDataReader := callback(NewDataReader(data))
	if responseDataReader == nil {
		return
	}
	server.response(server.conn, rAddr, responseDataReader.GetBytes())
}

// response writes data to rAddr over conn, applying the configured write
// timeout. It runs on a handler goroutine with no caller to return an error
// to, so failures are printed.
func (server *UDPServer) response(conn *net.UDPConn, rAddr *net.UDPAddr, data []byte) {
	err := writeUDPWithRemoteAddr(conn, rAddr, data, withWriteDeadline(server.options.writeTimeout))
	if err != nil {
		fmt.Println("Response Error:", err)
	}
}