Skip to content

Commit

Permalink
Minor changes to upgrade logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
can1357 committed Mar 9, 2024
1 parent 122fca5 commit 0693555
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
42 changes: 19 additions & 23 deletions example/api/lambda.ts
Original file line number Diff line number Diff line change
@@ -1,40 +1,36 @@
import { WebSocket, createWebSocketStream } from "ws";

type LambdaHandler = Record<string, (body: any) => any> | ((method: string, body: any) => any);
type LambdaHandler = Record<string, (body: any) => any>;
async function newLambda(handler: LambdaHandler): Promise<string> {
let cb: (method: string, body: any) => Promise<any>;
if (typeof handler === "function") {
cb = async (method, body) => {
return await handler(method, body);
};
} else {
cb = async (method, body) => {
if (handler[method]) {
return await handler[method](body);
} else {
throw new Error(`Method not found: ${method}`);
const duplex = createWebSocketStream(new WebSocket("ws://pm3/lambda/new"), { encoding: "utf-8" });

async function handle(id: any, method: string, body: any) {
let resp;
try {
const cb = handler[method];
if (!cb) {
throw new Error(`method not found: ${method}`);
}
};
const result = await cb(body);
resp = { id, result };
} catch (error) {
resp = { id, error: `${error}` };
}
console.log("->", resp);
duplex.write(JSON.stringify(resp));
}

const duplex = createWebSocketStream(new WebSocket("ws://pm3/lambda/new"), { encoding: "utf-8" });
return new Promise(async (resolve) => {
for await (const message of duplex) {
console.log("<-", message);
const { method, params, id } = JSON.parse(message);
if (method === "open") {
console.log("->", { id, result: "ok" });
duplex.write(JSON.stringify({ id, result: "ok" }));
resolve(params[0]);
continue;
}

cb(method, params[0] ?? {})
.then(
(result) => ({ id, result }),
(error) => ({ id, error: error.message })
)
.then((response) => {
duplex.write(JSON.stringify(response));
});
handle(id, method, params[0]);
}
});
}
Expand Down
17 changes: 16 additions & 1 deletion pmtp/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"net"
"net/http"
"net/url"
"sync/atomic"
"time"

Expand All @@ -29,6 +30,20 @@ func MakeUpgradeServer[Proto ServerProtocol[Arg], Arg any](protos ...Proto) Upgr
Websocket: websocket.Upgrader{
ReadBufferSize: 32 * 1024,
WriteBufferSize: 32 * 1024,
CheckOrigin: func(r *http.Request) bool {
// No browser will send a request with Origin not set or invalid.
origin := r.Header.Get("Origin")
if origin == "" {
return true
}
url, err := url.Parse(origin)
if err != nil {
return true
}

// Accept if pm3, localhost, 127.0.0.1.
return url.Host == "pm3" || url.Host == "localhost" || url.Host == "127.0.0.1"
},
},
}
for _, proto := range protos {
Expand All @@ -43,7 +58,7 @@ func (u *UpgradeServer[Proto, Arg]) Upgrade(w http.ResponseWriter, r *http.Reque
vhttp.Error(w, r, http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Connection") != "Upgrade" {
if conn := r.Header.Get("Connection"); conn != "Upgrade" && conn != "upgrade" {
vhttp.Error(w, r, http.StatusBadRequest)
return
}
Expand Down
12 changes: 4 additions & 8 deletions session/api_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ type ServiceMetrics struct {
Type string `json:"type"`
Server lb.LoadBalancerMetrics `json:"server"`
Processes []service.ProcTreeMetrics `json:"processes"`
}
type ServiceInfo struct {
ServiceMetrics
ServiceHealth
}

Expand Down Expand Up @@ -70,6 +67,7 @@ func (m *ServiceMetrics) Fill(sv *ServiceState) {
return
}
m.ID = sv.ID
m.ServiceHealth.Fill(sv)

if sv.Instance != nil {
ty := reflect.TypeOf(sv.Instance)
Expand All @@ -93,10 +91,6 @@ func (m *ServiceMetrics) Fill(sv *ServiceState) {
m.Server = l.Metrics()
}
}
func (m *ServiceInfo) Fill(sv *ServiceState) {
m.ServiceMetrics.Fill(sv)
m.ServiceHealth.Fill(sv)
}

func registerServiceView(name string, view func(*ServiceState) any) {
Match("/service/"+name+"/{svc}", func(session *Session, r *http.Request, _ struct{}) (h any, _ error) {
Expand Down Expand Up @@ -132,11 +126,13 @@ func init() {
m.Fill(sv)
return m
})
// deprecated alias
registerServiceView("info", func(sv *ServiceState) any {
var m ServiceInfo
var m ServiceMetrics
m.Fill(sv)
return m
})

Match("/service", func(session *Session, r *http.Request, _ struct{}) (res map[string]snowflake.ID, _ error) {
res = make(map[string]snowflake.ID)
session.ServiceMap.Range(func(_ string, v *ServiceState) bool {
Expand Down

0 comments on commit 0693555

Please sign in to comment.