sql_result_tag_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. package test
  2. import (
  3. "git.sxidc.com/go-framework/baize/framework/core/tag/sql/sql_result"
  4. "git.sxidc.com/go-tools/utils/encoding"
  5. "git.sxidc.com/go-tools/utils/strutils"
  6. "github.com/pkg/errors"
  7. "math/rand"
  8. "reflect"
  9. "strings"
  10. "testing"
  11. "time"
  12. )
// SqlResultTagStruct is a fixture for sql_result.DefaultUsage that exercises
// each sqlresult tag option on plain (non-pointer) fields.
//
// Only Column names its column explicitly. checkFields expects the other
// fields under snake_case column names such as "time_layout" and
// "wrong_trim_suffix". NOTE(review): this relies on the library's default
// column naming; confirm against sql_result.
//
// The Wrong* fields are ints that carry the same tag options as their
// string/time counterparts. checkFields expects their column values to come
// back unchanged.
type SqlResultTagStruct struct {
	// Ignore is tagged "-"; checkFields rejects an "ignore" column.
	Ignore          string    `sqlresult:"-"`
	// Column is bound to the explicitly named column "test_column".
	Column          string    `sqlresult:"column:test_column"`
	// TimeLayout carries the date-only layout 2006-01-02. checkFields expects
	// the column to read "<date>T00:00:00".
	TimeLayout      time.Time `sqlresult:"timeLayout:2006-01-02"`
	// Aes carries an AES key. checkFields compares the field with the column
	// decrypted using that same key.
	Aes             string    `sqlresult:"aes:@MKU^*HF%p%G43Fd)UAHCVD#$XZSWQ@L"`
	// SplitWith uses the "##" separator. checkFields compares the field joined
	// with "##" against the column.
	SplitWith       []string  `sqlresult:"splitWith:##"`
	// Trim: checkFields compares the field with the column after stripping
	// '|' characters from both ends.
	Trim            string    `sqlresult:"trim:||"`
	// TrimPrefix: checkFields compares the field with the column minus a
	// leading "{{".
	TrimPrefix      string    `sqlresult:"trimPrefix:{{"`
	// TrimSuffix: checkFields compares the field with the column minus a
	// trailing "}}".
	TrimSuffix      string    `sqlresult:"trimSuffix:}}"`
	WrongTimeLayout int       `sqlresult:"timeLayout:2006-01-02"`
	WrongAes        int       `sqlresult:"aes:@MKU^*HF%p%G43Fd)UAHCVD#$XZSWQ@L"`
	WrongSplitWith  int       `sqlresult:"splitWith:##"`
	WrongTrim       int       `sqlresult:"trim:||"`
	WrongTrimPrefix int       `sqlresult:"trimPrefix:{{"`
	WrongTrimSuffix int       `sqlresult:"trimSuffix:}}"`
}
  29. func (s SqlResultTagStruct) checkFields(t *testing.T, result map[string]any) {
  30. if len(result) != reflect.TypeOf(s).NumField()-1 {
  31. t.Fatalf("%+v\n", errors.Errorf("有字段没有被解析"))
  32. }
  33. for columnName, value := range result {
  34. if columnName == "ignore" {
  35. t.Fatalf("%+v\n", errors.Errorf("忽略字段没有被忽略"))
  36. }
  37. switch columnName {
  38. case "test_column":
  39. if s.Column != value {
  40. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  41. columnName, s.Column, value))
  42. }
  43. case "time_layout":
  44. resultValue := value.(string)
  45. if s.TimeLayout.Format(time.DateOnly)+"T00:00:00" != resultValue {
  46. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  47. columnName, s.TimeLayout.Format(time.DateOnly), resultValue))
  48. }
  49. case "aes":
  50. resultValue := value.(string)
  51. decrypted, err := encoding.AESDecrypt(resultValue, "@MKU^*HF%p%G43Fd)UAHCVD#$XZSWQ@L")
  52. if err != nil {
  53. t.Fatalf("%+v\n", errors.Errorf(err.Error()))
  54. }
  55. if s.Aes != decrypted {
  56. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  57. columnName, s.Aes, decrypted))
  58. }
  59. case "split_with":
  60. resultValue := value.(string)
  61. if strings.Join(s.SplitWith, "##") != resultValue {
  62. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  63. columnName, strings.Join(s.SplitWith, "##"), resultValue))
  64. }
  65. case "trim":
  66. resultValue := value.(string)
  67. if s.Trim != strings.Trim(resultValue, "||") {
  68. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  69. columnName, s.Trim, strings.Trim(resultValue, "||")))
  70. }
  71. case "trim_prefix":
  72. resultValue := value.(string)
  73. if s.TrimPrefix != strings.TrimPrefix(resultValue, "{{") {
  74. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  75. columnName, s.TrimPrefix, strings.TrimPrefix(resultValue, "{{")))
  76. }
  77. case "trim_suffix":
  78. resultValue := value.(string)
  79. if s.TrimSuffix != strings.TrimSuffix(resultValue, "}}") {
  80. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  81. columnName, s.TrimSuffix, strings.TrimSuffix(resultValue, "}}")))
  82. }
  83. case "wrong_time_layout":
  84. resultValue := value.(int)
  85. if s.WrongTimeLayout != resultValue {
  86. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  87. columnName, s.WrongTimeLayout, resultValue))
  88. }
  89. case "wrong_aes":
  90. resultValue := value.(int)
  91. if s.WrongAes != resultValue {
  92. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  93. columnName, s.WrongAes, resultValue))
  94. }
  95. case "wrong_split_with":
  96. resultValue := value.(int)
  97. if s.WrongSplitWith != resultValue {
  98. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  99. columnName, s.WrongSplitWith, resultValue))
  100. }
  101. case "wrong_trim":
  102. resultValue := value.(int)
  103. if s.WrongTrim != resultValue {
  104. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  105. columnName, s.WrongTrim, resultValue))
  106. }
  107. case "wrong_trim_prefix":
  108. resultValue := value.(int)
  109. if s.WrongTrimPrefix != resultValue {
  110. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  111. columnName, s.WrongTrimPrefix, resultValue))
  112. }
  113. case "wrong_trim_suffix":
  114. resultValue := value.(int)
  115. if s.WrongTrimSuffix != resultValue {
  116. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  117. columnName, s.WrongTrimSuffix, resultValue))
  118. }
  119. default:
  120. t.Fatalf("%+v\n", errors.Errorf("未知的列: tag columnName: %v", columnName))
  121. }
  122. }
  123. }
