瀏覽代碼

添加使用反射创建结构和函数的工具函数

yjp 4 月之前
父節點
當前提交
34a318e773
共有 2 個文件被更改,包括 321 次插入0 次删除
  1. 221 0
      reflectutils/maker.go
  2. 100 0
      reflectutils/maker_test.go

+ 221 - 0
reflectutils/maker.go

@@ -0,0 +1,221 @@
+package reflectutils
+
+import (
+	"fmt"
+	"reflect"
+)
+
+type StructFieldDefinition struct {
+	Name string
+	Type reflect.Type
+	Tag  string
+}
+
+type StructMethodDefinition struct {
+	Name             string
+	ArgTypes         []reflect.Type
+	ReturnValueTypes []reflect.Type
+	Body             func(this *Struct, args ...any) []any
+}
+
+type Struct struct {
+	structValueElem reflect.Value
+	fieldMap        map[string]reflect.Value
+	methodMap       map[string]reflect.Value
+}
+
+func NewStruct(fields ...StructFieldDefinition) *Struct {
+	if len(fields) == 0 {
+		return &Struct{
+			structValueElem: reflect.ValueOf(struct{}{}),
+		}
+	}
+
+	reflectStructFields := make([]reflect.StructField, len(fields))
+	for i, field := range fields {
+		reflectStructFields[i] = reflect.StructField{
+			Name: field.Name,
+			Type: field.Type,
+			Tag:  reflect.StructTag(field.Tag),
+		}
+	}
+
+	return &Struct{
+		structValueElem: reflect.New(reflect.StructOf(reflectStructFields)).Elem(),
+		fieldMap:        map[string]reflect.Value{},
+		methodMap:       map[string]reflect.Value{},
+	}
+}
+
+func (s *Struct) Any() any {
+	return s.structValueElem.Interface()
+}
+
+func (s *Struct) Pointer() any {
+	return s.structValueElem.Addr().Interface()
+}
+
+func (s *Struct) SetFieldValue(fieldName string, value any) {
+	s.loadFieldValue(fieldName).Set(reflect.ValueOf(value))
+}
+
+func (s *Struct) FieldValue(fieldName string) any {
+	return s.loadFieldValue(fieldName).Interface()
+}
+
+func (s *Struct) SetFieldValues(fieldAndValues map[string]any) {
+	for fieldName, value := range fieldAndValues {
+		s.SetFieldValue(fieldName, value)
+	}
+}
+
+func (s *Struct) FieldValues(fieldNames ...string) map[string]any {
+	if len(fieldNames) == 0 {
+		return map[string]any{}
+	}
+
+	fieldAndValues := make(map[string]any)
+	for _, fieldName := range fieldNames {
+		fieldAndValues[fieldName] = s.FieldValue(fieldName)
+	}
+
+	return fieldAndValues
+}
+
+func (s *Struct) MakeMethod(methods ...StructMethodDefinition) {
+	for _, method := range methods {
+		argTypes := make([]reflect.Type, 0)
+		argTypes = append(argTypes, s.structValueElem.Type())
+		argTypes = append(argTypes, method.ArgTypes...)
+
+		funcValue := reflect.MakeFunc(reflect.FuncOf(argTypes, method.ReturnValueTypes, false),
+			func(argValues []reflect.Value) []reflect.Value {
+				args := make([]any, len(argValues)-1)
+				for i, arg := range argValues {
+					if i == 0 {
+						continue
+					}
+
+					args[i-1] = arg.Interface()
+				}
+
+				returns := method.Body(s, args)
+
+				returnValues := make([]reflect.Value, len(returns))
+				for j, returnValue := range returns {
+					reflectValue := reflect.ValueOf(returnValue)
+					if !reflectValue.IsValid() {
+						returnValues[j] = reflect.Zero(method.ReturnValueTypes[j])
+					} else {
+						returnValues[j] = reflect.ValueOf(returnValue)
+					}
+				}
+
+				return returnValues
+			})
+
+		s.methodMap[method.Name] = funcValue
+	}
+}
+
+func (s *Struct) CallMethod(methodName string, args ...any) []any {
+	method, ok := s.methodMap[methodName]
+	if !ok {
+		panic(fmt.Sprintf("%s方法不存在", methodName))
+	}
+
+	argValues := make([]reflect.Value, len(args)+1)
+	argValues[0] = s.structValueElem
+	for i, arg := range args {
+		argValues[i+1] = reflect.ValueOf(arg)
+	}
+
+	returnValues := method.Call(argValues)
+
+	returns := make([]any, len(returnValues))
+	for j, returnValue := range returnValues {
+		if returnValue.IsNil() {
+			returns[j] = nil
+		} else {
+			returns[j] = returnValue.Interface()
+		}
+	}
+
+	return returns
+}
+
+func (s *Struct) loadFieldValue(fieldName string) reflect.Value {
+	fieldValue, ok := s.fieldMap[fieldName]
+	if ok {
+		return fieldValue
+	}
+
+	fieldValue = s.structValueElem.FieldByName(fieldName)
+	if !fieldValue.IsValid() {
+		panic(fmt.Sprintf("%s字段不存在", fieldName))
+	}
+
+	return fieldValue
+}
+
+type FunctionDefinition struct {
+	ArgTypes         []reflect.Type
+	ReturnValueTypes []reflect.Type
+	Body             func(args ...any) []any
+}
+
+type Function struct {
+	functionValue reflect.Value
+}
+
+func (f *Function) Call(args ...any) []any {
+	argValues := make([]reflect.Value, len(args))
+	for i, arg := range args {
+		argValues[i] = reflect.ValueOf(arg)
+	}
+
+	returnValues := f.functionValue.Call(argValues)
+
+	returns := make([]any, len(returnValues))
+	for j, returnValue := range returnValues {
+		if returnValue.IsNil() {
+			returns[j] = nil
+		} else {
+			returns[j] = returnValue.Interface()
+		}
+	}
+
+	return returns
+}
+
+func MakeFunction(function FunctionDefinition) *Function {
+	funcValue := reflect.MakeFunc(reflect.FuncOf(function.ArgTypes, function.ReturnValueTypes, false),
+		func(argValues []reflect.Value) []reflect.Value {
+			args := make([]any, len(argValues)-1)
+			for i, arg := range argValues {
+				if i == 0 {
+					continue
+				}
+
+				args[i-1] = arg.Interface()
+			}
+
+			returns := function.Body(args)
+
+			returnValues := make([]reflect.Value, len(returns))
+			for j, returnValue := range returns {
+				reflectValue := reflect.ValueOf(returnValue)
+				if !reflectValue.IsValid() {
+					returnValues[j] = reflect.Zero(function.ReturnValueTypes[j])
+				} else {
+					returnValues[j] = reflect.ValueOf(returnValue)
+				}
+			}
+
+			return returnValues
+		})
+
+	return &Function{
+		functionValue: funcValue,
+	}
+}

