From 97d10f99e4916c639ed01fc2153040eca799897e Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Thu, 14 Nov 2024 11:32:35 +0800 Subject: [PATCH] feat: gob json encoder --- internal/msggateway/client.go | 17 +++++++++++--- internal/msggateway/constant.go | 6 +++++ internal/msggateway/context.go | 6 ++++- internal/msggateway/encoder.go | 38 ++++++++++++++++++++++++++----- internal/msggateway/hub_server.go | 15 +++--------- internal/msggateway/ws_server.go | 16 ++++--------- 6 files changed, 64 insertions(+), 34 deletions(-) diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 48cbf2b6ad..19548a71c4 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -70,6 +70,8 @@ type Client struct { IsCompress bool `json:"isCompress"` UserID string `json:"userID"` IsBackground bool `json:"isBackground"` + SDKType string `json:"sdkType"` + Encoder Encoder ctx *UserConnContext longConnServer LongConnServer closed atomic.Bool @@ -82,7 +84,7 @@ type Client struct { } // ResetClient updates the client's state with new connection and context information. -func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) { +func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer, sdkType string) { c.w = new(sync.Mutex) c.conn = conn c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID()) @@ -95,11 +97,20 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer c.closed.Store(false) c.closedErr = nil c.token = ctx.GetToken() + c.SDKType = sdkType c.hbCtx, c.hbCancel = context.WithCancel(c.ctx) c.subLock = new(sync.Mutex) if c.subUserIDs != nil { clear(c.subUserIDs) } + if c.SDKType == "" { + c.SDKType = GoSDK + } + if c.SDKType == GoSDK { + c.Encoder = NewGobEncoder() + } else { + c.Encoder = NewJsonEncoder() + } c.subUserIDs = make(map[string]struct{}) } @@ -192,7 +203,7 @@ func (c *Client) handleMessage(message []byte) error { var binaryReq = getReq() defer freeReq(binaryReq) - err := c.longConnServer.Decode(message, binaryReq) + err := c.Encoder.Decode(message, binaryReq) if err != nil { return err } @@ -339,7 +350,7 @@ func (c *Client) writeBinaryMsg(resp Resp) error { return nil } - encodedBuf, err := c.longConnServer.Encode(resp) + encodedBuf, err := c.Encoder.Encode(resp) if err != nil { return err } diff --git a/internal/msggateway/constant.go b/internal/msggateway/constant.go index 584cebe1e1..a825c05196 100644 --- a/internal/msggateway/constant.go +++ b/internal/msggateway/constant.go @@ -27,6 +27,12 @@ const ( GzipCompressionProtocol = "gzip" BackgroundStatus = "isBackground" SendResponse = "isMsgResp" + SDKType = "sdkType" +) + +const ( + GoSDK = "go" + JsSDK = "js" ) const ( diff --git a/internal/msggateway/context.go b/internal/msggateway/context.go index 3909766b1b..f3f168f616 100644 --- a/internal/msggateway/context.go +++ b/internal/msggateway/context.go @@ -193,7 +193,11 @@ func (c *UserConnContext) ParseEssentialArgs() error { _, err := strconv.Atoi(platformIDStr) if err != nil { return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int") - + } + switch sdkType, _ := c.Query(SDKType); sdkType { + case "", GoSDK, JsSDK: + default: + return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js") } return nil } diff --git a/internal/msggateway/encoder.go b/internal/msggateway/encoder.go index 056f462367..6a5936d6d2 100644 --- a/internal/msggateway/encoder.go +++ b/internal/msggateway/encoder.go @@ -15,6 +15,8 @@ package msggateway import ( + "bytes" + "encoding/gob" "encoding/json" "github.com/openimsdk/tools/errs" @@ -27,22 +29,46 @@ type Encoder interface { type GobEncoder struct{} -func NewGobEncoder() *GobEncoder { - return &GobEncoder{} +func NewGobEncoder() Encoder { + return GobEncoder{} } -func (g *GobEncoder) Encode(data any) ([]byte, error) { +func (g GobEncoder) Encode(data any) ([]byte, error) { + var buff bytes.Buffer + enc := gob.NewEncoder(&buff) + if err := enc.Encode(data); err != nil { + return nil, errs.WrapMsg(err, "GobEncoder.Encode failed", "action", "encode") + } + return buff.Bytes(), nil +} + +func (g GobEncoder) Decode(encodeData []byte, decodeData any) error { + buff := bytes.NewBuffer(encodeData) + dec := gob.NewDecoder(buff) + if err := dec.Decode(decodeData); err != nil { + return errs.WrapMsg(err, "GobEncoder.Decode failed", "action", "decode") + } + return nil +} + +type JsonEncoder struct{} + +func NewJsonEncoder() Encoder { + return JsonEncoder{} +} + +func (g JsonEncoder) Encode(data any) ([]byte, error) { b, err := json.Marshal(data) if err != nil { - return nil, errs.New("Encoder.Encode failed", "action", "encode") + return nil, errs.New("JsonEncoder.Encode failed", "action", "encode") } return b, nil } -func (g *GobEncoder) Decode(encodeData []byte, decodeData any) error { +func (g JsonEncoder) Decode(encodeData []byte, decodeData any) error { err := json.Unmarshal(encodeData, decodeData) if err != nil { - return errs.New("Encoder.Decode failed", "action", "decode") + return errs.New("JsonEncoder.Decode failed", "action", "decode") } return nil } diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index cc29876447..23d9150133 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -83,17 +83,11 @@ func NewServer(rpcPort int, longConnServer LongConnServer, conf *Config, ready f return s } -func (s *Server) OnlinePushMsg( - context context.Context, - req *msggateway.OnlinePushMsgReq, -) (*msggateway.OnlinePushMsgResp, error) { +func (s *Server) OnlinePushMsg(context context.Context, req *msggateway.OnlinePushMsgReq) (*msggateway.OnlinePushMsgResp, error) { panic("implement me") } -func (s *Server) GetUsersOnlineStatus( - ctx context.Context, - req *msggateway.GetUsersOnlineStatusReq, -) (*msggateway.GetUsersOnlineStatusResp, error) { +func (s *Server) GetUsersOnlineStatus(ctx context.Context, req *msggateway.GetUsersOnlineStatusReq) (*msggateway.GetUsersOnlineStatusResp, error) { if !authverify.IsAppManagerUid(ctx, s.config.Share.IMAdminUserID) { return nil, errs.ErrNoPermission.WrapMsg("only app manager") } @@ -221,10 +215,7 @@ func (s *Server) SuperGroupOnlineBatchPushOneMsg(ctx context.Context, req *msgga } } -func (s *Server) KickUserOffline( - ctx context.Context, - req *msggateway.KickUserOfflineReq, -) (*msggateway.KickUserOfflineResp, error) { +func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOfflineReq) (*msggateway.KickUserOfflineResp, error) { for _, v := range req.KickUserIDList { clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)) if !ok { diff --git a/internal/msggateway/ws_server.go b/internal/msggateway/ws_server.go index b92d7eb442..32cdeaee64 100644 --- a/internal/msggateway/ws_server.go +++ b/internal/msggateway/ws_server.go @@ -37,7 +37,6 @@ type LongConnServer interface { SetKickHandlerInfo(i *kickHandler) SubUserOnlineStatus(ctx context.Context, client *Client, data *Req) ([]byte, error) Compressor - Encoder MessageHandler } @@ -61,7 +60,7 @@ type WsServer struct { authClient *rpcclient.Auth disCov discovery.SvcDiscoveryRegistry Compressor - Encoder + //Encoder MessageHandler webhookClient *webhook.Client } @@ -135,7 +134,6 @@ func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer { clients: newUserMap(), subscription: newSubscription(), Compressor: NewGzipCompressor(), - Encoder: NewGobEncoder(), webhookClient: webhook.NewWebhookClient(msgGatewayConfig.WebhooksConfig.URL), } } @@ -278,14 +276,7 @@ func (ws *WsServer) registerClient(client *Client) { wg.Wait() - log.ZDebug( - client.ctx, - "user online", - "online user Num", - ws.onlineUserNum.Load(), - "online user conn Num", - ws.onlineUserConnNum.Load(), - ) + log.ZDebug(client.ctx, "user online", "online user Num", ws.onlineUserNum.Load(), "online user conn Num", ws.onlineUserConnNum.Load()) } func getRemoteAdders(client []*Client) string { @@ -484,7 +475,8 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { // Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection client := ws.clientPool.Get().(*Client) - client.ResetClient(connContext, wsLongConn, ws) + sdkType, _ := connContext.Query(SDKType) + client.ResetClient(connContext, wsLongConn, ws, sdkType) // Register the client with the server and start message processing ws.registerChan <- client