From c00d6c5bb4f1fcb29348ccb26d990e0f73c7c10f Mon Sep 17 00:00:00 2001 From: icey-yu <1186114839@qq.com> Date: Fri, 22 Nov 2024 12:08:00 +0800 Subject: [PATCH] fix: admin token limit --- internal/rpc/auth/auth.go | 10 +++- pkg/common/storage/controller/auth.go | 69 +++++++++++++++------------ pkg/common/storage/controller/msg.go | 2 +- 3 files changed, 48 insertions(+), 33 deletions(-) diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 62df74d214..a1acfd9313 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -16,6 +16,7 @@ package auth import ( "context" + "errors" "github.com/openimsdk/open-im-server/v3/pkg/common/config" redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" @@ -66,6 +67,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg config.Share.Secret, config.RpcConfig.TokenPolicy.Expire, config.Share.MultiLogin, + config.Share.IMAdminUserID, ), config: config, }) @@ -129,6 +131,10 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim if err != nil { return nil, errs.Wrap(err) } + isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) + if isAdmin { + return claims, nil + } m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) if err != nil { return nil, err @@ -190,7 +196,7 @@ func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID } m, err := s.authDatabase.GetTokensWithoutError(ctx, userID, int(platformID)) - if err != nil && err != redis.Nil { + if err != nil && errors.Is(err, redis.Nil) { return err } for k := range m { @@ -208,7 +214,7 @@ func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID func (s *authServer) InvalidateToken(ctx context.Context, req *pbauth.InvalidateTokenReq) (*pbauth.InvalidateTokenResp, error) { m, err := s.authDatabase.GetTokensWithoutError(ctx, req.UserID, int(req.PlatformID)) - if err != nil && err != redis.Nil { + if err != nil && errors.Is(err, redis.Nil) { return nil, err } if m == nil { diff --git a/pkg/common/storage/controller/auth.go b/pkg/common/storage/controller/auth.go index df12749677..e7b5bc297c 100644 --- a/pkg/common/storage/controller/auth.go +++ b/pkg/common/storage/controller/auth.go @@ -35,9 +35,10 @@ type authDatabase struct { accessSecret string accessExpire int64 multiLogin multiLoginConfig + adminUserIDs []string } -func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin) AuthDatabase { +func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin, adminUserIDs []string) AuthDatabase { return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLogin: multiLoginConfig{ Policy: multiLogin.Policy, MaxNumOneEnd: multiLogin.MaxNumOneEnd, @@ -53,7 +54,8 @@ func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire i constant.IPadPlatformID: multiLogin.CustomizeLoginNum.IPad, constant.AdminPlatformID: multiLogin.CustomizeLoginNum.Admin, }, - }} + }, adminUserIDs: adminUserIDs, + } } // If the result is empty. @@ -90,27 +92,31 @@ func (a *authDatabase) BatchSetTokenMapByUidPid(ctx context.Context, tokens []st // Create Token. func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { - tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID) - if err != nil { - return "", err - } - deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) - if err != nil { - return "", err - } - if len(deleteTokenKey) != 0 { - err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) + isAdmin := authverify.IsManagerUserID(userID, a.adminUserIDs) + if !isAdmin { + tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID) if err != nil { return "", err } - } - if len(kickedTokenKey) != 0 { - for _, k := range kickedTokenKey { - err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) + + deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) + if err != nil { + return "", err + } + if len(deleteTokenKey) != 0 { + err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) if err != nil { return "", err } - log.ZDebug(ctx, "kicked token in create token", "token", k) + } + if len(kickedTokenKey) != 0 { + for _, k := range kickedTokenKey { + err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) + if err != nil { + return "", err + } + log.ZDebug(ctx, "kicked token in create token", "token", k) + } } } @@ -121,9 +127,12 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI return "", errs.WrapMsg(err, "token.SignedString") } - if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil { - return "", err + if !isAdmin { + if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil { + return "", err + } } + return tokenString, nil } @@ -226,16 +235,16 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string return nil, nil, errs.New("unknown multiLogin policy").Wrap() } - var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd - if a.multiLogin.Policy == constant.Customize { - adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID] - } - l := len(adminToken) - if platformID == constant.AdminPlatformID { - l++ - } - if l > adminTokenMaxNum { - kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) - } + //var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd + //if a.multiLogin.Policy == constant.Customize { + // adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID] + //} + //l := len(adminToken) + //if platformID == constant.AdminPlatformID { + // l++ + //} + //if l > adminTokenMaxNum { + // kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) + //} return deleteToken, kickToken, nil } diff --git a/pkg/common/storage/controller/msg.go b/pkg/common/storage/controller/msg.go index 90b4790646..9636f7a152 100644 --- a/pkg/common/storage/controller/msg.go +++ b/pkg/common/storage/controller/msg.go @@ -490,7 +490,7 @@ func (db *commonMsgDatabase) GetMsgBySeqs(ctx context.Context, userID string, co } successMsgs, failedSeqs, err := db.msg.GetMessagesBySeq(ctx, conversationID, newSeqs) if err != nil { - if err != redis.Nil { + if errors.Is(err, redis.Nil) { log.ZError(ctx, "get message from redis exception", err, "failedSeqs", failedSeqs, "conversationID", conversationID) } }