Bladeren bron

完成所有查询函数添加

yjp 1 jaar geleden
bovenliggende
commit
6e7bce61f2
4 gewijzigde bestanden met toevoegingen van 472 en 54 verwijderingen
  1. 19 0
      sdk/error.go
  2. 395 23
      sdk/sql.go
  3. 45 30
      sdk/tag/sql_mapping.go
  4. 13 1
      test/sdk_test.go

+ 19 - 0
sdk/error.go

@@ -0,0 +1,19 @@
+package sdk
+
+import (
+	"errors"
+	"strings"
+)
+
+var (
+	ErrDBRecordHasExist = errors.New("记录已存在")
+	ErrDBRecordNotExist = errors.New("记录不存在")
+)
+
+func IsErrorDBRecordHasExist(err error) bool {
+	return strings.Contains(err.Error(), "记录已存在")
+}
+
+func IsErrorDBRecordNotExist(err error) bool {
+	return strings.Contains(err.Error(), "记录不存在")
+}

+ 395 - 23
sdk/sql.go

@@ -7,6 +7,7 @@ import (
 	"git.sxidc.com/service-supports/ds-sdk/sdk/tag"
 	"reflect"
 	"strconv"
+	"strings"
 	"time"
 )
 
@@ -21,9 +22,10 @@ const (
 	lastUpdatedTimeFieldName = "LastUpdatedTime"
 )
 
-type InsertCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)
+type ValueCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)
+type ConditionCallback[T any] func(e T, fieldName string, columnName string, value any) (retConditionOp string, retConditionValue any, err error)
 
