Browse Source

完成writer

yjp 1 year ago
parent
commit
e7de951301
4 changed files with 254 additions and 39 deletions
  1. 18 15
      network/data_reader.go
  2. 24 24
      network/data_reader_test.go
  3. 76 0
      network/data_writer.go
  4. 136 0
      network/data_writer_test.go

+ 18 - 15
network/data_reader.go

@@ -25,54 +25,57 @@ func (reader *DataReader) Byte() (byte, error) {
 }
 
 func (reader *DataReader) Bytes(bytesLen int) ([]byte, error) {
-	retBytes := make([]byte, bytesLen)
-	err := binary.Read(reader.r, binary.BigEndian, retBytes)
+	bs := make([]byte, bytesLen)
+	err := binary.Read(reader.r, binary.BigEndian, bs)
 	if err != nil {
 		return nil, err
 	}
 
-	return retBytes, nil
+	return bs, nil
 }
 
 func (reader *DataReader) Uint8() (uint8, error) {
-	var b uint8
+	var u uint8
 
-	err := binary.Read(reader.r, binary.BigEndian, &b)
+	err := binary.Read(reader.r, binary.BigEndian, &u)
 	if err != nil {
 		return 0, err
 	}
 
-	return b, nil
+	return u, nil
 }
 
 func (reader *DataReader) Uint16() (uint16, error) {
-	retBytes := make([]byte, 2)
-	_, err := reader.r.Read(retBytes)
+	var u uint16
+
+	err := binary.Read(reader.r, binary.BigEndian, &u)
 	if err != nil {
 		return 0, err
 	}
 
-	return binary.BigEndian.Uint16(retBytes), nil
+	return u, nil
 }
 
 func (reader *DataReader) Uint32() (uint32, error) {
-	retBytes := make([]byte, 4)
-	_, err := reader.r.Read(retBytes)
+	var u uint32
+
+	err := binary.Read(reader.r, binary.BigEndian, &u)
 	if err != nil {
 		return 0, err
 	}
 
-	return binary.BigEndian.Uint32(retBytes), nil
+	return u, nil
 }
 
 func (reader *DataReader) Uint64() (uint64, error) {
-	retBytes := make([]byte, 8)
-	_, err := reader.r.Read(retBytes)
+	var u uint64
+
+	err := binary.Read(reader.r, binary.BigEndian, &u)
 	if err != nil {
 		return 0, err
 	}
 
-	return binary.BigEndian.Uint64(retBytes), nil
+	return u, nil
 }
 
 func (reader *DataReader) String(bytesLen int) (string, error) {

+ 24 - 24
network/data_reader_test.go

@@ -7,52 +7,52 @@ import (
 )
 
 const (
-	byteValue   byte   = 1
-	uint8Value  uint8  = 5
-	uint16Value uint16 = 6
-	uint32Value uint32 = 7
-	uint64Value uint64 = 8
-	strValue           = "91011"
+	readerByteValue   byte   = 1
+	readerUint8Value  uint8  = 5
+	readerUint16Value uint16 = 6
+	readerUint32Value uint32 = 7
+	readerUint64Value uint64 = 8
+	readerStringValue        = "91011"
 )
 
 var (
-	bytesValue = []byte{2, 3, 4}
+	readerBytesValue = []byte{2, 3, 4}
 )
 
 func TestDataReader(t *testing.T) {
 	buffer := &bytes.Buffer{}
 
-	err := binary.Write(buffer, binary.BigEndian, byteValue)
+	err := binary.Write(buffer, binary.BigEndian, readerByteValue)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	err = binary.Write(buffer, binary.BigEndian, bytesValue)
+	err = binary.Write(buffer, binary.BigEndian, readerBytesValue)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	err = binary.Write(buffer, binary.BigEndian, uint8Value)
+	err = binary.Write(buffer, binary.BigEndian, readerUint8Value)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	err = binary.Write(buffer, binary.BigEndian, uint16Value)
+	err = binary.Write(buffer, binary.BigEndian, readerUint16Value)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	err = binary.Write(buffer, binary.BigEndian, uint32Value)
+	err = binary.Write(buffer, binary.BigEndian, readerUint32Value)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	err = binary.Write(buffer, binary.BigEndian, uint64Value)
+	err = binary.Write(buffer, binary.BigEndian, readerUint64Value)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	err = binary.Write(buffer, binary.BigEndian, []byte(strValue))
+	err = binary.Write(buffer, binary.BigEndian, []byte(readerStringValue))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -64,21 +64,21 @@ func TestDataReader(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if retByteValue != byteValue {
+	if retByteValue != readerByteValue {
 		t.Fatal("byte不正确", err.Error())
 	}
 
-	retBytesValue, err := reader.Bytes(len(bytesValue))
+	retBytesValue, err := reader.Bytes(len(readerBytesValue))
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	if len(retBytesValue) != len(bytesValue) {
+	if len(retBytesValue) != len(readerBytesValue) {
 		t.Fatal("bytes长度不正确")
 	}
 
 	for i := 0; i < len(retBytesValue); i++ {
-		if retBytesValue[i] != bytesValue[i] {
+		if retBytesValue[i] != readerBytesValue[i] {
 			t.Fatal("bytes不正确")
 		}
 	}
@@ -88,7 +88,7 @@ func TestDataReader(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if retUint8Value != uint8Value {
+	if retUint8Value != readerUint8Value {
 		t.Fatal("uint8不正确", err.Error())
 	}
 
@@ -97,7 +97,7 @@ func TestDataReader(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if retUint16Value != uint16Value {
+	if retUint16Value != readerUint16Value {
 		t.Fatal("uint16不正确", err.Error())
 	}
 
@@ -106,7 +106,7 @@ func TestDataReader(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if retUint32Value != uint32Value {
+	if retUint32Value != readerUint32Value {
 		t.Fatal("uint32不正确", err.Error())
 	}
 
@@ -115,16 +115,16 @@ func TestDataReader(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if retUint64Value != uint64Value {
+	if retUint64Value != readerUint64Value {
 		t.Fatal("uint64不正确", err.Error())
 	}
 
-	retStringValue, err := reader.String(len(strValue))
+	retStringValue, err := reader.String(len(readerStringValue))
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	if retStringValue != strValue {
+	if retStringValue != readerStringValue {
 		t.Fatal("uint64不正确", err.Error())
 	}
 }

+ 76 - 0
network/data_writer.go

@@ -1 +1,77 @@
 package network
+
+import (
+	"encoding/binary"
+	"io"
+)
+
+type DataWriter struct {
+	w io.Writer
+}
+
+func NewDataWriter(w io.Writer) *DataWriter {
+	return &DataWriter{w: w}
+}
+
+func (reader *DataWriter) Byte(b byte) error {
+	err := binary.Write(reader.w, binary.BigEndian, b)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (reader *DataWriter) Bytes(bs []byte) error {
+	err := binary.Write(reader.w, binary.BigEndian, bs)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (reader *DataWriter) Uint8(u uint8) error {
+	err := binary.Write(reader.w, binary.BigEndian, u)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (reader *DataWriter) Uint16(u uint16) error {
+	err := binary.Write(reader.w, binary.BigEndian, u)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (reader *DataWriter) Uint32(u uint32) error {
+	err := binary.Write(reader.w, binary.BigEndian, u)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (reader *DataWriter) Uint64(u uint64) error {
+	err := binary.Write(reader.w, binary.BigEndian, u)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (reader *DataWriter) String(s string) error {
+	err := binary.Write(reader.w, binary.BigEndian, []byte(s))
+	if err != nil {
+		return err
+	}
+
+	return nil
+}

+ 136 - 0
network/data_writer_test.go

@@ -0,0 +1,136 @@
+package network
+
+import (
+	"bytes"
+	"encoding/binary"
+	"testing"
+)
+
+const (
+	writerByteValue   byte   = 1
+	writerUint8Value  uint8  = 5
+	writerUint16Value uint16 = 6
+	writerUint32Value uint32 = 7
+	writerUint64Value uint64 = 8
+	writerStringValue        = "91011"
+)
+
+var (
+	writerBytesValue = []byte{2, 3, 4}
+)
+
+func TestDataWriter(t *testing.T) {
+	buffer := &bytes.Buffer{}
+	writer := NewDataWriter(buffer)
+
+	err := writer.Byte(writerByteValue)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = writer.Bytes(writerBytesValue)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = writer.Uint8(writerUint8Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = writer.Uint16(writerUint16Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = writer.Uint32(writerUint32Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = writer.Uint64(writerUint64Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = writer.String(writerStringValue)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var retByteValue byte
+	err = binary.Read(buffer, binary.BigEndian, &retByteValue)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retByteValue != writerByteValue {
+		t.Fatal("byte不一致")
+	}
+
+	retBytesValue := make([]byte, len(writerBytesValue))
+	err = binary.Read(buffer, binary.BigEndian, &retBytesValue)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(retBytesValue) != len(readerBytesValue) {
+		t.Fatal("bytes长度不正确")
+	}
+
+	for i := 0; i < len(retBytesValue); i++ {
+		if retBytesValue[i] != readerBytesValue[i] {
+			t.Fatal("bytes不正确")
+		}
+	}
+
+	var retUint8Value uint8
+	err = binary.Read(buffer, binary.BigEndian, &retUint8Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retUint8Value != writerUint8Value {
+		t.Fatal("uint8不一致")
+	}
+
+	var retUint16Value uint16
+	err = binary.Read(buffer, binary.BigEndian, &retUint16Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retUint16Value != writerUint16Value {
+		t.Fatal("uint16不一致")
+	}
+
+	var retUint32Value uint32
+	err = binary.Read(buffer, binary.BigEndian, &retUint32Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retUint32Value != writerUint32Value {
+		t.Fatal("uint32不一致")
+	}
+
+	var retUint64Value uint64
+	err = binary.Read(buffer, binary.BigEndian, &retUint64Value)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retUint64Value != writerUint64Value {
+		t.Fatal("uint64不一致")
+	}
+
+	retStringByteValue := make([]byte, len(writerStringValue))
+	err = binary.Read(buffer, binary.BigEndian, &retStringByteValue)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(retStringByteValue) != writerStringValue {
+		t.Fatal("string不一致")
+	}
+}