package redis_lock

import (
	"context"
	"errors"
	"github.com/redis/go-redis/v9"
	"time"
)

// tryLockTimeoutSec is the timeout, in seconds, given to the context of each
// Redis command issued by the helper methods in this file.
const tryLockTimeoutSec = 5

// Lock is a Redis-backed lock handle.
//
// redisClient is the client every lock command is sent through; DestroyLock
// closes it and resets the field to nil. ownerID holds the owner recorded by
// the most recent successful tryLock; it is one value per Lock (not per key)
// and is what Unlock compares its ownerID argument against.
type Lock struct {
	redisClient *redis.Client
	ownerID     string
}

// NewLock builds a Lock backed by a freshly created Redis client.
//
// address is used as the client's Addr and must be non-empty; password and db
// are forwarded unchanged as the client's Password and DB options. The only
// error this function produces is for an empty address; it sends no Redis
// command itself.
func NewLock(address string, password string, db int) (*Lock, error) {
	if len(address) == 0 {
		return nil, errors.New("redis address不能为空")
	}

	opts := &redis.Options{
		Addr:     address,
		Password: password,
		DB:       db,
	}
	client := redis.NewClient(opts)

	return &Lock{redisClient: client}, nil
}

// DestroyLock closes the Redis client held by lock and clears the reference.
//
// A nil lock, or a lock whose client is already nil, is treated as already
// destroyed and yields nil. If closing the client fails, that error is
// returned and the client reference is left in place.
func DestroyLock(lock *Lock) error {
	if lock == nil {
		return nil
	}

	client := lock.redisClient
	if client == nil {
		return nil
	}

	if err := client.Close(); err != nil {
		return err
	}

	// Only drop the reference once the close has succeeded.
	lock.redisClient = nil
	return nil
}

// Lock blocks until the lock stored under key has been acquired for ownerID,
// or until an acquisition attempt fails with an error.
//
// ownerID and key must be non-empty and expireSec (the lock lifetime in
// seconds) must be positive; a violation is reported as an error before any
// attempt is made. Each attempt goes through tryLock and, while the lock is
// held elsewhere, is repeated every 200ms. There is no deadline or
// cancellation: the call returns only on success or on a tryLock error.
func (lock *Lock) Lock(ownerID string, key string, expireSec int) error {
	// Pause between two acquisition attempts.
	const retryInterval = 200 * time.Millisecond

	switch {
	case ownerID == "":
		return errors.New("ownerID不能为空")
	case key == "":
		return errors.New("key不能为空")
	case expireSec == 0:
		return errors.New("expireSec不能为0")
	case expireSec < 0:
		// A negative lifetime would make tryLock compute an expiry timestamp
		// that is already in the past, so reject it up front.
		return errors.New("expireSec不能为负数")
	}

	for {
		locked, err := lock.tryLock(ownerID, key, expireSec)
		if err != nil {
			return err
		}
		if locked {
			return nil
		}

		time.Sleep(retryInterval)
	}
}

// TryLock makes a single, non-retrying attempt to acquire the lock stored
// under key for ownerID and reports whether it was acquired.
//
// ownerID and key must be non-empty and expireSec (the lock lifetime in
// seconds) must be positive; a violation is reported as an error before any
// attempt is made. Otherwise the result and error come straight from tryLock.
func (lock *Lock) TryLock(ownerID string, key string, expireSec int) (bool, error) {
	switch {
	case ownerID == "":
		return false, errors.New("ownerID不能为空")
	case key == "":
		return false, errors.New("key不能为空")
	case expireSec == 0:
		return false, errors.New("expireSec不能为0")
	case expireSec < 0:
		// A negative lifetime would make tryLock compute an expiry timestamp
		// that is already in the past, so reject it up front.
		return false, errors.New("expireSec不能为负数")
	}

	return lock.tryLock(ownerID, key, expireSec)
}

// Unlock releases the lock stored under key on behalf of ownerID.
//
// ownerID and key must be non-empty, otherwise an error is returned. When
// ownerID differs from the owner recorded on this Lock by the last successful
// acquisition, nothing is deleted and nil is returned. Otherwise the key is
// deleted and any error from that deletion is returned.
func (lock *Lock) Unlock(ownerID string, key string) error {
	switch {
	case ownerID == "":
		return errors.New("ownerID不能为空")
	case key == "":
		return errors.New("key不能为空")
	case lock.ownerID != ownerID:
		// Not the recorded owner: silently leave the lock untouched.
		return nil
	}

	return lock.deleteLockTime(key)
}

