yjp 11 місяців тому
батько
коміт
42d1960ebb
2 змінених файлів з 39 додано та 12 видалено
  1. 2 2
      sql/sql_tpl/condition.go
  2. 37 10
      sql/sql_tpl/value.go

+ 2 - 2
sql/sql_tpl/condition.go

@@ -53,7 +53,7 @@ func (conditions *Conditions) In(columnName string, value any, opts ...AfterPars
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value, opts...)
+	parsedValue, err := parseSliceValues(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -69,7 +69,7 @@ func (conditions *Conditions) NotIn(columnName string, value any, opts ...AfterP
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value, opts...)
+	parsedValue, err := parseSliceValues(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions

+ 37 - 10
sql/sql_tpl/value.go

@@ -127,16 +127,6 @@ func parseValue(value any, opts ...AfterParsedStrValueOption) (string, error) {
 		if opts == nil || len(opts) == 0 {
 			return parsedValue, nil
 		}
-	case []string:
-		if v == nil || len(v) == 0 {
-			parsedValue = ""
-		} else {
-			parsedValue = strings.Join(v, ",")
-		}
-
-		if opts == nil || len(opts) == 0 {
-			return "'" + parsedValue + "'", nil
-		}
 	default:
 		return "", errors.New("不支持的类型")
 	}
@@ -152,3 +142,40 @@ func parseValue(value any, opts ...AfterParsedStrValueOption) (string, error) {
 
 	return parsedValue, nil
 }
+
+func parseSliceValues(values any, opts ...AfterParsedStrValueOption) (string, error) {
+	sliceValue := reflect.ValueOf(values)
+
+	if !sliceValue.IsValid() {
+		return "", errors.New("无效值")
+	}
+
+	if sliceValue.Kind() == reflect.Ptr && sliceValue.IsNil() {
+		return "()", nil
+	}
+
+	if sliceValue.Kind() == reflect.Ptr {
+		sliceValue = sliceValue.Elem()
+	}
+
+	if sliceValue.Kind() != reflect.Slice {
+		return "", errors.New("传递的不是slice")
+	}
+
+	parsedValues := make([]string, 0)
+	for i := 0; i < sliceValue.Len(); i++ {
+		value := sliceValue.Index(i).Interface()
+		parsedValue, err := parseValue(value)
+		if err != nil {
+			return "", err
+		}
+
+		parsedValues = append(parsedValues, parsedValue)
+	}
+
+	if len(parsedValues) == 0 {
+		return "()", nil
+	}
+
+	return "(" + strings.Join(parsedValues, ",") + ")", nil
+}