Explorar el Código

修改插入接口

yjp hace 1 año
padre
commit
fc53283517
Se han modificado 1 ficheros con 21 adiciones y 17 borrados
  1. 21 17
      framework/core/infrastructure/database/database.go

+ 21 - 17
framework/core/infrastructure/database/database.go

@@ -52,7 +52,7 @@ func Transaction(executor Executor, txFunc func(tx Executor) error) error {
 	return nil
 }
 
-func InsertEntity(executor Executor, tableName string, e any) error {
+func InsertEntity(executor Executor, tableName string, es any) error {
 	if executor == nil {
 		return errors.New("没有传递执行器")
 	}
@@ -61,15 +61,23 @@ func InsertEntity(executor Executor, tableName string, e any) error {
 		return errors.New("没有传递表名")
 	}
 
-	if e == nil {
+	if es == nil {
 		return nil
 	}
 
-	entityType := reflect.TypeOf(e)
-	if !reflectutils.IsTypeStructOrStructPointer(entityType) {
-		return errors.New("实体参数不是结构或结构指针")
+	entityType := reflect.TypeOf(es)
+	entityElemType := reflectutils.PointerTypeElem(entityType)
+
+	if entityElemType.Kind() == reflect.Struct {
+		return insertEntitySingle(executor, tableName, es)
+	} else if entityElemType.Kind() == reflect.Slice {
+		return insertEntityBatch(executor, tableName, es)
+	} else {
+		return errors.New("实体可以是结构,结构指针,结构Slice,结构指针的Slice或Slice的指针")
 	}
+}
 
+func insertEntitySingle(executor Executor, tableName string, e any) error {
 	fields, err := sql_mapping.DefaultUsage(e)
 	if err != nil {
 		return err
@@ -95,23 +103,19 @@ func InsertEntity(executor Executor, tableName string, e any) error {
 	return nil
 }
 
-func InsertEntityBatch(executor Executor, tableName string, es []any) error {
-	if executor == nil {
-		return errors.New("没有传递执行器")
-	}
-
-	if strutils.IsStringEmpty(tableName) {
-		return errors.New("没有传递表名")
-	}
-
+func insertEntityBatch(executor Executor, tableName string, es any) error {
 	now := time.Now().Local()
 	tableRowBatch := make([]sql.TableRow, 0)
+	entitiesValue := reflectutils.PointerValueElem(reflect.ValueOf(es))
 
-	for _, e := range es {
-		if e == nil {
-			return nil
+	for i := 0; i < entitiesValue.Len(); i++ {
+		entityValue := entitiesValue.Index(i)
+		if !entityValue.IsValid() || entityValue.IsZero() {
+			continue
 		}
 
+		e := entityValue.Interface()
+
 		entityType := reflect.TypeOf(e)
 		if !reflectutils.IsTypeStructOrStructPointer(entityType) {
 			return errors.New("实体参数不是结构或结构指针")