yjp před 23 hodinami
rodič
revize
e372c9fe2d

+ 17 - 17
convenient/relation/many2many/service.go

@@ -58,6 +58,7 @@ func Update(middleTableName string,
 			return nil, err
 		}
 
+		tableRows := make([]sql.TableRow, len(toIDs))
 		if toIDs != nil && len(toIDs) != 0 {
 			for _, toID := range toIDs {
 				err := entity.CheckIDTypeValue(fromDomainCNName, fromRelationFieldName, toID)
@@ -67,39 +68,38 @@ func Update(middleTableName string,
 			}
 
 			toIDs = slice.RemoveRepeatElement(toIDs)
-		}
-
-		err = database.Transaction(dbExecutor, func(tx database.Executor) error {
-			err := database.Delete(tx, &sql.DeleteExecuteParams{
-				TableName:  middleTableName,
-				Conditions: sql.NewConditions().Equal(fromRelationColumnName, fromEntity.GetID()),
-			})
-			if err != nil {
-				return err
-			}
-
-			if toIDs == nil || len(toIDs) == 0 {
-				return nil
-			}
 
 			toCount, err := database.Count(dbExecutor, &sql.CountExecuteParams{
 				TableName:  toTableName,
 				Conditions: sql.NewConditions().In(entity.ColumnID, toIDs),
 			})
 			if err != nil {
-				return err
+				return nil, err
 			}
 
 			if int(toCount) != len(toIDs) {
-				return errors.New("部分{{ $toCNName }}不存在")
+				return nil, errors.New("部分{{ $toCNName }}不存在")
 			}
 
-			tableRows := make([]sql.TableRow, len(toIDs))
 			for index, toID := range toIDs {
 				tableRows[index] = *(sql.NewTableRow().
 					Add(fromRelationColumnName, fromEntity.GetID()).
 					Add(toRelationColumnName, toID))
 			}
+		}
+
+		err = database.Transaction(dbExecutor, func(tx database.Executor) error {
+			err := database.Delete(tx, &sql.DeleteExecuteParams{
+				TableName:  middleTableName,
+				Conditions: sql.NewConditions().Equal(fromRelationColumnName, fromEntity.GetID()),
+			})
+			if err != nil {
+				return err
+			}
+
+			if tableRows == nil || len(tableRows) == 0 {
+				return nil
+			}
 
 			err = database.InsertBatch(tx, &sql.InsertBatchExecuteParams{
 				TableName:     middleTableName,

+ 14 - 13
convenient/relation/one2many/service.go

@@ -1,6 +1,8 @@
 package one2many
 
 import (
+	"reflect"
+
 	"git.sxidc.com/go-framework/baize/framework/binding"
 	"git.sxidc.com/go-framework/baize/framework/core/api"
 	"git.sxidc.com/go-framework/baize/framework/core/api/request"
@@ -15,7 +17,6 @@ import (
 	"git.sxidc.com/go-tools/utils/slice"
 	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/pkg/errors"
-	"reflect"
 )
 
 func UpdateLeft(leftTableName string, leftDomainCNName string, leftRelationFieldName string, leftRelationColumnName string,
@@ -70,32 +71,32 @@ func UpdateLeft(leftTableName string, leftDomainCNName string, leftRelationField
 			}
 
 			rightIDs = slice.RemoveRepeatElement(rightIDs)
-		}
 
-		err = database.Transaction(dbExecutor, func(tx database.Executor) error {
-			err := database.Update(tx, &sql.UpdateExecuteParams{
+			rightCount, err := database.Count(dbExecutor, &sql.CountExecuteParams{
 				TableName:  rightTableName,
-				TableRow:   sql.NewTableRow().Add(leftRelationColumnName, ""),
-				Conditions: sql.NewConditions().Equal(leftRelationColumnName, leftEntity.GetID()),
+				Conditions: sql.NewConditions().In(entity.ColumnID, rightIDs),
 			})
 			if err != nil {
-				return err
+				return nil, err
 			}
 
-			if rightIDs == nil || len(rightIDs) == 0 {
-				return nil
+			if int(rightCount) != len(rightIDs) {
+				return nil, errors.New("部分" + rightDomainCNName + "不存在")
 			}
+		}
 
-			rightCount, err := database.Count(dbExecutor, &sql.CountExecuteParams{
+		err = database.Transaction(dbExecutor, func(tx database.Executor) error {
+			err := database.Update(tx, &sql.UpdateExecuteParams{
 				TableName:  rightTableName,
-				Conditions: sql.NewConditions().In(entity.ColumnID, rightIDs),
+				TableRow:   sql.NewTableRow().Add(leftRelationColumnName, ""),
+				Conditions: sql.NewConditions().Equal(leftRelationColumnName, leftEntity.GetID()),
 			})
 			if err != nil {
 				return err
 			}
 
-			if int(rightCount) != len(rightIDs) {
-				return errors.New("部分" + rightDomainCNName + "不存在")
+			if rightIDs == nil || len(rightIDs) == 0 {
+				return nil
 			}
 
 			err = database.Update(tx, &sql.UpdateExecuteParams{

+ 17 - 17
convenient/relation/remote/service.go

@@ -60,6 +60,7 @@ func Update(middleTableName string,
 			return nil, err
 		}
 
+		tableRows := make([]sql.TableRow, len(toIDs))
 		if toIDs != nil && len(toIDs) != 0 {
 			for _, toID := range toIDs {
 				err := entity.CheckIDTypeValue(fromDomainCNName, fromRelationFieldName, toID)
@@ -69,20 +70,6 @@ func Update(middleTableName string,
 			}
 
 			toIDs = slice.RemoveRepeatElement(toIDs)
-		}
-
-		err = database.Transaction(dbExecutor, func(tx database.Executor) error {
-			err := database.Delete(tx, &sql.DeleteExecuteParams{
-				TableName:  middleTableName,
-				Conditions: sql.NewConditions().Equal(fromRelationColumnName, fromEntity.GetID()),
-			})
-			if err != nil {
-				return err
-			}
-
-			if toIDs == nil || len(toIDs) == 0 {
-				return nil
-			}
 
 			if !toRemote {
 				toCount, err := database.Count(dbExecutor, &sql.CountExecuteParams{
@@ -90,20 +77,33 @@ func Update(middleTableName string,
 					Conditions: sql.NewConditions().In(entity.ColumnID, toIDs),
 				})
 				if err != nil {
-					return err
+					return nil, err
 				}
 
 				if int(toCount) != len(toIDs) {
-					return errors.New("部分{{ $toCNName }}不存在")
+					return nil, errors.New("部分{{ $toCNName }}不存在")
 				}
 			}
 
-			tableRows := make([]sql.TableRow, len(toIDs))
 			for index, toID := range toIDs {
 				tableRows[index] = *(sql.NewTableRow().
 					Add(fromRelationColumnName, fromEntity.GetID()).
 					Add(toRelationColumnName, toID))
 			}
+		}
+
+		err = database.Transaction(dbExecutor, func(tx database.Executor) error {
+			err := database.Delete(tx, &sql.DeleteExecuteParams{
+				TableName:  middleTableName,
+				Conditions: sql.NewConditions().Equal(fromRelationColumnName, fromEntity.GetID()),
+			})
+			if err != nil {
+				return err
+			}
+
+			if tableRows == nil || len(tableRows) == 0 {
+				return nil
+			}
 
 			err = database.InsertBatch(tx, &sql.InsertBatchExecuteParams{
 				TableName:     middleTableName,