sql.go 6.5 KB


  1. package sdk
import (
	"errors"
	"math"
	"reflect"
	"strconv"
	"strings"
	"time"

	"git.sxidc.com/go-tools/utils/strutils"
	"git.sxidc.com/service-supports/ds-sdk/sdk/raw_sql_tpl"
	"git.sxidc.com/service-supports/ds-sdk/sdk/tag"
)
  11. type RawSqlExecutor interface {
  12. ExecuteRawSql(sql string, executeParams map[string]any) ([]map[string]any, error)
  13. }
  14. const (
  15. timeWriteFormat = time.DateTime + ".000000 +08:00"
  16. createdTimeFieldName = "CreatedTime"
  17. lastUpdatedTimeFieldName = "LastUpdatedTime"
  18. )
  19. type InsertCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)
  20. func Insert[T any](executor RawSqlExecutor, tableName string, e T, callback InsertCallback[T]) error {
  21. if executor == nil {
  22. return errors.New("没有传递执行器")
  23. }
  24. if strutils.IsStringEmpty(tableName) {
  25. return errors.New("没有传递表名")
  26. }
  27. if reflect.TypeOf(e) == nil {
  28. return errors.New("没有传递实体")
  29. }
  30. sqlMapping, err := tag.ParseSqlMapping(e)
  31. if err != nil {
  32. return err
  33. }
  34. executeParams := raw_sql_tpl.InsertExecuteParams{
  35. TableName: tableName,
  36. }
  37. now := time.Now()
  38. for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
  39. fieldType := sqlMappingColumn.ValueFieldType
  40. value := reflect.Zero(fieldType).Interface()
  41. if !sqlMappingColumn.ValueFieldValue.IsZero() {
  42. value = sqlMappingColumn.ValueFieldValue.Interface()
  43. }
  44. if sqlMappingColumn.InsertCallback {
  45. if callback == nil {
  46. return errors.New("需要使用回调函数但是没有传递回调函数")
  47. }
  48. retValue, err := callback(e, fieldName, value)
  49. if err != nil {
  50. return err
  51. }
  52. retValueType := reflect.TypeOf(retValue)
  53. if retValueType == nil || retValueType.Kind() == reflect.Ptr {
  54. return errors.New("返回应当为值类型")
  55. }
  56. value = retValue
  57. }
  58. if (fieldName == createdTimeFieldName || fieldName == lastUpdatedTimeFieldName) &&
  59. fieldType.String() == "time.Time" && value.(time.Time).IsZero() {
  60. value = now
  61. }
  62. tableRowValue, err := parseValue(value)
  63. if err != nil {
  64. return err
  65. }
  66. executeParams.TableRows = append(executeParams.TableRows, raw_sql_tpl.TableRow{
  67. Column: sqlMappingColumn.Name,
  68. Value: tableRowValue,
  69. })
  70. }
  71. executeParamsMap, err := executeParams.Map()
  72. if err != nil {
  73. return err
  74. }
  75. _, err = executor.ExecuteRawSql(raw_sql_tpl.InsertTpl, executeParamsMap)
  76. if err != nil {
  77. return err
  78. }
  79. return nil
  80. }
  81. func Delete[T any](executor RawSqlExecutor, tableName string, e T) error {
  82. if executor == nil {
  83. return errors.New("没有传递执行器")
  84. }
  85. if strutils.IsStringEmpty(tableName) {
  86. return errors.New("没有传递表名")
  87. }
  88. if reflect.TypeOf(e) == nil {
  89. return errors.New("没有传递实体")
  90. }
  91. sqlMapping, err := tag.ParseSqlMapping(e)
  92. if err != nil {
  93. return err
  94. }
  95. executeParams := raw_sql_tpl.DeleteExecuteParams{
  96. TableName: tableName,
  97. }
  98. for _, sqlMappingColumn := range sqlMapping.ColumnMap {
  99. if !sqlMappingColumn.IsKey {
  100. continue
  101. }
  102. fieldType := sqlMappingColumn.ValueFieldType
  103. value := reflect.Zero(fieldType).Interface()
  104. if !sqlMappingColumn.ValueFieldValue.IsZero() {
  105. value = sqlMappingColumn.ValueFieldValue.Interface()
  106. }
  107. tableRowValue, err := parseValue(value)
  108. if err != nil {
  109. return err
  110. }
  111. executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
  112. Column: sqlMappingColumn.Name,
  113. Operator: "=",
  114. Value: tableRowValue,
  115. })
  116. }
  117. executeParamsMap, err := executeParams.Map()
  118. if err != nil {
  119. return err
  120. }
  121. _, err = executor.ExecuteRawSql(raw_sql_tpl.DeleteTpl, executeParamsMap)
  122. if err != nil {
  123. return err
  124. }
  125. return nil
  126. }
  127. type UpdateCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)
  128. func Update[T any](executor RawSqlExecutor, tableName string, e T, callback UpdateCallback[T]) error {
  129. if executor == nil {
  130. return errors.New("没有传递执行器")
  131. }
  132. if strutils.IsStringEmpty(tableName) {
  133. return errors.New("没有传递表名")
  134. }
  135. if reflect.TypeOf(e) == nil {
  136. return errors.New("没有传递实体")
  137. }
  138. sqlMapping, err := tag.ParseSqlMapping(e)
  139. if err != nil {
  140. return err
  141. }
  142. executeParams := raw_sql_tpl.UpdateExecuteParams{
  143. TableName: tableName,
  144. }
  145. now := time.Now()
  146. for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
  147. fieldType := sqlMappingColumn.ValueFieldType
  148. value := reflect.Zero(fieldType).Interface()
  149. if !sqlMappingColumn.ValueFieldValue.IsZero() {
  150. value = sqlMappingColumn.ValueFieldValue.Interface()
  151. }
  152. if sqlMappingColumn.InsertCallback {
  153. if callback == nil {
  154. return errors.New("需要使用回调函数但是没有传递回调函数")
  155. }
  156. retValue, err := callback(e, fieldName, value)
  157. if err != nil {
  158. return err
  159. }
  160. retValueType := reflect.TypeOf(retValue)
  161. if retValueType == nil || retValueType.Kind() == reflect.Ptr {
  162. return errors.New("返回应当为值类型")
  163. }
  164. value = retValue
  165. }
  166. if fieldName == lastUpdatedTimeFieldName &&
  167. fieldType.String() == "time.Time" && value.(time.Time).IsZero() {
  168. value = now
  169. }
  170. // 字段为空不更新
  171. if reflect.ValueOf(value).IsZero() && !sqlMappingColumn.CanUpdateClear {
  172. continue
  173. }
  174. tableRowValue, err := parseValue(value)
  175. if err != nil {
  176. return err
  177. }
  178. executeParams.TableRows = append(executeParams.TableRows, raw_sql_tpl.TableRow{
  179. Column: sqlMappingColumn.Name,
  180. Value: tableRowValue,
  181. })
  182. if sqlMappingColumn.IsKey {
  183. executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
  184. Column: sqlMappingColumn.Name,
  185. Operator: "=",
  186. Value: tableRowValue,
  187. })
  188. }
  189. }
  190. executeParamsMap, err := executeParams.Map()
  191. if err != nil {
  192. return err
  193. }
  194. _, err = executor.ExecuteRawSql(raw_sql_tpl.UpdateTpl, executeParamsMap)
  195. if err != nil {
  196. return err
  197. }
  198. return nil
  199. }
  200. func parseValue(value any) (string, error) {
  201. switch v := value.(type) {
  202. case string:
  203. return "'" + v + "'", nil
  204. case bool:
  205. return strconv.FormatBool(v), nil
  206. case time.Time:
  207. return "'" + v.Format(timeWriteFormat) + "'", nil
  208. case int:
  209. return strconv.Itoa(v), nil
  210. case int8:
  211. return strconv.FormatInt(int64(v), 10), nil
  212. case int16:
  213. return strconv.FormatInt(int64(v), 10), nil
  214. case int32:
  215. return strconv.FormatInt(int64(v), 10), nil
  216. case int64:
  217. return strconv.FormatInt(v, 10), nil
  218. case uint:
  219. return strconv.FormatUint(uint64(v), 10), nil
  220. case uint8:
  221. return strconv.FormatUint(uint64(v), 10), nil
  222. case uint16:
  223. return strconv.FormatUint(uint64(v), 10), nil
  224. case uint32:
  225. return strconv.FormatUint(uint64(v), 10), nil
  226. case uint64:
  227. return strconv.FormatUint(v, 10), nil
  228. default:
  229. return "", errors.New("不支持的类型")
  230. }
  231. }