// SqlResultTagPointFieldsStruct mirrors SqlResultTagStruct with every field
// declared as a pointer. It checks that sql_result.DefaultUsage applies the
// same sqlresult tag options when decoding into pointer fields.
//
// checkFields dereferences each field, so every field except Ignore is
// expected to be non-nil after decoding. Column naming follows the same
// snake_case convention checkFields assumes for SqlResultTagStruct.
// NOTE(review): confirm that default naming against sql_result.
type SqlResultTagPointFieldsStruct struct {
	// Ignore is tagged "-"; checkFields rejects an "ignore" column.
	Ignore          *string    `sqlresult:"-"`
	// Column is bound to the explicitly named column "test_column".
	Column          *string    `sqlresult:"column:test_column"`
	// TimeLayout carries the date-only layout 2006-01-02.
	TimeLayout      *time.Time `sqlresult:"timeLayout:2006-01-02"`
	// Aes carries an AES key; checkFields decrypts the column with it.
	Aes             *string    `sqlresult:"aes:@MKU^*HF%p%G43Fd)UAHCVD#$XZSWQ@L"`
	// SplitWith uses the "##" separator.
	SplitWith       *[]string  `sqlresult:"splitWith:##"`
	Trim            *string    `sqlresult:"trim:||"`
	TrimPrefix      *string    `sqlresult:"trimPrefix:{{"`
	TrimSuffix      *string    `sqlresult:"trimSuffix:}}"`
	// The Wrong* fields are *int with string/time-oriented options;
	// checkFields expects their column values to come back unchanged.
	WrongTimeLayout *int       `sqlresult:"timeLayout:2006-01-02"`
	WrongAes        *int       `sqlresult:"aes:@MKU^*HF%p%G43Fd)UAHCVD#$XZSWQ@L"`
	WrongSplitWith  *int       `sqlresult:"splitWith:##"`
	WrongTrim       *int       `sqlresult:"trim:||"`
	WrongTrimPrefix *int       `sqlresult:"trimPrefix:{{"`
	WrongTrimSuffix *int       `sqlresult:"trimSuffix:}}"`
}
  140. func (s SqlResultTagPointFieldsStruct) checkFields(t *testing.T, result map[string]any) {
  141. if len(result) != reflect.TypeOf(s).NumField()-1 {
  142. t.Fatalf("%+v\n", errors.Errorf("有字段没有被解析"))
  143. }
  144. for columnName, value := range result {
  145. if columnName == "ignore" {
  146. t.Fatalf("%+v\n", errors.Errorf("忽略字段没有被忽略"))
  147. }
  148. switch columnName {
  149. case "test_column":
  150. if *s.Column != value {
  151. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  152. columnName, *s.Column, value))
  153. }
  154. case "time_layout":
  155. resultValue := value.(string)
  156. if (*s.TimeLayout).Format(time.DateOnly)+"T00:00:00" != resultValue {
  157. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  158. columnName, (*s.TimeLayout).Format(time.DateOnly), resultValue))
  159. }
  160. case "aes":
  161. resultValue := value.(string)
  162. decrypted, err := encoding.AESDecrypt(resultValue, "@MKU^*HF%p%G43Fd)UAHCVD#$XZSWQ@L")
  163. if err != nil {
  164. t.Fatalf("%+v\n", errors.Errorf(err.Error()))
  165. }
  166. if *s.Aes != decrypted {
  167. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  168. columnName, *s.Aes, decrypted))
  169. }
  170. case "split_with":
  171. resultValue := value.(string)
  172. if strings.Join(*s.SplitWith, "##") != resultValue {
  173. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  174. columnName, strings.Join(*s.SplitWith, "##"), resultValue))
  175. }
  176. case "trim":
  177. resultValue := value.(string)
  178. if *s.Trim != strings.Trim(resultValue, "||") {
  179. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  180. columnName, s.Trim, strings.Trim(resultValue, "||")))
  181. }
  182. case "trim_prefix":
  183. resultValue := value.(string)
  184. if *s.TrimPrefix != strings.TrimPrefix(resultValue, "{{") {
  185. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  186. columnName, s.TrimPrefix, strings.TrimPrefix(resultValue, "{{")))
  187. }
  188. case "trim_suffix":
  189. resultValue := value.(string)
  190. if *s.TrimSuffix != strings.TrimSuffix(resultValue, "}}") {
  191. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  192. columnName, s.TrimSuffix, strings.TrimSuffix(resultValue, "}}")))
  193. }
  194. case "wrong_time_layout":
  195. resultValue := value.(int)
  196. if *s.WrongTimeLayout != resultValue {
  197. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  198. columnName, s.WrongTimeLayout, resultValue))
  199. }
  200. case "wrong_aes":
  201. resultValue := value.(int)
  202. if *s.WrongAes != resultValue {
  203. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  204. columnName, s.WrongAes, resultValue))
  205. }
  206. case "wrong_split_with":
  207. resultValue := value.(int)
  208. if *s.WrongSplitWith != resultValue {
  209. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  210. columnName, s.WrongSplitWith, resultValue))
  211. }
  212. case "wrong_trim":
  213. resultValue := value.(int)
  214. if *s.WrongTrim != resultValue {
  215. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  216. columnName, s.WrongTrim, resultValue))
  217. }
  218. case "wrong_trim_prefix":
  219. resultValue := value.(int)
  220. if *s.WrongTrimPrefix != resultValue {
  221. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  222. columnName, s.WrongTrimPrefix, resultValue))
  223. }
  224. case "wrong_trim_suffix":
  225. resultValue := value.(int)
  226. if *s.WrongTrimSuffix != resultValue {
  227. t.Fatalf("%+v\n", errors.Errorf("值不正确: columnName: %v, fieldValue %v, resultValue %v",
  228. columnName, s.WrongTrimSuffix, resultValue))
  229. }
  230. default:
  231. t.Fatalf("%+v\n", errors.Errorf("未知的列: tag columnName: %v", columnName))
  232. }
  233. }
  234. }
  235. func TestSqlResultTagDefaultUsage(t *testing.T) {
  236. aesEncrypted, err := encoding.AESEncrypt(strutils.SimpleUUID(), "@MKU^*HF%p%G43Fd)UAHCVD#$XZSWQ@L")
  237. if err != nil {
  238. t.Fatalf("%+v\n", errors.Errorf(err.Error()))
  239. }
  240. result := map[string]any{
  241. "test_column": strutils.SimpleUUID(),
  242. "time_layout": "2024-07-03T00:00:00",
  243. "aes": aesEncrypted,
  244. "split_with": strings.Join([]string{strutils.SimpleUUID(), strutils.SimpleUUID()}, "##"),
  245. "trim": "||" + strutils.SimpleUUID() + "||",
  246. "trim_prefix": "{{" + strutils.SimpleUUID(),
  247. "trim_suffix": strutils.SimpleUUID() + "}}",
  248. "wrong_time_layout": rand.Int(),
  249. "wrong_aes": rand.Int(),
  250. "wrong_split_with": rand.Int(),
  251. "wrong_trim": rand.Int(),
  252. "wrong_trim_prefix": rand.Int(),
  253. "wrong_trim_suffix": rand.Int(),
  254. }
  255. s := SqlResultTagStruct{}
  256. sPointerField := SqlResultTagPointFieldsStruct{}
  257. err = sql_result.DefaultUsage(result, s, "")
  258. if err == nil || err.Error() != "参数不是结构指针" {
  259. t.Fatalf("%+v\n", errors.Errorf("没有检测出参数必须是指针类型"))
  260. }
  261. err = sql_result.DefaultUsage(result, sPointerField, "")
  262. if err == nil || err.Error() != "参数不是结构指针" {
  263. t.Fatalf("%+v\n", errors.Errorf("没有检测出参数必须是指针类型"))
  264. }
  265. err = sql_result.DefaultUsage(result, &s, "")
  266. if err != nil {
  267. t.Fatalf("%+v\n", err)
  268. }
  269. s.checkFields(t, result)
  270. err = sql_result.DefaultUsage(result, &sPointerField, "")
  271. if err != nil {
  272. t.Fatalf("%+v\n", err)
  273. }
  274. sPointerField.checkFields(t, result)
  275. }