package reflectutils

import (
	"fmt"
	"reflect"
)

// StructFieldDefinition describes one field of a struct type that NewStruct
// assembles at runtime.
//
//   - Name is copied into reflect.StructField.Name.
//   - Type is copied into reflect.StructField.Type.
//   - Tag is the raw tag text; NewStruct converts it to a reflect.StructTag.
type StructFieldDefinition struct {
	Name string
	Type reflect.Type
	Tag  string
}

// StructMethodDefinition describes a function that MakeMethod attaches to a
// Struct under a name.
//
//   - Name is the key the generated function is stored under.
//   - ArgTypes lists the parameter types after the receiver; MakeMethod
//     prepends the wrapped struct's type itself.
//   - ReturnValueTypes lists the result types of the generated function.
//   - Body supplies the implementation. It is handed the owning *Struct plus
//     the call arguments and returns the result values.
type StructMethodDefinition struct {
	Name             string
	ArgTypes         []reflect.Type
	ReturnValueTypes []reflect.Type
	Body             func(this *Struct, args ...any) []any
}

// Struct wraps a reflect.Value holding a struct instance, together with two
// name-keyed maps of reflect.Values.
//
//   - structValueElem is the wrapped struct value; it is set by NewStruct.
//   - fieldMap is consulted by loadFieldValue before it falls back to
//     FieldByName on structValueElem.
//   - methodMap holds the function values registered by MakeMethod and looked
//     up by CallMethod.
type Struct struct {
	structValueElem reflect.Value
	fieldMap        map[string]reflect.Value
	methodMap       map[string]reflect.Value
}

// NewStruct builds a struct type at runtime from the given field definitions
// and returns a Struct wrapping a fresh, zero-valued instance of that type.
//
// Each definition is translated into a reflect.StructField (Name, Type, Tag)
// and the resulting slice is handed to reflect.StructOf, which panics on
// definitions it does not accept. When no definitions are given the wrapped
// value is an empty struct{}.
//
// In every case the wrapped value is allocated with reflect.New, so it is
// addressable, and both lookup maps are initialized. This keeps Pointer and
// MakeMethod usable on a field-less Struct as well.
func NewStruct(fields ...StructFieldDefinition) *Struct {
	structType := reflect.TypeOf(struct{}{})

	if len(fields) > 0 {
		reflectStructFields := make([]reflect.StructField, len(fields))
		for i, field := range fields {
			reflectStructFields[i] = reflect.StructField{
				Name: field.Name,
				Type: field.Type,
				Tag:  reflect.StructTag(field.Tag),
			}
		}

		structType = reflect.StructOf(reflectStructFields)
	}

	return &Struct{
		// reflect.New(...).Elem() yields an addressable, settable value;
		// reflect.ValueOf(struct{}{}) would not.
		structValueElem: reflect.New(structType).Elem(),
		fieldMap:        map[string]reflect.Value{},
		methodMap:       map[string]reflect.Value{},
	}
}

// Any returns the wrapped struct value boxed in an interface, as produced by
// reflect.Value.Interface.
func (s *Struct) Any() any {
	elem := s.structValueElem
	return elem.Interface()
}

// Pointer returns a pointer to the wrapped struct value boxed in an
// interface. It relies on reflect.Value.Addr, which panics when the wrapped
// value is not addressable.
func (s *Struct) Pointer() any {
	address := s.structValueElem.Addr()
	return address.Interface()
}

// SetFieldValue assigns value to the field called fieldName.
//
// The field is resolved through loadFieldValue, which panics when no such
// field exists. The assignment itself is done with reflect.Value.Set, so it
// panics when value is not assignable to the field's type.
//
// An untyped nil value stores the zero value of the field's type. This makes
// it possible to clear pointer, interface, map, slice, func and chan fields,
// and matches how nil results are treated by the functions MakeMethod builds.
func (s *Struct) SetFieldValue(fieldName string, value any) {
	field := s.loadFieldValue(fieldName)

	newValue := reflect.ValueOf(value)
	if !newValue.IsValid() {
		// reflect.ValueOf(nil) is the invalid Value, which Set rejects.
		newValue = reflect.Zero(field.Type())
	}

	field.Set(newValue)
}

// FieldValue returns the current value of the field called fieldName, boxed
// in an interface. The field is resolved through loadFieldValue.
func (s *Struct) FieldValue(fieldName string) any {
	field := s.loadFieldValue(fieldName)
	return field.Interface()
}

