yjp před 2 měsíci
rodič
revize
919cf45cf6

+ 33 - 11
network/data_reader.go

@@ -1,22 +1,22 @@
 package network
 
 import (
+	"bytes"
 	"encoding/binary"
-	"io"
 )
 
 type DataReader struct {
-	r io.Reader
+	buffer *bytes.Buffer
 }
 
-func NewDataReader(r io.Reader) *DataReader {
-	return &DataReader{r: r}
+func NewDataReader(data []byte) *DataReader {
+	return &DataReader{buffer: bytes.NewBuffer(data)}
 }
 
 func (reader *DataReader) Byte() (byte, error) {
 	var b byte
 
-	err := binary.Read(reader.r, binary.BigEndian, &b)
+	err := binary.Read(reader.buffer, binary.BigEndian, &b)
 	if err != nil {
 		return 0, err
 	}
@@ -26,7 +26,7 @@ func (reader *DataReader) Byte() (byte, error) {
 
 func (reader *DataReader) Bytes(bytesLen int) ([]byte, error) {
 	bs := make([]byte, bytesLen)
-	err := binary.Read(reader.r, binary.BigEndian, bs)
+	err := binary.Read(reader.buffer, binary.BigEndian, bs)
 	if err != nil {
 		return nil, err
 	}
@@ -37,7 +37,7 @@ func (reader *DataReader) Bytes(bytesLen int) ([]byte, error) {
 func (reader *DataReader) Uint8() (uint8, error) {
 	var u uint8
 
-	err := binary.Read(reader.r, binary.BigEndian, &u)
+	err := binary.Read(reader.buffer, binary.BigEndian, &u)
 	if err != nil {
 		return 0, err
 	}
@@ -48,7 +48,7 @@ func (reader *DataReader) Uint8() (uint8, error) {
 func (reader *DataReader) Uint16() (uint16, error) {
 	var u uint16
 
-	err := binary.Read(reader.r, binary.BigEndian, &u)
+	err := binary.Read(reader.buffer, binary.BigEndian, &u)
 	if err != nil {
 		return 0, err
 	}
@@ -59,7 +59,7 @@ func (reader *DataReader) Uint16() (uint16, error) {
 func (reader *DataReader) Uint32() (uint32, error) {
 	var u uint32
 
-	err := binary.Read(reader.r, binary.BigEndian, &u)
+	err := binary.Read(reader.buffer, binary.BigEndian, &u)
 	if err != nil {
 		return 0, err
 	}
@@ -70,7 +70,7 @@ func (reader *DataReader) Uint32() (uint32, error) {
 func (reader *DataReader) Uint64() (uint64, error) {
 	var u uint64
 
-	err := binary.Read(reader.r, binary.BigEndian, &u)
+	err := binary.Read(reader.buffer, binary.BigEndian, &u)
 	if err != nil {
 		return 0, err
 	}
@@ -80,10 +80,32 @@ func (reader *DataReader) Uint64() (uint64, error) {
 
 func (reader *DataReader) String(bytesLen int) (string, error) {
 	retBytes := make([]byte, bytesLen)
-	err := binary.Read(reader.r, binary.BigEndian, retBytes)
+	err := binary.Read(reader.buffer, binary.BigEndian, retBytes)
 	if err != nil {
 		return "", err
 	}
 
 	return string(retBytes), nil
 }
+
+func (reader *DataReader) Len() int {
+	return reader.buffer.Len()
+}
+
+func (reader *DataReader) Cap() int {
+	return reader.buffer.Cap()
+}
+
+func (reader *DataReader) GetBytes() []byte {
+	return reader.buffer.Bytes()
+}
+
+func (reader *DataReader) ToWriter() (*DataWriter, error) {
+	writer := NewDataWriter()
+	err := writer.Bytes(writer.buffer.Bytes())
+	if err != nil {
+		return nil, err
+	}
+
+	return writer, nil
+}

+ 1 - 1
network/data_reader_test.go

@@ -57,7 +57,7 @@ func TestDataReader(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	reader := NewDataReader(buffer)
+	reader := NewDataReader(buffer.Bytes())
 
 	retByteValue, err := reader.Byte()
 	if err != nil {

+ 34 - 18
network/data_writer.go

@@ -1,20 +1,20 @@
 package network
 
 import (
+	"bytes"
 	"encoding/binary"
-	"io"
 )
 
 type DataWriter struct {
-	w io.Writer
+	buffer *bytes.Buffer
 }
 
-func NewDataWriter(w io.Writer) *DataWriter {
-	return &DataWriter{w: w}
+func NewDataWriter() *DataWriter {
+	return &DataWriter{buffer: &bytes.Buffer{}}
 }
 
-func (reader *DataWriter) Byte(b byte) error {
-	err := binary.Write(reader.w, binary.BigEndian, b)
+func (writer *DataWriter) Byte(b byte) error {
+	err := binary.Write(writer.buffer, binary.BigEndian, b)
 	if err != nil {
 		return err
 	}
@@ -22,8 +22,8 @@ func (reader *DataWriter) Byte(b byte) error {
 	return nil
 }
 
-func (reader *DataWriter) Bytes(bs []byte) error {
-	err := binary.Write(reader.w, binary.BigEndian, bs)
+func (writer *DataWriter) Bytes(bs []byte) error {
+	err := binary.Write(writer.buffer, binary.BigEndian, bs)
 	if err != nil {
 		return err
 	}
@@ -31,8 +31,8 @@ func (reader *DataWriter) Bytes(bs []byte) error {
 	return nil
 }
 
-func (reader *DataWriter) Uint8(u uint8) error {
-	err := binary.Write(reader.w, binary.BigEndian, u)
+func (writer *DataWriter) Uint8(u uint8) error {
+	err := binary.Write(writer.buffer, binary.BigEndian, u)
 	if err != nil {
 		return err
 	}
@@ -40,8 +40,8 @@ func (reader *DataWriter) Uint8(u uint8) error {
 	return nil
 }
 
-func (reader *DataWriter) Uint16(u uint16) error {
-	err := binary.Write(reader.w, binary.BigEndian, u)
+func (writer *DataWriter) Uint16(u uint16) error {
+	err := binary.Write(writer.buffer, binary.BigEndian, u)
 	if err != nil {
 		return err
 	}
@@ -49,8 +49,8 @@ func (reader *DataWriter) Uint16(u uint16) error {
 	return nil
 }
 
-func (reader *DataWriter) Uint32(u uint32) error {
-	err := binary.Write(reader.w, binary.BigEndian, u)
+func (writer *DataWriter) Uint32(u uint32) error {
+	err := binary.Write(writer.buffer, binary.BigEndian, u)
 	if err != nil {
 		return err
 	}
@@ -58,8 +58,8 @@ func (reader *DataWriter) Uint32(u uint32) error {
 	return nil
 }
 
-func (reader *DataWriter) Uint64(u uint64) error {
-	err := binary.Write(reader.w, binary.BigEndian, u)
+func (writer *DataWriter) Uint64(u uint64) error {
+	err := binary.Write(writer.buffer, binary.BigEndian, u)
 	if err != nil {
 		return err
 	}
@@ -67,11 +67,27 @@ func (reader *DataWriter) Uint64(u uint64) error {
 	return nil
 }
 
-func (reader *DataWriter) String(s string) error {
-	err := binary.Write(reader.w, binary.BigEndian, []byte(s))
+func (writer *DataWriter) String(s string) error {
+	err := binary.Write(writer.buffer, binary.BigEndian, []byte(s))
 	if err != nil {
 		return err
 	}
 
 	return nil
 }
+
+func (writer *DataWriter) Len() int {
+	return writer.buffer.Len()
+}
+
+func (writer *DataWriter) Cap() int {
+	return writer.buffer.Cap()
+}
+
+func (writer *DataWriter) GetBytes() []byte {
+	return writer.buffer.Bytes()
+}
+
+func (writer *DataWriter) ToReader() *DataReader {
+	return NewDataReader(writer.buffer.Bytes())
+}

+ 10 - 9
network/data_writer_test.go

@@ -20,8 +20,7 @@ var (
 )
 
 func TestDataWriter(t *testing.T) {
-	buffer := &bytes.Buffer{}
-	writer := NewDataWriter(buffer)
+	writer := NewDataWriter()
 
 	err := writer.Byte(writerByteValue)
 	if err != nil {
@@ -58,8 +57,10 @@ func TestDataWriter(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	readBuffer := bytes.NewReader(writer.GetBytes())
+
 	var retByteValue byte
-	err = binary.Read(buffer, binary.BigEndian, &retByteValue)
+	err = binary.Read(readBuffer, binary.BigEndian, &retByteValue)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -69,7 +70,7 @@ func TestDataWriter(t *testing.T) {
 	}
 
 	retBytesValue := make([]byte, len(writerBytesValue))
-	err = binary.Read(buffer, binary.BigEndian, &retBytesValue)
+	err = binary.Read(readBuffer, binary.BigEndian, &retBytesValue)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -85,7 +86,7 @@ func TestDataWriter(t *testing.T) {
 	}
 
 	var retUint8Value uint8
-	err = binary.Read(buffer, binary.BigEndian, &retUint8Value)
+	err = binary.Read(readBuffer, binary.BigEndian, &retUint8Value)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -95,7 +96,7 @@ func TestDataWriter(t *testing.T) {
 	}
 
 	var retUint16Value uint16
-	err = binary.Read(buffer, binary.BigEndian, &retUint16Value)
+	err = binary.Read(readBuffer, binary.BigEndian, &retUint16Value)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -105,7 +106,7 @@ func TestDataWriter(t *testing.T) {
 	}
 
 	var retUint32Value uint32
-	err = binary.Read(buffer, binary.BigEndian, &retUint32Value)
+	err = binary.Read(readBuffer, binary.BigEndian, &retUint32Value)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -115,7 +116,7 @@ func TestDataWriter(t *testing.T) {
 	}
 
 	var retUint64Value uint64
-	err = binary.Read(buffer, binary.BigEndian, &retUint64Value)
+	err = binary.Read(readBuffer, binary.BigEndian, &retUint64Value)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -125,7 +126,7 @@ func TestDataWriter(t *testing.T) {
 	}
 
 	retStringByteValue := make([]byte, len(writerStringValue))
-	err = binary.Read(buffer, binary.BigEndian, &retStringByteValue)
+	err = binary.Read(readBuffer, binary.BigEndian, &retStringByteValue)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 4 - 17
network/udp_client.go

@@ -1,7 +1,6 @@
 package network
 
 import (
-	"bytes"
 	"fmt"
 	"net"
 	"time"
@@ -13,10 +12,10 @@ const (
 
 // UDPClientResponseCallback 响应回调
 // 参数:
-// data: 响应数据,大端模式,可以用DataReader处理
+// dataReader: 响应数据DataReader
 // 返回值:
 // 是否存在错误
-type UDPClientResponseCallback func(data []byte)
+type UDPClientResponseCallback func(dataReader *DataReader)
 
 type UDPClientOption func(opt *UDPClientOptions)
 
@@ -106,7 +105,7 @@ func (client *UDPClient) Connect(serverAddress string, options *UDPClientOptions
 	client.doneChan = make(chan any)
 
 	// 启动发送请求协程
-	client.sendRequest()
+	go client.sendRequest()
 
 	return nil
 }
@@ -155,20 +154,8 @@ func (client *UDPClient) sendRequest() {
 			}
 
 			if client.options.responseCallback != nil {
-				go client.dealResponse(responseBytes)
+				go client.options.responseCallback(NewDataReader(responseBytes))
 			}
 		}
 	}
 }
-
-func (client *UDPClient) dealResponse(responseBytes []byte) {
-	dataBuffer := bytes.NewReader(responseBytes)
-	reader := NewDataReader(dataBuffer)
-	data, err := reader.Bytes(len(responseBytes))
-	if err != nil {
-		fmt.Println(err)
-		return
-	}
-
-	client.options.responseCallback(data)
-}

+ 6 - 6
network/udp_server.go

@@ -15,11 +15,11 @@ var UDPServerIgnoreResponse = errors.New("忽略响应")
 
 // UDPServerRequestCallback 请求回调
 // 参数:
-// data: 请求数据,大端模式,可以用DataReader处理
+// dataReader: 请求数据DataReader
 // 返回值:
-// 响应数据: 大端模式,可以用DataWriter构造
+// 响应数据DataReader: 可以使用DataWriter构造,然后使用ToReader转换成DataReader
 // 是否存在错误: 如果是UDPServerIgnoreResponse,则忽略,不进行响应
-type UDPServerRequestCallback func(data []byte) ([]byte, error)
+type UDPServerRequestCallback func(dataReader *DataReader) (*DataReader, error)
 
 type UDPServerOption func(opt *UDPServerOptions)
 
@@ -99,7 +99,7 @@ func (server *UDPServer) Connect(address string, options *UDPServerOptions) erro
 	server.doneChan = make(chan any)
 
 	// 启动读取请求协程
-	server.readRequest()
+	go server.readRequest()
 
 	return nil
 }
@@ -154,7 +154,7 @@ func (server *UDPServer) dealRequest(data []byte, rAddr *net.UDPAddr, doneChan c
 			}
 
 			// 交给上层回调处理,返回处理结果和响应数据
-			responseBytes, err := server.options.requestCallback(data)
+			responseDataReader, err := server.options.requestCallback(NewDataReader(data))
 			if err != nil {
 				// 忽略响应
 				if errors.Is(err, UDPServerIgnoreResponse) {
@@ -165,7 +165,7 @@ func (server *UDPServer) dealRequest(data []byte, rAddr *net.UDPAddr, doneChan c
 				return
 			}
 
-			server.response(server.conn, rAddr, responseBytes)
+			server.response(server.conn, rAddr, responseDataReader.GetBytes())
 			return
 		}
 	}

+ 44 - 8
network/udp_test.go

@@ -1,8 +1,9 @@
 package network
 
 import (
-	"bytes"
+	"fmt"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 )
@@ -11,6 +12,9 @@ const (
 	testUDPServerAddress           = "127.0.0.1:10060"
 	testUDPServerTimeout           = time.Second
 	testUDPServerReceiveBufferSize = 1024
+
+	testUDPClientTimeout           = time.Second
+	testUDPClientReceiveBufferSize = 1024
 )
 
 func TestUDP(t *testing.T) {
@@ -20,16 +24,19 @@ func TestUDP(t *testing.T) {
 		WithUDPServerReadTimeout(testUDPServerTimeout),
 		WithUDPServerWriteTimeout(testUDPServerTimeout),
 		WithUDPServerReceiveBufferSize(testUDPServerReceiveBufferSize),
-		WithUDPServerRequestCallback(func(data []byte) ([]byte, error) {
-			responseBytes := []byte(strings.ToUpper(string(data)))
-			responseBuffer := &bytes.Buffer{}
-			responseReader := NewDataWriter(responseBuffer)
-			err := responseReader.Bytes(responseBytes)
+		WithUDPServerRequestCallback(func(dataReader *DataReader) (*DataReader, error) {
+			requestBytes, err := dataReader.Bytes(dataReader.Len())
+			if err != nil {
+				return nil, err
+			}
+
+			responseWriter := NewDataWriter()
+			err = responseWriter.Bytes([]byte(strings.ToUpper(string(requestBytes))))
 			if err != nil {
 				return nil, err
 			}
 
-			return responseBuffer.Bytes(), nil
+			return responseWriter.ToReader(), nil
 		}),
 	))
 	if err != nil {
@@ -38,5 +45,34 @@ func TestUDP(t *testing.T) {
 
 	defer server.Disconnect()
 
-	// TODO 完成客户端后补充
+	wg := &sync.WaitGroup{}
+	wg.Add(2)
+
+	client := &UDPClient{}
+	err = client.Connect(testUDPServerAddress, NewUDPClientOptions(
+		WithUDPClientReadTimeout(testUDPClientTimeout),
+		WithUDPClientWriteTimeout(testUDPClientTimeout),
+		WithUDPClientReceiveBufferSize(testUDPClientReceiveBufferSize),
+		WithUDPClientRequestNonBlockCount(2),
+		WithUDPClientResponseCallback(func(dataReader *DataReader) {
+			requestBytes, err := dataReader.Bytes(dataReader.Len())
+			if err != nil {
+				fmt.Println(err)
+				return
+			}
+
+			fmt.Println(string(requestBytes))
+			wg.Done()
+		}),
+	))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer client.Disconnect()
+
+	client.Send([]byte("test1"))
+	client.Send([]byte("test2"))
+
+	wg.Wait()
 }