+ 100 - 0
reflectutils/maker_test.go

@@ -0,0 +1,100 @@
+package reflectutils
+
+import (
+	"fmt"
+	"github.com/pkg/errors"
+	"reflect"
+	"testing"
+	"time"
+)
+
+func TestStruct(t *testing.T) {
+	name := "test"
+	age := 18
+	enterTime := time.Now().Local()
+
+	studentStruct := NewStruct(
+		StructFieldDefinition{
+			Name: "Name",
+			Type: reflect.TypeOf(""),
+			Tag:  "json:name",
+		},
+		StructFieldDefinition{
+			Name: "Age",
+			Type: reflect.TypeOf(0),
+			Tag:  "json:age",
+		},
+		StructFieldDefinition{
+			Name: "EnterTime",
+			Type: reflect.TypeOf(time.Time{}),
+			Tag:  "json:enterTime",
+		},
+	)
+
+	studentStruct.SetFieldValues(map[string]any{
+		"Name":      name,
+		"Age":       age,
+		"EnterTime": enterTime,
+	})
+
+	values := studentStruct.FieldValues("Name", "Age", "EnterTime")
+
+	for fieldName, value := range values {
+		switch fieldName {
+		case "Name":
+			if value != name {
+				t.Fatalf("%+v\n", errors.Errorf("名字不一致: except: %v, actual: %v", name, value))
+			}
+		case "Age":
+			if value != age {
+				t.Fatalf("%+v\n", errors.Errorf("年龄不一致: except: %v, actual: %v", age, value))
+			}
+		case "EnterTime":
+			if value != enterTime {
+				t.Fatalf("%+v\n", errors.Errorf("入学时间不一致: except: %v, actual: %v", enterTime, value))
+			}
+		default:
+			t.Fatalf("%+v\n", errors.New("不存在的字段"))
+		}
+	}
+
+	studentStruct.MakeMethod(
+		StructMethodDefinition{
+			Name:             "Print",
+			ArgTypes:         []reflect.Type{reflect.TypeOf("")},
+			ReturnValueTypes: []reflect.Type{reflect.TypeOf(errors.New(""))},
+			Body: func(s *Struct, args ...any) []any {
+				values := studentStruct.FieldValues("Name", "Age", "EnterTime")
+
+				fmt.Println("Student Info:")
+				fmt.Println("Name:", values["Name"])
+				fmt.Println("Age:", values["Age"])
+				fmt.Println("EnterTime:", values["EnterTime"].(time.Time).Format(time.DateTime))
+				fmt.Println("Arg:", args[0])
+
+				return []any{nil}
+			},
+		},
+	)
+
+	returns := studentStruct.CallMethod("Print", "Hello Args")
+	if returns[0] != nil {
+		t.Fatalf("%+v\n", errors.Errorf("%v", returns[0]))
+	}
+}
+
+func TestFunction(t *testing.T) {
+	printHelloFunc := MakeFunction(FunctionDefinition{
+		ArgTypes:         []reflect.Type{reflect.TypeOf("")},
+		ReturnValueTypes: []reflect.Type{reflect.TypeOf(errors.New(""))},
+		Body: func(args ...any) []any {
+			fmt.Printf("Hello %v!\n", args[0])
+			return []any{nil}
+		},
+	})
+
+	returns := printHelloFunc.Call("World")
+	if returns[0] != nil {
+		t.Fatalf("%+v\n", errors.Errorf("%v", returns[0]))
+	}
+}