middlewares.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. package middlewares
  2. import (
  3. "git.sxidc.com/go-framework/baize/convenient/domain/auth/jwt_tools"
  4. "git.sxidc.com/go-framework/baize/convenient/domain/auth/user"
  5. "git.sxidc.com/go-framework/baize/framework/binding"
  6. "git.sxidc.com/go-framework/baize/framework/core/api"
  7. "git.sxidc.com/go-framework/baize/framework/core/api/response"
  8. "git.sxidc.com/go-framework/baize/framework/core/domain"
  9. "git.sxidc.com/go-framework/baize/framework/core/infrastructure"
  10. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/database"
  11. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/database/sql"
  12. "github.com/dgrijalva/jwt-go/request"
  13. "github.com/pkg/errors"
  14. "net/http"
  15. )
  16. func Authentication(dbSchema string, jwtSecretKey string) binding.Middleware {
  17. return func(c *api.Context, i *infrastructure.Infrastructure) {
  18. respFunc := response.SendMapResponse
  19. // 获取token
  20. token, err := request.AuthorizationHeaderExtractor.ExtractToken(c.Request)
  21. if err != nil {
  22. respFunc(c, http.StatusUnauthorized, nil, errors.New(err.Error()))
  23. c.Abort()
  24. return
  25. }
  26. // 校验token
  27. valid, _, err := jwt_tools.CheckJWT(jwtSecretKey, token)
  28. if err != nil {
  29. respFunc(c, http.StatusUnauthorized, nil, errors.New(err.Error()))
  30. c.Abort()
  31. return
  32. }
  33. if !valid {
  34. respFunc(c, http.StatusUnauthorized, nil, errors.New("无效token"))
  35. c.Abort()
  36. return
  37. }
  38. // 获取用户信息
  39. dbExecutor := i.DBExecutor()
  40. // 查询用户
  41. result, err := database.QueryOne(dbExecutor, &sql.QueryOneExecuteParams{
  42. TableName: domain.TableName(dbSchema, &user.Entity{}),
  43. Conditions: sql.NewConditions().Equal(user.ColumnToken, token),
  44. })
  45. if err != nil {
  46. if database.IsErrorDBRecordNotExist(err) {
  47. respFunc(c, http.StatusUnauthorized, nil, errors.New("token对应的用户不存在"))
  48. } else {
  49. respFunc(c, http.StatusUnauthorized, nil, errors.New(err.Error()))
  50. }
  51. c.Abort()
  52. return
  53. }
  54. userInfo := new(user.Info)
  55. err = sql.ParseSqlResult(result, userInfo)
  56. if err != nil {
  57. respFunc(c, http.StatusUnauthorized, nil, errors.New(err.Error()))
  58. c.Abort()
  59. return
  60. }
  61. // 设置用户上下文
  62. c.SetUserInfo(userInfo)
  63. c.Next()
  64. }
  65. }