udp_client.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package network
  2. import (
  3. "fmt"
  4. "net"
  5. "time"
  6. )
  7. const (
  8. udpClientReceiveBufferSize = 1024
  9. )
  10. // UDPClientResponseCallback 响应回调
  11. // 参数:
  12. // dataReader: 响应数据DataReader
  13. // 返回值:
  14. // 是否存在错误
  15. type UDPClientResponseCallback func(dataReader *DataReader)
  16. type UDPClientOption func(opt *UDPClientOptions)
  17. func WithUDPClientRequestNonBlockCount(count int) UDPClientOption {
  18. return func(opt *UDPClientOptions) {
  19. opt.requestNonBlockCount = count
  20. }
  21. }
  22. func WithUDPClientReceiveBufferSize(receiveBufferSize int) UDPClientOption {
  23. return func(opt *UDPClientOptions) {
  24. opt.receiveBufferSize = receiveBufferSize
  25. }
  26. }
  27. func WithUDPClientWriteTimeout(timeout time.Duration) UDPClientOption {
  28. return func(opt *UDPClientOptions) {
  29. opt.writeTimeout = timeout
  30. }
  31. }
  32. func WithUDPClientReadTimeout(timeout time.Duration) UDPClientOption {
  33. return func(opt *UDPClientOptions) {
  34. opt.readTimeout = timeout
  35. }
  36. }
  37. func WithUDPClientResponseCallback(responseCallback UDPClientResponseCallback) UDPClientOption {
  38. return func(opt *UDPClientOptions) {
  39. opt.responseCallback = responseCallback
  40. }
  41. }
  42. type UDPClientOptions struct {
  43. // 不阻塞的请求数量,默认为阻塞
  44. requestNonBlockCount int
  45. // 默认1024字节,一般保证足够收取一个数据包
  46. receiveBufferSize int
  47. // 写超时,不设置就是阻塞写
  48. writeTimeout time.Duration
  49. // 读超时,不设置就是阻塞读
  50. readTimeout time.Duration
  51. // 响应数据回调
  52. responseCallback UDPClientResponseCallback
  53. }
  54. func NewUDPClientOptions(opts ...UDPClientOption) *UDPClientOptions {
  55. options := new(UDPClientOptions)
  56. for _, opt := range opts {
  57. opt(options)
  58. }
  59. if options.receiveBufferSize == 0 {
  60. options.receiveBufferSize = udpClientReceiveBufferSize
  61. }
  62. return options
  63. }
  64. type UDPClient struct {
  65. options *UDPClientOptions
  66. conn *net.UDPConn
  67. requestChan chan []byte
  68. doneChan chan any
  69. }
  70. // Connect 建立连接
  71. func (client *UDPClient) Connect(serverAddress string, options *UDPClientOptions) error {
  72. serverAddr, err := net.ResolveUDPAddr("udp", serverAddress)
  73. if err != nil {
  74. panic(err)
  75. }
  76. conn, err := net.DialUDP("udp", nil, serverAddr)
  77. if err != nil {
  78. panic(err)
  79. }
  80. client.options = options
  81. client.conn = conn
  82. client.requestChan = make(chan []byte, options.requestNonBlockCount)
  83. client.doneChan = make(chan any)
  84. // 启动发送请求协程
  85. go client.sendRequest()
  86. return nil
  87. }
  88. // Disconnect 断开连接
  89. func (client *UDPClient) Disconnect() {
  90. client.doneChan <- nil
  91. close(client.doneChan)
  92. client.doneChan = nil
  93. close(client.requestChan)
  94. client.requestChan = nil
  95. CloseConnection(client.conn)
  96. client.conn = nil
  97. }
  98. // Send 发送数据包,data应该为大端字节序
  99. func (client *UDPClient) Send(data []byte) {
  100. client.requestChan <- data
  101. }
  102. func (client *UDPClient) sendRequest() {
  103. dealRequestDoneChannels := make([]chan any, 0)
  104. for {
  105. select {
  106. case <-client.doneChan:
  107. for _, dealRequestDoneChan := range dealRequestDoneChannels {
  108. dealRequestDoneChan <- nil
  109. close(dealRequestDoneChan)
  110. }
  111. return
  112. case data := <-client.requestChan:
  113. err := writeUDP(client.conn, data, withWriteDeadline(client.options.writeTimeout))
  114. if err != nil {
  115. fmt.Println(err)
  116. continue
  117. }
  118. responseBytes, _, err := readUDP(client.conn, client.options.receiveBufferSize, withReadDeadline(client.options.readTimeout))
  119. if err != nil {
  120. fmt.Println(err)
  121. continue
  122. }
  123. if client.options.responseCallback != nil {
  124. go client.options.responseCallback(NewDataReader(responseBytes))
  125. }
  126. }
  127. }
  128. }