tcp_server.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package network
  import (
	"errors"
	"fmt"
	"io"
	"net"
	"time"

	"git.sxidc.com/go-tools/utils/syncutils"
)
  // tcpServerReceiveBufferSize is the receive buffer size, in bytes, that
// NewTCPServerOptions falls back to when no size was configured.
const tcpServerReceiveBufferSize = 1024
  12. // TCPServerRequestCallback 请求回调
  13. // 参数:
  14. // dataReader: 请求数据DataReader
  15. // 返回值:
  16. // 响应数据DataReader: 可以使用DataWriter构造,然后使用ToReader转换成DataReader,返回nil代表还要继续接收数据,不做响应
  17. type TCPServerRequestCallback func(dataReader *DataReader) *DataReader
  18. type TCPServerOption func(opt *TCPServerOptions)
  19. func WithTCPServerReceiveBufferSize(receiveBufferSize int) TCPServerOption {
  20. return func(opt *TCPServerOptions) {
  21. opt.receiveBufferSize = receiveBufferSize
  22. }
  23. }
  24. func WithTCPServerWriteTimeout(timeout time.Duration) TCPServerOption {
  25. return func(opt *TCPServerOptions) {
  26. opt.writeTimeout = timeout
  27. }
  28. }
  29. func WithTCPServerReadTimeout(timeout time.Duration) TCPServerOption {
  30. return func(opt *TCPServerOptions) {
  31. opt.readTimeout = timeout
  32. }
  33. }
  34. func WithTCPServerRequestCallback(requestCallback TCPServerRequestCallback) TCPServerOption {
  35. return func(opt *TCPServerOptions) {
  36. opt.requestCallback = requestCallback
  37. }
  38. }
  39. type TCPServerOptions struct {
  40. // 默认1024字节,一般保证足够收取一个数据包
  41. receiveBufferSize int
  42. // 写超时,不设置就是阻塞写
  43. writeTimeout time.Duration
  44. // 读超时,不设置就是阻塞读
  45. readTimeout time.Duration
  46. // 请求数据回调
  47. requestCallback TCPServerRequestCallback
  48. }
  49. func NewTCPServerOptions(opts ...TCPServerOption) *TCPServerOptions {
  50. options := new(TCPServerOptions)
  51. for _, opt := range opts {
  52. opt(options)
  53. }
  54. if options.receiveBufferSize == 0 {
  55. options.receiveBufferSize = tcpServerReceiveBufferSize
  56. }
  57. return options
  58. }
  59. type TCPServer struct {
  60. options *TCPServerOptions
  61. listener *net.TCPListener
  62. doneChan chan any
  63. }
  64. // Connect 建立连接
  65. func (server *TCPServer) Connect(address string, options *TCPServerOptions) error {
  66. addr, err := net.ResolveTCPAddr("tcp", address)
  67. if err != nil {
  68. return err
  69. }
  70. listener, err := net.ListenTCP("tcp", addr)
  71. if err != nil {
  72. return err
  73. }
  74. server.options = options
  75. server.listener = listener
  76. server.doneChan = make(chan any)
  77. // 启动读取请求协程
  78. go server.accept()
  79. return nil
  80. }
  81. // Disconnect 断开连接
  82. func (server *TCPServer) Disconnect() error {
  83. err := server.listener.Close()
  84. if err != nil {
  85. return err
  86. }
  87. server.listener = nil
  88. server.doneChan <- nil
  89. close(server.doneChan)
  90. server.doneChan = nil
  91. return nil
  92. }
  93. func (server *TCPServer) accept() {
  94. readRequestDoneChannels := syncutils.NewSyncVar(make([]chan any, 0), false)
  95. for {
  96. select {
  97. case <-server.doneChan:
  98. readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
  99. for _, channel := range channels {
  100. channel <- nil
  101. close(channel)
  102. }
  103. return make([]chan any, 0)
  104. })
  105. return
  106. default:
  107. conn, err := server.listener.AcceptTCP()
  108. if err != nil {
  109. fmt.Println(err)
  110. continue
  111. }
  112. readRequestDoneChan := make(chan any)
  113. readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
  114. channels = append(channels, readRequestDoneChan)
  115. return channels
  116. })
  117. go server.readRequest(conn, readRequestDoneChan, func(doneChan chan any, err error) {
  118. readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
  119. for i, toRemoveDoneChan := range channels {
  120. if doneChan == toRemoveDoneChan {
  121. close(toRemoveDoneChan)
  122. channels = append(channels[:i], channels[i+1:]...)
  123. break
  124. }
  125. }
  126. return channels
  127. })
  128. })
  129. }
  130. }
  131. }
  132. func (server *TCPServer) readRequest(conn *net.TCPConn, doneChan chan any, errCloseCallback func(doneChan chan any, err error)) {
  133. for {
  134. select {
  135. case <-doneChan:
  136. if conn != nil {
  137. closeConnection(conn)
  138. }
  139. return
  140. default:
  141. err := readTCP(conn, server.options.receiveBufferSize, func(data []byte) (bool, error) {
  142. // 没有提供请求响应函数
  143. if server.options.requestCallback == nil {
  144. return false, nil
  145. }
  146. // 交给上层回调处理,返回处理结果和响应数据
  147. responseDataReader := server.options.requestCallback(NewDataReader(data))
  148. if responseDataReader != nil {
  149. server.response(conn, responseDataReader.GetBytes())
  150. return true, nil
  151. }
  152. return true, nil
  153. }, withReadDeadline(server.options.readTimeout))
  154. if err != nil {
  155. // 对端关闭
  156. if err == io.EOF {
  157. closeConnection(conn)
  158. errCloseCallback(doneChan, err)
  159. return
  160. }
  161. fmt.Println(err)
  162. continue
  163. }
  164. }
  165. }
  166. }
  167. func (server *TCPServer) response(conn net.Conn, data []byte) {
  168. err := writeTCP(conn, data, withWriteDeadline(server.options.writeTimeout))
  169. if err != nil {
  170. fmt.Println("Response Error:", err)
  171. return
  172. }
  173. }