context.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. package api
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
  6. "github.com/gin-gonic/gin"
  7. "github.com/pkg/errors"
  8. "io"
  9. "mime/multipart"
  10. "strings"
  11. )
  // Keys under which the per-request caches are stored in the gin Context
  // (via Context.Set / Context.Get).
  const bodyKey = "body-context"                // cached request body: []byte or map[string]any
  const queryParamsKey = "query-params-context" // cached query parameters: map[string]string
  const pathParamsKey = "path-params-context"   // cached path parameters: map[string]string
  17. type Context struct {
  18. *gin.Context
  19. }
  20. // GetFileHeaderBytes 获取传递的文件名和文件内容
  21. func (c *Context) GetFileHeaderBytes(fileHeader *multipart.FileHeader) (string, []byte, error) {
  22. file, err := fileHeader.Open()
  23. if err != nil {
  24. return "", nil, errors.New(err.Error())
  25. }
  26. defer func(file multipart.File) {
  27. err := file.Close()
  28. if err != nil {
  29. logger.GetInstance().Error(errors.New(err.Error()))
  30. return
  31. }
  32. }(file)
  33. contentBytes, err := io.ReadAll(file)
  34. if err != nil {
  35. return "", nil, errors.New(err.Error())
  36. }
  37. return fileHeader.Filename, contentBytes, nil
  38. }
  39. func (c *Context) GetHeaders() map[string]string {
  40. headers := make(map[string]string)
  41. for key, values := range c.Request.Header {
  42. headers[key] = strings.Join(values, ",")
  43. }
  44. return headers
  45. }
  46. type CacheBody struct {
  47. c *Context
  48. bytesBody []byte
  49. }
  50. func (cacheBody *CacheBody) Set(bytesBody []byte) {
  51. cacheBody.bytesBody = bytesBody
  52. cacheBody.c.Set(bodyKey, cacheBody.bytesBody)
  53. }
  54. func (cacheBody *CacheBody) Bytes() []byte {
  55. return cacheBody.bytesBody
  56. }
  57. func (cacheBody *CacheBody) Marshal(output any) error {
  58. bytesBody, err := json.Marshal(output)
  59. if err != nil {
  60. return errors.New(err.Error())
  61. }
  62. cacheBody.Set(bytesBody)
  63. return nil
  64. }
  65. func (cacheBody *CacheBody) Unmarshal(output any) error {
  66. err := json.Unmarshal(cacheBody.Bytes(), output)
  67. if err != nil {
  68. return errors.New(err.Error())
  69. }
  70. return nil
  71. }
  72. func (c *Context) GetBytesBody() (*CacheBody, error) {
  73. body, exist := c.Get(bodyKey)
  74. if !exist {
  75. bytesBody, err := c.readOriginBody()
  76. if err != nil {
  77. return nil, err
  78. }
  79. return &CacheBody{
  80. c: c,
  81. bytesBody: bytesBody,
  82. }, nil
  83. }
  84. switch b := body.(type) {
  85. case []byte:
  86. return &CacheBody{
  87. c: c,
  88. bytesBody: b,
  89. }, nil
  90. case map[string]any:
  91. bytesBody, err := json.Marshal(b)
  92. if err != nil {
  93. return nil, errors.New(err.Error())
  94. }
  95. return &CacheBody{
  96. c: c,
  97. bytesBody: bytesBody,
  98. }, nil
  99. default:
  100. return nil, errors.New("不支持的body类型")
  101. }
  102. }
  103. type JsonBody struct {
  104. c *Context
  105. jsonBodyMap map[string]any
  106. }
  107. func (jsonBody *JsonBody) Set(key string, value any) {
  108. jsonBody.jsonBodyMap[key] = value
  109. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  110. }
  111. func (jsonBody *JsonBody) Delete(key string) {
  112. delete(jsonBody.jsonBodyMap, key)
  113. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  114. }
  115. func (jsonBody *JsonBody) Get(key string) any {
  116. return jsonBody.jsonBodyMap[key]
  117. }
  118. func (jsonBody *JsonBody) Map() map[string]any {
  119. return jsonBody.jsonBodyMap
  120. }
  121. func (jsonBody *JsonBody) Bytes() ([]byte, error) {
  122. jsonBytes, err := json.Marshal(jsonBody.jsonBodyMap)
  123. if err != nil {
  124. return nil, errors.New(err.Error())
  125. }
  126. return jsonBytes, nil
  127. }
  128. func (jsonBody *JsonBody) Unmarshal(output any) error {
  129. jsonBytes, err := jsonBody.Bytes()
  130. if err != nil {
  131. return err
  132. }
  133. err = json.Unmarshal(jsonBytes, output)
  134. if err != nil {
  135. return errors.New(err.Error())
  136. }
  137. return nil
  138. }
  139. func (c *Context) GetJsonBody() (*JsonBody, error) {
  140. body, exist := c.Get(bodyKey)
  141. if !exist {
  142. bytesBody, err := c.readOriginBody()
  143. if err != nil {
  144. return nil, err
  145. }
  146. jsonBodyMap := make(map[string]any)
  147. err = json.Unmarshal(bytesBody, &jsonBodyMap)
  148. if err != nil {
  149. return nil, errors.New(err.Error())
  150. }
  151. return &JsonBody{
  152. c: c,
  153. jsonBodyMap: jsonBodyMap,
  154. }, nil
  155. }
  156. switch b := body.(type) {
  157. case []byte:
  158. jsonBodyMap := make(map[string]any)
  159. err := json.Unmarshal(b, &jsonBodyMap)
  160. if err != nil {
  161. return nil, errors.New(err.Error())
  162. }
  163. return &JsonBody{
  164. c: c,
  165. jsonBodyMap: jsonBodyMap,
  166. }, nil
  167. case map[string]any:
  168. return &JsonBody{
  169. c: c,
  170. jsonBodyMap: body.(map[string]any),
  171. }, nil
  172. default:
  173. return nil, errors.New("不支持的body类型")
  174. }
  175. }
  176. type QueryPrams struct {
  177. c *Context
  178. queryParams map[string]string
  179. }
  180. func (queryParams *QueryPrams) Set(key string, value string) {
  181. queryParams.queryParams[key] = value
  182. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  183. }
  184. func (queryParams *QueryPrams) Delete(key string) {
  185. delete(queryParams.queryParams, key)
  186. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  187. }
  188. func (queryParams *QueryPrams) Get(key string) string {
  189. return queryParams.queryParams[key]
  190. }
  191. func (queryParams *QueryPrams) Map() map[string]string {
  192. return queryParams.queryParams
  193. }
  194. func (c *Context) GetQueryParams() *QueryPrams {
  195. queryParams, exist := c.Get(queryParamsKey)
  196. if !exist {
  197. return &QueryPrams{
  198. c: c,
  199. queryParams: c.getAllQueryParams(),
  200. }
  201. }
  202. return &QueryPrams{
  203. c: c,
  204. queryParams: queryParams.(map[string]string),
  205. }
  206. }
  207. type PathPrams struct {
  208. c *Context
  209. pathParams map[string]string
  210. }
  211. func (pathParams *PathPrams) Set(key string, value string) {
  212. pathParams.pathParams[key] = value
  213. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  214. }
  215. func (pathParams *PathPrams) Delete(key string) {
  216. delete(pathParams.pathParams, key)
  217. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  218. }
  219. func (pathParams *PathPrams) Get(key string) string {
  220. return pathParams.pathParams[key]
  221. }
  222. func (pathParams *PathPrams) Map() map[string]string {
  223. return pathParams.pathParams
  224. }
  225. func (c *Context) GetPathParams() *PathPrams {
  226. pathParams, exist := c.Get(pathParamsKey)
  227. if !exist {
  228. return &PathPrams{
  229. c: c,
  230. pathParams: c.getAllPathParams(),
  231. }
  232. }
  233. return &PathPrams{
  234. c: c,
  235. pathParams: pathParams.(map[string]string),
  236. }
  237. }
  238. func (c *Context) getAllQueryParams() map[string]string {
  239. queryParams := make(map[string]string)
  240. for key, values := range c.Request.URL.Query() {
  241. queryParams[key] = strings.Join(values, ",")
  242. }
  243. return queryParams
  244. }
  245. func (c *Context) getAllPathParams() map[string]string {
  246. pathParams := make(map[string]string)
  247. for _, params := range c.Params {
  248. pathParams[params.Key] = params.Value
  249. }
  250. return pathParams
  251. }
  252. func (c *Context) readOriginBody() ([]byte, error) {
  253. if c.Request.Body == nil {
  254. return make([]byte, 0), nil
  255. }
  256. body, err := io.ReadAll(c.Request.Body)
  257. if err != nil {
  258. return nil, errors.New(err.Error())
  259. }
  260. defer func(Body io.ReadCloser) {
  261. err := Body.Close()
  262. if err != nil {
  263. logger.GetInstance().Error(errors.New(err.Error()))
  264. return
  265. }
  266. }(c.Request.Body)
  267. c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
  268. c.Set(bodyKey, body)
  269. return body, nil
  270. }
  // Keys under which SetTenantInfo / SetUserInfo store the request's tenant
  // and user in the gin Context.
  const tenantInfoKey = "context-tenant-info"
  const userInfoKey = "context-user-info"
  // TenantInfo is the tenant attached to a request with SetTenantInfo.
  type TenantInfo interface {
  	GetName() string
  	GetID() string
  }

  // UserInfo is the user attached to a request with SetUserInfo.
  type UserInfo interface {
  	GetName() string
  	GetUserName() string
  	GetID() string
  }
  284. func (c *Context) SetTenantInfo(tenantInfo TenantInfo) {
  285. c.Set(tenantInfoKey, tenantInfo)
  286. }
  287. func (c *Context) SetUserInfo(userInfo UserInfo) {
  288. c.Set(userInfoKey, userInfo)
  289. }
  290. func (c *Context) GetTenantInfo() TenantInfo {
  291. tenantInfo, exist := c.Get(tenantInfoKey)
  292. if !exist {
  293. return nil
  294. }
  295. return tenantInfo.(TenantInfo)
  296. }
  297. func (c *Context) GetUserInfo() UserInfo {
  298. userInfo, exist := c.Get(userInfoKey)
  299. if !exist {
  300. return nil
  301. }
  302. return userInfo.(UserInfo)
  303. }