-func Insert[T any](executor SqlExecutor, tableName string, e T, callback InsertCallback[T]) error {
+func Insert[T any](executor SqlExecutor, tableName string, e T, callback ValueCallback[T]) error {
 	if executor == nil {
 		return errors.New("没有传递执行器")
 	}
@@ -96,6 +98,10 @@ func Insert[T any](executor SqlExecutor, tableName string, e T, callback InsertC
 
 	_, err = executor.ExecuteRawSql(raw_sql_tpl.InsertTpl, executeParamsMap)
 	if err != nil {
+		if strings.Contains(err.Error(), "SQLSTATE 23505") {
+			return ErrDBRecordHasExist
+		}
+
 		return err
 	}
 
@@ -161,9 +167,7 @@ func Delete[T any](executor SqlExecutor, tableName string, e T) error {
 	return nil
 }
 
-type UpdateCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)
-
-func Update[T any](executor SqlExecutor, tableName string, e T, callback UpdateCallback[T]) error {
+func Update[T any](executor SqlExecutor, tableName string, e T, callback ValueCallback[T]) error {
 	if executor == nil {
 		return errors.New("没有传递执行器")
 	}
@@ -259,24 +263,22 @@ func Update[T any](executor SqlExecutor, tableName string, e T, callback UpdateC
 	return nil
 }
 
-type QueryCallback[T any] func(e T, fieldName string, columnName string, value any) (retConditionOp string, retConditionValue any, err error)
-
-func Query[T any](executor SqlExecutor, tableName string, e T, pageNo int, pageSize int, callback QueryCallback[T]) ([]map[string]any, error) {
+func Query[T any](executor SqlExecutor, tableName string, e T, pageNo int, pageSize int, callback ConditionCallback[T]) ([]map[string]any, int64, error) {
 	if executor == nil {
-		return nil, errors.New("没有传递执行器")
+		return nil, 0, errors.New("没有传递执行器")
 	}
 
 	if strutils.IsStringEmpty(tableName) {
-		return nil, errors.New("没有传递表名")
+		return nil, 0, errors.New("没有传递表名")
 	}
 
 	if reflect.TypeOf(e) == nil {
-		return nil, errors.New("没有传递实体")
+		return nil, 0, errors.New("没有传递实体")
 	}
 
 	sqlMapping, err := tag.ParseSqlMapping(e)
 	if err != nil {
-		return nil, err
+		return nil, 0, err
 	}
 
 	var offset int
@@ -292,6 +294,10 @@ func Query[T any](executor SqlExecutor, tableName string, e T, pageNo int, pageS
 		Offset:    offset,
 	}
 
+	countParams := raw_sql_tpl.CountExecuteParams{
+		TableName: tableName,
+	}
+
 	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
 		if !sqlMappingColumn.CanQuery {
 			continue
@@ -308,17 +314,17 @@ func Query[T any](executor SqlExecutor, tableName string, e T, pageNo int, pageS
 
 		if sqlMappingColumn.QueryCallback {
 			if callback == nil {
-				return nil, errors.New("需要使用回调函数但是没有传递回调函数")
+				return nil, 0, errors.New("需要使用回调函数但是没有传递回调函数")
 			}
 
 			retConditionOp, retConditionValue, err := callback(e, fieldName, sqlMappingColumn.Name, conditionValue)
 			if err != nil {
-				return nil, err
+				return nil, 0, err
 			}
 
 			retValueType := reflect.TypeOf(retConditionValue)
 			if retValueType == nil || retValueType.Kind() == reflect.Ptr {
-				return nil, errors.New("返回应当为值类型")
+				return nil, 0, errors.New("返回应当为值类型")
 			}
 
 			conditionValue = retConditionValue
@@ -327,16 +333,91 @@ func Query[T any](executor SqlExecutor, tableName string, e T, pageNo int, pageS
 
 		tableRowValue, err := parseValue(conditionValue)
 		if err != nil {
-			return nil, err
+			return nil, 0, err
 		}
 
-		if sqlMappingColumn.IsKey {
-			executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
-				Column:   sqlMappingColumn.Name,
-				Operator: conditionOp,
-				Value:    tableRowValue,
-			})
+		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
+			Column:   sqlMappingColumn.Name,
+			Operator: conditionOp,
+			Value:    tableRowValue,
+		})
+
+		countParams.Conditions = append(countParams.Conditions, raw_sql_tpl.Condition{
+			Column:   sqlMappingColumn.Name,
+			Operator: conditionOp,
+			Value:    tableRowValue,
+		})
+	}
+
+	executeParamsMap, err := executeParams.Map()
+	if err != nil {
+		return nil, 0, err
+	}
+
+	countParamsMap, err := countParams.Map()
+	if err != nil {
+		return nil, 0, err
+	}
+
+	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.QueryTpl, executeParamsMap)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	countTableRow, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, countParamsMap)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return tableRows, int64(countTableRow[0]["count"].(float64)), nil
+}
+
+func QueryByKeys[T any](executor SqlExecutor, tableName string, e T) (map[string]any, error) {
+	if executor == nil {
+		return nil, errors.New("没有传递执行器")
+	}
+
+	if strutils.IsStringEmpty(tableName) {
+		return nil, errors.New("没有传递表名")
+	}
+
+	if reflect.TypeOf(e) == nil {
+		return nil, errors.New("没有传递实体")
+	}
+
+	sqlMapping, err := tag.ParseSqlMapping(e)
+	if err != nil {
+		return nil, err
+	}
+
+	executeParams := raw_sql_tpl.QueryExecuteParams{
+		TableName: tableName,
+		Limit:     0,
+		Offset:    0,
+	}
+
+	for _, sqlMappingColumn := range sqlMapping.ColumnMap {
+		if !sqlMappingColumn.IsKey {
+			continue
+		}
+
+		fieldType := sqlMappingColumn.ValueFieldType
+
+		conditionValue := reflect.Zero(fieldType).Interface()
+		if !sqlMappingColumn.ValueFieldValue.IsZero() {
+			conditionValue = sqlMappingColumn.ValueFieldValue.Interface()
+		}
+
+		tableRowValue, err := parseValue(conditionValue)
+		if err != nil {
+			return nil, err
 		}
+
+		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
+			Column:   sqlMappingColumn.Name,
+			Operator: "=",
+			Value:    tableRowValue,
+		})
 	}
 
 	executeParamsMap, err := executeParams.Map()
@@ -349,7 +430,298 @@ func Query[T any](executor SqlExecutor, tableName string, e T, pageNo int, pageS
 		return nil, err
 	}
 
-	return tableRows, nil
+	if tableRows == nil || len(tableRows) == 0 {
+		return nil, ErrDBRecordNotExist
+	}
+
+	return tableRows[0], nil
+}
+
+func Count[T any](executor SqlExecutor, tableName string, e T, callback ConditionCallback[T]) (int64, error) {
+	if executor == nil {
+		return 0, errors.New("没有传递执行器")
+	}
+
+	if strutils.IsStringEmpty(tableName) {
+		return 0, errors.New("没有传递表名")
+	}
+
+	if reflect.TypeOf(e) == nil {
+		return 0, errors.New("没有传递实体")
+	}
+
+	sqlMapping, err := tag.ParseSqlMapping(e)
+	if err != nil {
+		return 0, err
+	}
+
+	executeParams := raw_sql_tpl.CountExecuteParams{
+		TableName: tableName,
+	}
+
+	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
+		fieldType := sqlMappingColumn.ValueFieldType
+
+		conditionValue := reflect.Zero(fieldType).Interface()
+		if !sqlMappingColumn.ValueFieldValue.IsZero() {
+			conditionValue = sqlMappingColumn.ValueFieldValue.Interface()
+		}
+
+		conditionOp := "="
+
+		if sqlMappingColumn.CountCallback {
+			if callback == nil {
+				return 0, errors.New("需要使用回调函数但是没有传递回调函数")
+			}
+
+			retConditionOp, retConditionValue, err := callback(e, fieldName, sqlMappingColumn.Name, conditionValue)
+			if err != nil {
+				return 0, err
+			}
+
+			retValueType := reflect.TypeOf(retConditionValue)
+			if retValueType == nil || retValueType.Kind() == reflect.Ptr {
+				return 0, errors.New("返回应当为值类型")
+			}
+
+			conditionValue = retConditionValue
+			conditionOp = retConditionOp
+		}
+
+		tableRowValue, err := parseValue(conditionValue)
+		if err != nil {
+			return 0, err
+		}
+
+		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
+			Column:   sqlMappingColumn.Name,
+			Operator: conditionOp,
+			Value:    tableRowValue,
+		})
+	}
+
+	executeParamsMap, err := executeParams.Map()
+	if err != nil {
+		return 0, err
+	}
+
+	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, executeParamsMap)
+	if err != nil {
+		return 0, err
+	}
+
+	return int64(tableRows[0]["count"].(float64)), nil
+}
+
+func CheckExist[T any](executor SqlExecutor, tableName string, e T, callback ConditionCallback[T]) (bool, error) {
+	if executor == nil {
+		return false, errors.New("没有传递执行器")
+	}
+
+	if strutils.IsStringEmpty(tableName) {
+		return false, errors.New("没有传递表名")
+	}
+
+	if reflect.TypeOf(e) == nil {
+		return false, errors.New("没有传递实体")
+	}
+
+	sqlMapping, err := tag.ParseSqlMapping(e)
+	if err != nil {
+		return false, err
+	}
+
+	executeParams := raw_sql_tpl.CountExecuteParams{
+		TableName: tableName,
+	}
+
+	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
+		fieldType := sqlMappingColumn.ValueFieldType
+
+		conditionValue := reflect.Zero(fieldType).Interface()
+		if !sqlMappingColumn.ValueFieldValue.IsZero() {
+			conditionValue = sqlMappingColumn.ValueFieldValue.Interface()
+		}
+
+		conditionOp := "="
+
+		if sqlMappingColumn.CheckExistCallback {
+			if callback == nil {
+				return false, errors.New("需要使用回调函数但是没有传递回调函数")
+			}
+
+			retConditionOp, retConditionValue, err := callback(e, fieldName, sqlMappingColumn.Name, conditionValue)
+			if err != nil {
+				return false, err
+			}
+
+			retValueType := reflect.TypeOf(retConditionValue)
+			if retValueType == nil || retValueType.Kind() == reflect.Ptr {
+				return false, errors.New("返回应当为值类型")
+			}
+
+			conditionValue = retConditionValue
+			conditionOp = retConditionOp
+		}
+
+		tableRowValue, err := parseValue(conditionValue)
+		if err != nil {
+			return false, err
+		}
+
+		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
+			Column:   sqlMappingColumn.Name,
+			Operator: conditionOp,
+			Value:    tableRowValue,
+		})
+	}
+
+	executeParamsMap, err := executeParams.Map()
+	if err != nil {
+		return false, err
+	}
+
+	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, executeParamsMap)
+	if err != nil {
+		return false, err
+	}
+
+	return int64(tableRows[0]["count"].(float64)) > 0, nil
+}
+
+func CheckExistByKey[T any](executor SqlExecutor, tableName string, e T) (bool, error) {
+	if executor == nil {
+		return false, errors.New("没有传递执行器")
+	}
+
+	if strutils.IsStringEmpty(tableName) {
+		return false, errors.New("没有传递表名")
+	}
+
+	if reflect.TypeOf(e) == nil {
+		return false, errors.New("没有传递实体")
+	}
+
+	sqlMapping, err := tag.ParseSqlMapping(e)
+	if err != nil {
+		return false, err
+	}
+
+	executeParams := raw_sql_tpl.CountExecuteParams{
+		TableName: tableName,
+	}
+
+	for _, sqlMappingColumn := range sqlMapping.ColumnMap {
+		if !sqlMappingColumn.IsKey {
+			continue
+		}
+
+		fieldType := sqlMappingColumn.ValueFieldType
+
+		conditionValue := reflect.Zero(fieldType).Interface()
+		if !sqlMappingColumn.ValueFieldValue.IsZero() {
+			conditionValue = sqlMappingColumn.ValueFieldValue.Interface()
+		}
+
+		tableRowValue, err := parseValue(conditionValue)
+		if err != nil {
+			return false, err
+		}
+
+		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
+			Column:   sqlMappingColumn.Name,
+			Operator: "=",
+			Value:    tableRowValue,
+		})
+	}
+
+	executeParamsMap, err := executeParams.Map()
+	if err != nil {
+		return false, err
+	}
+
+	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, executeParamsMap)
+	if err != nil {
+		return false, err
+	}
+
+	return int64(tableRows[0]["count"].(float64)) > 0, nil
+}
+
+func CheckHasOnlyOne[T any](executor SqlExecutor, tableName string, e T, callback ConditionCallback[T]) (bool, error) {
+	if executor == nil {
+		return false, errors.New("没有传递执行器")
+	}
+
+	if strutils.IsStringEmpty(tableName) {
+		return false, errors.New("没有传递表名")
+	}
+
+	if reflect.TypeOf(e) == nil {
+		return false, errors.New("没有传递实体")
+	}
+
+	sqlMapping, err := tag.ParseSqlMapping(e)
+	if err != nil {
+		return false, err
+	}
+
+	executeParams := raw_sql_tpl.CountExecuteParams{
+		TableName: tableName,
+	}
+
+	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
+		fieldType := sqlMappingColumn.ValueFieldType
+
+		conditionValue := reflect.Zero(fieldType).Interface()
+		if !sqlMappingColumn.ValueFieldValue.IsZero() {
+			conditionValue = sqlMappingColumn.ValueFieldValue.Interface()
+		}
+
+		conditionOp := "="
+
+		if sqlMappingColumn.QueryCallback {
+			if callback == nil {
+				return false, errors.New("需要使用回调函数但是没有传递回调函数")
+			}
+
+			retConditionOp, retConditionValue, err := callback(e, fieldName, sqlMappingColumn.Name, conditionValue)
+			if err != nil {
+				return false, err
+			}
+
+			retValueType := reflect.TypeOf(retConditionValue)
+			if retValueType == nil || retValueType.Kind() == reflect.Ptr {
+				return false, errors.New("返回应当为值类型")
+			}
+
+			conditionValue = retConditionValue
+			conditionOp = retConditionOp
+		}
+
+		tableRowValue, err := parseValue(conditionValue)
+		if err != nil {
+			return false, err
+		}
+
+		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
+			Column:   sqlMappingColumn.Name,
+			Operator: conditionOp,
+			Value:    tableRowValue,
+		})
+	}
+
+	executeParamsMap, err := executeParams.Map()
+	if err != nil {
+		return false, err
+	}
+
+	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, executeParamsMap)
+	if err != nil {
+		return false, err
+	}
+
+	return int64(tableRows[0]["count"].(float64)) == 1, nil
 }
 
 func ExecuteRawSql(executor SqlExecutor, sql string, executeParams map[string]any) ([]map[string]any, error) {

+ 45 - 30
sdk/tag/sql_mapping.go

@@ -13,16 +13,19 @@ const (
 )
 
 const (
-	sqlMappingTagKey         = "sqlmapping"
-	sqlMappingIgnore         = "-"
-	sqlMappingColumn         = "column"
-	sqlMappingKey            = "key"
-	sqlMappingNotUpdate      = "notUpdate"
-	sqlMappingUpdateClear    = "updateClear"
-	sqlMappingNotQuery       = "notQuery"
-	sqlMappingInsertCallback = "insertCallback"
-	sqlMappingUpdateCallback = "updateCallback"
-	sqlMappingQueryCallback  = "queryCallback"
+	sqlMappingTagKey                  = "sqlmapping"
+	sqlMappingIgnore                  = "-"
+	sqlMappingColumn                  = "column"
+	sqlMappingKey                     = "key"
+	sqlMappingNotUpdate               = "notUpdate"
+	sqlMappingUpdateClear             = "updateClear"
+	sqlMappingNotQuery                = "notQuery"
+	sqlMappingInsertCallback          = "insertCallback"
+	sqlMappingUpdateCallback          = "updateCallback"
+	sqlMappingQueryCallback           = "queryCallback"
+	sqlMappingCountCallback           = "countCallback"
+	sqlMappingCheckExistCallback      = "checkExistCallback"
+	sqlMappingCheckHasOnlyOneCallback = "checkHasOnlyCallback"
 )
 
 type SqlMapping struct {
@@ -72,14 +75,17 @@ func ParseSqlMapping(e any) (*SqlMapping, error) {
 }
 
 type SqlMappingColumn struct {
-	Name           string
-	IsKey          bool
-	CanUpdate      bool
-	CanUpdateClear bool
-	CanQuery       bool
-	InsertCallback bool
-	UpdateCallback bool
-	QueryCallback  bool
+	Name                    string
+	IsKey                   bool
+	CanUpdate               bool
+	CanUpdateClear          bool
+	CanQuery                bool
+	InsertCallback          bool
+	UpdateCallback          bool
+	QueryCallback           bool
+	CountCallback           bool
+	CheckExistCallback      bool
+	CheckHasOnlyOneCallback bool
 
 	// 原字段的反射结构
 	OriginFieldType  reflect.Type
@@ -105,18 +111,21 @@ func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value)
 	}
 
 	sqlColumn := &SqlMappingColumn{
-		Name:             strcase.ToSnake(field.Name),
-		IsKey:            false,
-		CanUpdate:        true,
-		CanUpdateClear:   false,
-		CanQuery:         true,
-		InsertCallback:   false,
-		UpdateCallback:   false,
-		QueryCallback:    false,
-		OriginFieldType:  field.Type,
-		OriginFieldValue: fieldValue,
-		ValueFieldType:   valueFieldType,
-		ValueFieldValue:  valueFieldValue,
+		Name:                    strcase.ToSnake(field.Name),
+		IsKey:                   false,
+		CanUpdate:               true,
+		CanUpdateClear:          false,
+		CanQuery:                true,
+		InsertCallback:          false,
+		UpdateCallback:          false,
+		QueryCallback:           false,
+		CountCallback:           false,
+		CheckExistCallback:      false,
+		CheckHasOnlyOneCallback: false,
+		OriginFieldType:         field.Type,
+		OriginFieldValue:        fieldValue,
+		ValueFieldType:          valueFieldType,
+		ValueFieldValue:         valueFieldValue,
 	}
 
 	sqlMappingTag, ok := field.Tag.Lookup(sqlMappingTagKey)
@@ -150,6 +159,12 @@ func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value)
 				sqlColumn.UpdateCallback = true
 			case sqlMappingQueryCallback:
 				sqlColumn.QueryCallback = true
