|
@@ -0,0 +1,210 @@
|
|
|
+package redis_lock
|
|
|
+
|
|
|
+import (
|
|
|
+ "context"
|
|
|
+ "errors"
|
|
|
+ "github.com/redis/go-redis/v9"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ tryLockTimeoutSec = 5
|
|
|
+)
|
|
|
+
|
|
|
+type Lock struct {
|
|
|
+ redisClient *redis.Client
|
|
|
+ ownerID string
|
|
|
+}
|
|
|
+
|
|
|
+func NewLock(address string, password string, db int) (*Lock, error) {
|
|
|
+ if address == "" {
|
|
|
+ return nil, errors.New("redis address不能为空")
|
|
|
+ }
|
|
|
+
|
|
|
+ return &Lock{redisClient: redis.NewClient(&redis.Options{
|
|
|
+ Addr: address,
|
|
|
+ Password: password,
|
|
|
+ DB: db,
|
|
|
+ })}, nil
|
|
|
+}
|
|
|
+
|
|
|
+func DestroyLock(lock *Lock) error {
|
|
|
+ if lock == nil || lock.redisClient == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := lock.redisClient.Close()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ lock.redisClient = nil
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (lock *Lock) Lock(ownerID string, key string, expireSec int) error {
|
|
|
+ if ownerID == "" {
|
|
|
+ return errors.New("ownerID不能为空")
|
|
|
+ }
|
|
|
+
|
|
|
+ if key == "" {
|
|
|
+ return errors.New("key不能为空")
|
|
|
+ }
|
|
|
+
|
|
|
+ if expireSec == 0 {
|
|
|
+ return errors.New("expireSec不能为0")
|
|
|
+ }
|
|
|
+
|
|
|
+ for {
|
|
|
+ locked, err := lock.tryLock(ownerID, key, expireSec)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ if locked {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ time.Sleep(200 * time.Millisecond)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (lock *Lock) TryLock(ownerID string, key string, expireSec int) (bool, error) {
|
|
|
+ if ownerID == "" {
|
|
|
+ return false, errors.New("ownerID不能为空")
|
|
|
+ }
|
|
|
+
|
|
|
+ if key == "" {
|
|
|
+ return false, errors.New("key不能为空")
|
|
|
+ }
|
|
|
+
|
|
|
+ if expireSec == 0 {
|
|
|
+ return false, errors.New("expireSec不能为0")
|
|
|
+ }
|
|
|
+
|
|
|
+ return lock.tryLock(ownerID, key, expireSec)
|
|
|
+}
|
|
|
+
|
|
|
+func (lock *Lock) Unlock(ownerID string, key string) error {
|
|
|
+ if ownerID == "" {
|
|
|
+ return errors.New("ownerID不能为空")
|
|
|
+ }
|
|
|
+
|
|
|
+ if key == "" {
|
|
|
+ return errors.New("key不能为空")
|
|
|
+ }
|
|
|
+
|
|
|
+ if ownerID != lock.ownerID {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return lock.deleteLockTime(key)
|
|
|
+}
|
|
|
+
|
|
|
+func (lock *Lock) tryLock(ownerID string, key string, expireSec int) (bool, error) {
|
|
|
+ expire := time.Duration(expireSec) * time.Second
|
|
|
+ nowTime := time.Now()
|
|
|
+ nowNano := nowTime.UnixNano()
|
|
|
+ currentLockNano := nowTime.Add(expire).UnixNano()
|
|
|
+
|
|
|
+
|
|
|
+ locked, err := lock.setLockTime(key, currentLockNano, expire)
|
|
|
+ if err != nil {
|
|
|
+ return false, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if locked {
|
|
|
+ lock.ownerID = ownerID
|
|
|
+ return true, nil
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ savedLockTimeNano, err := lock.getLockTime(key)
|
|
|
+ if err != nil {
|
|
|
+ return false, err
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ if savedLockTimeNano != 0 && savedLockTimeNano < nowNano {
|
|
|
+
|
|
|
+ oldLockTime, err := lock.getAndSetLockTime(key, currentLockNano)
|
|
|
+ if err != nil {
|
|
|
+ return false, err
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ if oldLockTime != 0 && oldLockTime == savedLockTimeNano {
|
|
|
+ lock.ownerID = ownerID
|
|
|
+ return true, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return false, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (lock *Lock) setLockTime(key string, lockTime int64, expireSec time.Duration) (bool, error) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ cmd := lock.redisClient.SetNX(ctx, key, lockTime, expireSec)
|
|
|
+ if cmd.Err() != nil {
|
|
|
+ return false, cmd.Err()
|
|
|
+ }
|
|
|
+
|
|
|
+ return cmd.Val(), nil
|
|
|
+}
|
|
|
+
|
|
|
+func (lock *Lock) getLockTime(key string) (int64, error) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ cmd := lock.redisClient.Get(ctx, key)
|
|
|
+ if cmd.Err() != nil {
|
|
|
+ if cmd.Err().Error() == redis.Nil.Error() {
|
|
|
+ return 0, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return 0, cmd.Err()
|
|
|
+ }
|
|
|
+
|
|
|
+ lockTime, err := cmd.Int64()
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return lockTime, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (lock *Lock) getAndSetLockTime(key string, newLockTime int64) (int64, error) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ cmd := lock.redisClient.GetSet(ctx, key, newLockTime)
|
|
|
+ if cmd.Err() != nil {
|
|
|
+ if cmd.Err().Error() == redis.Nil.Error() {
|
|
|
+ return 0, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return 0, cmd.Err()
|
|
|
+ }
|
|
|
+
|
|
|
+ lockTime, err := cmd.Int64()
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return lockTime, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (lock *Lock) deleteLockTime(key string) error {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ cmd := lock.redisClient.Del(ctx, key)
|
|
|
+ if cmd.Err() != nil {
|
|
|
+ return cmd.Err()
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|