|
- package network
import (
	"errors"
	"fmt"
	"io"
	"net"
	"time"

	"git.sxidc.com/go-tools/utils/syncutils"
)
// tcpServerReceiveBufferSize is the receive buffer size, in bytes, that
// NewTCPServerOptions falls back to when none is configured.
const tcpServerReceiveBufferSize = 1024
- // TCPServerRequestCallback 请求回调
- // 参数:
- // dataReader: 请求数据DataReader
- // 返回值:
- // 响应数据DataReader: 可以使用DataWriter构造,然后使用ToReader转换成DataReader,返回nil代表还要继续接收数据,不做响应
- type TCPServerRequestCallback func(dataReader *DataReader) *DataReader
- type TCPServerOption func(opt *TCPServerOptions)
- func WithTCPServerReceiveBufferSize(receiveBufferSize int) TCPServerOption {
- return func(opt *TCPServerOptions) {
- opt.receiveBufferSize = receiveBufferSize
- }
- }
- func WithTCPServerWriteTimeout(timeout time.Duration) TCPServerOption {
- return func(opt *TCPServerOptions) {
- opt.writeTimeout = timeout
- }
- }
- func WithTCPServerReadTimeout(timeout time.Duration) TCPServerOption {
- return func(opt *TCPServerOptions) {
- opt.readTimeout = timeout
- }
- }
- func WithTCPServerRequestCallback(requestCallback TCPServerRequestCallback) TCPServerOption {
- return func(opt *TCPServerOptions) {
- opt.requestCallback = requestCallback
- }
- }
- type TCPServerOptions struct {
- // 默认1024字节,一般保证足够收取一个数据包
- receiveBufferSize int
- // 写超时,不设置就是阻塞写
- writeTimeout time.Duration
- // 读超时,不设置就是阻塞读
- readTimeout time.Duration
- // 请求数据回调
- requestCallback TCPServerRequestCallback
- }
- func NewTCPServerOptions(opts ...TCPServerOption) *TCPServerOptions {
- options := new(TCPServerOptions)
- for _, opt := range opts {
- opt(options)
- }
- if options.receiveBufferSize == 0 {
- options.receiveBufferSize = tcpServerReceiveBufferSize
- }
- return options
- }
- type TCPServer struct {
- options *TCPServerOptions
- listener *net.TCPListener
- doneChan chan any
- }
- // Connect 建立连接
- func (server *TCPServer) Connect(address string, options *TCPServerOptions) error {
- addr, err := net.ResolveTCPAddr("tcp", address)
- if err != nil {
- return err
- }
- listener, err := net.ListenTCP("tcp", addr)
- if err != nil {
- return err
- }
- server.options = options
- server.listener = listener
- server.doneChan = make(chan any)
- // 启动读取请求协程
- go server.accept()
- return nil
- }
- // Disconnect 断开连接
- func (server *TCPServer) Disconnect() error {
- err := server.listener.Close()
- if err != nil {
- return err
- }
- server.listener = nil
- server.doneChan <- nil
- close(server.doneChan)
- server.doneChan = nil
- return nil
- }
- func (server *TCPServer) accept() {
- readRequestDoneChannels := syncutils.NewSyncVar(make([]chan any, 0), false)
- for {
- select {
- case <-server.doneChan:
- readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
- for _, channel := range channels {
- channel <- nil
- close(channel)
- }
- return make([]chan any, 0)
- })
- return
- default:
- conn, err := server.listener.AcceptTCP()
- if err != nil {
- fmt.Println(err)
- continue
- }
- readRequestDoneChan := make(chan any)
- readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
- channels = append(channels, readRequestDoneChan)
- return channels
- })
- go server.readRequest(conn, readRequestDoneChan, func(doneChan chan any, err error) {
- readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
- for i, toRemoveDoneChan := range channels {
- if doneChan == toRemoveDoneChan {
- close(toRemoveDoneChan)
- channels = append(channels[:i], channels[i+1:]...)
- break
- }
- }
- return channels
- })
- })
- }
- }
- }
- func (server *TCPServer) readRequest(conn *net.TCPConn, doneChan chan any, errCloseCallback func(doneChan chan any, err error)) {
- for {
- select {
- case <-doneChan:
- if conn != nil {
- closeConnection(conn)
- }
- return
- default:
- err := readTCP(conn, server.options.receiveBufferSize, func(data []byte) (bool, error) {
- // 没有提供请求响应函数
- if server.options.requestCallback == nil {
- return false, nil
- }
- // 交给上层回调处理,返回处理结果和响应数据
- responseDataReader := server.options.requestCallback(NewDataReader(data))
- if responseDataReader != nil {
- server.response(conn, responseDataReader.GetBytes())
- return true, nil
- }
- return true, nil
- }, withReadDeadline(server.options.readTimeout))
- if err != nil {
- // 对端关闭
- if err == io.EOF {
- closeConnection(conn)
- errCloseCallback(doneChan, err)
- return
- }
- fmt.Println(err)
- continue
- }
- }
- }
- }
- func (server *TCPServer) response(conn net.Conn, data []byte) {
- err := writeTCP(conn, data, withWriteDeadline(server.options.writeTimeout))
- if err != nil {
- fmt.Println("Response Error:", err)
- return
- }
- }
|