tcp_server.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package network
  2. import (
  3. "errors"
  4. "fmt"
  5. "git.sxidc.com/go-tools/utils/syncutils"
  6. "io"
  7. "net"
  8. "time"
  9. )
  // tcpServerReceiveBufferSize is the receive buffer size, in bytes, that
  // NewTCPServerOptions falls back to when no size has been configured.
  const tcpServerReceiveBufferSize = 1024

  // TCPServerIgnoreResponse is a sentinel error whose message reads
  // "ignore response".
  // NOTE(review): not referenced in the visible part of this file; confirm where
  // (and whether) callers use it.
  var TCPServerIgnoreResponse = errors.New("忽略响应")
  14. // TCPServerRequestCallback 请求回调
  15. // 参数:
  16. // dataReader: 请求数据DataReader
  17. // 返回值:
  18. // 响应数据DataReader: 可以使用DataWriter构造,然后使用ToReader转换成DataReader,返回nil代表还要继续接收数据,不做响应
  19. type TCPServerRequestCallback func(dataReader *DataReader) *DataReader
  20. type TCPServerOption func(opt *TCPServerOptions)
  21. func WithTCPServerReceiveBufferSize(receiveBufferSize int) TCPServerOption {
  22. return func(opt *TCPServerOptions) {
  23. opt.receiveBufferSize = receiveBufferSize
  24. }
  25. }
  26. func WithTCPServerWriteTimeout(timeout time.Duration) TCPServerOption {
  27. return func(opt *TCPServerOptions) {
  28. opt.writeTimeout = timeout
  29. }
  30. }
  31. func WithTCPServerReadTimeout(timeout time.Duration) TCPServerOption {
  32. return func(opt *TCPServerOptions) {
  33. opt.readTimeout = timeout
  34. }
  35. }
  36. func WithTCPServerRequestCallback(requestCallback TCPServerRequestCallback) TCPServerOption {
  37. return func(opt *TCPServerOptions) {
  38. opt.requestCallback = requestCallback
  39. }
  40. }
  41. type TCPServerOptions struct {
  42. // 默认1024字节,一般保证足够收取一个数据包
  43. receiveBufferSize int
  44. // 写超时,不设置就是阻塞写
  45. writeTimeout time.Duration
  46. // 读超时,不设置就是阻塞读
  47. readTimeout time.Duration
  48. // 请求数据回调
  49. requestCallback TCPServerRequestCallback
  50. }
  51. func NewTCPServerOptions(opts ...TCPServerOption) *TCPServerOptions {
  52. options := new(TCPServerOptions)
  53. for _, opt := range opts {
  54. opt(options)
  55. }
  56. if options.receiveBufferSize == 0 {
  57. options.receiveBufferSize = tcpServerReceiveBufferSize
  58. }
  59. return options
  60. }
  61. type TCPServer struct {
  62. options *TCPServerOptions
  63. listener *net.TCPListener
  64. doneChan chan any
  65. }
  66. // Connect 建立连接
  67. func (server *TCPServer) Connect(address string, options *TCPServerOptions) error {
  68. addr, err := net.ResolveTCPAddr("tcp", address)
  69. if err != nil {
  70. return err
  71. }
  72. listener, err := net.ListenTCP("tcp", addr)
  73. if err != nil {
  74. return err
  75. }
  76. server.options = options
  77. server.listener = listener
  78. server.doneChan = make(chan any)
  79. // 启动读取请求协程
  80. go server.accept()
  81. return nil
  82. }
  83. // Disconnect 断开连接
  84. func (server *TCPServer) Disconnect() error {
  85. err := server.listener.Close()
  86. if err != nil {
  87. return err
  88. }
  89. server.listener = nil
  90. server.doneChan <- nil
  91. close(server.doneChan)
  92. server.doneChan = nil
  93. return nil
  94. }
  95. func (server *TCPServer) accept() {
  96. readRequestDoneChannels := syncutils.NewSyncVar(make([]chan any, 0), false)
  97. for {
  98. select {
  99. case <-server.doneChan:
  100. readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
  101. for _, channel := range channels {
  102. channel <- nil
  103. close(channel)
  104. }
  105. return make([]chan any, 0)
  106. })
  107. return
  108. default:
  109. conn, err := server.listener.AcceptTCP()
  110. if err != nil {
  111. fmt.Println(err)
  112. continue
  113. }
  114. readRequestDoneChan := make(chan any)
  115. readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
  116. channels = append(channels, readRequestDoneChan)
  117. return channels
  118. })
  119. go server.readRequest(conn, readRequestDoneChan, func(doneChan chan any, err error) {
  120. readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
  121. for i, toRemoveDoneChan := range channels {
  122. if doneChan == toRemoveDoneChan {
  123. close(toRemoveDoneChan)
  124. channels = append(channels[:i], channels[i+1:]...)
  125. break
  126. }
  127. }
  128. return channels
  129. })
  130. })
  131. }
  132. }
  133. }
  134. func (server *TCPServer) readRequest(conn *net.TCPConn, doneChan chan any, errCloseCallback func(doneChan chan any, err error)) {
  135. for {
  136. select {
  137. case <-doneChan:
  138. if conn != nil {
  139. closeConnection(conn)
  140. }
  141. return
  142. default:
  143. err := readTCP(conn, server.options.receiveBufferSize, func(data []byte) (bool, error) {
  144. // 没有提供请求响应函数
  145. if server.options.requestCallback == nil {
  146. return false, nil
  147. }
  148. // 交给上层回调处理,返回处理结果和响应数据
  149. responseDataReader := server.options.requestCallback(NewDataReader(data))
  150. if responseDataReader != nil {
  151. server.response(conn, responseDataReader.GetBytes())
  152. return true, nil
  153. }
  154. return true, nil
  155. }, withReadDeadline(server.options.readTimeout))
  156. if err != nil {
  157. // 对端关闭
  158. if err == io.EOF {
  159. closeConnection(conn)
  160. errCloseCallback(doneChan, err)
  161. return
  162. }
  163. fmt.Println(err)
  164. continue
  165. }
  166. }
  167. }
  168. }
  169. func (server *TCPServer) response(conn net.Conn, data []byte) {
  170. err := writeTCP(conn, data, withWriteDeadline(server.options.writeTimeout))
  171. if err != nil {
  172. fmt.Println("Response Error:", err)
  173. return
  174. }
  175. }