package gwtools

import (
	"encoding/json"

	"git.sxidc.com/go-framework/baize/framework/core/api"
	"git.sxidc.com/go-framework/baize/framework/gateway"
	"git.sxidc.com/service-supports/fserr"
)

// GetTenantIDFunc is a callback that returns the tenant ID to inject for the
// request handled through c.
type GetTenantIDFunc func(c *api.Context) (string, error)

// GetUserIDFunc is a callback that returns the user ID to inject for the
// request handled through c.
type GetUserIDFunc func(c *api.Context) (string, error)

// AddBodyTenantIDAndUserID returns a gateway.FormBodyFunc that builds the body
// with gateway.DefaultFormBodyFunc and then adds tenant/user ID fields to it.
//
// The default body must be a []byte holding a JSON object. An empty body is
// treated as an empty object. A field is only added when its key is absent
// from the body; existing values (including explicit nulls) are left as-is,
// and the matching callback is only invoked when the key is missing. A nil
// callback disables that field. When both callbacks are nil, the default
// body is returned unchanged without being parsed.
//
// The rewritten body is written back with c.ReplaceBody and also returned.
// Errors come from the default body builder, a non-[]byte body, invalid
// JSON, the callbacks, or ReplaceBody.
func AddBodyTenantIDAndUserID(tenantIDFieldName string, userIDFieldName string, getTenantIDFunc GetTenantIDFunc, getUserIDFunc GetUserIDFunc) gateway.FormBodyFunc {
	return func(c *api.Context, historyRequests []gateway.BuilderRequest, customResultMap map[string]any) (any, error) {
		body, err := gateway.DefaultFormBodyFunc(c, historyRequests, customResultMap)
		if err != nil {
			return nil, err
		}

		if getTenantIDFunc == nil && getUserIDFunc == nil {
			return body, nil
		}

		bodyBytes, ok := body.([]byte)
		if !ok {
			return nil, fserr.New("body不是json")
		}

		// Decode into raw values instead of map[string]any so that fields we do
		// not touch are re-emitted verbatim. Decoding into `any` turns every JSON
		// number into a float64, which corrupts integers larger than 2^53.
		bodyMap := make(map[string]json.RawMessage)
		if len(bodyBytes) > 0 {
			if err := json.Unmarshal(bodyBytes, &bodyMap); err != nil {
				return nil, err
			}
		}

		// A literal JSON `null` body unmarshals into a nil map; assigning into a
		// nil map panics, so start from an empty object instead.
		if bodyMap == nil {
			bodyMap = make(map[string]json.RawMessage)
		}

		// fillIfAbsent stores the ID returned by getID under fieldName unless the
		// body already has that key. A nil getID disables the field.
		fillIfAbsent := func(fieldName string, getID func(*api.Context) (string, error)) error {
			if getID == nil {
				return nil
			}
			if _, exists := bodyMap[fieldName]; exists {
				return nil
			}

			id, err := getID(c)
			if err != nil {
				return err
			}

			raw, err := json.Marshal(id)
			if err != nil {
				return err
			}

			bodyMap[fieldName] = raw
			return nil
		}

		// Tenant first, then user: if both field names coincide, the tenant ID wins.
		if err := fillIfAbsent(tenantIDFieldName, getTenantIDFunc); err != nil {
			return nil, err
		}
		if err := fillIfAbsent(userIDFieldName, getUserIDFunc); err != nil {
			return nil, err
		}

		newBody, err := json.Marshal(bodyMap)
		if err != nil {
			return nil, err
		}

		if err := c.ReplaceBody(newBody); err != nil {
			return nil, err
		}

		return newBody, nil
	}
}

// AddQueryParamsTenantIDAndUserID returns a gateway.FormQueryParamsFunc that
// builds the query parameters with gateway.DefaultFormQueryParamsFunc and then
// adds tenant/user ID parameters to them.
//
// A parameter is only added when its key is absent; existing values are left
// as-is, and the matching callback is only invoked when the key is missing. A
// nil callback disables that parameter. When both callbacks are nil, the
// default parameters are returned unchanged. The map returned by the default
// builder is modified in place (or a new one is allocated if it is nil).
func AddQueryParamsTenantIDAndUserID(tenantIDFieldName string, userIDFieldName string, getTenantIDFunc GetTenantIDFunc, getUserIDFunc GetUserIDFunc) gateway.FormQueryParamsFunc {
	return func(c *api.Context, historyRequests []gateway.BuilderRequest, customResultMap map[string]any) (map[string]string, error) {
		queryParams, err := gateway.DefaultFormQueryParamsFunc(c, historyRequests, customResultMap)
		if err != nil {
			return nil, err
		}

		if getTenantIDFunc == nil && getUserIDFunc == nil {
			return queryParams, nil
		}

		// NOTE(review): whether the default builder can return a nil map (e.g. for
		// a request without query parameters) is not visible here. Guard anyway,
		// because assigning into a nil map panics.
		if queryParams == nil {
			queryParams = make(map[string]string, 2)
		}

		// fillIfAbsent stores the ID returned by getID under fieldName unless the
		// parameters already have that key. A nil getID disables the parameter.
		fillIfAbsent := func(fieldName string, getID func(*api.Context) (string, error)) error {
			if getID == nil {
				return nil
			}
			if _, exists := queryParams[fieldName]; exists {
				return nil
			}

			id, err := getID(c)
			if err != nil {
				return err
			}

			queryParams[fieldName] = id
			return nil
		}

		// Tenant first, then user: if both field names coincide, the tenant ID wins.
		if err := fillIfAbsent(tenantIDFieldName, getTenantIDFunc); err != nil {
			return nil, err
		}
		if err := fillIfAbsent(userIDFieldName, getUserIDFunc); err != nil {
			return nil, err
		}

		return queryParams, nil
	}
}