diff --git a/libc/str/isutf8.c b/libc/str/isutf8.c index 6c9a6334eaa..9dae24c6221 100644 --- a/libc/str/isutf8.c +++ b/libc/str/isutf8.c @@ -27,8 +27,8 @@ static const char kUtf8Dispatch[] = { 1, 1, 1, 1, 1, 1, 1, 1, // 0320 1, 1, 1, 1, 1, 1, 1, 1, // 0330 2, 3, 3, 3, 3, 3, 3, 3, // 0340 utf8-3 - 3, 3, 3, 3, 3, 3, 3, 3, // 0350 - 4, 5, 5, 5, 5, 0, 0, 0, // 0360 utf8-4 + 3, 3, 3, 3, 3, 4, 3, 3, // 0350 + 5, 6, 6, 6, 7, 0, 0, 0, // 0360 utf8-4 0, 0, 0, 0, 0, 0, 0, 0, // 0370 }; @@ -94,6 +94,7 @@ bool32 isutf8(const void *data, size_t size) { } // fallthrough case 3: + case3: if (p + 2 <= e && // (p[0] & 0300) == 0200 && // (p[1] & 0300) == 0200) { // @@ -103,11 +104,17 @@ bool32 isutf8(const void *data, size_t size) { return false; // missing cont } case 4: + if (p < e && (*p & 040)) { + return false; // utf-16 surrogate + } + goto case3; + case 5: if (p < e && (*p & 0377) < 0220) { return false; // overlong } // fallthrough - case 5: + case 6: + case6: if (p + 3 <= e && // (((uint32_t)(p[+2] & 0377) << 030 | // (uint32_t)(p[+1] & 0377) << 020 | // @@ -119,6 +126,11 @@ bool32 isutf8(const void *data, size_t size) { } else { return false; // missing cont } + case 7: + if (p < e && (*p & 0x3F) > 0xF) { + return false; // over limit + } + goto case6; default: __builtin_unreachable(); } diff --git a/net/http/gethttpheader.gperf b/net/http/gethttpheader.gperf index 26c3c2dd013..fad88e6bb15 100644 --- a/net/http/gethttpheader.gperf +++ b/net/http/gethttpheader.gperf @@ -104,3 +104,4 @@ CF-Visitor, kHttpCfVisitor CF-Connecting-IP, kHttpCfConnectingIp CF-IPCountry, kHttpCfIpcountry CDN-Loop, kHttpCdnLoop +Sec-WebSocket-Key, kHttpWebsocketKey diff --git a/net/http/gethttpheader.inc b/net/http/gethttpheader.inc index 72f3b7afe60..ae5a682a03d 100644 --- a/net/http/gethttpheader.inc +++ b/net/http/gethttpheader.inc @@ -39,7 +39,7 @@ #line 12 "gethttpheader.gperf" struct thatispacked HttpHeaderSlot { char *name; char code; }; -#define TOTAL_KEYWORDS 93 +#define TOTAL_KEYWORDS 94 #define MIN_WORD_LENGTH 2 #define MAX_WORD_LENGTH 32 #define MIN_HASH_VALUE 3 @@ -387,7 +387,10 @@ LookupHttpHeader (register const char *str, register size_t len) #line 87 "gethttpheader.gperf" {"Strict-Transport-Security", kHttpStrictTransportSecurity}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, - {""}, {""}, {""}, {""}, {""}, + {""}, {""}, +#line 107 "gethttpheader.gperf" + {"Sec-WebSocket-Key", kHttpWebsocketKey}, + {""}, {""}, #line 22 "gethttpheader.gperf" {"X-Forwarded-For", kHttpXForwardedFor}, {""}, diff --git a/net/http/gethttpheadername.c b/net/http/gethttpheadername.c index 898cf327a8d..b01f68e1fc7 100644 --- a/net/http/gethttpheadername.c +++ b/net/http/gethttpheadername.c @@ -206,6 +206,8 @@ const char *GetHttpHeaderName(int h) { return "CDN-Loop"; case kHttpSecChUaPlatform: return "Sec-CH-UA-Platform"; + case kHttpWebsocketKey: + return "Sec-WebSocket-Key"; default: return NULL; } diff --git a/net/http/http.h b/net/http/http.h index 5e70c0370fa..ee196399942 100644 --- a/net/http/http.h +++ b/net/http/http.h @@ -146,7 +146,8 @@ #define kHttpCfIpcountry 90 #define kHttpSecChUaPlatform 91 #define kHttpCdnLoop 92 -#define kHttpHeadersMax 93 +#define kHttpWebsocketKey 93 +#define kHttpHeadersMax 94 #if !(__ASSEMBLER__ + __LINKER__ + 0) COSMOPOLITAN_C_START_ diff --git a/tool/net/redbean.c b/tool/net/redbean.c index f7f934bf895..22ea79c998b 100644 --- a/tool/net/redbean.c +++ b/tool/net/redbean.c @@ -43,6 +43,7 @@ #include "libc/intrin/bits.h" #include "libc/intrin/bsr.h" #include "libc/intrin/likely.h" +#include "libc/intrin/newbie.h" #include "libc/intrin/nomultics.internal.h" #include "libc/intrin/safemacros.internal.h" #include "libc/log/appendresourcereport.internal.h" @@ -124,6 +125,7 @@ #include "third_party/mbedtls/net_sockets.h" #include "third_party/mbedtls/oid.h" #include "third_party/mbedtls/san.h" +#include "third_party/mbedtls/sha1.h" #include "third_party/mbedtls/ssl.h" #include "third_party/mbedtls/ssl_ticket.h" #include "third_party/mbedtls/x509.h" @@ -410,6 +412,7 @@ struct ClearedPerMessage { bool hascontenttype; bool gotcachecontrol; bool gotxcontenttypeoptions; + char wstype; int frags; int statuscode; int isyielding; @@ -490,6 +493,8 @@ static uint8_t *zmap; static uint8_t *zcdir; static size_t hdrsize; static size_t amtread; +static size_t wsfragread; +static char wsfragtype; static reader_f reader; static writer_f writer; static char *extrahdrs; @@ -5053,6 +5058,195 @@ static bool LuaRunAsset(const char *path, bool mandatory) { return !!a; } +static int LuaWSUpgrade(lua_State *L) { + size_t i; + char *p, *q; + bool haskey; + mbedtls_sha1_context ctx; + unsigned char hash[20]; + + if (cpm.generator) + luaL_error(L, "Cannot upgrade to websocket after yielding normally"); + + if (!HasHeader(kHttpWebsocketKey)) + luaL_error(L, "No Sec-WebSocket-Key header"); + + mbedtls_sha1_init(&ctx); + mbedtls_sha1_starts_ret(&ctx); + mbedtls_sha1_update_ret(&ctx, (unsigned char*) + HeaderData(kHttpWebsocketKey), + HeaderLength(kHttpWebsocketKey)); + + p = SetStatus(101, "Switching Protocols"); + while (p - hdrbuf.p + (20 + 21 + (20 + 28 + 4)) + 512 > hdrbuf.n) { + hdrbuf.n += hdrbuf.n >> 1; + q = xrealloc(hdrbuf.p, hdrbuf.n); + cpm.luaheaderp = p = q + (p - hdrbuf.p); + hdrbuf.p = q; + } + + mbedtls_sha1_update_ret( + &ctx, (unsigned char *)"258EAFA5-E914-47DA-95CA-C5AB0DC85B11", 36); + mbedtls_sha1_finish_ret(&ctx, hash); + char *accept = EncodeBase64((char *)hash, 20, NULL); + + p = stpcpy(p, "Upgrade: websocket\r\n"); + p = stpcpy(p, "Connection: upgrade\r\n"); + p = AppendHeader(p, "Sec-WebSocket-Accept", accept); + + cpm.luaheaderp = p; + cpm.wstype = 1; + + return 0; +} + +static int LuaWSRead(lua_State *L) { + ssize_t rc; + size_t i, got, amt, bufsize; + unsigned char wshdr[10], wshdrlen, *extlen, *mask, op; + char *bufstart; + uint64_t len; + struct iovec iov[2]; + OnlyCallDuringRequest(L, "ws.Read"); + + got = 0; + do { + if ((rc = reader(client, wshdr + got, 2 - got)) == -1) + luaL_error(L, "Could not read WS header"); + } while ((got += rc) < 2); + + op = wshdr[0] & 0xF; + + if (wshdr[0] & 0x70) goto close; // reserved bit set + if (!(wshdr[1] | (1 << 7))) goto close; // unmasked + if ((wshdr[0] & 0x7) >= 0x3) goto close; // reserved opcode + if (!wsfragtype && !op) goto close; // not in continuation + + len = wshdr[1] & ~(1 << 7); + if (wshdr[0] & 0x8) { // control frame + if (!(wshdr[0] & 0x80) || len >= 126) goto close; // fragmented or too long + } else { + if (op && wsfragtype) goto close; // during fragmented seq + } + + wshdrlen = 6; + if (len == 126) { + wshdrlen = 8; + } else if (len == 127) { + wshdrlen = 14; + } + + while (got < wshdrlen) { + if ((rc = reader(client, wshdr + got, wshdrlen - got)) == -1) + luaL_error(L, "Could not read WS extended length"); + got += rc; + } + + extlen = &wshdr[2]; + mask = &wshdr[wshdrlen - 4]; + if (len == 126) { + len = be16toh(*(uint16_t *)extlen); + } else if (len == 127) { + len = be64toh(*(uint64_t *)extlen); + } + + if (len >= inbuf.n - wsfragread) + luaL_error(L, "Required %d bytes to read WS frame, %d bytes available", len, + inbuf.n - wsfragread); + + for (got = 0, amt = wsfragread; got < len; got += rc, amt += rc) { + if ((rc = reader(client, inbuf.p + amt, len - got)) == -1) + luaL_error(L, "Could not read WS data"); + } + + for (i = 0, amt = wsfragread; i < got; ++i, ++amt) + inbuf.p[amt] ^= mask[i & 0x3]; + + if (op == 0x9) { + wshdr[0] = (wshdr[0] & ~0xF) | 0xA; + wshdr[1] = wshdr[1] & ~0x80; + iov[0].iov_base = wshdr; + iov[0].iov_len = wshdrlen - 4; + iov[1].iov_base = inbuf.p + wsfragread; + iov[1].iov_len = got; + Send(iov, 2); + } + + if (wshdr[0] & 0x80) { + if (op) { + bufstart = inbuf.p + wsfragread; + bufsize = got; + + if (op == 0x1 && !isutf8(bufstart, bufsize)) goto close; + lua_pushlstring(L, bufstart, bufsize); + lua_pushinteger(L, op); + } else { + bufstart = inbuf.p + amtread; + bufsize = (wsfragread - amtread) + got; + + if (wsfragtype == 0x1 && !isutf8(bufstart, bufsize)) goto close; + lua_pushlstring(L, bufstart, bufsize); + lua_pushinteger(L, wsfragtype); + + wsfragread = amtread; + wsfragtype = 0; + } + } else { + lua_pushnil(L); + lua_pushinteger(L, 0); + + if (!wsfragtype) wsfragtype = op; + wsfragread += got; + } + + return 2; + +close: + lua_pushnil(L); + lua_pushinteger(L, 0x08); + return 2; +} + +static int LuaWSWrite(lua_State *L) { + int type; + size_t size; + const char *data; + + OnlyCallDuringRequest(L, "ws.Write"); + if (!cpm.wstype) + LuaWSUpgrade(L); + + type = luaL_optinteger(L, 2, -1); + if (type == 1 || type == 2) { + cpm.wstype = type; + } else if (type != -1) { + luaL_error(L, "Invalid WS type"); + } + + if (!lua_isnil(L, 1)) { + data = luaL_checklstring(L, 1, &size); + appendd(&cpm.outbuf, data, size); + } + return 0; +} + +static const luaL_Reg kLuaWS[] = { + {"Read", LuaWSRead}, // + {"Write", LuaWSWrite}, // + {0} // +}; + +int LuaWS(lua_State *L) { + luaL_newlib(L, kLuaWS); + lua_pushinteger(L, 0); lua_setfield(L, -2, "CONT"); + lua_pushinteger(L, 1); lua_setfield(L, -2, "TEXT"); + lua_pushinteger(L, 2); lua_setfield(L, -2, "BIN"); + lua_pushinteger(L, 8); lua_setfield(L, -2, "CLOSE"); + lua_pushinteger(L, 9); lua_setfield(L, -2, "PING"); + lua_pushinteger(L, 10); lua_setfield(L, -2, "PONG"); + return 1; +} + // // list of functions that can't be run from the repl static const char *const kDontAutoComplete[] = { @@ -5317,6 +5511,7 @@ static const luaL_Reg kLuaLibs[] = { {"path", LuaPath}, // {"re", LuaRe}, // {"unix", LuaUnix}, // + {"ws", LuaWS} // }; static void LuaSetArgv(lua_State *L) { @@ -6334,6 +6529,73 @@ static bool StreamResponse(char *p) { return true; } +static bool StreamWS(char *p) { + ssize_t rc; + struct iovec iov[2]; + char *s, wshdr[10], *extlen; + int nresults, status; + + p = AppendCrlf(p); + CHECK_LE(p - hdrbuf.p, hdrbuf.n); + if (logmessages) { + LogMessage("sending", hdrbuf.p, p - hdrbuf.p); + } + iov[0].iov_base = hdrbuf.p; + iov[0].iov_len = p - hdrbuf.p; + Send(iov, 1); + + bzero(iov, sizeof(iov)); + iov[0].iov_base = wshdr; + + extlen = &wshdr[2]; + wsfragread = amtread; + wsfragtype = 0; + + for (;;) { + if (!YL || lua_status(YL) != LUA_YIELD) break; // done yielding + cpm.contentlength = 0; + status = lua_resume(YL, NULL, 0, &nresults); + if (status == LUA_OK) { + lua_pop(YL, nresults); + break; + } else if (status != LUA_YIELD) { + LogLuaError("resume", lua_tostring(YL, -1)); + lua_pop(YL, 1); + break; + } + lua_pop(YL, nresults); + if (!cpm.contentlength) UseOutput(); + + DEBUGF("(lua) ws yielded with %ld bytes generated", cpm.contentlength); + + iov[1].iov_base = cpm.content; + iov[1].iov_len = rc = cpm.contentlength; + + if (rc < 126) { + wshdr[1] = rc; + iov[0].iov_len = 2; + } else if (rc <= 0xFFFF) { + wshdr[1] = 126; + *(uint16_t *)extlen = htobe16(rc); + iov[0].iov_len = 4; + } else { + wshdr[1] = 127; + *(uint64_t *)extlen = htobe64(rc); + iov[0].iov_len = 10; + } + wshdr[0] = cpm.wstype | (1 << 7); + if (Send(iov, 2) == -1) break; + } + + wshdr[0] = 0x8 | (1 << 7); + wshdr[1] = 0; + iov[0].iov_len = 2; + Send(iov, 1); + connectionclose = true; + + return true; +} + static bool HandleMessageActual(void) { int rc; long reqtime, contime; @@ -6392,6 +6654,8 @@ static bool HandleMessageActual(void) { } if (!cpm.generator) { return TransmitResponse(p); + } else if (cpm.wstype) { + return StreamWS(p); } else { return StreamResponse(p); }