Browse Source

完成data_reader

yjp 1 năm trước cách đây
mục cha
commit
ca0aeae573
2 tập tin đã thay đổi với 202 bổ sung20 xóa
  1. 72 20
      network/data_reader.go
  2. 130 0
      network/data_reader_test.go

+ 72 - 20
network/data_reader.go

@@ -1,34 +1,86 @@
 package network
 
 import (
-	"bytes"
+	"encoding/binary"
 	"io"
 )
 
-// ReadBigEndianString 大端读取字符串 TODO 考虑改为一个网络字节序得写入和读取器
-func ReadBigEndianString(bigEndianBytes []byte) (string, error) {
-	// 使用bytes.Buffer来模拟一个io.Reader
-	reader := bytes.NewReader(bigEndianBytes)
+type DataReader struct {
+	r io.Reader
+}
+
+func NewDataReader(r io.Reader) *DataReader {
+	return &DataReader{r: r}
+}
 
-	// 初始化一个buffer,用于存放读取的字符串
-	var buffer []byte
+func (reader *DataReader) Byte() (byte, error) {
+	var b byte
 
-	// 使用binary.Read以大端模式读取字符串
-	// 因为字符串可能包含\x00,所以需要读取到\x00为止
-	for {
-		b := make([]byte, 1)
+	err := binary.Read(reader.r, binary.BigEndian, &b)
+	if err != nil {
+		return 0, err
+	}
 
-		_, err := reader.Read(b)
-		if err != nil && err != io.EOF {
-			return "", err
-		}
+	return b, nil
+}
+
+func (reader *DataReader) Bytes(bytesLen int) ([]byte, error) {
+	retBytes := make([]byte, bytesLen)
+	err := binary.Read(reader.r, binary.BigEndian, retBytes)
+	if err != nil {
+		return nil, err
+	}
 
-		if b[0] == 0 {
-			break
-		}
+	return retBytes, nil
+}
+
+func (reader *DataReader) Uint8() (uint8, error) {
+	var b uint8
+
+	err := binary.Read(reader.r, binary.BigEndian, &b)
+	if err != nil {
+		return 0, err
+	}
+
+	return b, nil
+}
+
+func (reader *DataReader) Uint16() (uint16, error) {
+	retBytes := make([]byte, 2)
+	_, err := reader.r.Read(retBytes)
+	if err != nil {
+		return 0, err
+	}
+
+	return binary.BigEndian.Uint16(retBytes), nil
+}
+
+func (reader *DataReader) Uint32() (uint32, error) {
+	retBytes := make([]byte, 4)
+	_, err := reader.r.Read(retBytes)
+	if err != nil {
+		return 0, err
+	}
+
+	return binary.BigEndian.Uint32(retBytes), nil
+}
+
+func (reader *DataReader) Uint64() (uint64, error) {
+	retBytes := make([]byte, 8)
+	_, err := reader.r.Read(retBytes)
+	if err != nil {
+		return 0, err
+	}
+
+	return binary.BigEndian.Uint64(retBytes), nil
+}
 
-		buffer = append(buffer, b...)
+func (reader *DataReader) String(bytesLen int) (string, error) {
+	retBytes := make([]byte, bytesLen)
+	err := binary.Read(reader.r, binary.BigEndian, retBytes)
+	if err != nil {
+		return "", err
 	}
 
-	return string(buffer), nil
+	return string(retBytes), nil
 }

+ 130 - 0
network/data_reader_test.go

@@ -0,0 +1,130 @@
+package network
+
+import (
+	"bytes"
+	"encoding/binary"
+	"testing"
+)
+
+const (
+	byteValue   byte   = 1
+	uint8Value  uint8  = 5
+	uint16Value uint16 = 6
+	uint32Value uint32 = 7
+	uint64Value uint64 = 8
+	strValue           = "91011"
+)
+
+var (
+	bytesValue = []byte{2, 3, 4}
+)
+
+func TestDataReader(t *testing.T) {
+	buffer := &bytes.Buffer{}
+
+	err := binary.Write(buffer, binary.BigEndian, byteValue)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = binary.Write(buffer, binary.BigEndian, bytesValue)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = binary.Write(buffer, binary.BigEndian, uint8Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = binary.Write(buffer, binary.BigEndian, uint16Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = binary.Write(buffer, binary.BigEndian, uint32Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = binary.Write(buffer, binary.BigEndian, uint64Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = binary.Write(buffer, binary.BigEndian, []byte(strValue))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	reader := NewDataReader(buffer)
+
+	retByteValue, err := reader.Byte()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retByteValue != byteValue {
+		t.Fatal("byte不正确", err.Error())
+	}
+
+	retBytesValue, err := reader.Bytes(len(bytesValue))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(retBytesValue) != len(bytesValue) {
+		t.Fatal("bytes长度不正确")
+	}
+
+	for i := 0; i < len(retBytesValue); i++ {
+		if retBytesValue[i] != bytesValue[i] {
+			t.Fatal("bytes不正确")
+		}
+	}
+
+	retUint8Value, err := reader.Uint8()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retUint8Value != uint8Value {
+		t.Fatal("uint8不正确", err.Error())
+	}
+
+	retUint16Value, err := reader.Uint16()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retUint16Value != uint16Value {
+		t.Fatal("uint16不正确", err.Error())
+	}
+
+	retUint32Value, err := reader.Uint32()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retUint32Value != uint32Value {
+		t.Fatal("uint32不正确", err.Error())
+	}
+
+	retUint64Value, err := reader.Uint64()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retUint64Value != uint64Value {
+		t.Fatal("uint64不正确", err.Error())
+	}
+
+	retStringValue, err := reader.String(len(strValue))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retStringValue != strValue {
+		t.Fatal("uint64不正确", err.Error())
+	}
+}