context.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. package api
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
  6. "github.com/gin-gonic/gin"
  7. "github.com/pkg/errors"
  8. "io"
  9. "mime/multipart"
  10. "strings"
  11. )
  12. const (
  13. bodyKey = "body-context"
  14. queryParamsKey = "query-params-context"
  15. pathParamsKey = "path-params-context"
  16. )
// Context wraps *gin.Context and adds helpers for reading the request body,
// headers, query/path parameters, uploaded files and per-request tenant/user
// info. All methods of the embedded *gin.Context are promoted onto Context.
type Context struct {
	*gin.Context
}
  20. // GetFileHeaderBytes 获取传递的文件名和文件内容
  21. func (c *Context) GetFileHeaderBytes(fileHeader *multipart.FileHeader) (string, []byte, error) {
  22. file, err := fileHeader.Open()
  23. if err != nil {
  24. return "", nil, errors.New(err.Error())
  25. }
  26. defer func(file multipart.File) {
  27. err := file.Close()
  28. if err != nil {
  29. logger.GetInstance().Error(errors.New(err.Error()))
  30. return
  31. }
  32. }(file)
  33. contentBytes, err := io.ReadAll(file)
  34. if err != nil {
  35. return "", nil, errors.New(err.Error())
  36. }
  37. return fileHeader.Filename, contentBytes, nil
  38. }
  39. func (c *Context) GetHeaders() map[string]string {
  40. headers := make(map[string]string)
  41. for key, values := range c.Request.Header {
  42. headers[key] = strings.Join(values, ",")
  43. }
  44. return headers
  45. }
  46. type CacheBody struct {
  47. c *Context
  48. bytesBody []byte
  49. }
  50. func (cacheBody *CacheBody) Set(bytesBody []byte) {
  51. cacheBody.bytesBody = bytesBody
  52. cacheBody.c.Set(bodyKey, cacheBody.bytesBody)
  53. }
  54. func (cacheBody *CacheBody) Bytes() []byte {
  55. return cacheBody.bytesBody
  56. }
  57. func (c *Context) GetBytesBody() (*CacheBody, error) {
  58. body, exist := c.Get(bodyKey)
  59. if !exist {
  60. bytesBody, err := c.readOriginBody()
  61. if err != nil {
  62. return nil, err
  63. }
  64. return &CacheBody{
  65. c: c,
  66. bytesBody: bytesBody,
  67. }, nil
  68. }
  69. switch b := body.(type) {
  70. case []byte:
  71. return &CacheBody{
  72. c: c,
  73. bytesBody: b,
  74. }, nil
  75. case map[string]any:
  76. bytesBody, err := json.Marshal(b)
  77. if err != nil {
  78. return nil, errors.New(err.Error())
  79. }
  80. return &CacheBody{
  81. c: c,
  82. bytesBody: bytesBody,
  83. }, nil
  84. default:
  85. return nil, errors.New("不支持的body类型")
  86. }
  87. }
  88. type JsonBody struct {
  89. c *Context
  90. jsonBodyMap map[string]any
  91. }
  92. func (jsonBody *JsonBody) Set(key string, value any) {
  93. jsonBody.jsonBodyMap[key] = value
  94. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  95. }
  96. func (jsonBody *JsonBody) Delete(key string) {
  97. delete(jsonBody.jsonBodyMap, key)
  98. jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap)
  99. }
  100. func (jsonBody *JsonBody) Get(key string) any {
  101. return jsonBody.jsonBodyMap[key]
  102. }
  103. func (jsonBody *JsonBody) Map() map[string]any {
  104. return jsonBody.jsonBodyMap
  105. }
  106. func (jsonBody *JsonBody) Bytes() ([]byte, error) {
  107. jsonBytes, err := json.Marshal(jsonBody.jsonBodyMap)
  108. if err != nil {
  109. return nil, errors.New(err.Error())
  110. }
  111. return jsonBytes, nil
  112. }
  113. func (jsonBody *JsonBody) Unmarshal(output any) error {
  114. jsonBytes, err := jsonBody.Bytes()
  115. if err != nil {
  116. return err
  117. }
  118. err = json.Unmarshal(jsonBytes, output)
  119. if err != nil {
  120. return errors.New(err.Error())
  121. }
  122. return nil
  123. }
  124. func (c *Context) GetJsonBody() (*JsonBody, error) {
  125. body, exist := c.Get(bodyKey)
  126. if !exist {
  127. bytesBody, err := c.readOriginBody()
  128. if err != nil {
  129. return nil, err
  130. }
  131. jsonBodyMap := make(map[string]any)
  132. err = json.Unmarshal(bytesBody, &jsonBodyMap)
  133. if err != nil {
  134. return nil, errors.New(err.Error())
  135. }
  136. return &JsonBody{
  137. c: c,
  138. jsonBodyMap: jsonBodyMap,
  139. }, nil
  140. }
  141. switch b := body.(type) {
  142. case []byte:
  143. jsonBodyMap := make(map[string]any)
  144. err := json.Unmarshal(b, &jsonBodyMap)
  145. if err != nil {
  146. return nil, errors.New(err.Error())
  147. }
  148. return &JsonBody{
  149. c: c,
  150. jsonBodyMap: jsonBodyMap,
  151. }, nil
  152. case map[string]any:
  153. return &JsonBody{
  154. c: c,
  155. jsonBodyMap: body.(map[string]any),
  156. }, nil
  157. default:
  158. return nil, errors.New("不支持的body类型")
  159. }
  160. }
  161. type QueryPrams struct {
  162. c *Context
  163. queryParams map[string]string
  164. }
  165. func (queryParams *QueryPrams) Set(key string, value string) {
  166. queryParams.queryParams[key] = value
  167. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  168. }
  169. func (queryParams *QueryPrams) Delete(key string) {
  170. delete(queryParams.queryParams, key)
  171. queryParams.c.Set(queryParamsKey, queryParams.queryParams)
  172. }
  173. func (queryParams *QueryPrams) Get(key string) string {
  174. return queryParams.queryParams[key]
  175. }
  176. func (queryParams *QueryPrams) Map() map[string]string {
  177. return queryParams.queryParams
  178. }
  179. func (c *Context) GetQueryParams() *QueryPrams {
  180. queryParams, exist := c.Get(queryParamsKey)
  181. if !exist {
  182. return &QueryPrams{
  183. c: c,
  184. queryParams: c.getAllQueryParams(),
  185. }
  186. }
  187. return &QueryPrams{
  188. c: c,
  189. queryParams: queryParams.(map[string]string),
  190. }
  191. }
  192. type PathPrams struct {
  193. c *Context
  194. pathParams map[string]string
  195. }
  196. func (pathParams *PathPrams) Set(key string, value string) {
  197. pathParams.pathParams[key] = value
  198. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  199. }
  200. func (pathParams *PathPrams) Delete(key string) {
  201. delete(pathParams.pathParams, key)
  202. pathParams.c.Set(pathParamsKey, pathParams.pathParams)
  203. }
  204. func (pathParams *PathPrams) Get(key string) string {
  205. return pathParams.pathParams[key]
  206. }
  207. func (pathParams *PathPrams) Map() map[string]string {
  208. return pathParams.pathParams
  209. }
  210. func (c *Context) GetPathParams() *PathPrams {
  211. pathParams, exist := c.Get(pathParamsKey)
  212. if !exist {
  213. return &PathPrams{
  214. c: c,
  215. pathParams: c.getAllPathParams(),
  216. }
  217. }
  218. return &PathPrams{
  219. c: c,
  220. pathParams: pathParams.(map[string]string),
  221. }
  222. }
  223. func (c *Context) getAllQueryParams() map[string]string {
  224. queryParams := make(map[string]string)
  225. for key, values := range c.Request.URL.Query() {
  226. queryParams[key] = strings.Join(values, ",")
  227. }
  228. return queryParams
  229. }
  230. func (c *Context) getAllPathParams() map[string]string {
  231. pathParams := make(map[string]string)
  232. for _, params := range c.Params {
  233. pathParams[params.Key] = params.Value
  234. }
  235. return pathParams
  236. }
  237. func (c *Context) readOriginBody() ([]byte, error) {
  238. if c.Request.Body == nil {
  239. return make([]byte, 0), nil
  240. }
  241. body, err := io.ReadAll(c.Request.Body)
  242. if err != nil {
  243. return nil, errors.New(err.Error())
  244. }
  245. defer func(Body io.ReadCloser) {
  246. err := Body.Close()
  247. if err != nil {
  248. logger.GetInstance().Error(errors.New(err.Error()))
  249. return
  250. }
  251. }(c.Request.Body)
  252. c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
  253. c.Set(bodyKey, body)
  254. return body, nil
  255. }
  256. const (
  257. tenantInfoKey = "context-tenant-info"
  258. userInfoKey = "context-user-info"
  259. )
  260. type TenantInfo interface {
  261. GetID() string
  262. GetName() string
  263. }
  264. type UserInfo interface {
  265. GetID() string
  266. GetUserName() string
  267. GetName() string
  268. }
  269. func (c *Context) SetTenantInfo(tenantInfo TenantInfo) {
  270. c.Set(tenantInfoKey, tenantInfo)
  271. }
  272. func (c *Context) SetUserInfo(userInfo UserInfo) {
  273. c.Set(userInfoKey, userInfo)
  274. }
  275. func (c *Context) GetTenantInfo() TenantInfo {
  276. tenantInfo, exist := c.Get(tenantInfoKey)
  277. if !exist {
  278. return nil
  279. }
  280. return tenantInfo.(TenantInfo)
  281. }
  282. func (c *Context) GetUserInfo() UserInfo {
  283. userInfo, exist := c.Get(userInfoKey)
  284. if !exist {
  285. return nil
  286. }
  287. return userInfo.(UserInfo)
  288. }