diff --git a/internal/rpc/admin/token.go b/internal/rpc/admin/token.go index a98511d74..c06ad344b 100644 --- a/internal/rpc/admin/token.go +++ b/internal/rpc/admin/token.go @@ -23,12 +23,13 @@ import ( "github.com/redis/go-redis/v9" ) -func (o *adminServer) CreateToken(ctx context.Context, req *adminpb.CreateTokenReq) (*adminpb.CreateTokenResp, error) { - token, err := o.Token.CreateToken(req.UserID, req.UserType) +func (o *adminServer) CreateToken(ctx context.Context, req *admin.CreateTokenReq) (*admin.CreateTokenResp, error) { + token, expire, err := o.Token.CreateToken(req.UserID, req.UserType) + if err != nil { return nil, err } - err = o.Database.CacheToken(ctx, req.UserID, token) + err = o.Database.CacheToken(ctx, req.UserID, token, expire) if err != nil { return nil, err } diff --git a/pkg/common/db/cache/token.go b/pkg/common/db/cache/token.go index 9da188a7e..a57571ebc 100644 --- a/pkg/common/db/cache/token.go +++ b/pkg/common/db/cache/token.go @@ -18,6 +18,7 @@ import ( "context" "github.com/openimsdk/tools/utils/stringutil" + "time" "github.com/openimsdk/tools/errs" "github.com/redis/go-redis/v9" @@ -29,12 +30,14 @@ const ( type TokenInterface interface { AddTokenFlag(ctx context.Context, userID string, token string, flag int) error + AddTokenFlagNXEx(ctx context.Context, userID string, token string, flag int, expire time.Duration) (bool, error) GetTokensWithoutError(ctx context.Context, userID string) (map[string]int32, error) DeleteTokenByUid(ctx context.Context, userID string) error } type TokenCacheRedis struct { - rdb redis.UniversalClient + rdb redis.UniversalClient + accessExpire int64 } func NewTokenInterface(rdb redis.UniversalClient) *TokenCacheRedis { @@ -46,6 +49,22 @@ func (t *TokenCacheRedis) AddTokenFlag(ctx context.Context, userID string, token return errs.Wrap(t.rdb.HSet(ctx, key, token, flag).Err()) } +func (t *TokenCacheRedis) AddTokenFlagNXEx(ctx context.Context, userID string, token string, flag int, expire time.Duration) (bool, error) { + key := chatToken + userID + isSet, err := t.rdb.HSetNX(ctx, key, token, flag).Result() + if err != nil { + return false, errs.Wrap(err) + } + if !isSet { + // key already exists + return false, nil + } + if err = t.rdb.Expire(ctx, key, expire).Err(); err != nil { + return false, errs.Wrap(err) + } + return isSet, nil +} + func (t *TokenCacheRedis) GetTokensWithoutError(ctx context.Context, userID string) (map[string]int32, error) { key := chatToken + userID m, err := t.rdb.HGetAll(ctx, key).Result() diff --git a/pkg/common/db/database/admin.go b/pkg/common/db/database/admin.go index 2b867f5f9..820864df1 100644 --- a/pkg/common/db/database/admin.go +++ b/pkg/common/db/database/admin.go @@ -16,6 +16,7 @@ package database import ( "context" + "time" "github.com/openimsdk/chat/pkg/common/db/cache" "github.com/openimsdk/protocol/constant" @@ -74,7 +75,7 @@ type AdminDatabaseInterface interface { DelUserLimitLogin(ctx context.Context, ms []*admindb.LimitUserLoginIP) error CountLimitUserLoginIP(ctx context.Context, userID string) (uint32, error) GetLimitUserLoginIP(ctx context.Context, userID string, ip string) (*admindb.LimitUserLoginIP, error) - CacheToken(ctx context.Context, userID string, token string) error + CacheToken(ctx context.Context, userID string, token string, expire time.Duration) error GetTokens(ctx context.Context, userID string) (map[string]int32, error) DeleteToken(ctx context.Context, userID string) error } @@ -325,8 +326,18 @@ func (o *AdminDatabase) GetLimitUserLoginIP(ctx context.Context, userID string, return o.limitUserLoginIP.Take(ctx, userID, ip) } -func (o *AdminDatabase) CacheToken(ctx context.Context, userID string, token string) error { - return o.cache.AddTokenFlag(ctx, userID, token, constant.NormalToken) +func (o *AdminDatabase) CacheToken(ctx context.Context, userID string, token string, expire time.Duration) error { + isSet, err := o.cache.AddTokenFlagNXEx(ctx, userID, token, constant.NormalToken, expire) + if err != nil { + return err + } + if !isSet { + // already exists, update + if err = o.cache.AddTokenFlag(ctx, userID, token, constant.NormalToken); err != nil { + return err + } + } + return nil } func (o *AdminDatabase) GetTokens(ctx context.Context, userID string) (map[string]int32, error) { diff --git a/pkg/common/tokenverify/token_verify.go b/pkg/common/tokenverify/token_verify.go index 9fdfbac54..3c35ffa71 100644 --- a/pkg/common/tokenverify/token_verify.go +++ b/pkg/common/tokenverify/token_verify.go @@ -86,16 +86,16 @@ func (t *Token) getToken(str string) (string, int32, error) { } } -func (t *Token) CreateToken(UserID string, userType int32) (string, error) { +func (t *Token) CreateToken(UserID string, userType int32) (string, time.Duration, error) { if !(userType == TokenUser || userType == TokenAdmin) { - return "", errs.ErrTokenUnknown.WrapMsg("token type unknown") + return "", 0, errs.ErrTokenUnknown.WrapMsg("token type unknown") } token := jwt.NewWithClaims(jwt.SigningMethodHS256, t.buildClaims(UserID, userType)) str, err := token.SignedString([]byte(t.Secret)) if err != nil { - return "", errs.Wrap(err) + return "", 0, errs.Wrap(err) } - return str, nil + return str, t.Expires, nil } func (t *Token) GetToken(token string) (string, int32, error) {