From e3d4aedeb272a1ccf920116439f118c6ef5322ec Mon Sep 17 00:00:00 2001 From: godcong Date: Tue, 26 Nov 2024 03:40:32 +0800 Subject: [PATCH] refactor(cache): implement a new memory cache for security module - Rename memory cache package from storage/cache/memory to contrib/cache/memory - Add a new security/cache.go file implementing a simple memory cache for security purposes - Update security/storage.go to use the new security cache instead of the old memory cache - Refactor cache/errors.go to use more generic error names --- {storage => contrib}/cache/memory/memory.go | 5 +- security/cache.go | 80 +++++++++++++++++++++ security/storage.go | 7 +- storage/cache/error.go | 4 +- 4 files changed, 87 insertions(+), 9 deletions(-) rename {storage => contrib}/cache/memory/memory.go (99%) create mode 100644 security/cache.go diff --git a/storage/cache/memory/memory.go b/contrib/cache/memory/memory.go similarity index 99% rename from storage/cache/memory/memory.go rename to contrib/cache/memory/memory.go index 9a86d11..3fb0956 100644 --- a/storage/cache/memory/memory.go +++ b/contrib/cache/memory/memory.go @@ -13,10 +13,13 @@ import ( ) const ( - ErrNotFound = errors.String("not found") defaultSize = 64 * 1024 * 1024 ) +const ( + ErrNotFound = errors.String("not found") +) + type Cache struct { Expiration time.Duration CleanupInterval time.Duration diff --git a/security/cache.go b/security/cache.go new file mode 100644 index 0000000..541f4d3 --- /dev/null +++ b/security/cache.go @@ -0,0 +1,80 @@ +// Package security implements the functions, types, and interfaces for the module. +package security + +import ( + "sync" + "time" + + "github.com/origadmin/toolkits/errors" + + "github.com/origadmin/toolkits/context" + "github.com/origadmin/toolkits/storage/cache" +) + +type element struct { + value string + expireAt time.Time +} + +type securityCache struct { + maps sync.Map +} + +func (s *securityCache) Get(ctx context.Context, key string) (string, error) { + value, ok := s.maps.Load(key) + if !ok { + return "", cache.ErrNotFound + } + ele, ok := value.(*element) + if !ok { + return "", errors.New("invalid cache value") + } + if ele.expireAt.Before(time.Now()) { + _ = s.Delete(ctx, key) + return "", cache.ErrNotFound + } + return ele.value, nil +} + +func (s *securityCache) GetAndDelete(ctx context.Context, key string) (string, error) { + value, ok := s.maps.LoadAndDelete(key) + if !ok { + return "", cache.ErrNotFound + } + ele, ok := value.(*element) + if !ok { + return "", errors.New("invalid cache value") + } + return ele.value, nil +} + +func (s *securityCache) Exists(ctx context.Context, key string) error { + _, ok := s.maps.Load(key) + if !ok { + return cache.ErrNotFound + } + return nil +} + +func (s *securityCache) Set(ctx context.Context, key string, value string, expiration ...time.Duration) error { + var expireAt time.Time + if len(expiration) > 0 { + expireAt = time.Now().Add(expiration[0]) + } else { + expireAt = time.Now().Add(time.Hour) + } + ele := &element{value: value, expireAt: expireAt} + s.maps.Store(key, ele) + return nil +} + +func (s *securityCache) Delete(ctx context.Context, key string) error { + s.maps.Delete(key) + return nil +} + +func NewSecurityCache() cache.Cache { + return &securityCache{} +} + +var _ cache.Cache = (*securityCache)(nil) diff --git a/security/storage.go b/security/storage.go index 2895c24..d40e007 100644 --- a/security/storage.go +++ b/security/storage.go @@ -10,7 +10,6 @@ import ( "github.com/goexts/generic/settings" "github.com/origadmin/toolkits/storage/cache" - "github.com/origadmin/toolkits/storage/cache/memory" ) const ( @@ -62,11 +61,7 @@ func NewTokenStorage(ss ...StorageSetting) TokenStorage { }, ss) if opt.Cache == nil { - c := memory.NewCache() - c.DefaultExpiration = 24 * time.Hour - c.CleanupInterval = 30 * time.Minute - c.Delimiter = ":" - opt.Cache = c + opt.Cache = NewSecurityCache() } s := &tokenStorage{ diff --git a/storage/cache/error.go b/storage/cache/error.go index ed54b73..8d2717a 100644 --- a/storage/cache/error.go +++ b/storage/cache/error.go @@ -37,8 +37,8 @@ func (c *cacheError) Is(err error) bool { } var ( - ErrCacheClosed error = &cacheError{msg: "cache closed"} - ErrCacheNotFound error = &cacheError{msg: "cache not found"} + ErrClosed error = &cacheError{msg: "cache closed"} + ErrNotFound error = &cacheError{msg: "cache not found"} ) func NewError(msg string) error {