Parcourir la source

服务码设置并发安全问题处理

jys il y a 9 mois
Parent
commit
7ebc64f7cd
4 fichiers modifiés avec 44 ajouts et 37 suppressions
  1. 16 3
      code.go
  2. 3 3
      errors_test.go
  3. 19 27
      public.go
  4. 6 4
      public_test.go

+ 16 - 3
code.go

@@ -2,6 +2,7 @@ package fserr
 
 import (
 	"net/http"
+	"sync/atomic"
 )
 
 type codeType interface {
@@ -9,7 +10,19 @@ type codeType interface {
 		uint8 | uint16 | uint32 | uint64 | uint
 }
 
-var serviceCode int
+type atomicServiceCode struct {
+	atomic.Int64
+}
+
+func (a *atomicServiceCode) Load() int {
+	return int(a.Int64.Load())
+}
+
+func (a *atomicServiceCode) Store(v int64) {
+	a.Int64.Store(v)
+}
+
+var serviceCode atomicServiceCode
 
 const (
 	ErrBasic = iota + 1
@@ -37,12 +50,12 @@ type ErrCode struct {
 // 当错误码匹配失败时,提供的备选方案,已内置默认错误码,
 // 它的HTTP码为200,业务码和信息均为零值
 func SetDefault(httpCode, businessCode int, message string) {
-	defaultErrCode = ErrCode{httpCode, serviceCode + businessCode, message}
+	defaultErrCode = ErrCode{httpCode, serviceCode.Load() + businessCode, message}
 }
 
 // NewCode 创建指定信息的错误码
 func NewCode(httpCode, businessCode int, message string) ErrCode {
-	code := ErrCode{httpCode, serviceCode + businessCode, message}
+	code := ErrCode{httpCode, serviceCode.Load() + businessCode, message}
 	register(code)
 	return code
 }

+ 3 - 3
errors_test.go

@@ -56,10 +56,10 @@ func (s *TestErrorsSuite) TestStack() {
 func (s *TestErrorsSuite) TestMessage() {
 	s.Equal(s.originStack, s.originMessage.Cause())
 	s.Equal(s.originStack, s.originMessage.Unwrap())
-	s.Equal("origin message: origin fundamental", s.originMessage.Error())
-	s.Equal("origin message: origin fundamental",
+	s.Equal("origin message", s.originMessage.Error())
+	s.Equal("origin message",
 		fmt.Sprintf("%s", s.originMessage))
-	s.Equal(`"origin message: origin fundamental"`,
+	s.Equal(`"origin message"`,
 		fmt.Sprintf("%q", s.originMessage))
 }
 

+ 19 - 27
public.go

@@ -39,7 +39,7 @@ func Wrap(err error, format string, args ...any) error {
 // WithCode 创建带有错误码的error,支持格式化占位符
 // 使用option可以替换其中信息
 func WithCode[T codeType](err error, businessCode T, options ...Option) error {
-	code := getCode(serviceCode + int(businessCode))
+	code := getCode(serviceCode.Load() + int(businessCode))
 	ret := &withCode{
 		cause:        wrapStack(err),
 		Msg:          code.Message,
@@ -101,20 +101,20 @@ func As(err error, target any) bool { return errors.As(err, target) }
 // 若想得到原始的错误码错误,可以使用As方法
 func ParseCode(err error) *withCode {
 	var target *withCode
-	if !As(err, &target) {
+	if As(err, &target) {
 		return &withCode{
-			cause:        nil,
-			Msg:          err.Error(),
-			HttpCode:     defaultErrCode.HttpCode,
-			BusinessCode: serviceCode,
+			cause:        target.cause,
+			Msg:          outerMsg(err),
+			HttpCode:     target.HttpCode,
+			BusinessCode: target.BusinessCode,
 		}
 	}
 
 	return &withCode{
-		cause:        target.cause,
-		Msg:          outerMsg(err),
-		HttpCode:     target.HttpCode,
-		BusinessCode: target.BusinessCode,
+		cause:        nil,
+		Msg:          err.Error(),
+		HttpCode:     defaultErrCode.HttpCode,
+		BusinessCode: serviceCode.Load(),
 	}
 }
 
@@ -124,14 +124,14 @@ func IsCode[T codeType](err error, code T) bool {
 	if !As(err, &target) {
 		return false
 	}
-	return target.BusinessCode == serviceCode+int(code)
+	return target.BusinessCode == serviceCode.Load()+int(code)
 }
 
 // SetAppCode 设置服务错误码
 // 模块码、模块错误码一共四位,指定应用码将拼接在前
 // 例如应用码位101,模块码为1,模块错误码为21,那么最终业务错误码为:1010121
 func SetAppCode[T codeType](code T) {
-	serviceCode = int(code) * 10000
+	serviceCode.Store(int64(code) * 10000)
 }
 
 // outerMsg 获取最外层的错误信息
@@ -152,21 +152,13 @@ func outerMsg(err error) string {
 }
 
 func wrapStack(err error) error {
-	if _, ok := err.(*fundamental); ok {
-		return err
-	}
-	if _, ok := err.(*withCode); ok {
-		return err
-	}
-	if _, ok := err.(*withMessage); ok {
-		return err
-	}
-	if _, ok := err.(*withStack); ok {
+	switch err.(type) {
+	case *fundamental, *withCode, *withMessage, *withStack:
 		return err
-	}
-
-	return &withStack{
-		error: err,
-		stack: callers(),
+	default:
+		return &withStack{
+			error: err,
+			stack: callers(),
+		}
 	}
 }

+ 6 - 4
public_test.go

@@ -13,7 +13,7 @@ type TestPublicSuite struct {
 	suite.Suite
 	outerErr,
 	newErr, newFmtErr, stackErr,
-	wrapErr, wrapFmtErr, wrapOuterErr, wrapNilErr,
+	wrapErr, wrapFmtErr, wrapOuterErr, wrapNilErr, wrapEmptyErr,
 	codeErr, codeOptionErr, codeOuterErr, codeNilErr error
 	errBasicMsg string
 }
@@ -32,6 +32,7 @@ func (s *TestPublicSuite) SetupTest() {
 	s.wrapFmtErr = Wrap(s.newErr, "wrap %s", "error")
 	s.wrapOuterErr = Wrap(s.outerErr, "wrap error")
 	s.wrapNilErr = Wrap(nil, "wrap error")
+	s.wrapEmptyErr = Wrap(s.newErr, " ")
 
 	// code
 	s.codeErr = WithCode(s.newErr, ErrBasic)
@@ -48,9 +49,10 @@ func (s *TestPublicSuite) TestNew() {
 }
 
 func (s *TestPublicSuite) TestWrap() {
-	s.Equal("wrap error: new error", s.wrapErr.Error())
-	s.Equal("wrap error: new error", s.wrapFmtErr.Error())
-	s.Equal("wrap error: outer error", s.wrapOuterErr.Error())
+	s.Equal("wrap error", s.wrapErr.Error())
+	s.Equal("wrap error", s.wrapFmtErr.Error())
+	s.Equal("wrap error", s.wrapOuterErr.Error())
+	s.Equal("new error", s.wrapEmptyErr.Error())
 	s.Nil(s.wrapNilErr)
 }