// SetFieldValues assigns every entry of fieldAndValues through SetFieldValue,
// using the map key as the field name. Entries are visited in map iteration
// order.
func (s *Struct) SetFieldValues(fieldAndValues map[string]any) {
	for name := range fieldAndValues {
		s.SetFieldValue(name, fieldAndValues[name])
	}
}

// FieldValues reads the named fields through FieldValue and returns them
// keyed by field name. The returned map is never nil; it is empty when no
// names are given.
func (s *Struct) FieldValues(fieldNames ...string) map[string]any {
	collected := make(map[string]any, len(fieldNames))

	for i := range fieldNames {
		name := fieldNames[i]
		collected[name] = s.FieldValue(name)
	}

	return collected
}

// MakeMethod builds one function value per definition and registers it in the
// Struct's method map under the definition's Name. A later definition with
// the same name replaces the earlier one.
//
// Each generated function has the signature
//
//	func(receiver <wrapped struct type>, ArgTypes...) ReturnValueTypes...
//
// When it is called, the receiver argument is dropped, the remaining
// arguments are unwrapped to plain values and passed to Body as individual
// variadic arguments, together with the *Struct that MakeMethod was called
// on. A nil element in the slice returned by Body is replaced by the zero
// value of the corresponding declared return type.
func (s *Struct) MakeMethod(methods ...StructMethodDefinition) {
	// A Struct that was not given a method map would otherwise panic on the
	// first registration.
	if s.methodMap == nil {
		s.methodMap = make(map[string]reflect.Value, len(methods))
	}

	receiverType := s.structValueElem.Type()

	for _, method := range methods {
		// Per-iteration copy: before Go 1.22 the loop variable is shared, so
		// every closure would end up using the last definition.
		method := method

		argTypes := make([]reflect.Type, 0, len(method.ArgTypes)+1)
		argTypes = append(argTypes, receiverType)
		argTypes = append(argTypes, method.ArgTypes...)

		funcType := reflect.FuncOf(argTypes, method.ReturnValueTypes, false)

		s.methodMap[method.Name] = reflect.MakeFunc(funcType,
			func(argValues []reflect.Value) []reflect.Value {
				// argValues[0] is the receiver; Body gets the owning *Struct
				// instead.
				args := make([]any, 0, len(argValues)-1)
				for _, argValue := range argValues[1:] {
					args = append(args, argValue.Interface())
				}

				// Spread the slice so Body sees each argument separately
				// rather than a single []any element.
				returns := method.Body(s, args...)

				returnValues := make([]reflect.Value, len(returns))
				for j, returnValue := range returns {
					reflectValue := reflect.ValueOf(returnValue)
					if !reflectValue.IsValid() {
						// Untyped nil: use the zero value of the declared type.
						reflectValue = reflect.Zero(method.ReturnValueTypes[j])
					}

					returnValues[j] = reflectValue
				}

				return returnValues
			})
	}
}

// CallMethod invokes the function registered under methodName, passing the
// wrapped struct value as the first argument followed by args, and returns
// the results as plain values.
//
// It panics with a "method does not exist" message (in Chinese) when no
// function is registered under methodName. reflect.Value.Call panics when the
// arguments do not fit the function's signature.
//
// An untyped nil argument is passed as the zero value of the corresponding
// parameter type. A result that is a nil pointer, map, slice, func, chan or
// interface is returned as an untyped nil; every other result is returned via
// reflect.Value.Interface.
func (s *Struct) CallMethod(methodName string, args ...any) []any {
	method, ok := s.methodMap[methodName]
	if !ok {
		panic(fmt.Sprintf("%s方法不存在", methodName))
	}

	methodType := method.Type()

	argValues := make([]reflect.Value, len(args)+1)
	argValues[0] = s.structValueElem
	for i, arg := range args {
		argValue := reflect.ValueOf(arg)
		// reflect.ValueOf(nil) is the invalid Value, which Call rejects; use
		// the parameter's zero value instead. Surplus arguments are left
		// alone so Call reports the arity mismatch itself.
		if !argValue.IsValid() && i+1 < methodType.NumIn() {
			argValue = reflect.Zero(methodType.In(i + 1))
		}

		argValues[i+1] = argValue
	}

	returnValues := method.Call(argValues)

	returns := make([]any, len(returnValues))
	for j, returnValue := range returnValues {
		// IsNil panics for kinds that cannot be nil, so only ask for these.
		switch returnValue.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
			reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
			if returnValue.IsNil() {
				// Keep returns[j] as an untyped nil rather than a typed one.
				continue
			}
		}

		returns[j] = returnValue.Interface()
	}

	return returns
}

