Bladeren bron

添加yaml加载函数

yjp 1 jaar geleden
bovenliggende
commit
3388288de1
4 gewijzigde bestanden met toevoegingen van 169 en 32 verwijderingen
  1. 49 32
      fileutils/fileutils.go
  2. 7 0
      yaml/test.yaml
  3. 54 0
      yaml/yaml.go
  4. 59 0
      yaml/yaml_test.go

+ 49 - 32
fileutils/fileutils.go

@@ -97,16 +97,60 @@ func ZipDir(dirPath string, savePath string) error {
 }
 
 func UnzipFile(srcFilePath string, destDir string) error {
-	fileBytes, err := os.ReadFile(srcFilePath)
+	srcFile, err := os.Open(srcFilePath)
 	if err != nil {
 		return err
 	}
 
-	return UnzipBytes(fileBytes, destDir)
+	defer func() {
+		err := srcFile.Close()
+		if err != nil {
+			log.Println(err)
+			return
+		}
+	}()
+
+	srcFileInfo, err := srcFile.Stat()
+	if err != nil {
+		return err
+	}
+
+	return unzip(srcFile, srcFileInfo.Size(), destDir)
 }
 
 func UnzipBytes(zipFileBytes []byte, destDir string) error {
-	zipReader, err := zip.NewReader(bytes.NewReader(zipFileBytes), int64(len(zipFileBytes)))
+	return unzip(bytes.NewReader(zipFileBytes), int64(len(zipFileBytes)), destDir)
+}
+
+func writeZipFile(filePath string, zipPath string, zipWriter *zip.Writer) error {
+	file, err := os.Open(filePath)
+	if err != nil {
+		return err
+	}
+
+	defer func() {
+		err := file.Close()
+		if err != nil {
+			log.Println(err)
+			return
+		}
+	}()
+
+	writer, err := zipWriter.Create(zipPath)
+	if err != nil {
+		return err
+	}
+
+	_, err = io.Copy(writer, file)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func unzip(r io.ReaderAt, size int64, destDir string) error {
+	zipReader, err := zip.NewReader(r, size)
 	if err != nil {
 		return err
 	}
@@ -114,11 +158,11 @@ func UnzipBytes(zipFileBytes []byte, destDir string) error {
 	for _, f := range zipReader.File {
 		fPath := filepath.Join(destDir, f.Name)
 		if f.FileInfo().IsDir() {
-			if err = os.MkdirAll(fPath, os.ModePerm); err != nil {
+			if err := os.MkdirAll(fPath, os.ModePerm); err != nil {
 				return err
 			}
 		} else {
-			if err = os.MkdirAll(filepath.Dir(fPath), os.ModePerm); err != nil {
+			if err := os.MkdirAll(filepath.Dir(fPath), os.ModePerm); err != nil {
 				return err
 			}
 
@@ -148,30 +192,3 @@ func UnzipBytes(zipFileBytes []byte, destDir string) error {
 
 	return nil
 }
-
-func writeZipFile(filePath string, zipPath string, zipWriter *zip.Writer) error {
-	file, err := os.Open(filePath)
-	if err != nil {
-		return err
-	}
-
-	defer func() {
-		err := file.Close()
-		if err != nil {
-			log.Println(err)
-			return
-		}
-	}()
-
-	writer, err := zipWriter.Create(zipPath)
-	if err != nil {
-		return err
-	}
-
-	_, err = io.Copy(writer, file)
-	if err != nil {
-		return err
-	}
-
-	return nil
-}

+ 7 - 0
yaml/test.yaml

@@ -0,0 +1,7 @@
+test1:
+  name: "foo1"
+
+---
+
+test2:
+  name: "foo2"

+ 54 - 0
yaml/yaml.go

@@ -0,0 +1,54 @@
+package yaml
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"gopkg.in/yaml.v3"
+	"io"
+	"os"
+	"reflect"
+)
+
+func LoadYamlFile(yamlFilePath string, retObjects *[]any) error {
+	yamlFile, err := os.Open(yamlFilePath)
+	if err != nil {
+		return err
+	}
+
+	defer func(yamlFile *os.File) {
+		err := yamlFile.Close()
+		if err != nil {
+			fmt.Println(err)
+		}
+	}(yamlFile)
+
+	return loadYaml(yamlFile, retObjects)
+}
+
+func LoadYamlBytes(yamlBytes []byte, retObjects *[]any) error {
+	return loadYaml(bytes.NewReader(yamlBytes), retObjects)
+}
+
+func loadYaml(r io.Reader, retObjects *[]any) error {
+	decoder := yaml.NewDecoder(r)
+
+	for i := 0; i < len(*retObjects); i++ {
+		retObject := (*retObjects)[i]
+
+		if reflect.TypeOf(retObject).Kind() != reflect.Ptr {
+			return errors.New("返回对象slice元素需要指针")
+		}
+
+		err := decoder.Decode(retObject)
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+
+			return err
+		}
+	}
+
+	return nil
+}

+ 59 - 0
yaml/yaml_test.go

@@ -0,0 +1,59 @@
+package yaml
+
+import (
+	"os"
+	"testing"
+)
+
+type TestYamlModel1 struct {
+	Test `yaml:"test1"`
+}
+
+type TestYamlModel2 struct {
+	Test `yaml:"test2"`
+}
+
+type Test struct {
+	Name string `yaml:"name"`
+}
+
+func TestLoadYAMLFile(t *testing.T) {
+	testModel1 := new(TestYamlModel1)
+	testModel2 := new(TestYamlModel2)
+	retObjects := []any{testModel1, testModel2}
+
+	err := LoadYamlFile("test.yaml", &retObjects)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retObjects[0].(*TestYamlModel1).Name != "foo1" {
+		t.Fatal("model1名称错误")
+	}
+
+	if retObjects[1].(*TestYamlModel2).Name != "foo2" {
+		t.Fatal("model2名称错误")
+	}
+
+	testModel1 = new(TestYamlModel1)
+	testModel2 = new(TestYamlModel2)
+	retObjects = []any{testModel1, testModel2}
+
+	yamlFileBytes, err := os.ReadFile("test.yaml")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = LoadYamlBytes(yamlFileBytes, &retObjects)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if retObjects[0].(*TestYamlModel1).Name != "foo1" {
+		t.Fatal("model1名称错误")
+	}
+
+	if retObjects[1].(*TestYamlModel2).Name != "foo2" {
+		t.Fatal("model2名称错误")
+	}
+}