diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 19548a71c4..af96e7d460 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -84,7 +84,7 @@ type Client struct { } // ResetClient updates the client's state with new connection and context information. -func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer, sdkType string) { +func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) { c.w = new(sync.Mutex) c.conn = conn c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID()) @@ -97,15 +97,12 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer c.closed.Store(false) c.closedErr = nil c.token = ctx.GetToken() - c.SDKType = sdkType + c.SDKType = ctx.GetSDKType() c.hbCtx, c.hbCancel = context.WithCancel(c.ctx) c.subLock = new(sync.Mutex) if c.subUserIDs != nil { clear(c.subUserIDs) } - if c.SDKType == "" { - c.SDKType = GoSDK - } if c.SDKType == GoSDK { c.Encoder = NewGobEncoder() } else { diff --git a/internal/msggateway/context.go b/internal/msggateway/context.go index f3f168f616..d73a96df4c 100644 --- a/internal/msggateway/context.go +++ b/internal/msggateway/context.go @@ -153,6 +153,14 @@ func (c *UserConnContext) GetCompression() bool { return false } +func (c *UserConnContext) GetSDKType() string { + sdkType := c.Req.URL.Query().Get(SDKType) + if sdkType == "" { + sdkType = GoSDK + } + return sdkType +} + func (c *UserConnContext) ShouldSendResp() bool { errResp, exists := c.Query(SendResponse) if exists { diff --git a/internal/msggateway/ws_server.go b/internal/msggateway/ws_server.go index 050369149e..e6b4f3fa47 100644 --- a/internal/msggateway/ws_server.go +++ b/internal/msggateway/ws_server.go @@ -455,8 +455,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { // Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection client := ws.clientPool.Get().(*Client) - sdkType, _ := connContext.Query(SDKType) - client.ResetClient(connContext, wsLongConn, ws, sdkType) + client.ResetClient(connContext, wsLongConn, ws) // Register the client with the server and start message processing ws.registerChan <- client