Skip to content

Commit

Permalink
feat: embedded cluster manager websocket (#5015)
Browse files Browse the repository at this point in the history
* feat: embedded cluster manager websocket
  • Loading branch information
sgalsaleh authored Nov 27, 2024
1 parent 2417b58 commit 767739e
Show file tree
Hide file tree
Showing 12 changed files with 461 additions and 0 deletions.
9 changes: 9 additions & 0 deletions pkg/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ func Start(params *APIServerParams) {

handlers.RegisterUnauthenticatedRoutes(handler, kotsStore, debugRouter, loggingRouter)

/**********************************************************************
* Websocket routes (only for embedded cluster)
**********************************************************************/

if util.IsEmbeddedCluster() {
wsRouter := r.NewRoute().Subrouter()
wsRouter.HandleFunc("/ec-ws", handler.ConnectToECWebsocket)
}

/**********************************************************************
* KOTS token auth routes
**********************************************************************/
Expand Down
20 changes: 20 additions & 0 deletions pkg/handlers/debug.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package handlers

import (
"net/http"

"github.com/replicatedhq/kots/pkg/websocket"
websockettypes "github.com/replicatedhq/kots/pkg/websocket/types"
)

type DebugInfoResponse struct {
WSClients map[string]websockettypes.WSClient `json:"wsClients"`
}

func (h *Handler) GetDebugInfo(w http.ResponseWriter, r *http.Request) {
response := DebugInfoResponse{
WSClients: websocket.GetClients(),
}

JSON(w, http.StatusOK, response)
}
4 changes: 4 additions & 0 deletions pkg/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ func RegisterSessionAuthRoutes(r *mux.Router, kotsStore store.Store, handler KOT
r.Name("ChangePassword").Path("/api/v1/password/change").Methods("PUT").
HandlerFunc(middleware.EnforceAccess(policy.PasswordChange, handler.ChangePassword))

// Debug info
r.Name("GetDebugInfo").Path("/api/v1/debug").Methods("GET").
HandlerFunc(middleware.EnforceAccess(policy.ClusterRead, handler.GetDebugInfo))

// Upgrade service
r.Name("StartUpgradeService").Path("/api/v1/app/{appSlug}/start-upgrade-service").Methods("POST").
HandlerFunc(middleware.EnforceAccess(policy.AppUpdate, handler.StartUpgradeService))
Expand Down
10 changes: 10 additions & 0 deletions pkg/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,16 @@ var HandlerPolicyTests = map[string][]HandlerPolicyTest{
ExpectStatus: http.StatusOK,
},
},
"GetDebugInfo": {
{
Roles: []rbactypes.Role{rbac.ClusterAdminRole},
SessionRoles: []string{rbac.ClusterAdminRoleID},
Calls: func(storeRecorder *mock_store.MockStoreMockRecorder, handlerRecorder *mock_handlers.MockKOTSHandlerMockRecorder) {
handlerRecorder.GetDebugInfo(gomock.Any(), gomock.Any())
},
ExpectStatus: http.StatusOK,
},
},

// Upgrade Service
"StartUpgradeService": {
Expand Down
6 changes: 6 additions & 0 deletions pkg/handlers/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,13 @@ type KOTSHandler interface {
// Password change
ChangePassword(w http.ResponseWriter, r *http.Request)

// Debug info
GetDebugInfo(w http.ResponseWriter, r *http.Request)

// Upgrade service
StartUpgradeService(w http.ResponseWriter, r *http.Request)
GetUpgradeServiceStatus(w http.ResponseWriter, r *http.Request)

// EC Websocket
ConnectToECWebsocket(w http.ResponseWriter, r *http.Request)
}
24 changes: 24 additions & 0 deletions pkg/handlers/mock/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions pkg/handlers/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package handlers

import (
"net/http"

"github.com/pkg/errors"
"github.com/replicatedhq/kots/pkg/logger"
"github.com/replicatedhq/kots/pkg/websocket"
)

type ConnectToECWebsocketResponse struct {
Error string `json:"error,omitempty"`
}

func (h *Handler) ConnectToECWebsocket(w http.ResponseWriter, r *http.Request) {
response := ConnectToECWebsocketResponse{}

nodeName := r.URL.Query().Get("nodeName")
if nodeName == "" {
response.Error = "missing node name"
logger.Error(errors.New(response.Error))
JSON(w, http.StatusBadRequest, response)
return
}

if err := websocket.Connect(w, r, nodeName); err != nil {
response.Error = "failed to establish websocket connection"
logger.Error(errors.Wrap(err, response.Error))
JSON(w, http.StatusInternalServerError, response)
return
}
}
21 changes: 21 additions & 0 deletions pkg/websocket/types/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package types

import (
"time"

"github.com/gorilla/websocket"
)

type WSClient struct {
Conn *websocket.Conn `json:"-"`
ConnectedAt time.Time `json:"connectedAt"`
LastPingSent PingPongInfo `json:"lastPingSent"`
LastPongRecv PingPongInfo `json:"lastPongRecv"`
LastPingRecv PingPongInfo `json:"lastPingRecv"`
LastPongSent PingPongInfo `json:"lastPongSent"`
}

type PingPongInfo struct {
Time time.Time `json:"time"`
Message string `json:"message"`
}
185 changes: 185 additions & 0 deletions pkg/websocket/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package websocket

import (
"fmt"
"math/rand"
"net"
"net/http"
"sync"
"time"

"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/replicatedhq/kots/pkg/logger"
"github.com/replicatedhq/kots/pkg/websocket/types"
)

var wsUpgrader = websocket.Upgrader{}
var wsClients = make(map[string]types.WSClient)
var wsMutex = sync.Mutex{}

func Connect(w http.ResponseWriter, r *http.Request, nodeName string) error {
conn, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil {
return errors.Wrap(err, "failed to upgrade to websocket")
}
defer conn.Close()

conn.SetPingHandler(wsPingHandler(nodeName, conn))
conn.SetPongHandler(wsPongHandler(nodeName))
conn.SetCloseHandler(wsCloseHandler(nodeName, conn))

// register the client
registerWSClient(nodeName, conn)

// ping client on a regular interval to make sure it's still connected
go pingWSClient(nodeName, conn)

// listen to client messages
listenToWSClient(nodeName, conn)
return nil
}

func pingWSClient(nodeName string, conn *websocket.Conn) {
for {
sleepDuration := time.Second * time.Duration(5+rand.Intn(16)) // 5-20 seconds
time.Sleep(sleepDuration)

pingMsg := fmt.Sprintf("%x", rand.Int())

if err := conn.WriteControl(websocket.PingMessage, []byte(pingMsg), time.Now().Add(1*time.Second)); err != nil {
if isWSConnClosed(nodeName, err) {
removeWSClient(nodeName, err)
return
}
logger.Debugf("Failed to send ping message to %s: %v", nodeName, err)
continue
}

wsMutex.Lock()
client := wsClients[nodeName]
wsMutex.Unlock()

client.LastPingSent = types.PingPongInfo{
Time: time.Now(),
Message: pingMsg,
}
wsClients[nodeName] = client
}
}

func listenToWSClient(nodeName string, conn *websocket.Conn) {
for {
_, _, err := conn.ReadMessage() // this is required to receive ping/pong messages
if err != nil {
if isWSConnClosed(nodeName, err) {
removeWSClient(nodeName, err)
return
}
logger.Debugf("Error reading websocket message from %s: %v", nodeName, err)
}
}
}

func registerWSClient(nodeName string, conn *websocket.Conn) {
wsMutex.Lock()
defer wsMutex.Unlock()

if e, ok := wsClients[nodeName]; ok {
e.Conn.Close()
delete(wsClients, nodeName)
}

wsClients[nodeName] = types.WSClient{
Conn: conn,
ConnectedAt: time.Now(),
}

logger.Infof("Registered new websocket for %s", nodeName)
}

func removeWSClient(nodeName string, err error) {
wsMutex.Lock()
defer wsMutex.Unlock()

if _, ok := wsClients[nodeName]; !ok {
return
}
logger.Infof("Websocket connection closed for %s: %v", nodeName, err)
delete(wsClients, nodeName)
}

func wsPingHandler(nodeName string, conn *websocket.Conn) func(message string) error {
return func(message string) error {
wsMutex.Lock()
defer wsMutex.Unlock()

client := wsClients[nodeName]
client.LastPingRecv = types.PingPongInfo{
Time: time.Now(),
Message: message,
}

if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(1*time.Second)); err != nil {
logger.Debugf("Failed to send pong message to %s: %v", nodeName, err)
} else {
client.LastPongSent = types.PingPongInfo{
Time: time.Now(),
Message: message,
}
}

wsClients[nodeName] = client
return nil
}
}

func wsPongHandler(nodeName string) func(message string) error {
return func(message string) error {
wsMutex.Lock()
defer wsMutex.Unlock()

client := wsClients[nodeName]
client.LastPongRecv = types.PingPongInfo{
Time: time.Now(),
Message: message,
}
wsClients[nodeName] = client

return nil
}
}

func wsCloseHandler(nodeName string, conn *websocket.Conn) func(code int, text string) error {
return func(code int, text string) error {
logger.Infof("Websocket connection closed for %s: %d (exit code), message: %q", nodeName, code, text)

wsMutex.Lock()
delete(wsClients, nodeName)
wsMutex.Unlock()

message := websocket.FormatCloseMessage(code, text)
conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
return nil
}
}

func isWSConnClosed(nodeName string, err error) bool {
wsMutex.Lock()
defer wsMutex.Unlock()

if _, ok := wsClients[nodeName]; !ok {
return true
}
if _, ok := err.(*websocket.CloseError); ok {
return true
}
if e, ok := err.(*net.OpError); ok && !e.Temporary() {
return true
}
return false
}

func GetClients() map[string]types.WSClient {
return wsClients
}
2 changes: 2 additions & 0 deletions web/src/Root.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import AppLicense from "@components/apps/AppLicense";
import AppRegistrySettings from "@components/apps/AppRegistrySettings";
import AppIdentityServiceSettings from "@components/apps/AppIdentityServiceSettings";
import TroubleshootContainer from "@components/troubleshoot/TroubleshootContainer";
import DebugInfo from "@components/DebugInfo";

import Footer from "./components/shared/Footer";
import NavBar from "./components/shared/NavBar";
Expand Down Expand Up @@ -749,6 +750,7 @@ const Root = () => {
/>
<Route path="/crashz" element={<Crashz />} />{" "}
<Route path="*" element={<NotFound />} />
<Route path="/debug" element={<DebugInfo />} />
<Route
path="/secure-console"
element={
Expand Down
Loading

0 comments on commit 767739e

Please sign in to comment.