package syncutils

import "sync"

// Callback signatures used to access the value held by a SyncVar.
type (
	// WriteVarFunc receives the current value and returns the value that
	// should replace it; SyncVar.ForWrite stores the result.
	WriteVarFunc[T any] func(v T) T

	// ReadVarFunc receives the current value for inspection only; it
	// returns nothing, so SyncVar.ForRead leaves the stored value as is.
	ReadVarFunc[T any] func(v T)
)

// SyncVar holds a value of type T together with the lock(s) guarding it.
// v is unexported, so code outside this package reaches it only through
// ForWrite and ForRead, which hold the corresponding lock while the
// callback runs.
//
// Create instances with NewSyncVar. The zero value has nil lockers, so
// calling ForWrite/ForRead with a non-nil callback on it would panic.
type SyncVar[T any] struct {
	writeLocker sync.Locker // exclusive lock taken by ForWrite
	readLocker  sync.Locker // lock taken by ForRead (set up by NewSyncVar)
	v           T           // the guarded value
}

// NewSyncVar wraps v in a SyncVar guarded by a freshly allocated lock.
//
// If isRWMutex is true, a sync.RWMutex is used: the write lock is the
// RWMutex itself and the read lock is its RLocker, so concurrent ForRead
// calls may overlap with each other but never with ForWrite. Otherwise a
// single sync.Mutex serves as both the read and the write lock.
func NewSyncVar[T any](v T, isRWMutex bool) *SyncVar[T] {
	sv := &SyncVar[T]{v: v}

	if isRWMutex {
		rw := new(sync.RWMutex)
		sv.writeLocker, sv.readLocker = rw, rw.RLocker()
		return sv
	}

	m := new(sync.Mutex)
	sv.writeLocker, sv.readLocker = m, m
	return sv
}

// ForWrite acquires the variable for writing: it calls writeVarFunc with the
// current value while holding the write lock and stores the returned value.
//
// A nil writeVarFunc is a no-op (no lock is taken). Because Unlock is
// deferred, the lock is released even if writeVarFunc panics, in which case
// the stored value is left unchanged.
func (syncVar *SyncVar[T]) ForWrite(writeVarFunc WriteVarFunc[T]) {
	if writeVarFunc == nil {
		return
	}

	locker := syncVar.writeLocker
	locker.Lock()
	defer locker.Unlock()

	updated := writeVarFunc(syncVar.v)
	syncVar.v = updated
}

// ForRead acquires the variable for reading: it calls readVarFunc with the
// current value while holding the read lock.
//
// For a SyncVar created with NewSyncVar(v, true) the read lock is the
// RWMutex's shared lock, so several ForRead calls may run concurrently;
// with NewSyncVar(v, false) it is the same Mutex that ForWrite uses.
// A nil readVarFunc is a no-op (no lock is taken).
//
// readVarFunc receives a copy of the stored value. If T is a pointer, map,
// slice or other reference type, the callback can still reach shared state
// and must not mutate it, because possibly only a shared lock is held.
func (syncVar *SyncVar[T]) ForRead(readVarFunc ReadVarFunc[T]) {
	if readVarFunc == nil {
		return
	}

	syncVar.readLocker.Lock()
	defer syncVar.readLocker.Unlock()

	readVarFunc(syncVar.v)
}