// loadFieldValue resolves fieldName to the reflect.Value of that field.
//
// An entry in fieldMap takes precedence; otherwise the field is looked up by
// name on the wrapped struct value. This function only reads fieldMap and
// never adds entries to it. It panics with a "field does not exist" message
// (in Chinese) when neither lookup finds the field.
func (s *Struct) loadFieldValue(fieldName string) reflect.Value {
	if cached, found := s.fieldMap[fieldName]; found {
		return cached
	}

	resolved := s.structValueElem.FieldByName(fieldName)
	if resolved.IsValid() {
		return resolved
	}

	panic(fmt.Sprintf("%s字段不存在", fieldName))
}

// FunctionDefinition describes a function that MakeFunction builds at
// runtime.
//
//   - ArgTypes lists the parameter types of the generated function.
//   - ReturnValueTypes lists its result types.
//   - Body supplies the implementation; it returns the result values.
type FunctionDefinition struct {
	ArgTypes         []reflect.Type
	ReturnValueTypes []reflect.Type
	Body             func(args ...any) []any
}

// Function wraps a reflect.Value holding a function; MakeFunction creates it
// and Call invokes it.
type Function struct {
	functionValue reflect.Value
}

// Call invokes the wrapped function with args and returns the results as
// plain values.
//
// An untyped nil argument is passed as the zero value of the corresponding
// parameter type. A result that is a nil pointer, map, slice, func, chan or
// interface is returned as an untyped nil; every other result is returned via
// reflect.Value.Interface.
//
// reflect.Value.Call panics when the arguments do not fit the function's
// signature.
func (f *Function) Call(args ...any) []any {
	functionType := f.functionValue.Type()

	argValues := make([]reflect.Value, len(args))
	for i, arg := range args {
		argValue := reflect.ValueOf(arg)
		// reflect.ValueOf(nil) is the invalid Value, which Call rejects; use
		// the parameter's zero value instead. Surplus arguments are left
		// alone so Call reports the arity mismatch itself.
		if !argValue.IsValid() && i < functionType.NumIn() {
			argValue = reflect.Zero(functionType.In(i))
		}

		argValues[i] = argValue
	}

	returnValues := f.functionValue.Call(argValues)

	returns := make([]any, len(returnValues))
	for j, returnValue := range returnValues {
		// IsNil panics for kinds that cannot be nil, so only ask for these.
		switch returnValue.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
			reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
			if returnValue.IsNil() {
				// Keep returns[j] as an untyped nil rather than a typed one.
				continue
			}
		}

		returns[j] = returnValue.Interface()
	}

	return returns
}

// MakeFunction builds a function value with the signature
//
//	func(ArgTypes...) ReturnValueTypes...
//
// and returns it wrapped in a Function.
//
// When the generated function is called, every argument is unwrapped to a
// plain value and passed to Body as an individual variadic argument. Unlike
// the functions built by MakeMethod there is no receiver, so no argument is
// skipped and a function without parameters is valid. A nil element in the
// slice returned by Body is replaced by the zero value of the corresponding
// declared return type.
func MakeFunction(function FunctionDefinition) *Function {
	funcType := reflect.FuncOf(function.ArgTypes, function.ReturnValueTypes, false)

	funcValue := reflect.MakeFunc(funcType,
		func(argValues []reflect.Value) []reflect.Value {
			args := make([]any, len(argValues))
			for i, argValue := range argValues {
				args[i] = argValue.Interface()
			}

			// Spread the slice so Body sees each argument separately rather
			// than a single []any element.
			returns := function.Body(args...)

			returnValues := make([]reflect.Value, len(returns))
			for j, returnValue := range returns {
				reflectValue := reflect.ValueOf(returnValue)
				if !reflectValue.IsValid() {
					// Untyped nil: use the zero value of the declared type.
					reflectValue = reflect.Zero(function.ReturnValueTypes[j])
				}

				returnValues[j] = reflectValue
			}

			return returnValues
		})

	return &Function{
		functionValue: funcValue,
	}
}