context.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package api
  2. import (
  3. "bytes"
  4. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "mime/multipart"
  8. "strings"
  9. )
// Context wraps *gin.Context by embedding it, so all fields and methods of
// the embedded gin.Context (for example Request and Params) are promoted and
// reachable directly on a Context value.
type Context struct {
	*gin.Context
}
  13. // GetFileHeaderBytes 获取传递的文件名和文件内容
  14. func (c *Context) GetFileHeaderBytes(fileHeader *multipart.FileHeader) (string, []byte, error) {
  15. file, err := fileHeader.Open()
  16. if err != nil {
  17. return "", nil, err
  18. }
  19. defer func(file multipart.File) {
  20. err := file.Close()
  21. if err != nil {
  22. logger.GetInstance().Error(err)
  23. return
  24. }
  25. }(file)
  26. contentBytes, err := io.ReadAll(file)
  27. if err != nil {
  28. return "", nil, err
  29. }
  30. return fileHeader.Filename, contentBytes, nil
  31. }
  32. func (c *Context) GetHeaders() map[string]string {
  33. headers := make(map[string]string, 0)
  34. for key, values := range c.Request.Header {
  35. headers[key] = strings.Join(values, ",")
  36. }
  37. return headers
  38. }
  39. func (c *Context) ReadBody() ([]byte, error) {
  40. if c.Request.Body == nil {
  41. return make([]byte, 0), nil
  42. }
  43. body, err := io.ReadAll(c.Request.Body)
  44. if err != nil {
  45. return nil, err
  46. }
  47. defer func(Body io.ReadCloser) {
  48. err := Body.Close()
  49. if err != nil {
  50. logger.GetInstance().Error(err)
  51. return
  52. }
  53. }(c.Request.Body)
  54. c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
  55. return body, nil
  56. }
  57. func (c *Context) ReplaceBody(body []byte) error {
  58. if c.Request.Body != nil {
  59. err := c.Request.Body.Close()
  60. if err != nil {
  61. return err
  62. }
  63. }
  64. c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
  65. return nil
  66. }
  67. func (c *Context) GetAllQueryParams() map[string]string {
  68. queryParams := make(map[string]string, 0)
  69. for key, values := range c.Request.URL.Query() {
  70. queryParams[key] = strings.Join(values, ",")
  71. }
  72. return queryParams
  73. }
  74. func (c *Context) GetAllPathParams() map[string]string {
  75. pathParams := make(map[string]string, 0)
  76. for _, params := range c.Params {
  77. pathParams[params.Key] = params.Value
  78. }
  79. return pathParams
  80. }