tcp_client.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package network
  import (
	"fmt"
	"net"
	"sync"
	"time"
)
  // tcpClientReceiveBufferSize is the receive buffer size, in bytes, that
// NewTCPClientOptions substitutes when no usable size was configured.
const tcpClientReceiveBufferSize = 1024
  10. // TCPClientResponseCallback 响应回调
  11. // 参数:
  12. // dataReader: 响应数据DataReader
  13. // 返回值:
  14. // 是否读取完成,false继续读取
  15. type TCPClientResponseCallback func(dataReader *DataReader) bool
  16. type TCPClientOption func(opt *TCPClientOptions)
  17. func WithTCPClientRequestNonBlockCount(count int) TCPClientOption {
  18. return func(opt *TCPClientOptions) {
  19. opt.requestNonBlockCount = count
  20. }
  21. }
  22. func WithTCPClientReceiveBufferSize(receiveBufferSize int) TCPClientOption {
  23. return func(opt *TCPClientOptions) {
  24. opt.receiveBufferSize = receiveBufferSize
  25. }
  26. }
  27. func WithTCPClientWriteTimeout(timeout time.Duration) TCPClientOption {
  28. return func(opt *TCPClientOptions) {
  29. opt.writeTimeout = timeout
  30. }
  31. }
  32. func WithTCPClientReadTimeout(timeout time.Duration) TCPClientOption {
  33. return func(opt *TCPClientOptions) {
  34. opt.readTimeout = timeout
  35. }
  36. }
  37. func WithTCPClientResponseCallback(responseCallback TCPClientResponseCallback) TCPClientOption {
  38. return func(opt *TCPClientOptions) {
  39. opt.responseCallback = responseCallback
  40. }
  41. }
  42. type TCPClientOptions struct {
  43. // 不阻塞的请求数量,默认为阻塞
  44. requestNonBlockCount int
  45. // 默认1024字节,一般保证足够收取一个数据包
  46. receiveBufferSize int
  47. // 写超时,不设置就是阻塞写
  48. writeTimeout time.Duration
  49. // 读超时,不设置就是阻塞读
  50. readTimeout time.Duration
  51. // 响应数据回调
  52. responseCallback TCPClientResponseCallback
  53. }
  54. func NewTCPClientOptions(opts ...TCPClientOption) *TCPClientOptions {
  55. options := new(TCPClientOptions)
  56. for _, opt := range opts {
  57. opt(options)
  58. }
  59. if options.receiveBufferSize == 0 {
  60. options.receiveBufferSize = tcpClientReceiveBufferSize
  61. }
  62. return options
  63. }
  64. type TCPClient struct {
  65. options *TCPClientOptions
  66. conn *net.TCPConn
  67. requestChan chan []byte
  68. doneChan chan any
  69. }
  70. // Connect 建立连接
  71. func (client *TCPClient) Connect(serverAddress string, options *TCPClientOptions) error {
  72. serverAddr, err := net.ResolveTCPAddr("tcp", serverAddress)
  73. if err != nil {
  74. panic(err)
  75. }
  76. conn, err := net.DialTCP("tcp", nil, serverAddr)
  77. if err != nil {
  78. panic(err)
  79. }
  80. client.options = options
  81. client.conn = conn
  82. client.requestChan = make(chan []byte, options.requestNonBlockCount)
  83. client.doneChan = make(chan any)
  84. // 启动发送请求协程
  85. go client.sendRequest()
  86. return nil
  87. }
  88. // Disconnect 断开连接
  89. func (client *TCPClient) Disconnect() {
  90. client.doneChan <- nil
  91. close(client.doneChan)
  92. client.doneChan = nil
  93. close(client.requestChan)
  94. client.requestChan = nil
  95. closeConnection(client.conn)
  96. client.conn = nil
  97. }
  98. // Send 发送数据包,data应该为大端字节序
  99. func (client *TCPClient) Send(data []byte) {
  100. client.requestChan <- data
  101. }
  102. func (client *TCPClient) sendRequest() {
  103. dealRequestDoneChannels := make([]chan any, 0)
  104. for {
  105. select {
  106. case <-client.doneChan:
  107. for _, dealRequestDoneChan := range dealRequestDoneChannels {
  108. dealRequestDoneChan <- nil
  109. close(dealRequestDoneChan)
  110. }
  111. return
  112. case data := <-client.requestChan:
  113. err := writeTCP(client.conn, data, withWriteDeadline(client.options.writeTimeout))
  114. if err != nil {
  115. fmt.Println(err)
  116. continue
  117. }
  118. if client.options.responseCallback != nil {
  119. err = readTCP(client.conn, client.options.receiveBufferSize, func(data []byte) (bool, error) {
  120. readOver := client.options.responseCallback(NewDataReader(data))
  121. return readOver, nil
  122. }, withReadDeadline(client.options.readTimeout))
  123. if err != nil {
  124. fmt.Println(err)
  125. continue
  126. }
  127. }
  128. }
  129. }
  130. }