context.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. package api
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
  6. "github.com/gin-gonic/gin"
  7. "github.com/pkg/errors"
  8. "io"
  9. "mime/multipart"
  10. "net/textproto"
  11. "strings"
  12. )
  13. const (
  14. bodyKey = "body-context"
  15. queryParamsKey = "query-params-context"
  16. pathParamsKey = "path-params-context"
  17. headerKey = "header-context"
  18. tenantInfoKey = "context-tenant-info"
  19. userInfoKey = "context-user-info"
  20. )
  21. type Context struct {
  22. *gin.Context
  23. }
  24. // GetFileHeaderBytes 获取Multipart中传递的文件名和文件内容
  25. // 参数:
  26. // - fileHeader: Multipart的文件头
  27. // 返回值:
  28. // - 文件名
  29. // - 文件字节
  30. // - 错误
  31. func (c *Context) GetFileHeaderBytes(fileHeader *multipart.FileHeader) (string, []byte, error) {
  32. file, err := fileHeader.Open()
  33. if err != nil {
  34. return "", nil, errors.New(err.Error())
  35. }
  36. defer func(file multipart.File) {
  37. err := file.Close()
  38. if err != nil {
  39. logger.GetInstance().Error(errors.New(err.Error()))
  40. return
  41. }
  42. }(file)
  43. contentBytes, err := io.ReadAll(file)
  44. if err != nil {
  45. return "", nil, errors.New(err.Error())
  46. }
  47. return fileHeader.Filename, contentBytes, nil
  48. }
  49. // GetHeader 获取上下文中的Header
  50. // 参数: 无
  51. // 返回值:
  52. // - 上下文中的Header
  53. func (c *Context) GetHeader() *Header {
  54. savedHeader, exist := c.Get(headerKey)
  55. if exist {
  56. return &Header{
  57. c: c,
  58. header: savedHeader.(map[string]string),
  59. }
  60. }
  61. return &Header{
  62. c: c,
  63. header: c.ReadOriginHeader(),
  64. }
  65. }
  66. func (c *Context) ReadOriginHeader() map[string]string {
  67. originHeader := make(map[string]string)
  68. for key, values := range c.Request.Header {
  69. originHeader[key] = strings.Join(values, ",")
  70. }
  71. return originHeader
  72. }
  73. // GetBytesBody 获取上下文中的字节Body
  74. // 参数: 无
  75. // 返回值:
  76. // - 上下文中的字节Body
  77. // - 错误
  78. func (c *Context) GetBytesBody() (*BytesBody, error) {
  79. body, exist := c.Get(bodyKey)
  80. if !exist {
  81. bytesBody, err := c.ReadOriginBody()
  82. if err != nil {
  83. return nil, err
  84. }
  85. c.Set(bodyKey, bytesBody)
  86. return &BytesBody{
  87. c: c,
  88. bytesBody: bytesBody,
  89. }, nil
  90. }
  91. switch b := body.(type) {
  92. case []byte:
  93. return &BytesBody{
  94. c: c,
  95. bytesBody: b,
  96. }, nil
  97. case map[string]any:
  98. bytesBody, err := json.Marshal(b)
  99. if err != nil {
  100. return nil, errors.New(err.Error())
  101. }
  102. return &BytesBody{
  103. c: c,
  104. bytesBody: bytesBody,
  105. }, nil
  106. default:
  107. return nil, errors.New("不支持的body类型")
  108. }
  109. }
  110. // GetJsonBody 获取上下文中的JsonBody
  111. // 参数: 无
  112. // 返回值:
  113. // - 上下文中的JsonBody
  114. // - 错误
  115. func (c *Context) GetJsonBody() (*JsonBody, error) {
  116. body, exist := c.Get(bodyKey)
  117. if !exist {
  118. bytesBody, err := c.ReadOriginBody()
  119. if err != nil {
  120. return nil, err
  121. }
  122. jsonBodyMap := make(map[string]any)
  123. err = json.Unmarshal(bytesBody, &jsonBodyMap)
  124. if err != nil {
  125. return nil, errors.New(err.Error())
  126. }
  127. c.Set(bodyKey, jsonBodyMap)
  128. return &JsonBody{
  129. c: c,
  130. jsonBodyMap: jsonBodyMap,
  131. }, nil
  132. }
  133. switch b := body.(type) {
  134. case []byte:
  135. jsonBodyMap := make(map[string]any)
  136. err := json.Unmarshal(b, &jsonBodyMap)
  137. if err != nil {
  138. return nil, errors.New(err.Error())
  139. }
  140. return &JsonBody{
  141. c: c,
  142. jsonBodyMap: jsonBodyMap,
  143. }, nil
  144. case map[string]any:
  145. return &JsonBody{
  146. c: c,
  147. jsonBodyMap: body.(map[string]any),
  148. }, nil
  149. default:
  150. return nil, errors.New("不支持的body类型")
  151. }
  152. }
  153. // GetQueryParams 获取上下文中的查询参数
  154. // 参数: 无
  155. // 返回值:
  156. // - 上下文中的查询参数
  157. // - 错误
  158. func (c *Context) GetQueryParams() *QueryPrams {
  159. queryParams, exist := c.Get(queryParamsKey)
  160. if !exist {
  161. originQueryParams := c.GetOriginQueryParams()
  162. c.Set(queryParamsKey, originQueryParams)
  163. return &QueryPrams{
  164. c: c,
  165. queryParams: originQueryParams,
  166. }
  167. }
  168. return &QueryPrams{
  169. c: c,
  170. queryParams: queryParams.(map[string]string),
  171. }
  172. }
  173. // GetPathParams 获取上下文中的路径参数
  174. // 参数: 无
  175. // 返回值:
  176. // - 上下文中的路径参数
  177. // - 错误
  178. func (c *Context) GetPathParams() *PathPrams {
  179. pathParams, exist := c.Get(pathParamsKey)
  180. if !exist {
  181. originPathParams := c.GetOriginPathParams()
  182. c.Set(pathParamsKey, originPathParams)
  183. return &PathPrams{
  184. c: c,
  185. pathParams: c.GetOriginPathParams(),
  186. }
  187. }
  188. return &PathPrams{
  189. c: c,
  190. pathParams: pathParams.(map[string]string),
  191. }
  192. }
  193. // ReadOriginBody 获取上下文中的原始字节Body
  194. // 参数: 无
  195. // 返回值:
  196. // - 上下文中的原始字节Body
  197. // - 错误
  198. func (c *Context) ReadOriginBody() ([]byte, error) {
  199. if c.Request.Body == nil {
  200. return make([]byte, 0), nil
  201. }
  202. body, err := io.ReadAll(c.Request.Body)
  203. if err != nil {
  204. return nil, errors.New(err.Error())
  205. }
  206. defer func(Body io.ReadCloser) {
  207. err := Body.Close()
  208. if err != nil {
  209. logger.GetInstance().Error(errors.New(err.Error()))
  210. return
  211. }
  212. }(c.Request.Body)
  213. c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
  214. c.Set(bodyKey, body)
  215. return body, nil
  216. }
  217. // GetOriginQueryParams 获取上下文中的原始查询参数
  218. // 参数: 无
  219. // 返回值:
  220. // - 上下文中的原始查询参数
  221. func (c *Context) GetOriginQueryParams() map[string]string {
  222. queryParams := make(map[string]string)
  223. for key, values := range c.Request.URL.Query() {
  224. queryParams[key] = strings.Join(values, ",")
  225. }
  226. return queryParams
  227. }
  228. // GetOriginPathParams 获取上下文中的原始路径参数
  229. // 参数: 无
  230. // 返回值:
  231. // - 上下文中的原始路径参数
  232. func (c *Context) GetOriginPathParams() map[string]string {
  233. pathParams := make(map[string]string)
  234. for _, params := range c.Params {
  235. pathParams[params.Key] = params.Value
  236. }
  237. return pathParams
  238. }
  239. type Header struct {
  240. c *Context
  241. header map[string]string
  242. }
  243. // Reload 从Context重新加载Header
  244. // 参数: 无
  245. // 返回值: 无
  246. func (header *Header) Reload() {
  247. originHeader := header.c.ReadOriginHeader()
  248. header.c.Set(headerKey, originHeader)
  249. header.header = originHeader
  250. }
  251. // Set 设置Header
  252. // 参数:
  253. // - key: Header的键
  254. // - value: Header对应键的值
  255. // 返回值: 无
  256. func (header *Header) Set(key string, value string) {
  257. header.header[key] = value
  258. header.c.Set(headerKey, header.header)
  259. }
  260. // Get 获取Header对应键的值
  261. // 参数:
  262. // - key: Header的键
  263. // 返回值:
  264. // - Header对应键的值
  265. func (header *Header) Get(key string) string {
  266. mineHeader := textproto.MIMEHeader{}
  267. for k, v := range header.header {
  268. mineHeader.Set(k, v)
  269. }
  270. return mineHeader.Get(key)
  271. }
  272. // Map 获取Header的map表示
  273. // 参数: 无
  274. // 返回值:
  275. // - Header的map表示
  276. func (header *Header) Map() map[string]string {
  277. return header.header
  278. }
  279. type BytesBody struct {
  280. c *Context
  281. bytesBody []byte
  282. }
  283. // Reload 从Context重新加载BytesBody
  284. // 参数: 无
  285. // 返回值:
  286. // - 错误
  287. func (bytesBody *BytesBody) Reload() error {
  288. originBody, err := bytesBody.c.ReadOriginBody()
  289. if err != nil {
  290. return err
  291. }
  292. bytesBody.c.Set(headerKey, originBody)
  293. bytesBody.bytesBody = originBody
  294. return nil
  295. }
  296. // Set 设置字节body
  297. // 参数:
  298. // - body: 字节body的内容
  299. // 返回值: 无
  300. func (bytesBody *BytesBody) Set(body []byte) {
  301. bytesBody.bytesBody = body
  302. bytesBody.c.Set(bodyKey, bytesBody.bytesBody)
  303. }
  304. // Bytes 获取字节body的内容
  305. // 参数: 无
  306. // 返回值:
  307. // - 字节body的内容
  308. func (bytesBody *BytesBody) Bytes() []byte {
  309. return bytesBody.bytesBody
  310. }
  311. // Marshal 将input Marshal到字节body
  312. // 参数:
  313. // - input: 输入,一般为结构或map[string]any
  314. // 返回值:
  315. // - 错误
  316. func (bytesBody *BytesBody) Marshal(input any) error {
  317. jsonBytesBody, err := json.Marshal(input)
  318. if err != nil {
  319. return errors.New(err.Error())
  320. }
  321. bytesBody.Set(jsonBytesBody)
  322. return nil
  323. }
  324. // Unmarshal 将字节body的内容Unmarshal到output
  325. // 参数:
  326. // - output: 输出,一般为结构指针或map[string]any指针
  327. // 返回值:
  328. // - 错误
  329. func (bytesBody *BytesBody) Unmarshal(output any) error {
  330. err := json.Unmarshal(bytesBody.Bytes(), output)
  331. if err != nil {
  332. return errors.New(err.Error())
  333. }
  334. return nil
  335. }
  336. type JsonBody struct {
  337. c *Context
  338. jsonBodyMap map[string]any
  339. }
  340. // Reload 从Context重新加载JsonBody
  341. // 参数: 无
  342. // 返回值:
  343. // - 错误
  344. func (jsonBody *JsonBody) Reload() error {
  345. bytesBody, err := jsonBody.c.ReadOriginBody()
  346. if err != nil {
  347. return err
  348. }
  349. jsonBodyMap := make(map[string]any)
  350. err = json.Unmarshal(bytesBody, &jsonBodyMap)
  351. if err != nil {
  352. return errors.New(err.Error())
  353. }
  354. jsonBody.c.Set(bodyKey, jsonBodyMap)
  355. jsonBody.jsonBodyMap = jsonBodyMap
  356. return nil
  357. }
  358. // Set 设置JsonBody键对应的值
  359. // 参数:
  360. // - key: JsonBody的键
  361. // - value: JsonBody对应键的值
  362. // 返回值: 无
  363. func (jsonBody *JsonBody) Set(key string, value any) {
  364. jsonBody.jsonBodyMap[key] = value
  365. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  366. }
  367. // Delete 删除JsonBody键对应的值
  368. // 参数:
  369. // - key: JsonBody的键
  370. // 返回值: 无
  371. func (jsonBody *JsonBody) Delete(key string) {
  372. delete(jsonBody.jsonBodyMap, key)
  373. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  374. }
  375. // Get 获取JsonBody对应键的值
  376. // 参数:
  377. // - key: JsonBody的键
  378. // 返回值:
  379. // - JsonBody对应键的值
  380. func (jsonBody *JsonBody) Get(key string) any {
  381. return jsonBody.jsonBodyMap[key]
  382. }
  383. // Map 获取JsonBody的map表示
  384. // 参数: 无
  385. // 返回值:
  386. // - JsonBody的map表示
  387. func (jsonBody *JsonBody) Map() map[string]any {
  388. return jsonBody.jsonBodyMap
  389. }
  390. // Bytes 获取JsonBody的内容
  391. // 参数: 无
  392. // 返回值:
  393. // - JsonBody的内容
  394. func (jsonBody *JsonBody) Bytes() ([]byte, error) {
  395. jsonBytes, err := json.Marshal(jsonBody.jsonBodyMap)
  396. if err != nil {
  397. return nil, errors.New(err.Error())
  398. }
  399. return jsonBytes, nil
  400. }
  401. // Unmarshal 将JsonBody的内容Unmarshal到output
  402. // 参数:
  403. // - output: 输出,一般为结构指针或map[string]any指针
  404. // 返回值:
  405. // - 错误
  406. func (jsonBody *JsonBody) Unmarshal(output any) error {
  407. jsonBytes, err := jsonBody.Bytes()
  408. if err != nil {
  409. return err
  410. }
  411. err = json.Unmarshal(jsonBytes, output)
  412. if err != nil {
  413. return errors.New(err.Error())
  414. }
  415. return nil
  416. }
  417. type QueryPrams struct {
  418. c *Context
  419. queryParams map[string]string
  420. }
  421. // Reload 从Context重新加载QueryParams
  422. // 参数: 无
  423. // 返回值: 无
  424. func (queryParams *QueryPrams) Reload() {
  425. originQueryParams := queryParams.c.GetOriginQueryParams()
  426. queryParams.c.Set(queryParamsKey, originQueryParams)
  427. queryParams.queryParams = originQueryParams
  428. }
  429. // Set 设置查询参数
  430. // 参数:
  431. // - key: 查询参数的键
  432. // - value: 查询参数对应键的值
  433. // 返回值: 无
  434. func (queryParams *QueryPrams) Set(key string, value string) {
  435. queryParams.queryParams[key] = value
  436. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  437. }
  438. // Delete 删除JsonBody键对应的值
  439. // 参数:
  440. // - key: JsonBody的键
  441. // 返回值: 无
  442. func (queryParams *QueryPrams) Delete(key string) {
  443. delete(queryParams.queryParams, key)
  444. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  445. }
  446. // Get 获取查询参数对应键的值
  447. // 参数:
  448. // - key: 查询参数的键
  449. // 返回值:
  450. // - 查询参数对应键的值
  451. func (queryParams *QueryPrams) Get(key string) string {
  452. return queryParams.queryParams[key]
  453. }
  454. // Map 获取查询参数的map表示
  455. // 参数: 无
  456. // 返回值:
  457. // - 查询参数的map表示
  458. func (queryParams *QueryPrams) Map() map[string]string {
  459. return queryParams.queryParams
  460. }
  461. type PathPrams struct {
  462. c *Context
  463. pathParams map[string]string
  464. }
  465. // Reload 从Context重新加载QueryParams
  466. // 参数: 无
  467. // 返回值: 无
  468. func (pathParams *PathPrams) Reload() {
  469. originPathParams := pathParams.c.GetOriginPathParams()
  470. pathParams.c.Set(pathParamsKey, originPathParams)
  471. pathParams.pathParams = originPathParams
  472. }
  473. // Set 设置路径参数
  474. // 参数:
  475. // - key: 路径参数的键
  476. // - value: 路径参数对应键的值
  477. // 返回值: 无
  478. func (pathParams *PathPrams) Set(key string, value string) {
  479. pathParams.pathParams[key] = value
  480. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  481. }
  482. // Delete 删除JsonBody键对应的值
  483. // 参数:
  484. // - key: JsonBody的键
  485. // 返回值: 无
  486. func (pathParams *PathPrams) Delete(key string) {
  487. delete(pathParams.pathParams, key)
  488. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  489. }
  490. // Get 获取路径参数对应键的值
  491. // 参数:
  492. // - key: 路径参数的键
  493. // 返回值:
  494. // - 路径参数对应键的值
  495. func (pathParams *PathPrams) Get(key string) string {
  496. return pathParams.pathParams[key]
  497. }
  498. // Map 获取路径参数的map表示
  499. // 参数: 无
  500. // 返回值:
  501. // - 路径参数的map表示
  502. func (pathParams *PathPrams) Map() map[string]string {
  503. return pathParams.pathParams
  504. }
  505. type TenantInfo interface {
  506. GetID() string
  507. GetName() string
  508. }
  509. type UserInfo interface {
  510. GetID() string
  511. GetUserName() string
  512. GetName() string
  513. }
  514. func (c *Context) SetTenantInfo(tenantInfo TenantInfo) {
  515. c.Set(tenantInfoKey, tenantInfo)
  516. }
  517. func (c *Context) SetUserInfo(userInfo UserInfo) {
  518. c.Set(userInfoKey, userInfo)
  519. }
  520. func (c *Context) GetTenantInfo() TenantInfo {
  521. tenantInfo, exist := c.Get(tenantInfoKey)
  522. if !exist {
  523. return nil
  524. }
  525. return tenantInfo.(TenantInfo)
  526. }
  527. func (c *Context) GetUserInfo() UserInfo {
  528. userInfo, exist := c.Get(userInfoKey)
  529. if !exist {
  530. return nil
  531. }
  532. return userInfo.(UserInfo)
  533. }