From af4ec6b8f73b61fdc28d6d0f7a3268d87188377e Mon Sep 17 00:00:00 2001 From: chao <48119764+withchao@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:20:01 +0800 Subject: [PATCH] fix: errors caused by too many tokens (#608) * fix: too many tokens * fix: too many tokens * fix: too many tokens --- internal/api/mw/mw.go | 44 +---------- internal/rpc/admin/start.go | 10 +-- internal/rpc/admin/token.go | 1 - pkg/common/db/cache/token.go | 103 ++++++++++++++++++------- pkg/common/db/database/admin.go | 18 +---- pkg/common/tokenverify/token_verify.go | 15 ++++ 6 files changed, 103 insertions(+), 88 deletions(-) diff --git a/internal/api/mw/mw.go b/internal/api/mw/mw.go index 8f8bb2a4..fa109243 100644 --- a/internal/api/mw/mw.go +++ b/internal/api/mw/mw.go @@ -20,7 +20,6 @@ import ( "github.com/gin-gonic/gin" "github.com/openimsdk/chat/pkg/common/constant" "github.com/openimsdk/chat/pkg/protocol/admin" - constantpb "github.com/openimsdk/protocol/constant" "github.com/openimsdk/tools/apiresp" "github.com/openimsdk/tools/errs" ) @@ -56,74 +55,37 @@ func (o *MW) parseTokenType(c *gin.Context, userType int32) (string, string, err return userID, token, nil } -func (o *MW) isValidToken(c *gin.Context, userID string, token string) error { - resp, err := o.client.GetUserToken(c, &admin.GetUserTokenReq{UserID: userID}) - if err != nil { - return err - } - if len(resp.TokensMap) == 0 { - return errs.ErrTokenExpired.Wrap() - } - if v, ok := resp.TokensMap[token]; ok { - switch v { - case constantpb.NormalToken: - case constantpb.KickedToken: - return errs.ErrTokenExpired.Wrap() - default: - return errs.ErrTokenUnknown.Wrap() - } - } else { - return errs.ErrTokenExpired.Wrap() - } - return nil -} - func (o *MW) setToken(c *gin.Context, userID string, userType int32) { SetToken(c, userID, userType) } func (o *MW) CheckToken(c *gin.Context) { - userID, userType, token, err := o.parseToken(c) + userID, userType, _, err := o.parseToken(c) if err != nil { c.Abort() apiresp.GinError(c, err) return } - if err := o.isValidToken(c, userID, token); err != nil { - c.Abort() - apiresp.GinError(c, err) - return - } o.setToken(c, userID, userType) } func (o *MW) CheckAdmin(c *gin.Context) { - userID, token, err := o.parseTokenType(c, constant.AdminUser) + userID, _, err := o.parseTokenType(c, constant.AdminUser) if err != nil { c.Abort() apiresp.GinError(c, err) return } - if err := o.isValidToken(c, userID, token); err != nil { - c.Abort() - apiresp.GinError(c, err) - return - } o.setToken(c, userID, constant.AdminUser) } func (o *MW) CheckUser(c *gin.Context) { - userID, token, err := o.parseTokenType(c, constant.NormalUser) + userID, _, err := o.parseTokenType(c, constant.NormalUser) if err != nil { c.Abort() apiresp.GinError(c, err) return } - if err := o.isValidToken(c, userID, token); err != nil { - c.Abort() - apiresp.GinError(c, err) - return - } o.setToken(c, userID, constant.NormalUser) } diff --git a/internal/rpc/admin/start.go b/internal/rpc/admin/start.go index 54de3be3..d08c31a7 100644 --- a/internal/rpc/admin/start.go +++ b/internal/rpc/admin/start.go @@ -47,7 +47,11 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg return err } var srv adminServer - srv.Database, err = database.NewAdminDatabase(mgocli, rdb) + srv.Token = &tokenverify.Token{ + Expires: time.Duration(config.RpcConfig.TokenPolicy.Expire) * time.Hour * 24, + Secret: config.RpcConfig.Secret, + } + srv.Database, err = database.NewAdminDatabase(mgocli, rdb, srv.Token) if err != nil { return err } @@ -56,10 +60,6 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg return err } srv.Chat = chatClient.NewChatClient(chat.NewChatClient(conn)) - srv.Token = &tokenverify.Token{ - Expires: time.Duration(config.RpcConfig.TokenPolicy.Expire) * time.Hour * 24, - Secret: config.RpcConfig.Secret, - } if err := srv.initAdmin(ctx, config.Share.ChatAdmin, config.Share.OpenIM.AdminUserID); err != nil { return err } diff --git a/internal/rpc/admin/token.go b/internal/rpc/admin/token.go index dc7fa191..57e81d6a 100644 --- a/internal/rpc/admin/token.go +++ b/internal/rpc/admin/token.go @@ -26,7 +26,6 @@ import ( func (o *adminServer) CreateToken(ctx context.Context, req *adminpb.CreateTokenReq) (*adminpb.CreateTokenResp, error) { token, expire, err := o.Token.CreateToken(req.UserID, req.UserType) - if err != nil { return nil, err } diff --git a/pkg/common/db/cache/token.go b/pkg/common/db/cache/token.go index a57571eb..3ad6b6ca 100644 --- a/pkg/common/db/cache/token.go +++ b/pkg/common/db/cache/token.go @@ -16,8 +16,9 @@ package cache import ( "context" - + "github.com/openimsdk/chat/pkg/common/tokenverify" "github.com/openimsdk/tools/utils/stringutil" + "sort" "time" "github.com/openimsdk/tools/errs" @@ -25,44 +26,24 @@ import ( ) const ( - chatToken = "CHAT_UID_TOKEN_STATUS:" + chatToken = "CHAT_UID_TOKEN_STATUS:" + userMaxTokenNum = 10 ) type TokenInterface interface { - AddTokenFlag(ctx context.Context, userID string, token string, flag int) error - AddTokenFlagNXEx(ctx context.Context, userID string, token string, flag int, expire time.Duration) (bool, error) + SetTokenExpire(ctx context.Context, userID string, token string, expire time.Duration) error GetTokensWithoutError(ctx context.Context, userID string) (map[string]int32, error) DeleteTokenByUid(ctx context.Context, userID string) error } type TokenCacheRedis struct { + token *tokenverify.Token rdb redis.UniversalClient accessExpire int64 } -func NewTokenInterface(rdb redis.UniversalClient) *TokenCacheRedis { - return &TokenCacheRedis{rdb: rdb} -} - -func (t *TokenCacheRedis) AddTokenFlag(ctx context.Context, userID string, token string, flag int) error { - key := chatToken + userID - return errs.Wrap(t.rdb.HSet(ctx, key, token, flag).Err()) -} - -func (t *TokenCacheRedis) AddTokenFlagNXEx(ctx context.Context, userID string, token string, flag int, expire time.Duration) (bool, error) { - key := chatToken + userID - isSet, err := t.rdb.HSetNX(ctx, key, token, flag).Result() - if err != nil { - return false, errs.Wrap(err) - } - if !isSet { - // key already exists - return false, nil - } - if err = t.rdb.Expire(ctx, key, expire).Err(); err != nil { - return false, errs.Wrap(err) - } - return isSet, nil +func NewTokenInterface(rdb redis.UniversalClient, token *tokenverify.Token) *TokenCacheRedis { + return &TokenCacheRedis{rdb: rdb, token: token} } func (t *TokenCacheRedis) GetTokensWithoutError(ctx context.Context, userID string) (map[string]int32, error) { @@ -82,3 +63,71 @@ func (t *TokenCacheRedis) DeleteTokenByUid(ctx context.Context, userID string) e key := chatToken + userID return errs.Wrap(t.rdb.Del(ctx, key).Err()) } + +func (t *TokenCacheRedis) SetTokenExpire(ctx context.Context, userID string, token string, expire time.Duration) error { + key := chatToken + userID + if err := t.rdb.HSet(ctx, key, token, "0").Err(); err != nil { + return errs.Wrap(err) + } + if err := t.rdb.Expire(ctx, key, expire).Err(); err != nil { + return errs.Wrap(err) + } + mm, err := t.rdb.HGetAll(ctx, key).Result() + if err != nil { + return errs.Wrap(err) + } + if len(mm) <= 1 { + return nil + } + var ( + fields []string + ts tokenTimes + ) + now := time.Now() + for k := range mm { + if k == token { + continue + } + val := t.token.GetExpire(k) + if val.IsZero() || val.Before(now) { + fields = append(fields, k) + } else { + ts = append(ts, tokenTime{Token: k, Time: val}) + } + } + var sorted bool + var index int + for i := len(mm) - len(fields); i > userMaxTokenNum; i-- { + if !sorted { + sorted = true + sort.Sort(ts) + } + fields = append(fields, ts[index].Token) + index++ + } + if len(fields) > 0 { + if err := t.rdb.HDel(ctx, key, fields...).Err(); err != nil { + return errs.Wrap(err) + } + } + return nil +} + +type tokenTime struct { + Token string + Time time.Time +} + +type tokenTimes []tokenTime + +func (t tokenTimes) Len() int { + return len(t) +} + +func (t tokenTimes) Less(i, j int) bool { + return t[i].Time.Before(t[j].Time) +} + +func (t tokenTimes) Swap(i, j int) { + t[i], t[j] = t[j], t[i] +} diff --git a/pkg/common/db/database/admin.go b/pkg/common/db/database/admin.go index 820864df..bf485977 100644 --- a/pkg/common/db/database/admin.go +++ b/pkg/common/db/database/admin.go @@ -16,10 +16,10 @@ package database import ( "context" + "github.com/openimsdk/chat/pkg/common/tokenverify" "time" "github.com/openimsdk/chat/pkg/common/db/cache" - "github.com/openimsdk/protocol/constant" "github.com/openimsdk/tools/db/mongoutil" "github.com/openimsdk/tools/db/pagination" "github.com/openimsdk/tools/db/tx" @@ -80,7 +80,7 @@ type AdminDatabaseInterface interface { DeleteToken(ctx context.Context, userID string) error } -func NewAdminDatabase(cli *mongoutil.Client, rdb redis.UniversalClient) (AdminDatabaseInterface, error) { +func NewAdminDatabase(cli *mongoutil.Client, rdb redis.UniversalClient, token *tokenverify.Token) (AdminDatabaseInterface, error) { a, err := admin.NewAdmin(cli.GetDB()) if err != nil { return nil, err @@ -128,7 +128,7 @@ func NewAdminDatabase(cli *mongoutil.Client, rdb redis.UniversalClient) (AdminDa registerAddGroup: registerAddGroup, applet: applet, clientConfig: clientConfig, - cache: cache.NewTokenInterface(rdb), + cache: cache.NewTokenInterface(rdb, token), }, nil } @@ -327,17 +327,7 @@ func (o *AdminDatabase) GetLimitUserLoginIP(ctx context.Context, userID string, } func (o *AdminDatabase) CacheToken(ctx context.Context, userID string, token string, expire time.Duration) error { - isSet, err := o.cache.AddTokenFlagNXEx(ctx, userID, token, constant.NormalToken, expire) - if err != nil { - return err - } - if !isSet { - // already exists, update - if err = o.cache.AddTokenFlag(ctx, userID, token, constant.NormalToken); err != nil { - return err - } - } - return nil + return o.cache.SetTokenExpire(ctx, userID, token, expire) } func (o *AdminDatabase) GetTokens(ctx context.Context, userID string) (map[string]int32, error) { diff --git a/pkg/common/tokenverify/token_verify.go b/pkg/common/tokenverify/token_verify.go index 3c35ffa7..59521873 100644 --- a/pkg/common/tokenverify/token_verify.go +++ b/pkg/common/tokenverify/token_verify.go @@ -109,6 +109,21 @@ func (t *Token) GetToken(token string) (string, int32, error) { return userID, userType, nil } +func (t *Token) GetExpire(token string) time.Time { + val, err := jwt.ParseWithClaims(token, &claims{}, t.secret()) + if err != nil { + return time.Time{} + } + c, ok := val.Claims.(*claims) + if !ok { + return time.Time{} + } + if c.ExpiresAt == nil { + return time.Time{} + } + return c.ExpiresAt.Time +} + //func (t *Token) GetAdminToken(token string) (string, error) { // userID, userType, err := getToken(token) // if err != nil {