context.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680
  1. package api
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
  6. "github.com/gin-gonic/gin"
  7. "github.com/pkg/errors"
  8. "io"
  9. "mime/multipart"
  10. "net/textproto"
  11. "strings"
  12. )
  13. const (
  14. bodyKey = "body-context"
  15. queryParamsKey = "query-params-context"
  16. pathParamsKey = "path-params-context"
  17. headerKey = "header-context"
  18. tenantInfoKey = "context-tenant-info"
  19. userInfoKey = "context-user-info"
  20. )
  21. type Context struct {
  22. *gin.Context
  23. }
  24. // GetFileHeaderBytes 获取Multipart中传递的文件名和文件内容
  25. // 参数:
  26. // - fileHeader: Multipart的文件头
  27. // 返回值:
  28. // - 文件名
  29. // - 文件字节
  30. // - 错误
  31. func (c *Context) GetFileHeaderBytes(fileHeader *multipart.FileHeader) (string, []byte, error) {
  32. file, err := fileHeader.Open()
  33. if err != nil {
  34. return "", nil, errors.New(err.Error())
  35. }
  36. defer func(file multipart.File) {
  37. err := file.Close()
  38. if err != nil {
  39. logger.GetInstance().Error(errors.New(err.Error()))
  40. return
  41. }
  42. }(file)
  43. contentBytes, err := io.ReadAll(file)
  44. if err != nil {
  45. return "", nil, errors.New(err.Error())
  46. }
  47. return fileHeader.Filename, contentBytes, nil
  48. }
  49. // GetHeader 获取上下文中的Header
  50. // 参数: 无
  51. // 返回值:
  52. // - 上下文中的Header
  53. func (c *Context) GetHeader() *Header {
  54. savedHeader, exist := c.Get(headerKey)
  55. if exist {
  56. return &Header{
  57. c: c,
  58. header: savedHeader.(map[string]string),
  59. }
  60. }
  61. return &Header{
  62. c: c,
  63. header: c.ReadOriginHeader(),
  64. }
  65. }
  66. func (c *Context) ReadOriginHeader() map[string]string {
  67. originHeader := make(map[string]string)
  68. for key, values := range c.Request.Header {
  69. originHeader[key] = strings.Join(values, ",")
  70. }
  71. return originHeader
  72. }
  73. // GetBytesBody 获取上下文中的字节Body
  74. // 参数: 无
  75. // 返回值:
  76. // - 上下文中的字节Body
  77. // - 错误
  78. func (c *Context) GetBytesBody() (*BytesBody, error) {
  79. body, exist := c.Get(bodyKey)
  80. if !exist {
  81. bytesBody, err := c.ReadOriginBody()
  82. if err != nil {
  83. return nil, err
  84. }
  85. c.Set(bodyKey, bytesBody)
  86. return &BytesBody{
  87. c: c,
  88. bytesBody: bytesBody,
  89. }, nil
  90. }
  91. switch b := body.(type) {
  92. case []byte:
  93. return &BytesBody{
  94. c: c,
  95. bytesBody: b,
  96. }, nil
  97. case map[string]any:
  98. bytesBody, err := json.Marshal(b)
  99. if err != nil {
  100. return nil, errors.New(err.Error())
  101. }
  102. return &BytesBody{
  103. c: c,
  104. bytesBody: bytesBody,
  105. }, nil
  106. default:
  107. return nil, errors.New("不支持的body类型")
  108. }
  109. }
  110. // GetJsonBody 获取上下文中的JsonBody
  111. // 参数: 无
  112. // 返回值:
  113. // - 上下文中的JsonBody
  114. // - 错误
  115. func (c *Context) GetJsonBody() (*JsonBody, error) {
  116. body, exist := c.Get(bodyKey)
  117. if !exist {
  118. bytesBody, err := c.ReadOriginBody()
  119. if err != nil {
  120. return nil, err
  121. }
  122. if bytesBody == nil || len(bytesBody) == 0 {
  123. return &JsonBody{
  124. c: c,
  125. jsonBodyMap: make(map[string]any),
  126. }, nil
  127. }
  128. jsonBodyMap := make(map[string]any)
  129. err = json.Unmarshal(bytesBody, &jsonBodyMap)
  130. if err != nil {
  131. return nil, errors.New(err.Error())
  132. }
  133. c.Set(bodyKey, jsonBodyMap)
  134. return &JsonBody{
  135. c: c,
  136. jsonBodyMap: jsonBodyMap,
  137. }, nil
  138. }
  139. switch b := body.(type) {
  140. case []byte:
  141. jsonBodyMap := make(map[string]any)
  142. if b == nil || len(b) == 0 {
  143. return &JsonBody{
  144. c: c,
  145. jsonBodyMap: jsonBodyMap,
  146. }, nil
  147. }
  148. err := json.Unmarshal(b, &jsonBodyMap)
  149. if err != nil {
  150. return nil, errors.New(err.Error())
  151. }
  152. return &JsonBody{
  153. c: c,
  154. jsonBodyMap: jsonBodyMap,
  155. }, nil
  156. case map[string]any:
  157. return &JsonBody{
  158. c: c,
  159. jsonBodyMap: body.(map[string]any),
  160. }, nil
  161. default:
  162. return nil, errors.New("不支持的body类型")
  163. }
  164. }
  165. // GetQueryParams 获取上下文中的查询参数
  166. // 参数: 无
  167. // 返回值:
  168. // - 上下文中的查询参数
  169. // - 错误
  170. func (c *Context) GetQueryParams() *QueryPrams {
  171. queryParams, exist := c.Get(queryParamsKey)
  172. if !exist {
  173. originQueryParams := c.GetOriginQueryParams()
  174. c.Set(queryParamsKey, originQueryParams)
  175. return &QueryPrams{
  176. c: c,
  177. queryParams: originQueryParams,
  178. }
  179. }
  180. return &QueryPrams{
  181. c: c,
  182. queryParams: queryParams.(map[string]string),
  183. }
  184. }
  185. // GetPathParams 获取上下文中的路径参数
  186. // 参数: 无
  187. // 返回值:
  188. // - 上下文中的路径参数
  189. // - 错误
  190. func (c *Context) GetPathParams() *PathPrams {
  191. pathParams, exist := c.Get(pathParamsKey)
  192. if !exist {
  193. originPathParams := c.GetOriginPathParams()
  194. c.Set(pathParamsKey, originPathParams)
  195. return &PathPrams{
  196. c: c,
  197. pathParams: c.GetOriginPathParams(),
  198. }
  199. }
  200. return &PathPrams{
  201. c: c,
  202. pathParams: pathParams.(map[string]string),
  203. }
  204. }
  205. // ReadOriginBody 获取上下文中的原始字节Body
  206. // 参数: 无
  207. // 返回值:
  208. // - 上下文中的原始字节Body
  209. // - 错误
  210. func (c *Context) ReadOriginBody() ([]byte, error) {
  211. if c.Request.Body == nil {
  212. return make([]byte, 0), nil
  213. }
  214. body, err := io.ReadAll(c.Request.Body)
  215. if err != nil {
  216. return nil, errors.New(err.Error())
  217. }
  218. defer func(Body io.ReadCloser) {
  219. err := Body.Close()
  220. if err != nil {
  221. logger.GetInstance().Error(errors.New(err.Error()))
  222. return
  223. }
  224. }(c.Request.Body)
  225. c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
  226. c.Set(bodyKey, body)
  227. return body, nil
  228. }
  229. // GetOriginQueryParams 获取上下文中的原始查询参数
  230. // 参数: 无
  231. // 返回值:
  232. // - 上下文中的原始查询参数
  233. func (c *Context) GetOriginQueryParams() map[string]string {
  234. queryParams := make(map[string]string)
  235. for key, values := range c.Request.URL.Query() {
  236. queryParams[key] = strings.Join(values, ",")
  237. }
  238. return queryParams
  239. }
  240. // GetOriginPathParams 获取上下文中的原始路径参数
  241. // 参数: 无
  242. // 返回值:
  243. // - 上下文中的原始路径参数
  244. func (c *Context) GetOriginPathParams() map[string]string {
  245. pathParams := make(map[string]string)
  246. for _, params := range c.Params {
  247. pathParams[params.Key] = params.Value
  248. }
  249. return pathParams
  250. }
  251. type Header struct {
  252. c *Context
  253. header map[string]string
  254. }
  255. // Reload 从Context重新加载Header
  256. // 参数: 无
  257. // 返回值: 无
  258. func (header *Header) Reload() {
  259. originHeader := header.c.ReadOriginHeader()
  260. header.c.Set(headerKey, originHeader)
  261. header.header = originHeader
  262. }
  263. // Set 设置Header
  264. // 参数:
  265. // - key: Header的键
  266. // - value: Header对应键的值
  267. // 返回值: 无
  268. func (header *Header) Set(key string, value string) {
  269. header.header[key] = value
  270. header.c.Set(headerKey, header.header)
  271. }
  272. // Get 获取Header对应键的值
  273. // 参数:
  274. // - key: Header的键
  275. // 返回值:
  276. // - Header对应键的值
  277. func (header *Header) Get(key string) string {
  278. mineHeader := textproto.MIMEHeader{}
  279. for k, v := range header.header {
  280. mineHeader.Set(k, v)
  281. }
  282. return mineHeader.Get(key)
  283. }
  284. // Map 获取Header的map表示
  285. // 参数: 无
  286. // 返回值:
  287. // - Header的map表示
  288. func (header *Header) Map() map[string]string {
  289. return header.header
  290. }
  291. type BytesBody struct {
  292. c *Context
  293. bytesBody []byte
  294. }
  295. // Reload 从Context重新加载BytesBody
  296. // 参数: 无
  297. // 返回值:
  298. // - 错误
  299. func (bytesBody *BytesBody) Reload() error {
  300. originBody, err := bytesBody.c.ReadOriginBody()
  301. if err != nil {
  302. return err
  303. }
  304. bytesBody.c.Set(headerKey, originBody)
  305. bytesBody.bytesBody = originBody
  306. return nil
  307. }
  308. // Set 设置字节body
  309. // 参数:
  310. // - body: 字节body的内容
  311. // 返回值: 无
  312. func (bytesBody *BytesBody) Set(body []byte) {
  313. bytesBody.bytesBody = body
  314. bytesBody.c.Set(bodyKey, bytesBody.bytesBody)
  315. }
  316. // Bytes 获取字节body的内容
  317. // 参数: 无
  318. // 返回值:
  319. // - 字节body的内容
  320. func (bytesBody *BytesBody) Bytes() []byte {
  321. return bytesBody.bytesBody
  322. }
  323. // Marshal 将input Marshal到字节body
  324. // 参数:
  325. // - input: 输入,一般为结构或map[string]any
  326. // 返回值:
  327. // - 错误
  328. func (bytesBody *BytesBody) Marshal(input any) error {
  329. jsonBytesBody, err := json.Marshal(input)
  330. if err != nil {
  331. return errors.New(err.Error())
  332. }
  333. bytesBody.Set(jsonBytesBody)
  334. return nil
  335. }
  336. // Unmarshal 将字节body的内容Unmarshal到output
  337. // 参数:
  338. // - output: 输出,一般为结构指针或map[string]any指针
  339. // 返回值:
  340. // - 错误
  341. func (bytesBody *BytesBody) Unmarshal(output any) error {
  342. err := json.Unmarshal(bytesBody.Bytes(), output)
  343. if err != nil {
  344. return errors.New(err.Error())
  345. }
  346. return nil
  347. }
  348. type JsonBody struct {
  349. c *Context
  350. jsonBodyMap map[string]any
  351. }
  352. // Reload 从Context重新加载JsonBody
  353. // 参数: 无
  354. // 返回值:
  355. // - 错误
  356. func (jsonBody *JsonBody) Reload() error {
  357. bytesBody, err := jsonBody.c.ReadOriginBody()
  358. if err != nil {
  359. return err
  360. }
  361. jsonBodyMap := make(map[string]any)
  362. if bytesBody == nil || len(bytesBody) == 0 {
  363. jsonBody.c.Set(bodyKey, jsonBodyMap)
  364. jsonBody.jsonBodyMap = jsonBodyMap
  365. return nil
  366. }
  367. err = json.Unmarshal(bytesBody, &jsonBodyMap)
  368. if err != nil {
  369. return errors.New(err.Error())
  370. }
  371. jsonBody.c.Set(bodyKey, jsonBodyMap)
  372. jsonBody.jsonBodyMap = jsonBodyMap
  373. return nil
  374. }
  375. // Set 设置JsonBody键对应的值
  376. // 参数:
  377. // - key: JsonBody的键
  378. // - value: JsonBody对应键的值
  379. // 返回值: 无
  380. func (jsonBody *JsonBody) Set(key string, value any) {
  381. jsonBody.jsonBodyMap[key] = value
  382. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  383. }
  384. // Delete 删除JsonBody键对应的值
  385. // 参数:
  386. // - key: JsonBody的键
  387. // 返回值: 无
  388. func (jsonBody *JsonBody) Delete(key string) {
  389. delete(jsonBody.jsonBodyMap, key)
  390. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  391. }
  392. // Get 获取JsonBody对应键的值
  393. // 参数:
  394. // - key: JsonBody的键
  395. // 返回值:
  396. // - JsonBody对应键的值
  397. func (jsonBody *JsonBody) Get(key string) any {
  398. return jsonBody.jsonBodyMap[key]
  399. }
  400. // SetMap 使用map设置JsonBody
  401. // 参数:
  402. // - mapBody: map表示的body
  403. // 返回值: 无
  404. func (jsonBody *JsonBody) SetMap(mapBody map[string]any) {
  405. jsonBody.jsonBodyMap = mapBody
  406. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  407. }
  408. // Map 获取JsonBody的map表示
  409. // 参数: 无
  410. // 返回值:
  411. // - JsonBody的map表示
  412. func (jsonBody *JsonBody) Map() map[string]any {
  413. return jsonBody.jsonBodyMap
  414. }
  415. // Bytes 获取JsonBody的内容
  416. // 参数: 无
  417. // 返回值:
  418. // - JsonBody的内容
  419. func (jsonBody *JsonBody) Bytes() ([]byte, error) {
  420. if jsonBody.jsonBodyMap == nil || len(jsonBody.jsonBodyMap) == 0 {
  421. return make([]byte, 0), nil
  422. }
  423. jsonBytes, err := json.Marshal(jsonBody.jsonBodyMap)
  424. if err != nil {
  425. return nil, errors.New(err.Error())
  426. }
  427. return jsonBytes, nil
  428. }
  429. // Unmarshal 将JsonBody的内容Unmarshal到output
  430. // 参数:
  431. // - output: 输出,一般为结构指针或map[string]any指针
  432. // 返回值:
  433. // - 错误
  434. func (jsonBody *JsonBody) Unmarshal(output any) error {
  435. jsonBytes, err := jsonBody.Bytes()
  436. if err != nil {
  437. return err
  438. }
  439. if jsonBytes == nil || len(jsonBytes) == 0 {
  440. return nil
  441. }
  442. err = json.Unmarshal(jsonBytes, output)
  443. if err != nil {
  444. return errors.New(err.Error())
  445. }
  446. return nil
  447. }
  448. type QueryPrams struct {
  449. c *Context
  450. queryParams map[string]string
  451. }
  452. // Reload 从Context重新加载QueryParams
  453. // 参数: 无
  454. // 返回值: 无
  455. func (queryParams *QueryPrams) Reload() {
  456. originQueryParams := queryParams.c.GetOriginQueryParams()
  457. queryParams.c.Set(queryParamsKey, originQueryParams)
  458. queryParams.queryParams = originQueryParams
  459. }
  460. // Set 设置查询参数
  461. // 参数:
  462. // - key: 查询参数的键
  463. // - value: 查询参数对应键的值
  464. // 返回值: 无
  465. func (queryParams *QueryPrams) Set(key string, value string) {
  466. queryParams.queryParams[key] = value
  467. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  468. }
  469. // Delete 删除JsonBody键对应的值
  470. // 参数:
  471. // - key: JsonBody的键
  472. // 返回值: 无
  473. func (queryParams *QueryPrams) Delete(key string) {
  474. delete(queryParams.queryParams, key)
  475. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  476. }
  477. // Get 获取查询参数对应键的值
  478. // 参数:
  479. // - key: 查询参数的键
  480. // 返回值:
  481. // - 查询参数对应键的值
  482. func (queryParams *QueryPrams) Get(key string) string {
  483. return queryParams.queryParams[key]
  484. }
  485. // SetMap 使用map设置查询参数
  486. // 参数:
  487. // - mapQueryParams: map表示的查询参数
  488. // 返回值: 无
  489. func (queryParams *QueryPrams) SetMap(mapQueryParams map[string]string) {
  490. queryParams.queryParams = mapQueryParams
  491. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  492. }
  493. // Map 获取查询参数的map表示
  494. // 参数: 无
  495. // 返回值:
  496. // - 查询参数的map表示
  497. func (queryParams *QueryPrams) Map() map[string]string {
  498. return queryParams.queryParams
  499. }
  500. type PathPrams struct {
  501. c *Context
  502. pathParams map[string]string
  503. }
  504. // Reload 从Context重新加载QueryParams
  505. // 参数: 无
  506. // 返回值: 无
  507. func (pathParams *PathPrams) Reload() {
  508. originPathParams := pathParams.c.GetOriginPathParams()
  509. pathParams.c.Set(pathParamsKey, originPathParams)
  510. pathParams.pathParams = originPathParams
  511. }
  512. // Set 设置路径参数
  513. // 参数:
  514. // - key: 路径参数的键
  515. // - value: 路径参数对应键的值
  516. // 返回值: 无
  517. func (pathParams *PathPrams) Set(key string, value string) {
  518. pathParams.pathParams[key] = value
  519. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  520. }
  521. // Delete 删除JsonBody键对应的值
  522. // 参数:
  523. // - key: JsonBody的键
  524. // 返回值: 无
  525. func (pathParams *PathPrams) Delete(key string) {
  526. delete(pathParams.pathParams, key)
  527. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  528. }
  529. // Get 获取路径参数对应键的值
  530. // 参数:
  531. // - key: 路径参数的键
  532. // 返回值:
  533. // - 路径参数对应键的值
  534. func (pathParams *PathPrams) Get(key string) string {
  535. return pathParams.pathParams[key]
  536. }
  537. // SetMap 使用map设置路径参数
  538. // 参数:
  539. // - mapPathParams: map表示的路径参数
  540. // 返回值: 无
  541. func (pathParams *PathPrams) SetMap(mapPathParams map[string]string) {
  542. pathParams.pathParams = mapPathParams
  543. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  544. }
  545. // Map 获取路径参数的map表示
  546. // 参数: 无
  547. // 返回值:
  548. // - 路径参数的map表示
  549. func (pathParams *PathPrams) Map() map[string]string {
  550. return pathParams.pathParams
  551. }
  552. type TenantInfo interface {
  553. GetID() string
  554. GetName() string
  555. }
  556. type UserInfo interface {
  557. GetID() string
  558. GetUserName() string
  559. GetName() string
  560. }
  561. func (c *Context) SetTenantInfo(tenantInfo TenantInfo) {
  562. c.Set(tenantInfoKey, tenantInfo)
  563. }
  564. func (c *Context) SetUserInfo(userInfo UserInfo) {
  565. c.Set(userInfoKey, userInfo)
  566. }
  567. func (c *Context) GetTenantInfo() TenantInfo {
  568. tenantInfo, exist := c.Get(tenantInfoKey)
  569. if !exist {
  570. return nil
  571. }
  572. return tenantInfo.(TenantInfo)
  573. }
  574. func (c *Context) GetUserInfo() UserInfo {
  575. userInfo, exist := c.Get(userInfoKey)
  576. if !exist {
  577. return nil
  578. }
  579. return userInfo.(UserInfo)
  580. }