yjp 2 weeks ago
parent
commit
26dab800a2

+ 27 - 0
framework/core/infrastructure/cache/cache.go

@@ -30,6 +30,12 @@ type Cache interface {
 
 	// Clear 清除缓存
 	Clear() error
+
+	// Expire 使用秒数设置key的过期时间
+	Expire(key string, expireSec int64) error
+
+	// TTL 获取过期剩余秒数
+	TTL(key string) (int64, error)
 }
 
 // Set 设置缓存值
@@ -130,6 +136,27 @@ func Clear(cache Cache) error {
 	return cache.Clear()
 }
 
+// Expire 使用秒数设置key的过期时间
+// 参数:
+// - cache: 缓存基础设施接口
+// - expireSec: 缓存过期时间,单位秒
+// 返回值:
+// - 错误
+func Expire(cache Cache, key string, expireSec int64) error {
+	return cache.Expire(key, expireSec)
+}
+
+// TTL 获取过期剩余秒数
+// 参数:
+// - cache: 缓存基础设施接口
+// - key: 缓存的键
+// 返回值:
+// - 剩余过期秒数
+// - 错误
+func TTL(cache Cache, key string) (int64, error) {
+	return cache.TTL(key)
+}
+
 func toString(valueReflectValue reflect.Value) (string, error) {
 	dataVal := reflect.Indirect(valueReflectValue)
 	dataKind := reflectutils.GroupValueKind(dataVal)

+ 150 - 74
framework/core/infrastructure/cache/local/local.go

@@ -1,24 +1,30 @@
 package local
 
 import (
+	"github.com/pkg/errors"
 	"strings"
 	"sync"
 	"time"
 )
 
+type valueItem struct {
+	value           string
+	expireStartTime time.Time
+	expireSec       int64
+	expireDoneChan  chan any
+}
+
 type Cache struct {
-	namespace                      string
-	syncMap                        *sync.Map
-	expireRoutineDoneChannelsMutex *sync.Mutex
-	expireRoutineDoneChannels      []chan any
+	*sync.RWMutex
+	namespace string
+	cacheMap  map[string]*valueItem
 }
 
 func New(namespace string) *Cache {
 	return &Cache{
-		namespace:                      namespace,
-		syncMap:                        new(sync.Map),
-		expireRoutineDoneChannelsMutex: new(sync.Mutex),
-		expireRoutineDoneChannels:      make([]chan any, 0),
+		RWMutex:   new(sync.RWMutex),
+		namespace: namespace,
+		cacheMap:  make(map[string]*valueItem),
 	}
 }
 
@@ -27,124 +33,194 @@ func Destroy(cache *Cache) {
 		return
 	}
 
-	cache.expireRoutineDoneChannelsMutex.Lock()
+	cache.Lock()
+	defer cache.Unlock()
 
-	for _, doneChan := range cache.expireRoutineDoneChannels {
-		doneChan <- nil
-		close(doneChan)
-		doneChan = nil
+	for _, item := range cache.cacheMap {
+		if item.expireDoneChan != nil {
+			item.expireDoneChan <- nil
+		}
 	}
 
-	cache.expireRoutineDoneChannels = nil
-
-	cache.expireRoutineDoneChannelsMutex.Unlock()
-
-	cache.syncMap = nil
+	cache.cacheMap = nil
 }
 
 func (cache *Cache) Set(key string, value string, expireSec int64) error {
-	cache.syncMap.Store(cache.key2CacheKey(key), value)
+	cache.Lock()
+	defer cache.Unlock()
+
+	oldItem, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if ok {
+		cache.stopExpireRoutine(oldItem)
 
-	if expireSec != 0 {
-		doneChan := make(chan any)
+		newItem := &valueItem{
+			value:           value,
+			expireStartTime: time.Now(),
+			expireSec:       expireSec,
+		}
 
-		cache.expireRoutineDoneChannelsMutex.Lock()
+		cache.cacheMap[cache.key2CacheKey(key)] = newItem
 
-		if cache.expireRoutineDoneChannels != nil {
-			cache.expireRoutineDoneChannels = append(cache.expireRoutineDoneChannels, doneChan)
+		if expireSec == 0 {
+			return nil
 		}
 
-		cache.expireRoutineDoneChannelsMutex.Unlock()
+		cache.startExpireRoutine(key, newItem)
 
-		go func() {
-			timer := time.NewTimer(time.Second * time.Duration(expireSec))
+		return nil
+	}
 
-			defer func() {
-				timer.Stop()
+	item := &valueItem{
+		value:           value,
+		expireStartTime: time.Now(),
+		expireSec:       expireSec,
+	}
 
-				cache.expireRoutineDoneChannelsMutex.Lock()
-
-				if cache.expireRoutineDoneChannels != nil {
-					findIndex := -1
-
-					for i, savedDoneChan := range cache.expireRoutineDoneChannels {
-						if savedDoneChan == doneChan {
-							findIndex = i
-							close(savedDoneChan)
-							savedDoneChan = nil
-							break
-						}
-					}
-
-					if findIndex != -1 {
-						cache.expireRoutineDoneChannels =
-							append(cache.expireRoutineDoneChannels[:findIndex], cache.expireRoutineDoneChannels[findIndex+1:]...)
-					}
-				}
-
-				cache.expireRoutineDoneChannelsMutex.Unlock()
-			}()
-
-			for {
-				select {
-				case <-timer.C:
-					cache.syncMap.Delete(cache.key2CacheKey(key))
-					return
-				case <-doneChan:
-					return
-				}
-			}
-		}()
+	cache.cacheMap[cache.key2CacheKey(key)] = item
+
+	if item.expireSec != 0 {
+		cache.startExpireRoutine(key, item)
 	}
 
 	return nil
 }
 
 func (cache *Cache) Get(key string) (string, error) {
-	value, loaded := cache.syncMap.Load(cache.key2CacheKey(key))
-	if !loaded {
+	cache.RLock()
+	defer cache.RUnlock()
+
+	item, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if !ok {
 		return "", nil
 	}
 
-	return value.(string), nil
+	return item.value, nil
 }
 
 func (cache *Cache) GetMulti(keys []string) (map[string]string, error) {
+	cache.RLock()
+	defer cache.RUnlock()
+
 	result := make(map[string]string)
 
 	for _, key := range keys {
-		value, loaded := cache.syncMap.Load(cache.key2CacheKey(key))
-		if !loaded {
+		item, ok := cache.cacheMap[cache.key2CacheKey(key)]
+		if !ok {
 			result[key] = ""
 		}
 
-		result[key] = value.(string)
+		result[key] = item.value
 	}
 
 	return result, nil
 }
 
 func (cache *Cache) GetAll() (map[string]string, error) {
+	cache.RLock()
+	defer cache.RUnlock()
+
 	result := make(map[string]string)
 
-	cache.syncMap.Range(func(key any, value any) bool {
-		result[cache.cacheKey2Key(key.(string))] = value.(string)
-		return true
-	})
+	for key, item := range cache.cacheMap {
+		result[cache.cacheKey2Key(key)] = item.value
+	}
 
 	return result, nil
 }
 
 func (cache *Cache) Delete(key string) error {
-	cache.syncMap.Delete(cache.key2CacheKey(key))
+	cache.Lock()
+	defer cache.Unlock()
+
+	delete(cache.cacheMap, cache.key2CacheKey(key))
 	return nil
 }
 
 func (cache *Cache) Clear() error {
-	cache.syncMap = new(sync.Map)
+	cache.Lock()
+	defer cache.Unlock()
+
+	cache.cacheMap = make(map[string]*valueItem)
+	return nil
+}
+
+func (cache *Cache) Expire(key string, expireSec int64) error {
+	cache.Lock()
+	defer cache.Unlock()
+
+	oldItem, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if !ok {
+		return errors.New("对应的键不存在: " + key)
+	}
+
+	cache.stopExpireRoutine(oldItem)
+
+	newItem := &valueItem{
+		value:           oldItem.value,
+		expireStartTime: time.Now(),
+		expireSec:       expireSec,
+	}
+
+	cache.cacheMap[cache.key2CacheKey(key)] = newItem
+
+	if expireSec == 0 {
+		return nil
+	}
+
+	cache.startExpireRoutine(key, newItem)
+
 	return nil
 }
 
+func (cache *Cache) TTL(key string) (int64, error) {
+	cache.RLock()
+	defer cache.RUnlock()
+
+	item, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if !ok {
+		return 0, nil
+	}
+
+	remainSec := int64(item.expireStartTime.Add(time.Duration(item.expireSec)*time.Second).Sub(time.Now()).Seconds() + 0.5)
+	if remainSec <= 0 {
+		return 0, nil
+	}
+
+	return remainSec, nil
+}
+
+func (cache *Cache) startExpireRoutine(key string, item *valueItem) {
+	item.expireDoneChan = make(chan any)
+
+	go func() {
+		timer := time.NewTimer(time.Second * time.Duration(item.expireSec))
+
+		defer func() {
+			if item.expireDoneChan != nil {
+				close(item.expireDoneChan)
+				item.expireDoneChan = nil
+			}
+		}()
+
+		for {
+			select {
+			case <-timer.C:
+				delete(cache.cacheMap, cache.key2CacheKey(key))
+				return
+			case <-item.expireDoneChan:
+				timer.Stop()
+				return
+			}
+		}
+	}()
+}
+
+func (cache *Cache) stopExpireRoutine(item *valueItem) {
+	if item.expireDoneChan != nil {
+		item.expireDoneChan <- nil
+	}
+}
+
 func (cache *Cache) key2CacheKey(key string) string {
 	return cache.namespace + "::" + key
 }

+ 22 - 0
framework/core/infrastructure/cache/redis/redis.go

@@ -141,6 +141,28 @@ func (cache *Cache) Clear() error {
 	return nil
 }
 
+func (cache *Cache) Expire(key string, expireSec int64) error {
+	cmd := cache.redisClient.Expire(context.Background(), cache.key2CacheKey(key), time.Duration(expireSec)*time.Second)
+	if cmd.Err() != nil {
+		return errors.New(cmd.Err().Error())
+	}
+
+	return nil
+}
+
+func (cache *Cache) TTL(key string) (int64, error) {
+	cmd := cache.redisClient.TTL(context.Background(), cache.key2CacheKey(key))
+	if cmd.Err() != nil {
+		if cmd.Err().Error() == redis.Nil.Error() {
+			return 0, nil
+		}
+
+		return 0, errors.New(cmd.Err().Error())
+	}
+
+	return int64(cmd.Val().Seconds()), nil
+}
+
 func (cache *Cache) key2CacheKey(key string) string {
 	return keyPrefix + cache.namespace + "::" + key
 }

+ 5 - 3
framework/core/infrastructure/database/clause/condition.go

@@ -17,9 +17,11 @@ func NewConditions() *Conditions {
 
 func (conditions *Conditions) AddCondition(query string, args ...any) *Conditions {
 	conditions.queries = append(conditions.queries, query)
-        for _, arg := range args {
-                conditions.args = append(conditions.args, []any{arg})
-        }
+
+	for _, arg := range args {
+		conditions.args = append(conditions.args, []any{arg})
+	}
+
 	return conditions
 }
 

+ 5 - 3
framework/core/infrastructure/database/sql/conditions.go

@@ -25,9 +25,11 @@ func NewConditions() *Conditions {
 // - 数据库条件
 func (conditions *Conditions) AddCondition(query string, args ...any) *Conditions {
 	conditions.queries = append(conditions.queries, query)
-        for _, arg := range args {
-                conditions.args = append(conditions.args, []any{arg})
-        }
+
+	for _, arg := range args {
+		conditions.args = append(conditions.args, []any{arg})
+	}
+
 	return conditions
 }
 

+ 105 - 2
test/cache_test.go

@@ -6,6 +6,7 @@ import (
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/cache/local"
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/cache/redis"
 	"git.sxidc.com/go-tools/utils/strutils"
+	"github.com/pkg/errors"
 	"testing"
 	"time"
 )
@@ -16,7 +17,7 @@ func TestLocalCache(t *testing.T) {
 }
 
 func TestRedisCache(t *testing.T) {
-	redisCache, err := redis.New("localhost:30379", "", "mtyzxhc", 1, "test")
+	redisCache, err := redis.New("localhost:6379", "", "mtyzxhc", 1, "test")
 	if err != nil {
 		t.Fatalf("%+v\n", err)
 	}
@@ -35,7 +36,7 @@ func TestCacheInfrastructure(t *testing.T) {
 		CacheConfig: infrastructure.CacheConfig{
 			Namespace: "test",
 			Redis: &infrastructure.RedisConfig{
-				Address:  "localhost:30379",
+				Address:  "localhost:6379",
 				UserName: "",
 				Password: "mtyzxhc",
 				DB:       1,
@@ -48,6 +49,24 @@ func TestCacheInfrastructure(t *testing.T) {
 	testCache(t, i.RedisCache())
 }
 
+func TestCacheExpire(t *testing.T) {
+	i := infrastructure.NewInfrastructure(infrastructure.Config{
+		CacheConfig: infrastructure.CacheConfig{
+			Namespace: "test",
+			Redis: &infrastructure.RedisConfig{
+				Address:  "localhost:6379",
+				UserName: "",
+				Password: "mtyzxhc",
+				DB:       1,
+			},
+		},
+	})
+	defer infrastructure.DestroyInfrastructure(i)
+
+	testCacheExpire(t, i.LocalCache())
+	testCacheExpire(t, i.RedisCache())
+}
+
 func testLocalCache(t *testing.T, localCache *local.Cache) {
 	err := localCache.Set("test1", "test1-value", 0)
 	if err != nil {
@@ -335,3 +354,87 @@ func testCache(t *testing.T, cacheInterface cache.Cache) {
 		t.Fatalf("%+v\n", err)
 	}
 }
+
+func testCacheExpire(t *testing.T, cacheInterface cache.Cache) {
+	err := cache.Set(cacheInterface, "test", 10, 1)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	time.Sleep(2 * time.Second)
+
+	value, err := cache.Get[string](cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if strutils.IsStringNotEmpty(value) {
+		t.Fatalf("%+v\n", errors.Errorf("Value Expire Error: cache %v", value))
+	}
+
+	err = cache.Set(cacheInterface, "test", 10, 0)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = cache.Expire(cacheInterface, "test", 1)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	time.Sleep(2 * time.Second)
+
+	value, err = cache.Get[string](cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if strutils.IsStringNotEmpty(value) {
+		t.Fatalf("%+v\n", errors.Errorf("Value Expire Error: cache %v", value))
+	}
+
+	err = cache.Set(cacheInterface, "test", 10, 5)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = cache.Expire(cacheInterface, "test", 1)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	time.Sleep(2 * time.Second)
+
+	value, err = cache.Get[string](cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if strutils.IsStringNotEmpty(value) {
+		t.Fatalf("%+v\n", errors.Errorf("Value Expire Error: cache %v", value))
+	}
+
+	err = cache.Set(cacheInterface, "test", 10, 10)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = cache.Expire(cacheInterface, "test", 5)
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	ttl, err := cache.TTL(cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	if ttl != 5 {
+		t.Fatalf("%+v\n", errors.Errorf("TTL Error: ttl %v", ttl))
+	}
+
+	err = cache.Delete(cacheInterface, "test")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+}