+			case sqlMappingCountCallback:
+				sqlColumn.CountCallback = true
+			case sqlMappingCheckExistCallback:
+				sqlColumn.CheckExistCallback = true
+			case sqlMappingCheckHasOnlyOneCallback:
+				sqlColumn.CheckHasOnlyOneCallback = true
 			default:
 				continue
 			}

+ 13 - 1
test/sdk_test.go

@@ -17,7 +17,7 @@ import (
 type Class struct {
 	ID              string    `sqlmapping:"key;"`
 	Name            string    `sqlmapping:"updateClear;notQuery;insertCallback;updateCallback;"`
-	StudentNum      int       `sqlmapping:"column:student_num;notUpdate;queryCallback;" sqlresult:"column:student_num_alias;"`
+	StudentNum      int       `sqlmapping:"column:student_num;notUpdate;queryCallback;countCallback;checkExistCallback;checkHasOnlyCallback;" sqlresult:"column:student_num_alias;"`
 	GraduatedTime   time.Time `sqlresult:"callback"`
 	CreatedTime     *time.Time
 	LastUpdatedTime time.Time
@@ -450,6 +450,18 @@ func TestSqlMapping(t *testing.T) {
 		if sqlColumn.QueryCallback && sqlColumn.Name != "student_num" {
 			t.Fatal("查询回调不正确")
 		}
+
+		if sqlColumn.CountCallback && sqlColumn.Name != "student_num" {
+			t.Fatal("计数回调不正确")
+		}
+
+		if sqlColumn.CheckExistCallback && sqlColumn.Name != "student_num" {
+			t.Fatal("检查存在性回调不正确")
+		}
+
+		if sqlColumn.CheckHasOnlyOneCallback && sqlColumn.Name != "student_num" {
+			t.Fatal("检查唯一性回调不正确")
+		}
 	}
 }