Browse Source

添加缓存接口

yjp 2 weeks ago
parent
commit
fe290d3a17

+ 0 - 14
framework/core/infrastructure/cache/cache.go

@@ -34,9 +34,6 @@ type Cache interface {
 	// Expire 使用秒数设置key的过期时间
 	Expire(key string, expireSec int64) error
 
-	// ExpireTime 获取过期秒数
-	ExpireTime(key string) (int64, error)
-
 	// TTL 获取过期剩余秒数
 	TTL(key string) (int64, error)
 }
@@ -149,17 +146,6 @@ func Expire(cache Cache, key string, expireSec int64) error {
 	return cache.Expire(key, expireSec)
 }
 
-// ExpireTime 获取过期秒数
-// 参数:
-// - cache: 缓存基础设施接口
-// - key: 缓存的键
-// 返回值:
-// - 过期秒数
-// - 错误
-func ExpireTime(cache Cache, key string) (int64, error) {
-	return cache.ExpireTime(key)
-}
-
 // TTL 获取过期剩余秒数
 // 参数:
 // - cache: 缓存基础设施接口

+ 49 - 74
framework/core/infrastructure/cache/local/local.go

@@ -11,23 +11,20 @@ type valueItem struct {
 	value           string
 	expireStartTime time.Time
 	expireSec       int64
+	expireDoneChan  chan any
 }
 
 type Cache struct {
 	*sync.RWMutex
-	namespace                      string
-	cacheMap                       map[string]*valueItem
-	expireRoutineDoneChannelsMutex *sync.Mutex
-	expireRoutineDoneChannels      map[string]chan any
+	namespace string
+	cacheMap  map[string]*valueItem
 }
 
 func New(namespace string) *Cache {
 	return &Cache{
-		RWMutex:                        new(sync.RWMutex),
-		namespace:                      namespace,
-		cacheMap:                       nil,
-		expireRoutineDoneChannelsMutex: new(sync.Mutex),
-		expireRoutineDoneChannels:      make(map[string]chan any),
+		RWMutex:   new(sync.RWMutex),
+		namespace: namespace,
+		cacheMap:  make(map[string]*valueItem),
 	}
 }
 
