|
@@ -1 +1,214 @@
|
|
|
package network
|
|
|
+
|
|
|
+import (
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "git.sxidc.com/go-tools/utils/syncutils"
|
|
|
+ "io"
|
|
|
+ "net"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ tcpServerReceiveBufferSize = 1024
|
|
|
+)
|
|
|
+
|
|
|
+var TCPServerIgnoreResponse = errors.New("忽略响应")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+type TCPServerRequestCallback func(dataReader *DataReader) *DataReader
|
|
|
+
|
|
|
+type TCPServerOption func(opt *TCPServerOptions)
|
|
|
+
|
|
|
+func WithTCPServerReceiveBufferSize(receiveBufferSize int) TCPServerOption {
|
|
|
+ return func(opt *TCPServerOptions) {
|
|
|
+ opt.receiveBufferSize = receiveBufferSize
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func WithTCPServerWriteTimeout(timeout time.Duration) TCPServerOption {
|
|
|
+ return func(opt *TCPServerOptions) {
|
|
|
+ opt.writeTimeout = timeout
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func WithTCPServerReadTimeout(timeout time.Duration) TCPServerOption {
|
|
|
+ return func(opt *TCPServerOptions) {
|
|
|
+ opt.readTimeout = timeout
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func WithTCPServerRequestCallback(requestCallback TCPServerRequestCallback) TCPServerOption {
|
|
|
+ return func(opt *TCPServerOptions) {
|
|
|
+ opt.requestCallback = requestCallback
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type TCPServerOptions struct {
|
|
|
+
|
|
|
+ receiveBufferSize int
|
|
|
+
|
|
|
+
|
|
|
+ writeTimeout time.Duration
|
|
|
+
|
|
|
+
|
|
|
+ readTimeout time.Duration
|
|
|
+
|
|
|
+
|
|
|
+ requestCallback TCPServerRequestCallback
|
|
|
+}
|
|
|
+
|
|
|
+func NewTCPServerOptions(opts ...TCPServerOption) *TCPServerOptions {
|
|
|
+ options := new(TCPServerOptions)
|
|
|
+
|
|
|
+ for _, opt := range opts {
|
|
|
+ opt(options)
|
|
|
+ }
|
|
|
+
|
|
|
+ if options.receiveBufferSize == 0 {
|
|
|
+ options.receiveBufferSize = tcpServerReceiveBufferSize
|
|
|
+ }
|
|
|
+
|
|
|
+ return options
|
|
|
+}
|
|
|
+
|
|
|
+type TCPServer struct {
|
|
|
+ options *TCPServerOptions
|
|
|
+ listener *net.TCPListener
|
|
|
+ doneChan chan any
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (server *TCPServer) Connect(address string, options *TCPServerOptions) error {
|
|
|
+ addr, err := net.ResolveTCPAddr("tcp", address)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ listener, err := net.ListenTCP("tcp", addr)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ server.options = options
|
|
|
+ server.listener = listener
|
|
|
+ server.doneChan = make(chan any)
|
|
|
+
|
|
|
+
|
|
|
+ go server.accept()
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (server *TCPServer) Disconnect() error {
|
|
|
+ err := server.listener.Close()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ server.listener = nil
|
|
|
+
|
|
|
+ server.doneChan <- nil
|
|
|
+ close(server.doneChan)
|
|
|
+ server.doneChan = nil
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (server *TCPServer) accept() {
|
|
|
+ readRequestDoneChannels := syncutils.NewSyncVar(make([]chan any, 0), false)
|
|
|
+
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-server.doneChan:
|
|
|
+ readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
|
|
|
+ for _, channel := range channels {
|
|
|
+ channel <- nil
|
|
|
+ close(channel)
|
|
|
+ }
|
|
|
+
|
|
|
+ return make([]chan any, 0)
|
|
|
+ })
|
|
|
+
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ conn, err := server.listener.AcceptTCP()
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println(err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ readRequestDoneChan := make(chan any)
|
|
|
+ readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
|
|
|
+ channels = append(channels, readRequestDoneChan)
|
|
|
+ return channels
|
|
|
+ })
|
|
|
+
|
|
|
+ go server.readRequest(conn, readRequestDoneChan, func(doneChan chan any, err error) {
|
|
|
+ readRequestDoneChannels.ForWrite(func(channels []chan any) []chan any {
|
|
|
+ for i, toRemoveDoneChan := range channels {
|
|
|
+ if doneChan == toRemoveDoneChan {
|
|
|
+ close(toRemoveDoneChan)
|
|
|
+ channels = append(channels[:i], channels[i+1:]...)
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return channels
|
|
|
+ })
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (server *TCPServer) readRequest(conn *net.TCPConn, doneChan chan any, errCloseCallback func(doneChan chan any, err error)) {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-doneChan:
|
|
|
+ if conn != nil {
|
|
|
+ closeConnection(conn)
|
|
|
+ }
|
|
|
+
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ err := readTCP(conn, server.options.receiveBufferSize, func(data []byte) (bool, error) {
|
|
|
+
|
|
|
+ if server.options.requestCallback == nil {
|
|
|
+ return false, nil
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ responseDataReader := server.options.requestCallback(NewDataReader(data))
|
|
|
+ if responseDataReader != nil {
|
|
|
+ server.response(conn, responseDataReader.GetBytes())
|
|
|
+ return true, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return true, nil
|
|
|
+ }, withReadDeadline(server.options.readTimeout))
|
|
|
+ if err != nil {
|
|
|
+
|
|
|
+ if err == io.EOF {
|
|
|
+ closeConnection(conn)
|
|
|
+ errCloseCallback(doneChan, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Println(err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (server *TCPServer) response(conn net.Conn, data []byte) {
|
|
|
+ err := writeTCP(conn, data, withWriteDeadline(server.options.writeTimeout))
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("Response Error:", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+}
|