From 069355504e1fe6077d057ee1bb6a46d9324ab05c Mon Sep 17 00:00:00 2001 From: can1357 Date: Sat, 9 Mar 2024 05:15:33 +0100 Subject: [PATCH] Minor changes to upgrade logic. --- example/api/lambda.ts | 42 +++++++++++++++++++----------------------- pmtp/upgrade.go | 17 ++++++++++++++++- session/api_service.go | 12 ++++-------- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/example/api/lambda.ts b/example/api/lambda.ts index bd7d8b3..7e76d5f 100644 --- a/example/api/lambda.ts +++ b/example/api/lambda.ts @@ -1,40 +1,36 @@ import { WebSocket, createWebSocketStream } from "ws"; -type LambdaHandler = Record any> | ((method: string, body: any) => any); +type LambdaHandler = Record any>; async function newLambda(handler: LambdaHandler): Promise { - let cb: (method: string, body: any) => Promise; - if (typeof handler === "function") { - cb = async (method, body) => { - return await handler(method, body); - }; - } else { - cb = async (method, body) => { - if (handler[method]) { - return await handler[method](body); - } else { - throw new Error(`Method not found: ${method}`); + const duplex = createWebSocketStream(new WebSocket("ws://pm3/lambda/new"), { encoding: "utf-8" }); + + async function handle(id: any, method: string, body: any) { + let resp; + try { + const cb = handler[method]; + if (!cb) { + throw new Error(`method not found: ${method}`); } - }; + const result = await cb(body); + resp = { id, result }; + } catch (error) { + resp = { id, error: `${error}` }; + } + console.log("->", resp); + duplex.write(JSON.stringify(resp)); } - const duplex = createWebSocketStream(new WebSocket("ws://pm3/lambda/new"), { encoding: "utf-8" }); return new Promise(async (resolve) => { for await (const message of duplex) { + console.log("<-", message); const { method, params, id } = JSON.parse(message); if (method === "open") { + console.log("->", { id, result: "ok" }); duplex.write(JSON.stringify({ id, result: "ok" })); resolve(params[0]); continue; } - - cb(method, params[0] ?? {}) - .then( - (result) => ({ id, result }), - (error) => ({ id, error: error.message }) - ) - .then((response) => { - duplex.write(JSON.stringify(response)); - }); + handle(id, method, params[0]); } }); } diff --git a/pmtp/upgrade.go b/pmtp/upgrade.go index 0c77427..b357f13 100644 --- a/pmtp/upgrade.go +++ b/pmtp/upgrade.go @@ -4,6 +4,7 @@ import ( "bytes" "net" "net/http" + "net/url" "sync/atomic" "time" @@ -29,6 +30,20 @@ func MakeUpgradeServer[Proto ServerProtocol[Arg], Arg any](protos ...Proto) Upgr Websocket: websocket.Upgrader{ ReadBufferSize: 32 * 1024, WriteBufferSize: 32 * 1024, + CheckOrigin: func(r *http.Request) bool { + // No browser will send a request with Origin not set or invalid. + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + url, err := url.Parse(origin) + if err != nil { + return true + } + + // Accept if pm3, localhost, 127.0.0.1. + return url.Host == "pm3" || url.Host == "localhost" || url.Host == "127.0.0.1" + }, }, } for _, proto := range protos { @@ -43,7 +58,7 @@ func (u *UpgradeServer[Proto, Arg]) Upgrade(w http.ResponseWriter, r *http.Reque vhttp.Error(w, r, http.StatusMethodNotAllowed) return } - if r.Header.Get("Connection") != "Upgrade" { + if conn := r.Header.Get("Connection"); conn != "Upgrade" && conn != "upgrade" { vhttp.Error(w, r, http.StatusBadRequest) return } diff --git a/session/api_service.go b/session/api_service.go index 3882c05..744c12c 100644 --- a/session/api_service.go +++ b/session/api_service.go @@ -22,9 +22,6 @@ type ServiceMetrics struct { Type string `json:"type"` Server lb.LoadBalancerMetrics `json:"server"` Processes []service.ProcTreeMetrics `json:"processes"` -} -type ServiceInfo struct { - ServiceMetrics ServiceHealth } @@ -70,6 +67,7 @@ func (m *ServiceMetrics) Fill(sv *ServiceState) { return } m.ID = sv.ID + m.ServiceHealth.Fill(sv) if sv.Instance != nil { ty := reflect.TypeOf(sv.Instance) @@ -93,10 +91,6 @@ func (m *ServiceMetrics) Fill(sv *ServiceState) { m.Server = l.Metrics() } } -func (m *ServiceInfo) Fill(sv *ServiceState) { - m.ServiceMetrics.Fill(sv) - m.ServiceHealth.Fill(sv) -} func registerServiceView(name string, view func(*ServiceState) any) { Match("/service/"+name+"/{svc}", func(session *Session, r *http.Request, _ struct{}) (h any, _ error) { @@ -132,11 +126,13 @@ func init() { m.Fill(sv) return m }) + // deprecated alias registerServiceView("info", func(sv *ServiceState) any { - var m ServiceInfo + var m ServiceMetrics m.Fill(sv) return m }) + Match("/service", func(session *Session, r *http.Request, _ struct{}) (res map[string]snowflake.ID, _ error) { res = make(map[string]snowflake.ID) session.ServiceMap.Range(func(_ string, v *ServiceState) bool {