package api import ( "bytes" "encoding/json" "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger" "github.com/gin-gonic/gin" "github.com/pkg/errors" "io" "mime/multipart" "net/textproto" "strings" ) const ( bodyKey = "body-context" queryParamsKey = "query-params-context" pathParamsKey = "path-params-context" headerKey = "header-context" tenantInfoKey = "context-tenant-info" userInfoKey = "context-user-info" ) type Context struct { *gin.Context } // GetFileHeaderBytes 获取Multipart中传递的文件名和文件内容 // 参数: // - fileHeader: Multipart的文件头 // 返回值: // - 文件名 // - 文件字节 // - 错误 func (c *Context) GetFileHeaderBytes(fileHeader *multipart.FileHeader) (string, []byte, error) { file, err := fileHeader.Open() if err != nil { return "", nil, errors.New(err.Error()) } defer func(file multipart.File) { err := file.Close() if err != nil { logger.GetInstance().Error(errors.New(err.Error())) return } }(file) contentBytes, err := io.ReadAll(file) if err != nil { return "", nil, errors.New(err.Error()) } return fileHeader.Filename, contentBytes, nil } // GetHeader 获取上下文中的Header // 参数: 无 // 返回值: // - 上下文中的Header func (c *Context) GetHeader() *Header { savedHeader, exist := c.Get(headerKey) if exist { return &Header{ c: c, header: savedHeader.(map[string]string), } } return &Header{ c: c, header: c.ReadOriginHeader(), } } func (c *Context) ReadOriginHeader() map[string]string { originHeader := make(map[string]string) for key, values := range c.Request.Header { originHeader[key] = strings.Join(values, ",") } return originHeader } // GetBytesBody 获取上下文中的字节Body // 参数: 无 // 返回值: // - 上下文中的字节Body // - 错误 func (c *Context) GetBytesBody() (*BytesBody, error) { body, exist := c.Get(bodyKey) if !exist { bytesBody, err := c.ReadOriginBody() if err != nil { return nil, err } c.Set(bodyKey, bytesBody) return &BytesBody{ c: c, bytesBody: bytesBody, }, nil } switch b := body.(type) { case []byte: return &BytesBody{ c: c, bytesBody: b, }, nil case map[string]any: bytesBody, err := json.Marshal(b) if err != nil { return nil, errors.New(err.Error()) } return &BytesBody{ c: c, bytesBody: bytesBody, }, nil default: return nil, errors.New("不支持的body类型") } } // GetJsonBody 获取上下文中的JsonBody // 参数: 无 // 返回值: // - 上下文中的JsonBody // - 错误 func (c *Context) GetJsonBody() (*JsonBody, error) { body, exist := c.Get(bodyKey) if !exist { bytesBody, err := c.ReadOriginBody() if err != nil { return nil, err } if bytesBody == nil || len(bytesBody) == 0 { return &JsonBody{ c: c, jsonBodyMap: make(map[string]any), }, nil } jsonBodyMap := make(map[string]any) err = json.Unmarshal(bytesBody, &jsonBodyMap) if err != nil { return nil, errors.New(err.Error()) } c.Set(bodyKey, jsonBodyMap) return &JsonBody{ c: c, jsonBodyMap: jsonBodyMap, }, nil } switch b := body.(type) { case []byte: jsonBodyMap := make(map[string]any) if b == nil || len(b) == 0 { return &JsonBody{ c: c, jsonBodyMap: jsonBodyMap, }, nil } err := json.Unmarshal(b, &jsonBodyMap) if err != nil { return nil, errors.New(err.Error()) } return &JsonBody{ c: c, jsonBodyMap: jsonBodyMap, }, nil case map[string]any: return &JsonBody{ c: c, jsonBodyMap: body.(map[string]any), }, nil default: return nil, errors.New("不支持的body类型") } } // GetQueryParams 获取上下文中的查询参数 // 参数: 无 // 返回值: // - 上下文中的查询参数 // - 错误 func (c *Context) GetQueryParams() *QueryPrams { queryParams, exist := c.Get(queryParamsKey) if !exist { originQueryParams := c.GetOriginQueryParams() c.Set(queryParamsKey, originQueryParams) return &QueryPrams{ c: c, queryParams: originQueryParams, } } return &QueryPrams{ c: c, queryParams: queryParams.(map[string]string), } } // GetPathParams 获取上下文中的路径参数 // 参数: 无 // 返回值: // - 上下文中的路径参数 // - 错误 func (c *Context) GetPathParams() *PathPrams { pathParams, exist := c.Get(pathParamsKey) if !exist { originPathParams := c.GetOriginPathParams() c.Set(pathParamsKey, originPathParams) return &PathPrams{ c: c, pathParams: c.GetOriginPathParams(), } } return &PathPrams{ c: c, pathParams: pathParams.(map[string]string), } } // ReadOriginBody 获取上下文中的原始字节Body // 参数: 无 // 返回值: // - 上下文中的原始字节Body // - 错误 func (c *Context) ReadOriginBody() ([]byte, error) { if c.Request.Body == nil { return make([]byte, 0), nil } body, err := io.ReadAll(c.Request.Body) if err != nil { return nil, errors.New(err.Error()) } defer func(Body io.ReadCloser) { err := Body.Close() if err != nil { logger.GetInstance().Error(errors.New(err.Error())) return } }(c.Request.Body) c.Request.Body = io.NopCloser(bytes.NewBuffer(body)) c.Set(bodyKey, body) return body, nil } // GetOriginQueryParams 获取上下文中的原始查询参数 // 参数: 无 // 返回值: // - 上下文中的原始查询参数 func (c *Context) GetOriginQueryParams() map[string]string { queryParams := make(map[string]string) for key, values := range c.Request.URL.Query() { queryParams[key] = strings.Join(values, ",") } return queryParams } // GetOriginPathParams 获取上下文中的原始路径参数 // 参数: 无 // 返回值: // - 上下文中的原始路径参数 func (c *Context) GetOriginPathParams() map[string]string { pathParams := make(map[string]string) for _, params := range c.Params { pathParams[params.Key] = params.Value } return pathParams } type Header struct { c *Context header map[string]string } // Reload 从Context重新加载Header // 参数: 无 // 返回值: 无 func (header *Header) Reload() { originHeader := header.c.ReadOriginHeader() header.c.Set(headerKey, originHeader) header.header = originHeader } // Set 设置Header // 参数: // - key: Header的键 // - value: Header对应键的值 // 返回值: 无 func (header *Header) Set(key string, value string) { header.header[key] = value header.c.Set(headerKey, header.header) } // Get 获取Header对应键的值 // 参数: // - key: Header的键 // 返回值: // - Header对应键的值 func (header *Header) Get(key string) string { mineHeader := textproto.MIMEHeader{} for k, v := range header.header { mineHeader.Set(k, v) } return mineHeader.Get(key) } // Map 获取Header的map表示 // 参数: 无 // 返回值: // - Header的map表示 func (header *Header) Map() map[string]string { return header.header } type BytesBody struct { c *Context bytesBody []byte } // Reload 从Context重新加载BytesBody // 参数: 无 // 返回值: // - 错误 func (bytesBody *BytesBody) Reload() error { originBody, err := bytesBody.c.ReadOriginBody() if err != nil { return err } bytesBody.c.Set(headerKey, originBody) bytesBody.bytesBody = originBody return nil } // Set 设置字节body // 参数: // - body: 字节body的内容 // 返回值: 无 func (bytesBody *BytesBody) Set(body []byte) { bytesBody.bytesBody = body bytesBody.c.Set(bodyKey, bytesBody.bytesBody) } // Bytes 获取字节body的内容 // 参数: 无 // 返回值: // - 字节body的内容 func (bytesBody *BytesBody) Bytes() []byte { return bytesBody.bytesBody } // Marshal 将input Marshal到字节body // 参数: // - input: 输入,一般为结构或map[string]any // 返回值: // - 错误 func (bytesBody *BytesBody) Marshal(input any) error { jsonBytesBody, err := json.Marshal(input) if err != nil { return errors.New(err.Error()) } bytesBody.Set(jsonBytesBody) return nil } // Unmarshal 将字节body的内容Unmarshal到output // 参数: // - output: 输出,一般为结构指针或map[string]any指针 // 返回值: // - 错误 func (bytesBody *BytesBody) Unmarshal(output any) error { err := json.Unmarshal(bytesBody.Bytes(), output) if err != nil { return errors.New(err.Error()) } return nil } type JsonBody struct { c *Context jsonBodyMap map[string]any } // Reload 从Context重新加载JsonBody // 参数: 无 // 返回值: // - 错误 func (jsonBody *JsonBody) Reload() error { bytesBody, err := jsonBody.c.ReadOriginBody() if err != nil { return err } jsonBodyMap := make(map[string]any) if bytesBody == nil || len(bytesBody) == 0 { jsonBody.c.Set(bodyKey, jsonBodyMap) jsonBody.jsonBodyMap = jsonBodyMap return nil } err = json.Unmarshal(bytesBody, &jsonBodyMap) if err != nil { return errors.New(err.Error()) } jsonBody.c.Set(bodyKey, jsonBodyMap) jsonBody.jsonBodyMap = jsonBodyMap return nil } // Set 设置JsonBody键对应的值 // 参数: // - key: JsonBody的键 // - value: JsonBody对应键的值 // 返回值: 无 func (jsonBody *JsonBody) Set(key string, value any) { jsonBody.jsonBodyMap[key] = value jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap) } // Delete 删除JsonBody键对应的值 // 参数: // - key: JsonBody的键 // 返回值: 无 func (jsonBody *JsonBody) Delete(key string) { delete(jsonBody.jsonBodyMap, key) jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap) } // Get 获取JsonBody对应键的值 // 参数: // - key: JsonBody的键 // 返回值: // - JsonBody对应键的值 func (jsonBody *JsonBody) Get(key string) any { return jsonBody.jsonBodyMap[key] } // SetMap 使用map设置JsonBody // 参数: // - mapBody: map表示的body // 返回值: 无 func (jsonBody *JsonBody) SetMap(mapBody map[string]any) { jsonBody.jsonBodyMap = mapBody jsonBody.c.Set(bodyKey, jsonBody.jsonBodyMap) } // Map 获取JsonBody的map表示 // 参数: 无 // 返回值: // - JsonBody的map表示 func (jsonBody *JsonBody) Map() map[string]any { return jsonBody.jsonBodyMap } // Bytes 获取JsonBody的内容 // 参数: 无 // 返回值: // - JsonBody的内容 func (jsonBody *JsonBody) Bytes() ([]byte, error) { if jsonBody.jsonBodyMap == nil || len(jsonBody.jsonBodyMap) == 0 { return make([]byte, 0), nil } jsonBytes, err := json.Marshal(jsonBody.jsonBodyMap) if err != nil { return nil, errors.New(err.Error()) } return jsonBytes, nil } // Unmarshal 将JsonBody的内容Unmarshal到output // 参数: // - output: 输出,一般为结构指针或map[string]any指针 // 返回值: // - 错误 func (jsonBody *JsonBody) Unmarshal(output any) error { jsonBytes, err := jsonBody.Bytes() if err != nil { return err } if jsonBytes == nil || len(jsonBytes) == 0 { return nil } err = json.Unmarshal(jsonBytes, output) if err != nil { return errors.New(err.Error()) } return nil } type QueryPrams struct { c *Context queryParams map[string]string } // Reload 从Context重新加载QueryParams // 参数: 无 // 返回值: 无 func (queryParams *QueryPrams) Reload() { originQueryParams := queryParams.c.GetOriginQueryParams() queryParams.c.Set(queryParamsKey, originQueryParams) queryParams.queryParams = originQueryParams } // Set 设置查询参数 // 参数: // - key: 查询参数的键 // - value: 查询参数对应键的值 // 返回值: 无 func (queryParams *QueryPrams) Set(key string, value string) { queryParams.queryParams[key] = value queryParams.c.Set(queryParamsKey, queryParams.queryParams) } // Delete 删除JsonBody键对应的值 // 参数: // - key: JsonBody的键 // 返回值: 无 func (queryParams *QueryPrams) Delete(key string) { delete(queryParams.queryParams, key) queryParams.c.Set(queryParamsKey, queryParams.queryParams) } // Get 获取查询参数对应键的值 // 参数: // - key: 查询参数的键 // 返回值: // - 查询参数对应键的值 func (queryParams *QueryPrams) Get(key string) string { return queryParams.queryParams[key] } // SetMap 使用map设置查询参数 // 参数: // - mapQueryParams: map表示的查询参数 // 返回值: 无 func (queryParams *QueryPrams) SetMap(mapQueryParams map[string]string) { queryParams.queryParams = mapQueryParams queryParams.c.Set(queryParamsKey, queryParams.queryParams) } // Map 获取查询参数的map表示 // 参数: 无 // 返回值: // - 查询参数的map表示 func (queryParams *QueryPrams) Map() map[string]string { return queryParams.queryParams } type PathPrams struct { c *Context pathParams map[string]string } // Reload 从Context重新加载QueryParams // 参数: 无 // 返回值: 无 func (pathParams *PathPrams) Reload() { originPathParams := pathParams.c.GetOriginPathParams() pathParams.c.Set(pathParamsKey, originPathParams) pathParams.pathParams = originPathParams } // Set 设置路径参数 // 参数: // - key: 路径参数的键 // - value: 路径参数对应键的值 // 返回值: 无 func (pathParams *PathPrams) Set(key string, value string) { pathParams.pathParams[key] = value pathParams.c.Set(pathParamsKey, pathParams.pathParams) } // Delete 删除JsonBody键对应的值 // 参数: // - key: JsonBody的键 // 返回值: 无 func (pathParams *PathPrams) Delete(key string) { delete(pathParams.pathParams, key) pathParams.c.Set(pathParamsKey, pathParams.pathParams) } // Get 获取路径参数对应键的值 // 参数: // - key: 路径参数的键 // 返回值: // - 路径参数对应键的值 func (pathParams *PathPrams) Get(key string) string { return pathParams.pathParams[key] } // SetMap 使用map设置路径参数 // 参数: // - mapPathParams: map表示的路径参数 // 返回值: 无 func (pathParams *PathPrams) SetMap(mapPathParams map[string]string) { pathParams.pathParams = mapPathParams pathParams.c.Set(pathParamsKey, pathParams.pathParams) } // Map 获取路径参数的map表示 // 参数: 无 // 返回值: // - 路径参数的map表示 func (pathParams *PathPrams) Map() map[string]string { return pathParams.pathParams } type TenantInfo interface { GetID() string GetName() string } type UserInfo interface { GetID() string GetUserName() string GetName() string } func (c *Context) SetTenantInfo(tenantInfo TenantInfo) { c.Set(tenantInfoKey, tenantInfo) } func (c *Context) SetUserInfo(userInfo UserInfo) { c.Set(userInfoKey, userInfo) } func (c *Context) GetTenantInfo() TenantInfo { tenantInfo, exist := c.Get(tenantInfoKey) if !exist { return nil } return tenantInfo.(TenantInfo) } func (c *Context) GetUserInfo() UserInfo { userInfo, exist := c.Get(userInfoKey) if !exist { return nil } return userInfo.(UserInfo) }