diff --git a/go.mod b/go.mod index 9379af790c..d9bdb8bfe2 100644 --- a/go.mod +++ b/go.mod @@ -225,4 +225,5 @@ require ( replace ( github.com/openimsdk/protocol => /Users/chao/Desktop/code/protocol + github.com/openimsdk/tools => /Users/chao/Desktop/code/tools ) \ No newline at end of file diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index f6c12350c9..7533267260 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -16,6 +16,7 @@ package msggateway import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/rpcli" "sync/atomic" "github.com/openimsdk/open-im-server/v3/pkg/authverify" @@ -34,7 +35,14 @@ import ( ) func (s *Server) InitServer(ctx context.Context, config *Config, disCov discovery.SvcDiscoveryRegistry, server *grpc.Server) error { - s.LongConnServer.SetDiscoveryRegistry(disCov, config) + userConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.User) + if err != nil { + return err + } + s.userClient = rpcli.NewUserClient(userConn) + if err := s.LongConnServer.SetDiscoveryRegistry(ctx, disCov, config); err != nil { + return err + } msggateway.RegisterMsgGatewayServer(server, s) if s.ready != nil { return s.ready(s) @@ -61,6 +69,7 @@ type Server struct { pushTerminal map[int]struct{} ready func(srv *Server) error queue *memamq.MemoryQueue + userClient *rpcli.UserClient } func (s *Server) SetLongConnServer(LongConnServer LongConnServer) { diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go index 00cc79ff68..6614b96bdb 100644 --- a/internal/msggateway/init.go +++ b/internal/msggateway/init.go @@ -63,7 +63,7 @@ func Start(ctx context.Context, index int, conf *Config) error { hubServer := NewServer(longServer, conf, func(srv *Server) error { var err error - longServer.online, err = rpccache.NewOnlineCache(conf.Share.IMAdminUserID, nil, rdb, false, longServer.subscriberUserOnlineStatusChanges) + longServer.online, err = rpccache.NewOnlineCache(srv.userClient, nil, rdb, false, longServer.subscriberUserOnlineStatusChanges) return err }) diff --git a/internal/msggateway/message_handler.go b/internal/msggateway/message_handler.go index d88d2fbfd8..9b59867d61 100644 --- a/internal/msggateway/message_handler.go +++ b/internal/msggateway/message_handler.go @@ -17,6 +17,7 @@ package msggateway import ( "context" "encoding/json" + "github.com/openimsdk/open-im-server/v3/pkg/rpcli" "sync" "github.com/go-playground/validator/v10" @@ -99,27 +100,33 @@ func (r *Resp) String() string { } type MessageHandler interface { - GetSeq(context context.Context, data *Req) ([]byte, error) - SendMessage(context context.Context, data *Req) ([]byte, error) - SendSignalMessage(context context.Context, data *Req) ([]byte, error) - PullMessageBySeqList(context context.Context, data *Req) ([]byte, error) - GetConversationsHasReadAndMaxSeq(context context.Context, data *Req) ([]byte, error) - GetSeqMessage(context context.Context, data *Req) ([]byte, error) - UserLogout(context context.Context, data *Req) ([]byte, error) - SetUserDeviceBackground(context context.Context, data *Req) ([]byte, bool, error) + GetSeq(ctx context.Context, data *Req) ([]byte, error) + SendMessage(ctx context.Context, data *Req) ([]byte, error) + SendSignalMessage(ctx context.Context, data *Req) ([]byte, error) + PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error) + GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error) + GetSeqMessage(ctx context.Context, data *Req) ([]byte, error) + UserLogout(ctx context.Context, data *Req) ([]byte, error) + SetUserDeviceBackground(ctx context.Context, data *Req) ([]byte, bool, error) } var _ MessageHandler = (*GrpcHandler)(nil) type GrpcHandler struct { - validate *validator.Validate + validate *validator.Validate + msgClient *rpcli.MsgClient + pushClient *rpcli.PushMsgServiceClient } -func NewGrpcHandler(validate *validator.Validate) *GrpcHandler { - return &GrpcHandler{validate: validate} +func NewGrpcHandler(validate *validator.Validate, msgClient *rpcli.MsgClient, pushClient *rpcli.PushMsgServiceClient) *GrpcHandler { + return &GrpcHandler{ + validate: validate, + msgClient: msgClient, + pushClient: pushClient, + } } -func (g GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) { +func (g *GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) { req := sdkws.GetMaxSeqReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, errs.WrapMsg(err, "GetSeq: error unmarshaling request", "action", "unmarshal", "dataType", "GetMaxSeqReq") @@ -127,7 +134,7 @@ func (g GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) { if err := g.validate.Struct(&req); err != nil { return nil, errs.WrapMsg(err, "GetSeq: validation failed", "action", "validate", "dataType", "GetMaxSeqReq") } - resp, err := msg.GetMaxSeqCaller.Invoke(ctx, &req) + resp, err := g.msgClient.MsgClient.GetMaxSeq(ctx, &req) if err != nil { return nil, err } @@ -140,7 +147,7 @@ func (g GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) { // SendMessage handles the sending of messages through gRPC. It unmarshals the request data, // validates the message, and then sends it using the message RPC client. -func (g GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) { +func (g *GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) { var msgData sdkws.MsgData if err := proto.Unmarshal(data.Data, &msgData); err != nil { return nil, errs.WrapMsg(err, "SendMessage: error unmarshaling message data", "action", "unmarshal", "dataType", "MsgData") @@ -151,7 +158,7 @@ func (g GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) } req := msg.SendMsgReq{MsgData: &msgData} - resp, err := msg.SendMsgCaller.Invoke(ctx, &req) + resp, err := g.msgClient.MsgClient.SendMsg(ctx, &req) if err != nil { return nil, err } @@ -164,8 +171,8 @@ func (g GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) return c, nil } -func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]byte, error) { - resp, err := msg.SendMsgCaller.Invoke(context, nil) +func (g *GrpcHandler) SendSignalMessage(ctx context.Context, data *Req) ([]byte, error) { + resp, err := g.msgClient.MsgClient.SendMsg(ctx, nil) if err != nil { return nil, err } @@ -176,7 +183,7 @@ func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]by return c, nil } -func (g GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error) { +func (g *GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error) { req := sdkws.PullMessageBySeqsReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, errs.WrapMsg(err, "err proto unmarshal", "action", "unmarshal", "dataType", "PullMessageBySeqsReq") @@ -184,7 +191,7 @@ func (g GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byt if err := g.validate.Struct(data); err != nil { return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "PullMessageBySeqsReq") } - resp, err := msg.PullMessageBySeqsCaller.Invoke(ctx, &req) + resp, err := g.msgClient.MsgClient.PullMessageBySeqs(ctx, &req) if err != nil { return nil, err } @@ -195,7 +202,7 @@ func (g GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byt return c, nil } -func (g GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error) { +func (g *GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error) { req := msg.GetConversationsHasReadAndMaxSeqReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, errs.WrapMsg(err, "err proto unmarshal", "action", "unmarshal", "dataType", "GetConversationsHasReadAndMaxSeq") @@ -203,7 +210,7 @@ func (g GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data if err := g.validate.Struct(data); err != nil { return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "GetConversationsHasReadAndMaxSeq") } - resp, err := msg.GetConversationsHasReadAndMaxSeqCaller.Invoke(ctx, &req) + resp, err := g.msgClient.MsgClient.GetConversationsHasReadAndMaxSeq(ctx, &req) if err != nil { return nil, err } @@ -214,7 +221,7 @@ func (g GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data return c, nil } -func (g GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, error) { +func (g *GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, error) { req := msg.GetSeqMessageReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "GetSeqMessage") @@ -222,7 +229,7 @@ func (g GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, erro if err := g.validate.Struct(data); err != nil { return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "GetSeqMessage") } - resp, err := msg.GetSeqMessageCaller.Invoke(ctx, &req) + resp, err := g.msgClient.MsgClient.GetSeqMessage(ctx, &req) if err != nil { return nil, err } @@ -233,12 +240,12 @@ func (g GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, erro return c, nil } -func (g GrpcHandler) UserLogout(ctx context.Context, data *Req) ([]byte, error) { +func (g *GrpcHandler) UserLogout(ctx context.Context, data *Req) ([]byte, error) { req := push.DelUserPushTokenReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "DelUserPushTokenReq") } - resp, err := push.DelUserPushTokenCaller.Invoke(ctx, &req) + resp, err := g.pushClient.PushMsgServiceClient.DelUserPushToken(ctx, &req) if err != nil { return nil, err } @@ -249,7 +256,7 @@ func (g GrpcHandler) UserLogout(ctx context.Context, data *Req) ([]byte, error) return c, nil } -func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data *Req) ([]byte, bool, error) { +func (g *GrpcHandler) SetUserDeviceBackground(ctx context.Context, data *Req) ([]byte, bool, error) { req := sdkws.SetAppBackgroundStatusReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { return nil, false, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "SetAppBackgroundStatusReq") diff --git a/internal/msggateway/online.go b/internal/msggateway/online.go index bff2639976..52b6c5d05f 100644 --- a/internal/msggateway/online.go +++ b/internal/msggateway/online.go @@ -88,7 +88,7 @@ func (ws *WsServer) ChangeOnlineStatus(concurrent int) { opIdCtx := mcontext.SetOperationID(context.Background(), operationIDPrefix+strconv.FormatInt(count.Add(1), 10)) ctx, cancel := context.WithTimeout(opIdCtx, time.Second*5) defer cancel() - if err := pbuser.SetUserOnlineStatusCaller.Execute(ctx, req); err != nil { + if err := ws.userClient.SetUserOnlineStatus(ctx, req); err != nil { log.ZError(ctx, "update user online status", err) } for _, ss := range req.Status { diff --git a/internal/msggateway/ws_server.go b/internal/msggateway/ws_server.go index 7271c37274..44b6ddb89f 100644 --- a/internal/msggateway/ws_server.go +++ b/internal/msggateway/ws_server.go @@ -3,6 +3,7 @@ package msggateway import ( "context" "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/rpcli" "net/http" "sync" "sync/atomic" @@ -31,7 +32,7 @@ type LongConnServer interface { GetUserAllCons(userID string) ([]*Client, bool) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) Validate(s any) error - SetDiscoveryRegistry(client discovery.SvcDiscoveryRegistry, config *Config) + SetDiscoveryRegistry(ctx context.Context, client discovery.SvcDiscoveryRegistry, config *Config) error KickUserConn(client *Client) error UnRegister(c *Client) SetKickHandlerInfo(i *kickHandler) @@ -61,6 +62,8 @@ type WsServer struct { //Encoder MessageHandler webhookClient *webhook.Client + userClient *rpcli.UserClient + authClient *rpcli.AuthClient } type kickHandler struct { @@ -69,9 +72,28 @@ type kickHandler struct { newClient *Client } -func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *Config) { - ws.MessageHandler = NewGrpcHandler(ws.validate) +func (ws *WsServer) SetDiscoveryRegistry(ctx context.Context, disCov discovery.SvcDiscoveryRegistry, config *Config) error { + userConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.User) + if err != nil { + return err + } + pushConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Push) + if err != nil { + return err + } + authConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Auth) + if err != nil { + return err + } + msgConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Msg) + if err != nil { + return err + } + ws.userClient = rpcli.NewUserClient(userConn) + ws.authClient = rpcli.NewAuthClient(authConn) + ws.MessageHandler = NewGrpcHandler(ws.validate, rpcli.NewMsgClient(msgConn), rpcli.NewPushMsgServiceClient(pushConn)) ws.disCov = disCov + return nil } //func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, status int32) { @@ -306,8 +328,7 @@ func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Clien []string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(), constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()}, ) - - if err := pbAuth.KickTokensCaller.Execute(ctx, &pbAuth.KickTokensReq{Tokens: kickTokens}); err != nil { + if err := ws.authClient.KickTokens(ctx, kickTokens); err != nil { log.ZWarn(newClient.ctx, "kickTokens err", err) } } @@ -334,11 +355,12 @@ func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Clien []string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(), constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()}, ) - if err := pbAuth.InvalidateTokenCaller.Execute(ctx, &pbAuth.InvalidateTokenReq{ + req := &pbAuth.InvalidateTokenReq{ PreservedToken: newClient.token, UserID: newClient.UserID, PlatformID: int32(newClient.PlatformID), - }); err != nil { + } + if err := ws.authClient.InvalidateToken(ctx, req); err != nil { log.ZWarn(newClient.ctx, "InvalidateToken err", err, "userID", newClient.UserID, "platformID", newClient.PlatformID) } @@ -409,7 +431,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { } // Call the authentication client to parse the Token obtained from the context - resp, err := pbAuth.ParseTokenCaller.Invoke(connContext, &pbAuth.ParseTokenReq{Token: connContext.GetToken()}) + resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken()) if err != nil { // If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag shouldSendError := connContext.ShouldSendResp() diff --git a/pkg/rpcli/auth.go b/pkg/rpcli/auth.go index 1f9731c21e..22a0db1094 100644 --- a/pkg/rpcli/auth.go +++ b/pkg/rpcli/auth.go @@ -1,6 +1,7 @@ package rpcli import ( + "context" "github.com/openimsdk/protocol/auth" "google.golang.org/grpc" ) @@ -12,3 +13,15 @@ func NewAuthClient(cc grpc.ClientConnInterface) *AuthClient { type AuthClient struct { auth.AuthClient } + +func (x *AuthClient) KickTokens(ctx context.Context, tokens []string) error { + return ignoreResp(x.AuthClient.KickTokens(ctx, &auth.KickTokensReq{Tokens: tokens})) +} + +func (x *AuthClient) InvalidateToken(ctx context.Context, req *auth.InvalidateTokenReq) error { + return ignoreResp(x.AuthClient.InvalidateToken(ctx, req)) +} + +func (x *AuthClient) ParseToken(ctx context.Context, token string) (*auth.ParseTokenResp, error) { + return x.AuthClient.ParseToken(ctx, &auth.ParseTokenReq{Token: token}) +} diff --git a/pkg/rpcli/user.go b/pkg/rpcli/user.go index 77d8c53c48..330d5e6408 100644 --- a/pkg/rpcli/user.go +++ b/pkg/rpcli/user.go @@ -71,3 +71,7 @@ func (x *UserClient) GetUserOnlinePlatform(ctx context.Context, userID string) ( } return status[0].PlatformIDs, nil } + +func (x *UserClient) SetUserOnlineStatus(ctx context.Context, req *user.SetUserOnlineStatusReq) error { + return ignoreResp(x.UserClient.SetUserOnlineStatus(ctx, req)) +}