package network import ( "fmt" "git.sxidc.com/go-tools/utils/syncutils" "io" "net" "time" ) const ( tcpServerReceiveBufferSize = 1024 ) // TCPServerRequestCallback 请求回调 // 参数: // dataReader: 请求数据DataReader // 返回值: // 响应数据DataReader: 可以使用DataWriter构造,然后使用ToReader转换成DataReader,返回nil代表还要继续接收数据,不做响应 type TCPServerRequestCallback func(dataReader *DataReader) *DataReader type TCPServerOption func(opt *TCPServerOptions) func WithTCPServerReceiveBufferSize(receiveBufferSize int) TCPServerOption { return func(opt *TCPServerOptions) { opt.receiveBufferSize = receiveBufferSize } } func WithTCPServerWriteTimeout(timeout time.Duration) TCPServerOption { return func(opt *TCPServerOptions) { opt.writeTimeout = timeout } } func WithTCPServerReadTimeout(timeout time.Duration) TCPServerOption { return func(opt *TCPServerOptions) { opt.readTimeout = timeout } } func WithTCPServerRequestCallback(requestCallback TCPServerRequestCallback) TCPServerOption { return func(opt *TCPServerOptions) { opt.requestCallback = requestCallback } } type TCPServerOptions struct { // 默认1024字节,一般保证足够收取一个数据包 receiveBufferSize int // 写超时,不设置就是阻塞写 writeTimeout time.Duration // 读超时,不设置就是阻塞读 readTimeout time.Duration // 请求数据回调 requestCallback TCPServerRequestCallback } func NewTCPServerOptions(opts ...TCPServerOption) *TCPServerOptions { options := new(TCPServerOptions) for _, opt := range opts { opt(options) } if options.receiveBufferSize == 0 { options.receiveBufferSize = tcpServerReceiveBufferSize } return options } type TCPServer struct { options *TCPServerOptions listener *net.TCPListener doneChan chan any } // Connect 建立连接 func (server *TCPServer) Connect(address string, options *TCPServerOptions) error { addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { return err } listener, err := net.ListenTCP("tcp", addr) if err != nil { return err } server.options = options server.listener = listener server.doneChan = make(chan any) // 启动读取请求协程 go server.accept() return nil } // Disconnect 断开连接 func (server *TCPServer) Disconnect() error { err := server.listener.Close() if err != nil { return err } server.listener = nil server.doneChan <- nil close(server.doneChan) server.doneChan = nil return nil } func (server *TCPServer) accept() { readRequestDoneChannels := syncutils.NewSyncVar(make([]chan any, 0), false) for { select { case <-server.doneChan: readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any { for _, channel := range channels { channel <- nil close(channel) } return make([]chan any, 0) }) return default: conn, err := server.listener.AcceptTCP() if err != nil { fmt.Println(err) continue } readRequestDoneChan := make(chan any) readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any { channels = append(channels, readRequestDoneChan) return channels }) go server.readRequest(conn, readRequestDoneChan, func(doneChan chan any, err error) { readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any { for i, toRemoveDoneChan := range channels { if doneChan == toRemoveDoneChan { close(toRemoveDoneChan) channels = append(channels[:i], channels[i+1:]...) break } } return channels }) }) } } } func (server *TCPServer) readRequest(conn *net.TCPConn, doneChan chan any, errCloseCallback func(doneChan chan any, err error)) { for { select { case <-doneChan: if conn != nil { closeConnection(conn) } return default: err := readTCP(conn, server.options.receiveBufferSize, func(data []byte) (bool, error) { // 没有提供请求响应函数 if server.options.requestCallback == nil { return false, nil } // 交给上层回调处理,返回处理结果和响应数据 responseDataReader := server.options.requestCallback(NewDataReader(data)) if responseDataReader != nil { server.response(conn, responseDataReader.GetBytes()) return true, nil } return true, nil }, withReadDeadline(server.options.readTimeout)) if err != nil { // 对端关闭 if err == io.EOF { closeConnection(conn) errCloseCallback(doneChan, err) return } fmt.Println(err) continue } } } } func (server *TCPServer) response(conn net.Conn, data []byte) { err := writeTCP(conn, data, withWriteDeadline(server.options.writeTimeout)) if err != nil { fmt.Println("Response Error:", err) return } }