package sql_tpl import ( "errors" "git.sxidc.com/go-tools/utils/encoding" "git.sxidc.com/go-tools/utils/strutils" "reflect" "strconv" "strings" "time" ) const ( timeWriteFormat = time.DateTime + ".000000 +08:00" ) type AfterParsedStrValueOption func(strValue string) (string, error) func WithAESKey(aesKey string) AfterParsedStrValueOption { return func(strValue string) (string, error) { if strutils.IsStringEmpty(strValue) { return "''", nil } encrypted, err := encoding.AESEncrypt(strValue, aesKey) if err != nil { return "", err } return "'" + encrypted + "'", nil } } func parseValue(value any, opts ...AfterParsedStrValueOption) (string, error) { valueValue := reflect.ValueOf(value) if !valueValue.IsValid() { return "", errors.New("无效值") } if valueValue.Kind() == reflect.Ptr && valueValue.IsNil() { return "", errors.New("空值") } if valueValue.Kind() == reflect.Ptr { valueValue = valueValue.Elem() } var parsedValue string switch v := valueValue.Interface().(type) { case string: parsedValue = v if opts == nil || len(opts) == 0 { return "'" + parsedValue + "'", nil } case bool: parsedValue = strconv.FormatBool(v) if opts == nil || len(opts) == 0 { return parsedValue, nil } case time.Time: parsedValue = v.Format(timeWriteFormat) if opts == nil || len(opts) == 0 { return "'" + parsedValue + "'", nil } case int: parsedValue = strconv.Itoa(v) if opts == nil || len(opts) == 0 { return parsedValue, nil } case int8: parsedValue = strconv.FormatInt(int64(v), 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } case int16: parsedValue = strconv.FormatInt(int64(v), 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } case int32: parsedValue = strconv.FormatInt(int64(v), 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } case int64: parsedValue = strconv.FormatInt(v, 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } case uint: parsedValue = strconv.FormatUint(uint64(v), 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } case uint8: parsedValue = strconv.FormatUint(uint64(v), 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } case uint16: parsedValue = strconv.FormatUint(uint64(v), 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } case uint32: parsedValue = strconv.FormatUint(uint64(v), 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } case uint64: parsedValue = strconv.FormatUint(v, 10) if opts == nil || len(opts) == 0 { return parsedValue, nil } default: return "", errors.New("不支持的类型") } for _, opt := range opts { innerParsedValue, err := opt(parsedValue) if err != nil { return "", err } parsedValue = innerParsedValue } return parsedValue, nil } func parseSliceValues(values any, opts ...AfterParsedStrValueOption) (string, error) { sliceValue := reflect.ValueOf(values) if !sliceValue.IsValid() { return "", errors.New("无效值") } if sliceValue.Kind() == reflect.Ptr && sliceValue.IsNil() { return "()", nil } if sliceValue.Kind() == reflect.Ptr { sliceValue = sliceValue.Elem() } if sliceValue.Kind() != reflect.Slice { return "", errors.New("传递的不是slice") } parsedValues := make([]string, 0) for i := 0; i < sliceValue.Len(); i++ { value := sliceValue.Index(i).Interface() parsedValue, err := parseValue(value) if err != nil { return "", err } for _, opt := range opts { innerParsedValue, err := opt(parsedValue) if err != nil { return "", err } parsedValue = innerParsedValue } parsedValues = append(parsedValues, parsedValue) } if len(parsedValues) == 0 { return "()", nil } return "(" + strings.Join(parsedValues, ",") + ")", nil }