diff --git a/config/share.yml b/config/share.yml index 5f8521eaa9..7d977ae150 100644 --- a/config/share.yml +++ b/config/share.yml @@ -13,4 +13,17 @@ rpcRegisterName: imAdminUserID: [ imAdmin ] # 1: For Android, iOS, Windows, Mac, and web platforms, only one instance can be online at a time -multiLoginPolicy: 1 +multiLogin: + policy: 1 + maxNumOneEnd: 30 + customizeLoginNum: + ios: 1 + android: 1 + windows: 1 + osx: 1 + web: 1 + miniWeb: 1 + linux: 1 + aPad: 1 + iPad: 1 + admin: 1 diff --git a/go.mod b/go.mod index b6baca2a13..aea5360d70 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/gorilla/websocket v1.5.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/openimsdk/protocol v0.0.72-alpha.41 + github.com/openimsdk/protocol v0.0.72-alpha.45 github.com/openimsdk/tools v0.0.50-alpha.16 github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.18.0 diff --git a/go.sum b/go.sum index 6f54752741..4732726609 100644 --- a/go.sum +++ b/go.sum @@ -319,8 +319,8 @@ github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= github.com/onsi/gomega v1.25.0/go.mod h1:r+zV744Re+DiYCIPRlYOTxn0YkOLcAnW8k1xXdMPGhM= github.com/openimsdk/gomake v0.0.14-alpha.5 h1:VY9c5x515lTfmdhhPjMvR3BBRrRquAUCFsz7t7vbv7Y= github.com/openimsdk/gomake v0.0.14-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= -github.com/openimsdk/protocol v0.0.72-alpha.41 h1:SMMoTc1iu+wtRqUqmIgqPJFejLgPeauOwoJ4VVG4iMQ= -github.com/openimsdk/protocol v0.0.72-alpha.41/go.mod h1:OZQA9FR55lseYoN2Ql1XAHYKHJGu7OMNkUbuekrKCM8= +github.com/openimsdk/protocol v0.0.72-alpha.45 h1:xTxEG/NzBw/ZxLggqz76l7rl9HUfg7Kb2xS+jU0G2E4= +github.com/openimsdk/protocol v0.0.72-alpha.45/go.mod h1:OZQA9FR55lseYoN2Ql1XAHYKHJGu7OMNkUbuekrKCM8= github.com/openimsdk/tools v0.0.50-alpha.16 h1:bC1AQvJMuOHtZm8LZRvN8L5mH1Ws2VYdL+TLTs1iGSc= github.com/openimsdk/tools v0.0.50-alpha.16/go.mod h1:h1cYmfyaVtgFbKmb1Cfsl8XwUOMTt8ubVUQrdGtsUh4= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= diff --git a/internal/msggateway/ws_server.go b/internal/msggateway/ws_server.go index 7df2974885..b92d7eb442 100644 --- a/internal/msggateway/ws_server.go +++ b/internal/msggateway/ws_server.go @@ -1,17 +1,3 @@ -// Copyright © 2023 OpenIM. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package msggateway import ( @@ -212,7 +198,6 @@ func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *C if err != nil { return err } - wg := errgroup.Group{} wg.SetLimit(concurrentRequest) @@ -321,8 +306,32 @@ func (ws *WsServer) KickUserConn(client *Client) error { } func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) { - switch ws.msgGatewayConfig.Share.MultiLoginPolicy { + kickTokenFunc := func(kickClients []*Client) { + var kickTokens []string + ws.clients.DeleteClients(newClient.UserID, kickClients) + for _, c := range kickClients { + kickTokens = append(kickTokens, c.token) + err := c.KickOnlineMessage() + if err != nil { + log.ZWarn(c.ctx, "KickOnlineMessage", err) + } + } + ctx := mcontext.WithMustInfoCtx( + []string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(), + constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()}, + ) + if _, err := ws.authClient.KickTokens(ctx, kickTokens); err != nil { + log.ZWarn(newClient.ctx, "kickTokens err", err) + } + } + + switch ws.msgGatewayConfig.Share.MultiLogin.Policy { case constant.DefalutNotKick: + case constant.WebAndOther: + if constant.PlatformIDToClass(newClient.PlatformID) == constant.WebPlatformStr { + return + } + fallthrough case constant.PCAndOther: if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC { return @@ -347,6 +356,35 @@ func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Clien log.ZWarn(newClient.ctx, "InvalidateToken err", err, "userID", newClient.UserID, "platformID", newClient.PlatformID) } + case constant.PcMobileAndWeb: + clients, ok := ws.clients.GetAll(newClient.UserID) + if !ok { + return + } + var ( + kickClients []*Client + ) + for _, client := range clients { + if constant.PlatformIDToClass(client.PlatformID) == constant.PlatformIDToClass(newClient.PlatformID) { + kickClients = append(kickClients, client) + } + } + kickTokenFunc(kickClients) + + case constant.SingleTerminalLogin: + clients, ok := ws.clients.GetAll(newClient.UserID) + if !ok { + return + } + var ( + kickClients []*Client + ) + for _, client := range clients { + kickClients = append(kickClients, client) + } + kickTokenFunc(kickClients) + case constant.Customize: + // todo } } diff --git a/internal/push/offlinepush/offlinepusher.go b/internal/push/offlinepush/offlinepusher.go index 9aa6625deb..d655a924a2 100644 --- a/internal/push/offlinepush/offlinepusher.go +++ b/internal/push/offlinepush/offlinepusher.go @@ -23,10 +23,13 @@ import ( "github.com/openimsdk/open-im-server/v3/internal/push/offlinepush/options" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/mcontext" + "strings" ) const ( - geTUI = "geTui" + geTUI = "getui" firebase = "fcm" jPush = "jpush" ) @@ -38,6 +41,7 @@ type OfflinePusher interface { func NewOfflinePusher(pushConf *config.Push, cache cache.ThirdCache, fcmConfigPath string) (OfflinePusher, error) { var offlinePusher OfflinePusher + pushConf.Enable = strings.ToLower(pushConf.Enable) switch pushConf.Enable { case geTUI: offlinePusher = getui.NewClient(pushConf, cache) @@ -47,6 +51,7 @@ func NewOfflinePusher(pushConf *config.Push, cache cache.ThirdCache, fcmConfigPa offlinePusher = jpush.NewClient(pushConf) default: offlinePusher = dummy.NewClient() + log.ZWarn(mcontext.WithMustInfoCtx([]string{"push start", "admin", "admin", ""}), "Unknown push config", nil) } return offlinePusher, nil } diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 06ae89d971..62df74d214 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -65,7 +65,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg redis2.NewTokenCacheModel(rdb, config.RpcConfig.TokenPolicy.Expire), config.Share.Secret, config.RpcConfig.TokenPolicy.Expire, - config.Share.MultiLoginPolicy, + config.Share.MultiLogin, ), config: config, }) @@ -230,3 +230,10 @@ func (s *authServer) InvalidateToken(ctx context.Context, req *pbauth.Invalidate } return &pbauth.InvalidateTokenResp{}, nil } + +func (s *authServer) KickTokens(ctx context.Context, req *pbauth.KickTokensReq) (*pbauth.KickTokensResp, error) { + if err := s.authDatabase.BatchSetTokenMapByUidPid(ctx, req.Tokens); err != nil { + return nil, err + } + return &pbauth.KickTokensResp{}, nil +} diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index c3ee0d3d53..b8e6c2aca9 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -1526,29 +1526,61 @@ func (g *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr case 0: if !isAppManagerUid { roleLevel := dbMembers[opUserIndex].RoleLevel - if roleLevel != constant.GroupOwner { - switch roleLevel { - case constant.GroupAdmin: - for _, member := range dbMembers { - if member.RoleLevel == constant.GroupOwner { - return nil, errs.ErrNoPermission.WrapMsg("admin can not change group owner") - } - if member.RoleLevel == constant.GroupAdmin && member.UserID != opUserID { - return nil, errs.ErrNoPermission.WrapMsg("admin can not change other group admin") - } + var ( + dbSelf = &model.GroupMember{} + reqSelf *pbgroup.SetGroupMemberInfo + ) + switch roleLevel { + case constant.GroupOwner: + for _, member := range dbMembers { + if member.UserID == opUserID { + dbSelf = member + break } - case constant.GroupOrdinaryUsers: - for _, member := range dbMembers { - if !(member.RoleLevel == constant.GroupOrdinaryUsers && member.UserID == opUserID) { - return nil, errs.ErrNoPermission.WrapMsg("ordinary users can not change other role level") - } + } + case constant.GroupAdmin: + for _, member := range dbMembers { + if member.UserID == opUserID { + dbSelf = member + } + if member.RoleLevel == constant.GroupOwner { + return nil, errs.ErrNoPermission.WrapMsg("admin can not change group owner") } - default: - for _, member := range dbMembers { - if member.RoleLevel >= roleLevel { - return nil, errs.ErrNoPermission.WrapMsg("can not change higher role level") - } + if member.RoleLevel == constant.GroupAdmin && member.UserID != opUserID { + return nil, errs.ErrNoPermission.WrapMsg("admin can not change other group admin") + } + } + case constant.GroupOrdinaryUsers: + for _, member := range dbMembers { + if member.UserID == opUserID { + dbSelf = member + } + if !(member.RoleLevel == constant.GroupOrdinaryUsers && member.UserID == opUserID) { + return nil, errs.ErrNoPermission.WrapMsg("ordinary users can not change other role level") + } + } + default: + for _, member := range dbMembers { + if member.UserID == opUserID { + dbSelf = member } + if member.RoleLevel >= roleLevel { + return nil, errs.ErrNoPermission.WrapMsg("can not change higher role level") + } + } + } + for _, member := range req.Members { + if member.UserID == opUserID { + reqSelf = member + break + } + } + if reqSelf != nil && reqSelf.RoleLevel != nil { + if reqSelf.RoleLevel.GetValue() > dbSelf.RoleLevel { + return nil, errs.ErrNoPermission.WrapMsg("can not improve role level by self") + } + if roleLevel == constant.GroupOwner { + return nil, errs.ErrArgs.WrapMsg("group owner can not change own role level") // Prevent the absence of a group owner } } } @@ -1589,7 +1621,7 @@ func (g *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr g.notification.GroupMemberSetToOrdinaryUserNotification(ctx, member.GroupID, member.UserID) } } - if member.Nickname != nil || member.FaceURL != nil || member.Ex != nil { + if member.Nickname != nil || member.FaceURL != nil || member.Ex != nil || member.RoleLevel != nil { g.notification.GroupMemberInfoSetNotification(ctx, member.GroupID, member.UserID) } } diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index 77fcbb8aa1..da6c63d600 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -361,11 +361,29 @@ type AfterConfig struct { } type Share struct { - Secret string `mapstructure:"secret"` - RpcRegisterName RpcRegisterName `mapstructure:"rpcRegisterName"` - IMAdminUserID []string `mapstructure:"imAdminUserID"` - MultiLoginPolicy int `mapstructure:"multiLoginPolicy"` + Secret string `mapstructure:"secret"` + RpcRegisterName RpcRegisterName `mapstructure:"rpcRegisterName"` + IMAdminUserID []string `mapstructure:"imAdminUserID"` + MultiLogin MultiLogin `mapstructure:"multiLogin"` +} + +type MultiLogin struct { + Policy int `mapstructure:"policy"` + MaxNumOneEnd int `mapstructure:"maxNumOneEnd"` + CustomizeLoginNum struct { + IOS int `mapstructure:"ios"` + Android int `mapstructure:"android"` + Windows int `mapstructure:"windows"` + OSX int `mapstructure:"osx"` + Web int `mapstructure:"web"` + MiniWeb int `mapstructure:"miniWeb"` + Linux int `mapstructure:"linux"` + APad int `mapstructure:"aPad"` + IPad int `mapstructure:"iPad"` + Admin int `mapstructure:"admin"` + } `mapstructure:"customizeLoginNum"` } + type RpcRegisterName struct { User string `mapstructure:"user"` Friend string `mapstructure:"friend"` diff --git a/pkg/common/storage/cache/cachekey/token.go b/pkg/common/storage/cache/cachekey/token.go index 94468dc315..83ba2f2111 100644 --- a/pkg/common/storage/cache/cachekey/token.go +++ b/pkg/common/storage/cache/cachekey/token.go @@ -1,6 +1,9 @@ package cachekey -import "github.com/openimsdk/protocol/constant" +import ( + "github.com/openimsdk/protocol/constant" + "strings" +) const ( UidPidToken = "UID_PID_TOKEN_STATUS:" @@ -9,3 +12,17 @@ const ( func GetTokenKey(userID string, platformID int) string { return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID) } + +func GetAllPlatformTokenKey(userID string) []string { + res := make([]string, len(constant.PlatformID2Name)) + for k := range constant.PlatformID2Name { + res[k-1] = GetTokenKey(userID, k) + } + return res +} + +func GetPlatformIDByTokenKey(key string) int { + splitKey := strings.Split(key, ":") + platform := splitKey[len(splitKey)-1] + return constant.PlatformNameToID(platform) +} diff --git a/pkg/common/storage/cache/redis/token.go b/pkg/common/storage/cache/redis/token.go index 24e9c30050..998b4f1c95 100644 --- a/pkg/common/storage/cache/redis/token.go +++ b/pkg/common/storage/cache/redis/token.go @@ -1,17 +1,3 @@ -// Copyright © 2023 OpenIM. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package redis import ( @@ -21,6 +7,7 @@ import ( "github.com/openimsdk/tools/errs" "github.com/redis/go-redis/v9" "strconv" + "sync" "time" ) @@ -67,6 +54,43 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p return mm, nil } +func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { + var ( + res = make(map[int]map[string]int) + resLock = sync.Mutex{} + ) + + keys := cachekey.GetAllPlatformTokenKey(userID) + if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { + pipe := c.rdb.Pipeline() + mapRes := make([]*redis.MapStringStringCmd, len(keys)) + for i, key := range keys { + mapRes[i] = pipe.HGetAll(ctx, key) + } + _, err := pipe.Exec(ctx) + if err != nil { + return err + } + for i, m := range mapRes { + mm := make(map[string]int) + for k, v := range m.Val() { + state, err := strconv.Atoi(v) + if err != nil { + return errs.WrapMsg(err, "redis token value is not int", "value", v, "userID", userID) + } + mm[k] = state + } + resLock.Lock() + res[cachekey.GetPlatformIDByTokenKey(keys[i])] = mm + resLock.Unlock() + } + return nil + }); err != nil { + return nil, err + } + return res, nil +} + func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { mm := make(map[string]any) for k, v := range m { @@ -75,6 +99,18 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla return errs.Wrap(c.rdb.HSet(ctx, cachekey.GetTokenKey(userID, platformID), mm).Err()) } +func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]int) error { + pipe := c.rdb.Pipeline() + for k, v := range tokens { + pipe.HSet(ctx, k, v) + } + _, err := pipe.Exec(ctx) + if err != nil { + return errs.Wrap(err) + } + return nil +} + func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error { return errs.Wrap(c.rdb.HDel(ctx, cachekey.GetTokenKey(userID, platformID), fields...).Err()) } diff --git a/pkg/common/storage/cache/token.go b/pkg/common/storage/cache/token.go index 4a0fee087d..ee0004d7f8 100644 --- a/pkg/common/storage/cache/token.go +++ b/pkg/common/storage/cache/token.go @@ -9,6 +9,8 @@ type TokenModel interface { // SetTokenFlagEx set token and flag with expire time SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) + GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error + BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]int) error DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error } diff --git a/pkg/common/storage/controller/auth.go b/pkg/common/storage/controller/auth.go index 94f18b3ae3..de8f93462f 100644 --- a/pkg/common/storage/controller/auth.go +++ b/pkg/common/storage/controller/auth.go @@ -1,21 +1,9 @@ -// Copyright © 2023 OpenIM. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package controller import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/tools/log" "github.com/golang-jwt/jwt/v4" @@ -32,18 +20,41 @@ type AuthDatabase interface { // Create token CreateToken(ctx context.Context, userID string, platformID int) (string, error) + BatchSetTokenMapByUidPid(ctx context.Context, tokens []string) error + SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error } +type multiLoginConfig struct { + Policy int + MaxNumOneEnd int + CustomizeLoginNum map[int]int +} + type authDatabase struct { - cache cache.TokenModel - accessSecret string - accessExpire int64 - multiLoginPolicy int + cache cache.TokenModel + accessSecret string + accessExpire int64 + multiLogin multiLoginConfig } -func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, policy int) AuthDatabase { - return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLoginPolicy: policy} +func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin) AuthDatabase { + return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLogin: multiLoginConfig{ + Policy: multiLogin.Policy, + MaxNumOneEnd: multiLogin.MaxNumOneEnd, + CustomizeLoginNum: map[int]int{ + constant.IOSPlatformID: multiLogin.CustomizeLoginNum.IOS, + constant.AndroidPlatformID: multiLogin.CustomizeLoginNum.Android, + constant.WindowsPlatformID: multiLogin.CustomizeLoginNum.Windows, + constant.OSXPlatformID: multiLogin.CustomizeLoginNum.OSX, + constant.WebPlatformID: multiLogin.CustomizeLoginNum.Web, + constant.MiniWebPlatformID: multiLogin.CustomizeLoginNum.MiniWeb, + constant.LinuxPlatformID: multiLogin.CustomizeLoginNum.Linux, + constant.AndroidPadPlatformID: multiLogin.CustomizeLoginNum.APad, + constant.IPadPlatformID: multiLogin.CustomizeLoginNum.IPad, + constant.AdminPlatformID: multiLogin.CustomizeLoginNum.Admin, + }, + }} } // If the result is empty. @@ -55,22 +66,38 @@ func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, p return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m) } +func (a *authDatabase) BatchSetTokenMapByUidPid(ctx context.Context, tokens []string) error { + setMap := make(map[string]map[string]int) + for _, token := range tokens { + claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret(a.accessSecret)) + key := cachekey.GetTokenKey(claims.UserID, claims.PlatformID) + if err != nil { + continue + } else { + if v, ok := setMap[key]; ok { + v[token] = constant.KickedToken + } else { + setMap[key] = map[string]int{ + token: constant.KickedToken, + } + } + } + } + if err := a.cache.BatchSetTokenMapByUidPid(ctx, setMap); err != nil { + return err + } + return nil +} + // Create Token. func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { - // todo: get all platform token - tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID) + tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID) if err != nil { return "", err } - var deleteTokenKey []string - var kickedTokenKey []string - for k, v := range tokens { - t, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret)) - if err != nil || v != constant.NormalToken { - deleteTokenKey = append(deleteTokenKey, k) - } else if a.checkKickToken(ctx, platformID, t) { - kickedTokenKey = append(kickedTokenKey, k) - } + deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) + if err != nil { + return "", err } if len(deleteTokenKey) != 0 { err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) @@ -78,16 +105,6 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI return "", err } } - - const adminTokenMaxNum = 30 - if platformID == constant.AdminPlatformID { - if len(kickedTokenKey) > adminTokenMaxNum { - kickedTokenKey = kickedTokenKey[:len(kickedTokenKey)-adminTokenMaxNum] - } else { - kickedTokenKey = nil - } - } - if len(kickedTokenKey) != 0 { for _, k := range kickedTokenKey { err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) @@ -111,22 +128,140 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI return tokenString, nil } -func (a *authDatabase) checkKickToken(ctx context.Context, platformID int, token *tokenverify.Claims) bool { - switch a.multiLoginPolicy { +func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) { + // todo: Move the logic for handling old data to another location. + var ( + loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0 + deleteToken = make([]string, 0) + kickToken = make([]string, 0) + adminToken = make([]string, 0) + unkickTerminal = "" + ) + + for plfID, tks := range tokens { + for k, v := range tks { + _, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret)) + if err != nil || v != constant.NormalToken { + deleteToken = append(deleteToken, k) + } else { + if plfID != constant.AdminPlatformID { + loginTokenMap[plfID] = append(loginTokenMap[plfID], k) + } else { + adminToken = append(adminToken, k) + } + } + } + } + + switch a.multiLogin.Policy { case constant.DefalutNotKick: - return false - case constant.PCAndOther: - if constant.PlatformIDToClass(platformID) == constant.TerminalPC || - constant.PlatformIDToClass(token.PlatformID) == constant.TerminalPC { - return false + for plt, ts := range loginTokenMap { + l := len(ts) + if platformID == plt { + l++ + } + limit := a.multiLogin.MaxNumOneEnd + if l > limit { + kickToken = append(kickToken, ts[:l-limit]...) + } } - return true case constant.AllLoginButSameTermKick: - if platformID == token.PlatformID { - return true + for plt, ts := range loginTokenMap { + kickToken = append(kickToken, ts[:len(ts)-1]...) + if plt == platformID { + kickToken = append(kickToken, ts[len(ts)-1]) + } + } + case constant.SingleTerminalLogin: + for _, ts := range loginTokenMap { + kickToken = append(kickToken, ts...) + } + case constant.WebAndOther: + unkickTerminal = constant.WebPlatformStr + fallthrough + case constant.PCAndOther: + if unkickTerminal == "" { + unkickTerminal = constant.TerminalPC + } + if constant.PlatformIDToClass(platformID) != unkickTerminal { + for plt, ts := range loginTokenMap { + if constant.PlatformIDToClass(plt) != unkickTerminal { + kickToken = append(kickToken, ts...) + } + } + } else { + var ( + preKick []string + isReserve = true + ) + for plt, ts := range loginTokenMap { + if constant.PlatformIDToClass(plt) != unkickTerminal { + // Keep a token from another end + if isReserve { + isReserve = false + kickToken = append(kickToken, ts[:len(ts)-1]...) + preKick = append(preKick, ts[len(ts)-1]) + continue + } else { + // Prioritize keeping Android + if plt == constant.AndroidPlatformID { + kickToken = append(kickToken, preKick...) + kickToken = append(kickToken, ts[:len(ts)-1]...) + } else { + kickToken = append(kickToken, ts...) + } + } + } + } + } + case constant.PcMobileAndWeb: + var ( + reserved = make(map[string]bool) + ) + + for plt, ts := range loginTokenMap { + if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) { + kickToken = append(kickToken, ts...) + } else { + if !reserved[constant.PlatformIDToClass(plt)] { + reserved[constant.PlatformIDToClass(plt)] = true + kickToken = append(kickToken, ts[:len(ts)-1]...) + continue + } else { + kickToken = append(kickToken, ts...) + } + } + } + + case constant.Customize: + if a.multiLogin.CustomizeLoginNum[platformID] <= 0 { + return nil, nil, errs.New("Do not allow login on this end").Wrap() + } + for plt, ts := range loginTokenMap { + l := len(ts) + if platformID == plt { + l++ + } + // a.multiLogin.CustomizeLoginNum[platformID] must > 0 + limit := min(a.multiLogin.CustomizeLoginNum[plt], a.multiLogin.MaxNumOneEnd) + if l > limit { + kickToken = append(kickToken, ts[:l-limit]...) + } } - return false default: - return false + return nil, nil, errs.New("unknown multiLogin policy").Wrap() + } + + var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd + if a.multiLogin.Policy == constant.Customize { + adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID] + } + l := len(adminToken) + if platformID == constant.AdminPlatformID { + l++ + } + if l > adminTokenMaxNum { + kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) } + return deleteToken, kickToken, nil } diff --git a/pkg/rpcclient/auth.go b/pkg/rpcclient/auth.go index fead624a3b..05fec35a08 100644 --- a/pkg/rpcclient/auth.go +++ b/pkg/rpcclient/auth.go @@ -61,3 +61,14 @@ func (a *Auth) InvalidateToken(ctx context.Context, preservedToken, userID strin } return resp, err } + +func (a *Auth) KickTokens(ctx context.Context, tokens []string) (*auth.KickTokensResp, error) { + req := auth.KickTokensReq{ + Tokens: tokens, + } + resp, err := a.Client.KickTokens(ctx, &req) + if err != nil { + return nil, err + } + return resp, err +}