123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- package redislock
- import (
- "context"
- "errors"
- "github.com/redis/go-redis/v9"
- "time"
- )
// tryLockTimeoutSec is the per-call timeout, in seconds, applied to every
// individual Redis command issued by the lock helpers.
const tryLockTimeoutSec = 5
- type Lock struct {
- redisClient *redis.Client
- ownerID string
- }
- func NewLock(address string, password string, db int) (*Lock, error) {
- if address == "" {
- return nil, errors.New("redis address不能为空")
- }
- return &Lock{redisClient: redis.NewClient(&redis.Options{
- Addr: address,
- Password: password,
- DB: db,
- })}, nil
- }
- func DestroyLock(lock *Lock) error {
- if lock == nil || lock.redisClient == nil {
- return nil
- }
- err := lock.redisClient.Close()
- if err != nil {
- return err
- }
- lock.redisClient = nil
- return nil
- }
- func (lock *Lock) Lock(ownerID string, key string, expireSec int) error {
- if ownerID == "" {
- return errors.New("ownerID不能为空")
- }
- if key == "" {
- return errors.New("key不能为空")
- }
- if expireSec == 0 {
- return errors.New("expireSec不能为0")
- }
- for {
- locked, err := lock.tryLock(ownerID, key, expireSec)
- if err != nil {
- return err
- }
- if locked {
- return nil
- }
- time.Sleep(200 * time.Millisecond)
- }
- }
- func (lock *Lock) TryLock(ownerID string, key string, expireSec int) (bool, error) {
- if ownerID == "" {
- return false, errors.New("ownerID不能为空")
- }
- if key == "" {
- return false, errors.New("key不能为空")
- }
- if expireSec == 0 {
- return false, errors.New("expireSec不能为0")
- }
- return lock.tryLock(ownerID, key, expireSec)
- }
- func (lock *Lock) Unlock(ownerID string, key string) error {
- if ownerID == "" {
- return errors.New("ownerID不能为空")
- }
- if key == "" {
- return errors.New("key不能为空")
- }
- if ownerID != lock.ownerID {
- return nil
- }
- return lock.deleteLockTime(key)
- }
- func (lock *Lock) tryLock(ownerID string, key string, expireSec int) (bool, error) {
- expire := time.Duration(expireSec) * time.Second
- nowTime := time.Now()
- nowNano := nowTime.UnixNano()
- currentLockNano := nowTime.Add(expire).UnixNano()
- // 尝试以当前lockTime上锁,如果返回true,则上锁成功,否则锁正被占用
- locked, err := lock.setLockTime(key, currentLockNano, expire)
- if err != nil {
- return false, err
- }
- if locked {
- lock.ownerID = ownerID
- return true, nil
- }
- // 获取保存的lockTime
- savedLockTimeNano, err := lock.getLockTime(key)
- if err != nil {
- return false, err
- }
- // 锁已经过期
- if savedLockTimeNano != 0 && savedLockTimeNano < nowNano {
- // 尝试以当前lockTime上锁,返回之前设置的值
- oldLockTime, err := lock.getAndSetLockTime(key, currentLockNano)
- if err != nil {
- return false, err
- }
- // 排除上一步操作存在的并发问题
- if oldLockTime != 0 && oldLockTime == savedLockTimeNano {
- lock.ownerID = ownerID
- return true, nil
- }
- }
- return false, nil
- }
- func (lock *Lock) setLockTime(key string, lockTime int64, expireSec time.Duration) (bool, error) {
- ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
- defer cancel()
- cmd := lock.redisClient.SetNX(ctx, key, lockTime, expireSec)
- if cmd.Err() != nil {
- return false, cmd.Err()
- }
- return cmd.Val(), nil
- }
- func (lock *Lock) getLockTime(key string) (int64, error) {
- ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
- defer cancel()
- cmd := lock.redisClient.Get(ctx, key)
- if cmd.Err() != nil {
- if cmd.Err().Error() == redis.Nil.Error() {
- return 0, nil
- }
- return 0, cmd.Err()
- }
- lockTime, err := cmd.Int64()
- if err != nil {
- return 0, err
- }
- return lockTime, nil
- }
- func (lock *Lock) getAndSetLockTime(key string, newLockTime int64) (int64, error) {
- ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
- defer cancel()
- cmd := lock.redisClient.GetSet(ctx, key, newLockTime)
- if cmd.Err() != nil {
- if cmd.Err().Error() == redis.Nil.Error() {
- return 0, nil
- }
- return 0, cmd.Err()
- }
- lockTime, err := cmd.Int64()
- if err != nil {
- return 0, err
- }
- return lockTime, nil
- }
- func (lock *Lock) deleteLockTime(key string) error {
- ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
- defer cancel()
- cmd := lock.redisClient.Del(ctx, key)
- if cmd.Err() != nil {
- return cmd.Err()
- }
- return nil
- }
|