Skip to content

Commit

Permalink
rpc client
Browse files Browse the repository at this point in the history
  • Loading branch information
withchao committed Dec 23, 2024
1 parent a11dd60 commit 774b3a0
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 37 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,5 @@ require (

replace (
github.com/openimsdk/protocol => /Users/chao/Desktop/code/protocol
github.com/openimsdk/tools => /Users/chao/Desktop/code/tools
)
11 changes: 10 additions & 1 deletion internal/msggateway/hub_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package msggateway

import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"sync/atomic"

"github.com/openimsdk/open-im-server/v3/pkg/authverify"
Expand All @@ -34,7 +35,14 @@ import (
)

func (s *Server) InitServer(ctx context.Context, config *Config, disCov discovery.SvcDiscoveryRegistry, server *grpc.Server) error {
s.LongConnServer.SetDiscoveryRegistry(disCov, config)
userConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
s.userClient = rpcli.NewUserClient(userConn)
if err := s.LongConnServer.SetDiscoveryRegistry(ctx, disCov, config); err != nil {
return err
}
msggateway.RegisterMsgGatewayServer(server, s)
if s.ready != nil {
return s.ready(s)
Expand All @@ -61,6 +69,7 @@ type Server struct {
pushTerminal map[int]struct{}
ready func(srv *Server) error
queue *memamq.MemoryQueue
userClient *rpcli.UserClient
}

func (s *Server) SetLongConnServer(LongConnServer LongConnServer) {
Expand Down
2 changes: 1 addition & 1 deletion internal/msggateway/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func Start(ctx context.Context, index int, conf *Config) error {

hubServer := NewServer(longServer, conf, func(srv *Server) error {
var err error
longServer.online, err = rpccache.NewOnlineCache(conf.Share.IMAdminUserID, nil, rdb, false, longServer.subscriberUserOnlineStatusChanges)
longServer.online, err = rpccache.NewOnlineCache(srv.userClient, nil, rdb, false, longServer.subscriberUserOnlineStatusChanges)
return err
})

Expand Down
59 changes: 33 additions & 26 deletions internal/msggateway/message_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package msggateway
import (
"context"
"encoding/json"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"sync"

"github.com/go-playground/validator/v10"
Expand Down Expand Up @@ -99,35 +100,41 @@ func (r *Resp) String() string {
}

type MessageHandler interface {
GetSeq(context context.Context, data *Req) ([]byte, error)
SendMessage(context context.Context, data *Req) ([]byte, error)
SendSignalMessage(context context.Context, data *Req) ([]byte, error)
PullMessageBySeqList(context context.Context, data *Req) ([]byte, error)
GetConversationsHasReadAndMaxSeq(context context.Context, data *Req) ([]byte, error)
GetSeqMessage(context context.Context, data *Req) ([]byte, error)
UserLogout(context context.Context, data *Req) ([]byte, error)
SetUserDeviceBackground(context context.Context, data *Req) ([]byte, bool, error)
GetSeq(ctx context.Context, data *Req) ([]byte, error)
SendMessage(ctx context.Context, data *Req) ([]byte, error)
SendSignalMessage(ctx context.Context, data *Req) ([]byte, error)
PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error)
GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error)
GetSeqMessage(ctx context.Context, data *Req) ([]byte, error)
UserLogout(ctx context.Context, data *Req) ([]byte, error)
SetUserDeviceBackground(ctx context.Context, data *Req) ([]byte, bool, error)
}

var _ MessageHandler = (*GrpcHandler)(nil)

type GrpcHandler struct {
validate *validator.Validate
validate *validator.Validate
msgClient *rpcli.MsgClient
pushClient *rpcli.PushMsgServiceClient
}

func NewGrpcHandler(validate *validator.Validate) *GrpcHandler {
return &GrpcHandler{validate: validate}
func NewGrpcHandler(validate *validator.Validate, msgClient *rpcli.MsgClient, pushClient *rpcli.PushMsgServiceClient) *GrpcHandler {
return &GrpcHandler{
validate: validate,
msgClient: msgClient,
pushClient: pushClient,
}
}

func (g GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) {
func (g *GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) {
req := sdkws.GetMaxSeqReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "GetSeq: error unmarshaling request", "action", "unmarshal", "dataType", "GetMaxSeqReq")
}
if err := g.validate.Struct(&req); err != nil {
return nil, errs.WrapMsg(err, "GetSeq: validation failed", "action", "validate", "dataType", "GetMaxSeqReq")
}
resp, err := msg.GetMaxSeqCaller.Invoke(ctx, &req)
resp, err := g.msgClient.MsgClient.GetMaxSeq(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -140,7 +147,7 @@ func (g GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) {

// SendMessage handles the sending of messages through gRPC. It unmarshals the request data,
// validates the message, and then sends it using the message RPC client.
func (g GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) {
func (g *GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) {
var msgData sdkws.MsgData
if err := proto.Unmarshal(data.Data, &msgData); err != nil {
return nil, errs.WrapMsg(err, "SendMessage: error unmarshaling message data", "action", "unmarshal", "dataType", "MsgData")
Expand All @@ -151,7 +158,7 @@ func (g GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error)
}

req := msg.SendMsgReq{MsgData: &msgData}
resp, err := msg.SendMsgCaller.Invoke(ctx, &req)
resp, err := g.msgClient.MsgClient.SendMsg(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -164,8 +171,8 @@ func (g GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error)
return c, nil
}

func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]byte, error) {
resp, err := msg.SendMsgCaller.Invoke(context, nil)
func (g *GrpcHandler) SendSignalMessage(ctx context.Context, data *Req) ([]byte, error) {
resp, err := g.msgClient.MsgClient.SendMsg(ctx, nil)
if err != nil {
return nil, err
}
Expand All @@ -176,15 +183,15 @@ func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]by
return c, nil
}

func (g GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error) {
func (g *GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error) {
req := sdkws.PullMessageBySeqsReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "err proto unmarshal", "action", "unmarshal", "dataType", "PullMessageBySeqsReq")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "PullMessageBySeqsReq")
}
resp, err := msg.PullMessageBySeqsCaller.Invoke(ctx, &req)
resp, err := g.msgClient.MsgClient.PullMessageBySeqs(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -195,15 +202,15 @@ func (g GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byt
return c, nil
}

func (g GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error) {
func (g *GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error) {
req := msg.GetConversationsHasReadAndMaxSeqReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "err proto unmarshal", "action", "unmarshal", "dataType", "GetConversationsHasReadAndMaxSeq")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "GetConversationsHasReadAndMaxSeq")
}
resp, err := msg.GetConversationsHasReadAndMaxSeqCaller.Invoke(ctx, &req)
resp, err := g.msgClient.MsgClient.GetConversationsHasReadAndMaxSeq(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -214,15 +221,15 @@ func (g GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data
return c, nil
}

func (g GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, error) {
func (g *GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, error) {
req := msg.GetSeqMessageReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "GetSeqMessage")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "GetSeqMessage")
}
resp, err := msg.GetSeqMessageCaller.Invoke(ctx, &req)
resp, err := g.msgClient.MsgClient.GetSeqMessage(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -233,12 +240,12 @@ func (g GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, erro
return c, nil
}

func (g GrpcHandler) UserLogout(ctx context.Context, data *Req) ([]byte, error) {
func (g *GrpcHandler) UserLogout(ctx context.Context, data *Req) ([]byte, error) {
req := push.DelUserPushTokenReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "DelUserPushTokenReq")
}
resp, err := push.DelUserPushTokenCaller.Invoke(ctx, &req)
resp, err := g.pushClient.PushMsgServiceClient.DelUserPushToken(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -249,7 +256,7 @@ func (g GrpcHandler) UserLogout(ctx context.Context, data *Req) ([]byte, error)
return c, nil
}

func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data *Req) ([]byte, bool, error) {
func (g *GrpcHandler) SetUserDeviceBackground(ctx context.Context, data *Req) ([]byte, bool, error) {
req := sdkws.SetAppBackgroundStatusReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, false, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "SetAppBackgroundStatusReq")
Expand Down
2 changes: 1 addition & 1 deletion internal/msggateway/online.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (ws *WsServer) ChangeOnlineStatus(concurrent int) {
opIdCtx := mcontext.SetOperationID(context.Background(), operationIDPrefix+strconv.FormatInt(count.Add(1), 10))
ctx, cancel := context.WithTimeout(opIdCtx, time.Second*5)
defer cancel()
if err := pbuser.SetUserOnlineStatusCaller.Execute(ctx, req); err != nil {
if err := ws.userClient.SetUserOnlineStatus(ctx, req); err != nil {
log.ZError(ctx, "update user online status", err)
}
for _, ss := range req.Status {
Expand Down
38 changes: 30 additions & 8 deletions internal/msggateway/ws_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package msggateway
import (
"context"
"fmt"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"net/http"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -31,7 +32,7 @@ type LongConnServer interface {
GetUserAllCons(userID string) ([]*Client, bool)
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s any) error
SetDiscoveryRegistry(client discovery.SvcDiscoveryRegistry, config *Config)
SetDiscoveryRegistry(ctx context.Context, client discovery.SvcDiscoveryRegistry, config *Config) error
KickUserConn(client *Client) error
UnRegister(c *Client)
SetKickHandlerInfo(i *kickHandler)
Expand Down Expand Up @@ -61,6 +62,8 @@ type WsServer struct {
//Encoder
MessageHandler
webhookClient *webhook.Client
userClient *rpcli.UserClient
authClient *rpcli.AuthClient
}

type kickHandler struct {
Expand All @@ -69,9 +72,28 @@ type kickHandler struct {
newClient *Client
}

func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *Config) {
ws.MessageHandler = NewGrpcHandler(ws.validate)
func (ws *WsServer) SetDiscoveryRegistry(ctx context.Context, disCov discovery.SvcDiscoveryRegistry, config *Config) error {
userConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
pushConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Push)
if err != nil {
return err
}
authConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Auth)
if err != nil {
return err
}
msgConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Msg)
if err != nil {
return err
}
ws.userClient = rpcli.NewUserClient(userConn)
ws.authClient = rpcli.NewAuthClient(authConn)
ws.MessageHandler = NewGrpcHandler(ws.validate, rpcli.NewMsgClient(msgConn), rpcli.NewPushMsgServiceClient(pushConn))
ws.disCov = disCov
return nil
}

//func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, status int32) {
Expand Down Expand Up @@ -306,8 +328,7 @@ func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Clien
[]string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(),
constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()},
)

if err := pbAuth.KickTokensCaller.Execute(ctx, &pbAuth.KickTokensReq{Tokens: kickTokens}); err != nil {
if err := ws.authClient.KickTokens(ctx, kickTokens); err != nil {
log.ZWarn(newClient.ctx, "kickTokens err", err)
}
}
Expand All @@ -334,11 +355,12 @@ func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Clien
[]string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(),
constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()},
)
if err := pbAuth.InvalidateTokenCaller.Execute(ctx, &pbAuth.InvalidateTokenReq{
req := &pbAuth.InvalidateTokenReq{
PreservedToken: newClient.token,
UserID: newClient.UserID,
PlatformID: int32(newClient.PlatformID),
}); err != nil {
}
if err := ws.authClient.InvalidateToken(ctx, req); err != nil {
log.ZWarn(newClient.ctx, "InvalidateToken err", err, "userID", newClient.UserID,
"platformID", newClient.PlatformID)
}
Expand Down Expand Up @@ -409,7 +431,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
}

// Call the authentication client to parse the Token obtained from the context
resp, err := pbAuth.ParseTokenCaller.Invoke(connContext, &pbAuth.ParseTokenReq{Token: connContext.GetToken()})
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
if err != nil {
// If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag
shouldSendError := connContext.ShouldSendResp()
Expand Down
13 changes: 13 additions & 0 deletions pkg/rpcli/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rpcli

import (
"context"
"github.com/openimsdk/protocol/auth"
"google.golang.org/grpc"
)
Expand All @@ -12,3 +13,15 @@ func NewAuthClient(cc grpc.ClientConnInterface) *AuthClient {
type AuthClient struct {
auth.AuthClient
}

func (x *AuthClient) KickTokens(ctx context.Context, tokens []string) error {
return ignoreResp(x.AuthClient.KickTokens(ctx, &auth.KickTokensReq{Tokens: tokens}))
}

func (x *AuthClient) InvalidateToken(ctx context.Context, req *auth.InvalidateTokenReq) error {
return ignoreResp(x.AuthClient.InvalidateToken(ctx, req))
}

func (x *AuthClient) ParseToken(ctx context.Context, token string) (*auth.ParseTokenResp, error) {
return x.AuthClient.ParseToken(ctx, &auth.ParseTokenReq{Token: token})
}
4 changes: 4 additions & 0 deletions pkg/rpcli/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,7 @@ func (x *UserClient) GetUserOnlinePlatform(ctx context.Context, userID string) (
}
return status[0].PlatformIDs, nil
}

func (x *UserClient) SetUserOnlineStatus(ctx context.Context, req *user.SetUserOnlineStatusReq) error {
return ignoreResp(x.UserClient.SetUserOnlineStatus(ctx, req))
}

0 comments on commit 774b3a0

Please sign in to comment.