// tryLock makes one attempt to acquire the lock stored under key and reports
// whether it succeeded. On success it records ownerID on the receiver.
//
// The value written to Redis is the lock's expiry instant in Unix nanoseconds
// (now + expireSec seconds). Two paths can succeed:
//  1. setLockTime reports that it stored the value (the lock was free).
//  2. The stored expiry instant is non-zero and already in the past, and
//     getAndSetLockTime returns exactly the value that was read just before.
//
// Any error from a helper call is returned together with a false result.
//
// NOTE(review): getAndSetLockTime writes the new value before oldLockTime is
// compared, so a caller that loses that race has still overwritten the stored
// expiry instant — confirm this is acceptable.
// NOTE(review): the takeover path passes no expiry to Redis; presumably the
// key then relies on the stored timestamp alone — verify against the TTL
// semantics of GETSET.
func (lock *Lock) tryLock(ownerID string, key string, expireSec int) (bool, error) {
	expire := time.Duration(expireSec) * time.Second
	nowTime := time.Now()
	nowNano := nowTime.UnixNano()
	currentLockNano := nowTime.Add(expire).UnixNano()

	// Try to take the lock with the current lockTime; true means the lock was
	// acquired, otherwise the lock is currently held.
	locked, err := lock.setLockTime(key, currentLockNano, expire)
	if err != nil {
		return false, err
	}

	if locked {
		lock.ownerID = ownerID
		return true, nil
	}

	// Fetch the stored lockTime.
	savedLockTimeNano, err := lock.getLockTime(key)
	if err != nil {
		return false, err
	}

	// The lock has expired.
	if savedLockTimeNano != 0 && savedLockTimeNano < nowNano {
		// Try to take the lock with the current lockTime; this returns the
		// previously stored value.
		oldLockTime, err := lock.getAndSetLockTime(key, currentLockNano)
		if err != nil {
			return false, err
		}

		// Rule out a concurrent takeover between the read above and the swap:
		// only the caller that swapped out the value it had read wins.
		if oldLockTime != 0 && oldLockTime == savedLockTimeNano {
			lock.ownerID = ownerID
			return true, nil
		}
	}

	return false, nil
}

// setLockTime issues SetNX for key with lockTime as the value and expireSec as
// the expiration, using a context limited to tryLockTimeoutSec seconds.
//
// It returns the command's boolean result, or false and the command error when
// the command fails. Despite its name, expireSec is a time.Duration.
func (lock *Lock) setLockTime(key string, lockTime int64, expireSec time.Duration) (bool, error) {
	timeout := tryLockTimeoutSec * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	stored, err := lock.redisClient.SetNX(ctx, key, lockTime, expireSec).Result()
	if err != nil {
		return false, err
	}

	return stored, nil
}

// getLockTime reads the value stored under key and parses it as an int64,
// using a context limited to tryLockTimeoutSec seconds.
//
// A redis.Nil reply (go-redis's result for a missing key) is not treated as a
// failure: it yields 0 and a nil error. Any other command error, or a value
// that cannot be parsed as an int64, is returned as an error with 0.
func (lock *Lock) getLockTime(key string) (int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
	defer cancel()

	cmd := lock.redisClient.Get(ctx, key)
	if err := cmd.Err(); err != nil {
		// Match the sentinel by identity (and through wrapping) rather than by
		// comparing error message strings.
		if errors.Is(err, redis.Nil) {
			return 0, nil
		}

		return 0, err
	}

	lockTime, err := cmd.Int64()
	if err != nil {
		return 0, err
	}

	return lockTime, nil
}

// getAndSetLockTime issues GetSet for key with newLockTime as the new value
// and returns the previous value parsed as an int64, using a context limited
// to tryLockTimeoutSec seconds.
//
// A redis.Nil reply (go-redis's result when there was no previous value) is
// not treated as a failure: it yields 0 and a nil error. Any other command
// error, or a previous value that cannot be parsed as an int64, is returned as
// an error with 0.
func (lock *Lock) getAndSetLockTime(key string, newLockTime int64) (int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), tryLockTimeoutSec*time.Second)
	defer cancel()

	cmd := lock.redisClient.GetSet(ctx, key, newLockTime)
	if err := cmd.Err(); err != nil {
		// Match the sentinel by identity (and through wrapping) rather than by
		// comparing error message strings.
		if errors.Is(err, redis.Nil) {
			return 0, nil
		}

		return 0, err
	}

	lockTime, err := cmd.Int64()
	if err != nil {
		return 0, err
	}

	return lockTime, nil
}

// deleteLockTime issues Del for key, using a context limited to
// tryLockTimeoutSec seconds, and returns the command's error (nil on success).
// The number of keys actually removed is not inspected.
func (lock *Lock) deleteLockTime(key string) error {
	timeout := tryLockTimeoutSec * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	return lock.redisClient.Del(ctx, key).Err()
}