Browse Source

完成测试

yjp 2 weeks ago
parent
commit
1ce134c3e0

+ 32 - 0
framework/core/infrastructure/cache/cache.go

@@ -139,6 +139,38 @@ func Clear(cache Cache) error {
 	return cache.Clear()
 }
 
+// Expire 使用秒数设置key的过期时间
+// 参数:
+// - cache: 缓存基础设施接口
+// - expireSec: 缓存过期时间,单位秒
+// 返回值:
+// - 错误
+func Expire(cache Cache, key string, expireSec int64) error {
+	return cache.Expire(key, expireSec)
+}
+
+// ExpireTime 获取过期秒数
+// 参数:
+// - cache: 缓存基础设施接口
+// - key: 缓存的键
+// 返回值:
+// - 过期秒数
+// - 错误
+func ExpireTime(cache Cache, key string) (int64, error) {
+	return cache.ExpireTime(key)
+}
+
+// TTL 获取过期剩余秒数
+// 参数:
+// - cache: 缓存基础设施接口
+// - key: 缓存的键
+// 返回值:
+// - 剩余过期秒数
+// - 错误
+func TTL(cache Cache, key string) (int64, error) {
+	return cache.TTL(key)
+}
+
 func toString(valueReflectValue reflect.Value) (string, error) {
 	dataVal := reflect.Indirect(valueReflectValue)
 	dataKind := reflectutils.GroupValueKind(dataVal)

+ 121 - 85
framework/core/infrastructure/cache/local/local.go

@@ -14,16 +14,18 @@ type valueItem struct {
 }
 
 type Cache struct {
+	*sync.RWMutex
 	namespace                      string
-	syncMap                        *sync.Map
+	cacheMap                       map[string]*valueItem
 	expireRoutineDoneChannelsMutex *sync.Mutex
 	expireRoutineDoneChannels      map[string]chan any
 }
 
 func New(namespace string) *Cache {
 	return &Cache{
+		RWMutex:                        new(sync.RWMutex),
 		namespace:                      namespace,
-		syncMap:                        new(sync.Map),
+		cacheMap:                       nil,
 		expireRoutineDoneChannelsMutex: new(sync.Mutex),
 		expireRoutineDoneChannels:      make(map[string]chan any),
 	}
@@ -34,6 +36,9 @@ func Destroy(cache *Cache) {
 		return
 	}
 
+	cache.Lock()
+	defer cache.Unlock()
+
 	cache.expireRoutineDoneChannelsMutex.Lock()
 
 	for _, doneChan := range cache.expireRoutineDoneChannels {
@@ -46,7 +51,7 @@ func Destroy(cache *Cache) {
 
 	cache.expireRoutineDoneChannelsMutex.Unlock()
 
-	cache.syncMap = nil
+	cache.cacheMap = nil
 }
 
 func (cache *Cache) Set(key string, value string, expireSec int64) error {
@@ -56,153 +61,132 @@ func (cache *Cache) Set(key string, value string, expireSec int64) error {
 		expireSec:       expireSec,
 	}
 
-	cache.syncMap.Store(cache.key2CacheKey(key), item)
+	cache.Lock()
+	defer cache.Unlock()
 
-	if item.expireSec != 0 {
-		doneChan := make(chan any)
+	cache.cacheMap[cache.key2CacheKey(key)] = item
 
+	if item.expireSec != 0 {
 		cache.expireRoutineDoneChannelsMutex.Lock()
-
-		if cache.expireRoutineDoneChannels != nil {
-			cache.expireRoutineDoneChannels[cache.key2CacheKey(key)] = doneChan
-		}
-
+		cache.startExpireRoutineWithoutLock(key, item)
 		cache.expireRoutineDoneChannelsMutex.Unlock()
-
-		go func() {
-			timer := time.NewTimer(time.Second * time.Duration(item.expireSec))
-
-			defer func() {
-				timer.Stop()
-
-				cache.expireRoutineDoneChannelsMutex.Lock()
-
-				if cache.expireRoutineDoneChannels != nil && len(cache.expireRoutineDoneChannels) > 0 {
-					savedDoneChan, ok := cache.expireRoutineDoneChannels[cache.key2CacheKey(key)]
-					if !ok {
-						return
-					}
-
-					close(savedDoneChan)
-					savedDoneChan = nil
-					delete(cache.expireRoutineDoneChannels, cache.key2CacheKey(key))
-				}
-
-				cache.expireRoutineDoneChannelsMutex.Unlock()
-			}()
-
-			for {
-				select {
-				case <-timer.C:
-					cache.syncMap.Delete(cache.key2CacheKey(key))
-					return
-				case <-doneChan:
-					return
-				}
-			}
-		}()
 	}
 
 	return nil
 }
 
 func (cache *Cache) Get(key string) (string, error) {
-	value, loaded := cache.syncMap.Load(cache.key2CacheKey(key))
-	if !loaded {
+	cache.RLock()
+	defer cache.RUnlock()
+
+	item, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if !ok {
 		return "", nil
 	}
 
-	return value.(string), nil
+	return item.value, nil
 }
 
 func (cache *Cache) GetMulti(keys []string) (map[string]string, error) {
+	cache.RLock()
+	defer cache.RUnlock()
+
 	result := make(map[string]string)
 
 	for _, key := range keys {
-		value, loaded := cache.syncMap.Load(cache.key2CacheKey(key))
-		if !loaded {
+		item, ok := cache.cacheMap[cache.key2CacheKey(key)]
+		if !ok {
 			result[key] = ""
 		}
 
-		result[key] = value.(string)
+		result[key] = item.value
 	}
 
 	return result, nil
 }
 
 func (cache *Cache) GetAll() (map[string]string, error) {
+	cache.RLock()
+	defer cache.RUnlock()
+
 	result := make(map[string]string)
 
-	cache.syncMap.Range(func(key any, value any) bool {
-		result[cache.cacheKey2Key(key.(string))] = value.(string)
-		return true
-	})
+	for key, item := range cache.cacheMap {
+		result[cache.cacheKey2Key(key)] = item.value
+	}
 
 	return result, nil
 }
 
 func (cache *Cache) Delete(key string) error {
-	cache.syncMap.Delete(cache.key2CacheKey(key))
+	cache.Lock()
+	defer cache.Unlock()
+
+	delete(cache.cacheMap, cache.key2CacheKey(key))
 	return nil
 }
 
 func (cache *Cache) Clear() error {
-	cache.syncMap = new(sync.Map)
+	cache.Lock()
+	defer cache.Unlock()
+
+	cache.cacheMap = make(map[string]*valueItem)
 	return nil
 }
 
 func (cache *Cache) Expire(key string, expireSec int64) error {
-	value, loaded := cache.syncMap.Load(cache.key2CacheKey(key))
-	if !loaded {
-		return errors.New("对应的键不存在")
-	}
-
-	oldItem := value.(*valueItem)
-
-	// 超时时间修改为0,停止协程
-	if expireSec == 0 {
-		cache.expireRoutineDoneChannelsMutex.Lock()
+	cache.Lock()
+	defer cache.Unlock()
 
-		doneChan, ok := cache.expireRoutineDoneChannels[cache.key2CacheKey(key)]
-		if ok {
-			doneChan <- nil
-			close(doneChan)
-			doneChan = nil
-		}
+	cache.expireRoutineDoneChannelsMutex.Lock()
+	defer cache.expireRoutineDoneChannelsMutex.Unlock()
 
-		delete(cache.expireRoutineDoneChannels, cache.key2CacheKey(key))
+	cache.stopExpireRoutineWithoutLock(key)
 
-		cache.expireRoutineDoneChannelsMutex.Unlock()
+	oldItem, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if !ok {
+		return errors.New("对应的键不存在")
 	}
 
-	cache.syncMap.Store(cache.key2CacheKey(key), &valueItem{
+	newItem := &valueItem{
 		value:           oldItem.value,
 		expireStartTime: time.Now(),
 		expireSec:       expireSec,
-	})
+	}
+
+	cache.cacheMap[cache.key2CacheKey(key)] = newItem
+
+	if expireSec == 0 {
+		return nil
+	}
+
+	cache.startExpireRoutineWithoutLock(key, newItem)
 
 	return nil
 }
 
 func (cache *Cache) ExpireTime(key string) (int64, error) {
-	value, loaded := cache.syncMap.Load(cache.key2CacheKey(key))
-	if !loaded {
+	cache.RLock()
+	defer cache.RUnlock()
+
+	item, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if !ok {
 		return 0, nil
 	}
 
-	oldItem := value.(*valueItem)
-
-	return oldItem.expireSec, nil
+	return item.expireSec, nil
 }
 
 func (cache *Cache) TTL(key string) (int64, error) {
-	value, loaded := cache.syncMap.Load(cache.key2CacheKey(key))
-	if !loaded {
+	cache.RLock()
+	defer cache.RUnlock()
+
+	item, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if !ok {
 		return 0, nil
 	}
 
-	oldItem := value.(*valueItem)
-	remainSec := int64(oldItem.expireStartTime.Add(time.Duration(oldItem.expireSec) * time.Second).Sub(time.Now()).Seconds())
+	remainSec := int64(item.expireStartTime.Add(time.Duration(item.expireSec)*time.Second).Sub(time.Now()).Seconds() + 0.5)
 	if remainSec <= 0 {
 		return 0, nil
 	}
@@ -210,6 +194,58 @@ func (cache *Cache) TTL(key string) (int64, error) {
 	return remainSec, nil
 }
 
+func (cache *Cache) startExpireRoutineWithoutLock(key string, item *valueItem) {
+	doneChan := make(chan any)
+
+	if cache.expireRoutineDoneChannels != nil {
+		cache.expireRoutineDoneChannels[cache.key2CacheKey(key)] = doneChan
+	}
+
+	go func() {
+		timer := time.NewTimer(time.Second * time.Duration(item.expireSec))
+
+		defer func() {
+			timer.Stop()
+
+			cache.expireRoutineDoneChannelsMutex.Lock()
+
+			if cache.expireRoutineDoneChannels != nil && len(cache.expireRoutineDoneChannels) > 0 {
+				savedDoneChan, ok := cache.expireRoutineDoneChannels[cache.key2CacheKey(key)]
+				if !ok {
+					return
+				}
+
+				close(savedDoneChan)
+				savedDoneChan = nil
+				delete(cache.expireRoutineDoneChannels, cache.key2CacheKey(key))
+			}
+
+			cache.expireRoutineDoneChannelsMutex.Unlock()
+		}()
+
+		for {
+			select {
+			case <-timer.C:
+				delete(cache.cacheMap, cache.key2CacheKey(key))
+				return
+			case <-doneChan:
+				return
+			}
+		}
+	}()
+}
+
+func (cache *Cache) stopExpireRoutineWithoutLock(key string) {
+	doneChan, ok := cache.expireRoutineDoneChannels[cache.key2CacheKey(key)]
+	if ok {
+		doneChan <- nil
+		close(doneChan)
+		doneChan = nil
+	}
+
+	delete(cache.expireRoutineDoneChannels, cache.key2CacheKey(key))
+}
+
 func (cache *Cache) key2CacheKey(key string) string {
 	return cache.namespace + "::" + key
 }

+ 118 - 0
test/cache_test.go

@@ -6,6 +6,7 @@ import (
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/cache/local"
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/cache/redis"
 	"git.sxidc.com/go-tools/utils/strutils"
+	"github.com/pkg/errors"
 	"testing"
 	"time"
 )
@@ -48,6 +49,24 @@ func TestCacheInfrastructure(t *testing.T) {
 	testCache(t, i.RedisCache())
 }
 
+func TestCacheExpire(t *testing.T) {
+	i := infrastructure.NewInfrastructure(infrastructure.Config{
+		CacheConfig: infrastructure.CacheConfig{
+			Namespace: "test",
+			Redis: &infrastructure.RedisConfig{
+				Address:  "localhost:30379",
+				UserName: "",
+				Password: "mtyzxhc",
+				DB:       1,
+			},
+		},
+	})
+	defer infrastructure.DestroyInfrastructure(i)
+
+	testCacheExpire(t, i.LocalCache())
+	testCacheExpire(t, i.RedisCache())
+}
+
 func testLocalCache(t *testing.T, localCache *local.Cache) {
 	err := localCache.Set("test1", "test1-value", 0)
 	if err != nil {
@@ -335,3 +354,102 @@ func testCache(t *testing.T, cacheInterface cache.Cache) {
 		t.Fatalf("%+v\n", err)
 	}
 }
+
+func testCacheExpire(t *testing.T, cacheInterface cache.Cache) {
+	err := cache.Set(cacheInterface, "test", 10, 1)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	time.Sleep(1 * time.Second)
+
+	value, err := cache.Get[string](cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if strutils.IsStringNotEmpty(value) {
+		t.Fatalf("%+v\n", errors.Errorf("Value Expire Error: cache %v", value))
+	}
+
+	err = cache.Set(cacheInterface, "test", 10, 0)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = cache.Expire(cacheInterface, "test", 1)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	time.Sleep(1 * time.Second)
+
+	value, err = cache.Get[string](cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if strutils.IsStringNotEmpty(value) {
+		t.Fatalf("%+v\n", errors.Errorf("Value Expire Error: cache %v", value))
+	}
+
+	err = cache.Set(cacheInterface, "test", 10, 10)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = cache.Expire(cacheInterface, "test", 1)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	time.Sleep(1 * time.Second)
+
+	value, err = cache.Get[string](cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if strutils.IsStringNotEmpty(value) {
+		t.Fatalf("%+v\n", errors.Errorf("Value Expire Error: cache %v", value))
+	}
+
+	err = cache.Set(cacheInterface, "test", 10, 10)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	expired, err := cache.ExpireTime(cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if expired != 10 {
+		t.Fatalf("%+v\n", errors.Errorf("Expire Error: expired %v", expired))
+	}
+
+	err = cache.Expire(cacheInterface, "test", 5)
+
+	expired, err = cache.ExpireTime(cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if expired != 5 {
+		t.Fatalf("%+v\n", errors.Errorf("Expire Error: expired %v", expired))
+	}
+
+	ttl, err := cache.TTL(cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if ttl != 5 {
+		t.Fatalf("%+v\n", errors.Errorf("TTL Error: ttl %v", ttl))
+	}
+
+	err = cache.Delete(cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+}