context.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. package api
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
  6. "github.com/gin-gonic/gin"
  7. "github.com/pkg/errors"
  8. "io"
  9. "mime/multipart"
  10. "net/textproto"
  11. "strings"
  12. )
  13. const (
  14. bodyKey = "body-context"
  15. queryParamsKey = "query-params-context"
  16. pathParamsKey = "path-params-context"
  17. headerKey = "header-context"
  18. tenantInfoKey = "context-tenant-info"
  19. userInfoKey = "context-user-info"
  20. )
  21. type Context struct {
  22. *gin.Context
  23. }
  24. // GetFileHeaderBytes 获取Multipart中传递的文件名和文件内容
  25. // 参数:
  26. // - fileHeader: Multipart的文件头
  27. // 返回值:
  28. // - 文件名
  29. // - 文件字节
  30. // - 错误
  31. func (c *Context) GetFileHeaderBytes(fileHeader *multipart.FileHeader) (string, []byte, error) {
  32. file, err := fileHeader.Open()
  33. if err != nil {
  34. return "", nil, errors.New(err.Error())
  35. }
  36. defer func(file multipart.File) {
  37. err := file.Close()
  38. if err != nil {
  39. logger.GetInstance().Error(errors.New(err.Error()))
  40. return
  41. }
  42. }(file)
  43. contentBytes, err := io.ReadAll(file)
  44. if err != nil {
  45. return "", nil, errors.New(err.Error())
  46. }
  47. return fileHeader.Filename, contentBytes, nil
  48. }
  49. // GetHeader 获取上下文中的Header
  50. // 参数: 无
  51. // 返回值:
  52. // - 上下文中的Header
  53. func (c *Context) GetHeader() *Header {
  54. savedHeader, exist := c.Get(headerKey)
  55. if exist {
  56. return &Header{
  57. c: c,
  58. header: savedHeader.(map[string]string),
  59. }
  60. }
  61. header := make(map[string]string)
  62. for key, values := range c.Request.Header {
  63. header[key] = strings.Join(values, ",")
  64. }
  65. c.Set(headerKey, header)
  66. return &Header{
  67. c: c,
  68. header: header,
  69. }
  70. }
  71. // GetBytesBody 获取上下文中的字节Body
  72. // 参数: 无
  73. // 返回值:
  74. // - 上下文中的字节Body
  75. // - 错误
  76. func (c *Context) GetBytesBody() (*BytesBody, error) {
  77. body, exist := c.Get(bodyKey)
  78. if !exist {
  79. bytesBody, err := c.ReadOriginBody()
  80. if err != nil {
  81. return nil, err
  82. }
  83. c.Set(bodyKey, bytesBody)
  84. return &BytesBody{
  85. c: c,
  86. bytesBody: bytesBody,
  87. }, nil
  88. }
  89. switch b := body.(type) {
  90. case []byte:
  91. return &BytesBody{
  92. c: c,
  93. bytesBody: b,
  94. }, nil
  95. case map[string]any:
  96. bytesBody, err := json.Marshal(b)
  97. if err != nil {
  98. return nil, errors.New(err.Error())
  99. }
  100. return &BytesBody{
  101. c: c,
  102. bytesBody: bytesBody,
  103. }, nil
  104. default:
  105. return nil, errors.New("不支持的body类型")
  106. }
  107. }
  108. // GetJsonBody 获取上下文中的JsonBody
  109. // 参数: 无
  110. // 返回值:
  111. // - 上下文中的JsonBody
  112. // - 错误
  113. func (c *Context) GetJsonBody() (*JsonBody, error) {
  114. body, exist := c.Get(bodyKey)
  115. if !exist {
  116. bytesBody, err := c.ReadOriginBody()
  117. if err != nil {
  118. return nil, err
  119. }
  120. jsonBodyMap := make(map[string]any)
  121. err = json.Unmarshal(bytesBody, &jsonBodyMap)
  122. if err != nil {
  123. return nil, errors.New(err.Error())
  124. }
  125. c.Set(bodyKey, jsonBodyMap)
  126. return &JsonBody{
  127. c: c,
  128. jsonBodyMap: jsonBodyMap,
  129. }, nil
  130. }
  131. switch b := body.(type) {
  132. case []byte:
  133. jsonBodyMap := make(map[string]any)
  134. err := json.Unmarshal(b, &jsonBodyMap)
  135. if err != nil {
  136. return nil, errors.New(err.Error())
  137. }
  138. return &JsonBody{
  139. c: c,
  140. jsonBodyMap: jsonBodyMap,
  141. }, nil
  142. case map[string]any:
  143. return &JsonBody{
  144. c: c,
  145. jsonBodyMap: body.(map[string]any),
  146. }, nil
  147. default:
  148. return nil, errors.New("不支持的body类型")
  149. }
  150. }
  151. // GetQueryParams 获取上下文中的查询参数
  152. // 参数: 无
  153. // 返回值:
  154. // - 上下文中的查询参数
  155. // - 错误
  156. func (c *Context) GetQueryParams() *QueryPrams {
  157. queryParams, exist := c.Get(queryParamsKey)
  158. if !exist {
  159. return &QueryPrams{
  160. c: c,
  161. queryParams: c.GetOriginQueryParams(),
  162. }
  163. }
  164. return &QueryPrams{
  165. c: c,
  166. queryParams: queryParams.(map[string]string),
  167. }
  168. }
  169. // GetPathParams 获取上下文中的路径参数
  170. // 参数: 无
  171. // 返回值:
  172. // - 上下文中的路径参数
  173. // - 错误
  174. func (c *Context) GetPathParams() *PathPrams {
  175. pathParams, exist := c.Get(pathParamsKey)
  176. if !exist {
  177. return &PathPrams{
  178. c: c,
  179. pathParams: c.GetOriginPathParams(),
  180. }
  181. }
  182. return &PathPrams{
  183. c: c,
  184. pathParams: pathParams.(map[string]string),
  185. }
  186. }
  187. // ReadOriginBody 获取上下文中的原始字节Body
  188. // 参数: 无
  189. // 返回值:
  190. // - 上下文中的原始字节Body
  191. // - 错误
  192. func (c *Context) ReadOriginBody() ([]byte, error) {
  193. if c.Request.Body == nil {
  194. return make([]byte, 0), nil
  195. }
  196. body, err := io.ReadAll(c.Request.Body)
  197. if err != nil {
  198. return nil, errors.New(err.Error())
  199. }
  200. defer func(Body io.ReadCloser) {
  201. err := Body.Close()
  202. if err != nil {
  203. logger.GetInstance().Error(errors.New(err.Error()))
  204. return
  205. }
  206. }(c.Request.Body)
  207. c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
  208. c.Set(bodyKey, body)
  209. return body, nil
  210. }
  211. // GetOriginQueryParams 获取上下文中的原始查询参数
  212. // 参数: 无
  213. // 返回值:
  214. // - 上下文中的原始查询参数
  215. func (c *Context) GetOriginQueryParams() map[string]string {
  216. queryParams := make(map[string]string)
  217. for key, values := range c.Request.URL.Query() {
  218. queryParams[key] = strings.Join(values, ",")
  219. }
  220. return queryParams
  221. }
  222. // GetOriginPathParams 获取上下文中的原始路径参数
  223. // 参数: 无
  224. // 返回值:
  225. // - 上下文中的原始路径参数
  226. func (c *Context) GetOriginPathParams() map[string]string {
  227. pathParams := make(map[string]string)
  228. for _, params := range c.Params {
  229. pathParams[params.Key] = params.Value
  230. }
  231. return pathParams
  232. }
  233. type Header struct {
  234. c *Context
  235. header map[string]string
  236. }
  237. // Set 设置Header
  238. // 参数:
  239. // - key: Header的键
  240. // - value: Header对应键的值
  241. // 返回值: 无
  242. func (header *Header) Set(key string, value string) {
  243. header.header[key] = value
  244. header.c.Set(headerKey, header.header)
  245. }
  246. // Get 获取Header对应键的值
  247. // 参数:
  248. // - key: Header的键
  249. // 返回值:
  250. // - Header对应键的值
  251. func (header *Header) Get(key string) string {
  252. mineHeader := textproto.MIMEHeader{}
  253. for k, v := range header.header {
  254. mineHeader.Set(k, v)
  255. }
  256. return mineHeader.Get(key)
  257. }
  258. // Map 获取Header的map表示
  259. // 参数: 无
  260. // 返回值:
  261. // - Header的map表示
  262. func (header *Header) Map() map[string]string {
  263. return header.header
  264. }
  265. type BytesBody struct {
  266. c *Context
  267. bytesBody []byte
  268. }
  269. // Set 设置字节body
  270. // 参数:
  271. // - body: 字节body的内容
  272. // 返回值: 无
  273. func (bytesBody *BytesBody) Set(body []byte) {
  274. bytesBody.bytesBody = body
  275. bytesBody.c.Set(bodyKey, bytesBody.bytesBody)
  276. }
  277. // Bytes 获取字节body的内容
  278. // 参数: 无
  279. // 返回值:
  280. // - 字节body的内容
  281. func (bytesBody *BytesBody) Bytes() []byte {
  282. return bytesBody.bytesBody
  283. }
  284. // Marshal 将input Marshal到字节body
  285. // 参数:
  286. // - input: 输入,一般为结构或map[string]any
  287. // 返回值:
  288. // - 错误
  289. func (bytesBody *BytesBody) Marshal(input any) error {
  290. jsonBytesBody, err := json.Marshal(input)
  291. if err != nil {
  292. return errors.New(err.Error())
  293. }
  294. bytesBody.Set(jsonBytesBody)
  295. return nil
  296. }
  297. // Unmarshal 将字节body的内容Unmarshal到output
  298. // 参数:
  299. // - output: 输出,一般为结构指针或map[string]any指针
  300. // 返回值:
  301. // - 错误
  302. func (bytesBody *BytesBody) Unmarshal(output any) error {
  303. err := json.Unmarshal(bytesBody.Bytes(), output)
  304. if err != nil {
  305. return errors.New(err.Error())
  306. }
  307. return nil
  308. }
  309. type JsonBody struct {
  310. c *Context
  311. jsonBodyMap map[string]any
  312. }
  313. // Set 设置JsonBody键对应的值
  314. // 参数:
  315. // - key: JsonBody的键
  316. // - value: JsonBody对应键的值
  317. // 返回值: 无
  318. func (jsonBody *JsonBody) Set(key string, value any) {
  319. jsonBody.jsonBodyMap[key] = value
  320. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  321. }
  322. // Delete 删除JsonBody键对应的值
  323. // 参数:
  324. // - key: JsonBody的键
  325. // 返回值: 无
  326. func (jsonBody *JsonBody) Delete(key string) {
  327. delete(jsonBody.jsonBodyMap, key)
  328. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  329. }
  330. // Get 获取JsonBody对应键的值
  331. // 参数:
  332. // - key: JsonBody的键
  333. // 返回值:
  334. // - JsonBody对应键的值
  335. func (jsonBody *JsonBody) Get(key string) any {
  336. return jsonBody.jsonBodyMap[key]
  337. }
  338. // Map 获取JsonBody的map表示
  339. // 参数: 无
  340. // 返回值:
  341. // - JsonBody的map表示
  342. func (jsonBody *JsonBody) Map() map[string]any {
  343. return jsonBody.jsonBodyMap
  344. }
  345. // Bytes 获取JsonBody的内容
  346. // 参数: 无
  347. // 返回值:
  348. // - JsonBody的内容
  349. func (jsonBody *JsonBody) Bytes() ([]byte, error) {
  350. jsonBytes, err := json.Marshal(jsonBody.jsonBodyMap)
  351. if err != nil {
  352. return nil, errors.New(err.Error())
  353. }
  354. return jsonBytes, nil
  355. }
  356. // Unmarshal 将JsonBody的内容Unmarshal到output
  357. // 参数:
  358. // - output: 输出,一般为结构指针或map[string]any指针
  359. // 返回值:
  360. // - 错误
  361. func (jsonBody *JsonBody) Unmarshal(output any) error {
  362. jsonBytes, err := jsonBody.Bytes()
  363. if err != nil {
  364. return err
  365. }
  366. err = json.Unmarshal(jsonBytes, output)
  367. if err != nil {
  368. return errors.New(err.Error())
  369. }
  370. return nil
  371. }
  372. type QueryPrams struct {
  373. c *Context
  374. queryParams map[string]string
  375. }
  376. // Set 设置查询参数
  377. // 参数:
  378. // - key: 查询参数的键
  379. // - value: 查询参数对应键的值
  380. // 返回值: 无
  381. func (queryParams *QueryPrams) Set(key string, value string) {
  382. queryParams.queryParams[key] = value
  383. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  384. }
  385. // Delete 删除JsonBody键对应的值
  386. // 参数:
  387. // - key: JsonBody的键
  388. // 返回值: 无
  389. func (queryParams *QueryPrams) Delete(key string) {
  390. delete(queryParams.queryParams, key)
  391. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  392. }
  393. // Get 获取查询参数对应键的值
  394. // 参数:
  395. // - key: 查询参数的键
  396. // 返回值:
  397. // - 查询参数对应键的值
  398. func (queryParams *QueryPrams) Get(key string) string {
  399. return queryParams.queryParams[key]
  400. }
  401. // Map 获取查询参数的map表示
  402. // 参数: 无
  403. // 返回值:
  404. // - 查询参数的map表示
  405. func (queryParams *QueryPrams) Map() map[string]string {
  406. return queryParams.queryParams
  407. }
  408. type PathPrams struct {
  409. c *Context
  410. pathParams map[string]string
  411. }
  412. // Set 设置路径参数
  413. // 参数:
  414. // - key: 路径参数的键
  415. // - value: 路径参数对应键的值
  416. // 返回值: 无
  417. func (pathParams *PathPrams) Set(key string, value string) {
  418. pathParams.pathParams[key] = value
  419. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  420. }
  421. // Delete 删除JsonBody键对应的值
  422. // 参数:
  423. // - key: JsonBody的键
  424. // 返回值: 无
  425. func (pathParams *PathPrams) Delete(key string) {
  426. delete(pathParams.pathParams, key)
  427. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  428. }
  429. // Get 获取路径参数对应键的值
  430. // 参数:
  431. // - key: 路径参数的键
  432. // 返回值:
  433. // - 路径参数对应键的值
  434. func (pathParams *PathPrams) Get(key string) string {
  435. return pathParams.pathParams[key]
  436. }
  437. // Map 获取路径参数的map表示
  438. // 参数: 无
  439. // 返回值:
  440. // - 路径参数的map表示
  441. func (pathParams *PathPrams) Map() map[string]string {
  442. return pathParams.pathParams
  443. }
  444. type TenantInfo interface {
  445. GetID() string
  446. GetName() string
  447. }
  448. type UserInfo interface {
  449. GetID() string
  450. GetUserName() string
  451. GetName() string
  452. }
  453. func (c *Context) SetTenantInfo(tenantInfo TenantInfo) {
  454. c.Set(tenantInfoKey, tenantInfo)
  455. }
  456. func (c *Context) SetUserInfo(userInfo UserInfo) {
  457. c.Set(userInfoKey, userInfo)
  458. }
  459. func (c *Context) GetTenantInfo() TenantInfo {
  460. tenantInfo, exist := c.Get(tenantInfoKey)
  461. if !exist {
  462. return nil
  463. }
  464. return tenantInfo.(TenantInfo)
  465. }
  466. func (c *Context) GetUserInfo() UserInfo {
  467. userInfo, exist := c.Get(userInfoKey)
  468. if !exist {
  469. return nil
  470. }
  471. return userInfo.(UserInfo)
  472. }