@@ -39,37 +36,50 @@ func Destroy(cache *Cache) {
 	cache.Lock()
 	defer cache.Unlock()
 
-	cache.expireRoutineDoneChannelsMutex.Lock()
-
-	for _, doneChan := range cache.expireRoutineDoneChannels {
-		doneChan <- nil
-		close(doneChan)
-		doneChan = nil
+	for _, item := range cache.cacheMap {
+		if item.expireDoneChan != nil {
+			item.expireDoneChan <- nil
+		}
 	}
 
-	cache.expireRoutineDoneChannels = nil
-
-	cache.expireRoutineDoneChannelsMutex.Unlock()
-
 	cache.cacheMap = nil
 }
 
 func (cache *Cache) Set(key string, value string, expireSec int64) error {
+	cache.Lock()
+	defer cache.Unlock()
+
+	oldItem, ok := cache.cacheMap[cache.key2CacheKey(key)]
+	if ok {
+		cache.stopExpireRoutine(oldItem)
+
+		newItem := &valueItem{
+			value:           value,
+			expireStartTime: time.Now(),
+			expireSec:       expireSec,
+		}
+
+		cache.cacheMap[cache.key2CacheKey(key)] = newItem
+
+		if expireSec == 0 {
+			return nil
+		}
+
+		cache.startExpireRoutine(key, newItem)
+
+		return nil
+	}
+
 	item := &valueItem{
 		value:           value,
 		expireStartTime: time.Now(),
 		expireSec:       expireSec,
 	}
 
-	cache.Lock()
-	defer cache.Unlock()
-
 	cache.cacheMap[cache.key2CacheKey(key)] = item
 
 	if item.expireSec != 0 {
-		cache.expireRoutineDoneChannelsMutex.Lock()
-		cache.startExpireRoutineWithoutLock(key, item)
-		cache.expireRoutineDoneChannelsMutex.Unlock()
+		cache.startExpireRoutine(key, item)
 	}
 
 	return nil
@@ -138,16 +148,13 @@ func (cache *Cache) Expire(key string, expireSec int64) error {
 	cache.Lock()
 	defer cache.Unlock()
 
-	cache.expireRoutineDoneChannelsMutex.Lock()
-	defer cache.expireRoutineDoneChannelsMutex.Unlock()
-
-	cache.stopExpireRoutineWithoutLock(key)
-
 	oldItem, ok := cache.cacheMap[cache.key2CacheKey(key)]
 	if !ok {
-		return errors.New("对应的键不存在")
+		return errors.New("对应的键不存在: " + key)
 	}
 
+	cache.stopExpireRoutine(oldItem)
+
 	newItem := &valueItem{
 		value:           oldItem.value,
 		expireStartTime: time.Now(),
@@ -160,23 +167,11 @@ func (cache *Cache) Expire(key string, expireSec int64) error {
 		return nil
 	}
 
-	cache.startExpireRoutineWithoutLock(key, newItem)
+	cache.startExpireRoutine(key, newItem)
 
 	return nil
 }
 
-func (cache *Cache) ExpireTime(key string) (int64, error) {
-	cache.RLock()
-	defer cache.RUnlock()
-
-	item, ok := cache.cacheMap[cache.key2CacheKey(key)]
-	if !ok {
-		return 0, nil
-	}
-
-	return item.expireSec, nil
-}
-
 func (cache *Cache) TTL(key string) (int64, error) {
 	cache.RLock()
 	defer cache.RUnlock()
@@ -194,33 +189,17 @@ func (cache *Cache) TTL(key string) (int64, error) {
 	return remainSec, nil
 }
 
-func (cache *Cache) startExpireRoutineWithoutLock(key string, item *valueItem) {
-	doneChan := make(chan any)
-
-	if cache.expireRoutineDoneChannels != nil {
-		cache.expireRoutineDoneChannels[cache.key2CacheKey(key)] = doneChan
-	}
+func (cache *Cache) startExpireRoutine(key string, item *valueItem) {
+	item.expireDoneChan = make(chan any)
 
 	go func() {
 		timer := time.NewTimer(time.Second * time.Duration(item.expireSec))
 
 		defer func() {
-			timer.Stop()
-
-			cache.expireRoutineDoneChannelsMutex.Lock()
-
-			if cache.expireRoutineDoneChannels != nil && len(cache.expireRoutineDoneChannels) > 0 {
-				savedDoneChan, ok := cache.expireRoutineDoneChannels[cache.key2CacheKey(key)]
-				if !ok {
-					return
-				}
-
-				close(savedDoneChan)
-				savedDoneChan = nil
-				delete(cache.expireRoutineDoneChannels, cache.key2CacheKey(key))
+			if item.expireDoneChan != nil {
+				close(item.expireDoneChan)
+				item.expireDoneChan = nil
 			}
-
-			cache.expireRoutineDoneChannelsMutex.Unlock()
 		}()
 
 		for {
@@ -228,22 +207,18 @@ func (cache *Cache) startExpireRoutineWithoutLock(key string, item *valueItem) {
 			case <-timer.C:
 				delete(cache.cacheMap, cache.key2CacheKey(key))
 				return
-			case <-doneChan:
+			case <-item.expireDoneChan:
+				timer.Stop()
 				return
 			}
 		}
 	}()
 }
 
-func (cache *Cache) stopExpireRoutineWithoutLock(key string) {
-	doneChan, ok := cache.expireRoutineDoneChannels[cache.key2CacheKey(key)]
-	if ok {
-		doneChan <- nil
-		close(doneChan)
-		doneChan = nil
+func (cache *Cache) stopExpireRoutine(item *valueItem) {
+	if item.expireDoneChan != nil {
+		item.expireDoneChan <- nil
 	}
-
-	delete(cache.expireRoutineDoneChannels, cache.key2CacheKey(key))
 }
 
 func (cache *Cache) key2CacheKey(key string) string {

+ 0 - 13
framework/core/infrastructure/cache/redis/redis.go

@@ -150,19 +150,6 @@ func (cache *Cache) Expire(key string, expireSec int64) error {
 	return nil
 }
 
-func (cache *Cache) ExpireTime(key string) (int64, error) {
-	cmd := cache.redisClient.ExpireTime(context.Background(), cache.key2CacheKey(key))
-	if cmd.Err() != nil {
-		if cmd.Err().Error() == redis.Nil.Error() {
-			return 0, nil
-		}
-
-		return 0, errors.New(cmd.Err().Error())
-	}
-
-	return int64(cmd.Val().Seconds()), nil
-}
-
 func (cache *Cache) TTL(key string) (int64, error) {
 	cmd := cache.redisClient.TTL(context.Background(), cache.key2CacheKey(key))
 	if cmd.Err() != nil {

+ 3 - 1
framework/core/infrastructure/database/clause/condition.go

@@ -17,7 +17,9 @@ func NewConditions() *Conditions {
 
 func (conditions *Conditions) AddCondition(query string, args ...any) *Conditions {
 	conditions.queries = append(conditions.queries, query)
-	conditions.args = append(conditions.args, args...)
+	for _, arg := range args {
+		conditions.args = append(conditions.args, []any{arg})
+	}
 	return conditions
 }
 

+ 3 - 1
framework/core/infrastructure/database/sql/conditions.go

@@ -25,7 +25,9 @@ func NewConditions() *Conditions {
 // - 数据库条件
 func (conditions *Conditions) AddCondition(query string, args ...any) *Conditions {
 	conditions.queries = append(conditions.queries, query)
-	conditions.args = append(conditions.args, args...)
+	for _, arg := range args {
+		conditions.args = append(conditions.args, []any{arg})
+	}
 	return conditions
 }
 

+ 7 - 22
test/cache_test.go

@@ -17,7 +17,7 @@ func TestLocalCache(t *testing.T) {
 }
 
 func TestRedisCache(t *testing.T) {
-	redisCache, err := redis.New("localhost:30379", "", "mtyzxhc", 1, "test")
+	redisCache, err := redis.New("localhost:6379", "", "mtyzxhc", 1, "test")
 	if err != nil {
 		t.Fatalf("%+v\n", err)
 	}
@@ -36,7 +36,7 @@ func TestCacheInfrastructure(t *testing.T) {
 		CacheConfig: infrastructure.CacheConfig{
 			Namespace: "test",
 			Redis: &infrastructure.RedisConfig{
-				Address:  "localhost:30379",
+				Address:  "localhost:6379",
 				UserName: "",
 				Password: "mtyzxhc",
 				DB:       1,
@@ -54,7 +54,7 @@ func TestCacheExpire(t *testing.T) {
 		CacheConfig: infrastructure.CacheConfig{
 			Namespace: "test",
 			Redis: &infrastructure.RedisConfig{
-				Address:  "localhost:30379",
+				Address:  "localhost:6379",
 				UserName: "",
 				Password: "mtyzxhc",
 				DB:       1,
@@ -361,7 +361,7 @@ func testCacheExpire(t *testing.T, cacheInterface cache.Cache) {
 		t.Fatalf("%+v\n", err)
 	}
 
-	time.Sleep(1 * time.Second)
+	time.Sleep(2 * time.Second)
 
 	value, err := cache.Get[string](cacheInterface, "test")
 	if err != nil {
@@ -382,7 +382,7 @@ func testCacheExpire(t *testing.T, cacheInterface cache.Cache) {
 		t.Fatalf("%+v\n", err)
 	}
 
-	time.Sleep(1 * time.Second)
+	time.Sleep(2 * time.Second)
 
 	value, err = cache.Get[string](cacheInterface, "test")
 	if err != nil {
@@ -393,7 +393,7 @@ func testCacheExpire(t *testing.T, cacheInterface cache.Cache) {
 		t.Fatalf("%+v\n", errors.Errorf("Value Expire Error: cache %v", value))
 	}
 
-	err = cache.Set(cacheInterface, "test", 10, 10)
+	err = cache.Set(cacheInterface, "test", 10, 5)
 	if err != nil {
 		t.Fatalf("%+v\n", err)
 	}
@@ -403,7 +403,7 @@ func testCacheExpire(t *testing.T, cacheInterface cache.Cache) {
 		t.Fatalf("%+v\n", err)
 	}
 
-	time.Sleep(1 * time.Second)
+	time.Sleep(2 * time.Second)
 
 	value, err = cache.Get[string](cacheInterface, "test")
 	if err != nil {
@@ -419,26 +419,11 @@ func testCacheExpire(t *testing.T, cacheInterface cache.Cache) {
 		t.Fatalf("%+v\n", err)
 	}
 
-	expired, err := cache.ExpireTime(cacheInterface, "test")
-	if err != nil {
-		t.Fatalf("%+v\n", err)
-	}
-
-	if expired != 10 {
-		t.Fatalf("%+v\n", errors.Errorf("Expire Error: expired %v", expired))
-	}
-
 	err = cache.Expire(cacheInterface, "test", 5)
-
-	expired, err = cache.ExpireTime(cacheInterface, "test")
 	if err != nil {
 		t.Fatalf("%+v\n", err)
 	}
 
-	if expired != 5 {
-		t.Fatalf("%+v\n", errors.Errorf("Expire Error: expired %v", expired))
-	}
-
 	ttl, err := cache.TTL(cacheInterface, "test")
 	if err != nil {
 		t.Fatalf("%+v\n", err)