crud.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. package sdk
import (
	"errors"
	"reflect"
	"strconv"
	"strings"
	"time"

	"git.sxidc.com/go-tools/utils/strutils"
	"git.sxidc.com/service-supports/ds-sdk/sdk/raw_sql_tpl"
	"git.sxidc.com/service-supports/ds-sdk/sdk/tag"
)
const (
	// timeWriteFormat is the layout parseValue uses to render time.Time values:
	// time.DateTime followed by six fractional-second digits and the text
	// " +08:00". The "+08:00" part is not a layout verb (Go's numeric zone
	// verbs are spelled with -07:00), so it is emitted verbatim for every
	// value regardless of the value's own location.
	timeWriteFormat = time.DateTime + ".000000 +08:00"

	// createdTimeFieldName is a struct field name that Insert fills with the
	// current time when the field is a time.Time holding the zero value.
	createdTimeFieldName = "CreatedTime"

	// lastUpdatedTimeFieldName is handled by Insert exactly like
	// createdTimeFieldName.
	lastUpdatedTimeFieldName = "LastUpdatedTime"
)
// InsertCallback is the hook Insert invokes for every mapped column whose
// mapping has InsertCallback set.
//
// It receives the entity being inserted, the name of the struct field, and the
// field's current value. The returned retValue replaces that value in the
// generated row; Insert rejects a nil or pointer-typed retValue. A non-nil err
// aborts the insert and is returned to the caller of Insert unchanged.
type InsertCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)
  17. func Insert[T any](sdk *SDK, tableName string, e T, callback InsertCallback[T]) error {
  18. if sdk == nil {
  19. return errors.New("没有传递sdk")
  20. }
  21. if strutils.IsStringEmpty(tableName) {
  22. return errors.New("没有传递表名")
  23. }
  24. if reflect.TypeOf(e) == nil {
  25. return errors.New("没有传递实体")
  26. }
  27. sqlMapping, err := tag.ParseSqlMapping(e)
  28. if err != nil {
  29. return err
  30. }
  31. executeParams := raw_sql_tpl.InsertExecuteParams{
  32. TableName: tableName,
  33. }
  34. now := time.Now()
  35. for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
  36. fieldType := sqlMappingColumn.FieldType
  37. value := reflect.Zero(fieldType).Interface()
  38. if !sqlMappingColumn.FieldValue.IsZero() {
  39. value = sqlMappingColumn.FieldValue.Interface()
  40. }
  41. if sqlMappingColumn.InsertCallback {
  42. if callback == nil {
  43. return errors.New("需要使用回调函数但是没有传递回调函数")
  44. }
  45. retValue, err := callback(e, fieldName, value)
  46. if err != nil {
  47. return err
  48. }
  49. retValueType := reflect.TypeOf(retValue)
  50. if retValueType == nil || retValueType.Kind() == reflect.Ptr {
  51. return errors.New("返回应当为值类型")
  52. }
  53. value = retValue
  54. }
  55. if (fieldName == createdTimeFieldName || fieldName == lastUpdatedTimeFieldName) &&
  56. fieldType.String() == "time.Time" && value.(time.Time).IsZero() {
  57. value = now
  58. }
  59. tableRowValue, err := parseValue(value)
  60. if err != nil {
  61. return err
  62. }
  63. executeParams.TableRows = append(executeParams.TableRows, raw_sql_tpl.TableRow{
  64. Column: sqlMappingColumn.Name,
  65. Value: tableRowValue,
  66. })
  67. }
  68. executeParamsMap, err := executeParams.Map()
  69. if err != nil {
  70. return err
  71. }
  72. _, err = sdk.ExecuteRawSql(raw_sql_tpl.InsertTpl, executeParamsMap)
  73. if err != nil {
  74. return err
  75. }
  76. return nil
  77. }
  78. func parseValue(value any) (string, error) {
  79. switch v := value.(type) {
  80. case string:
  81. return "'" + v + "'", nil
  82. case bool:
  83. return strconv.FormatBool(v), nil
  84. case time.Time:
  85. return "'" + v.Format(timeWriteFormat) + "'", nil
  86. case int:
  87. return strconv.Itoa(v), nil
  88. case int8:
  89. return strconv.FormatInt(int64(v), 10), nil
  90. case int16:
  91. return strconv.FormatInt(int64(v), 10), nil
  92. case int32:
  93. return strconv.FormatInt(int64(v), 10), nil
  94. case int64:
  95. return strconv.FormatInt(v, 10), nil
  96. case uint:
  97. return strconv.FormatUint(uint64(v), 10), nil
  98. case uint8:
  99. return strconv.FormatUint(uint64(v), 10), nil
  100. case uint16:
  101. return strconv.FormatUint(uint64(v), 10), nil
  102. case uint32:
  103. return strconv.FormatUint(uint64(v), 10), nil
  104. case uint64:
  105. return strconv.FormatUint(v, 10), nil
  106. default:
  107. return "", errors.New("不支持的类型")
  108. }
  109. }