Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

redbean: add tls socket lua binding #1279

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
3 changes: 2 additions & 1 deletion tool/net/BUILD.mk
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ TOOL_NET_REDBEAN_LUA_MODULES = \
o/$(MODE)/tool/net/lmaxmind.o \
o/$(MODE)/tool/net/lsqlite3.o \
o/$(MODE)/tool/net/largon2.o \
o/$(MODE)/tool/net/launch.o
o/$(MODE)/tool/net/launch.o \
o/$(MODE)/tool/net/ltls.o

o/$(MODE)/tool/net/redbean.dbg: \
$(TOOL_NET_DEPS) \
Expand Down
31 changes: 31 additions & 0 deletions tool/net/definitions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4982,6 +4982,37 @@ unix = {
X_OK = nil
}

---@class TlsContext
---@field connect fun(self: TlsContext, server_name: string, server_port: string): boolean, string?
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In terms of API design, have you considered this?

tls = require "tls"
conn = tls.TlsClient(ResolveIp("google.com"), 80) -- returns object of TlsClient class
conn.write("GET / HTTP/1.0\r\n\r\n")
print(conn.read())
conn.close()

That would be more similar and composable with other redbean APIs. For example, we like to pass IPs as uint32. You wouldn't need to maintain a state machine with this design. You could have functions or variables associated with the tls module for doing context-wide configuration, like whether or not SSL client verification should be enabled. You could also have that default to the redbean settings.

Another thing you might do that's even better is:

tls = require "tls"
unix = require "unix"

fd = assert(unix.socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP))
assert(unix.connect(fd, ResolveIp("google.com"), 80))
conn = assert(tls.TlsClient(fd)) -- returns object of TlsSocket class
assert(conn.write("GET / HTTP/1.0\r\n\r\n"))
response = assert(conn.read())
print(response)
assert(unix.close(fd))

I noticed you're drawing a lot of influence from mbedTLS's net_sockets.c API. I don't like that API. I don't think it's very good. If you use that abstraction, then you lose the ability to compose with redbean APIs like ResolveIp(), unix.poll(), etc. For examples of how I've made mbedTLS work with raw file descriptors, see tool/curl/curl.c and tool/build/lib/eztls.c.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API indeed seems more natural. I will make the modifications to get as close as possible to this example. Thank you

Copy link
Author

@chamot1111 chamot1111 Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use now:

---@class tls
local tls = {}

--- Creates a new TLS client.
---@param fd integer File descriptor of the socket
---@param verify? boolean Whether to verify the server's certificate (default: true)
---@param timeout? integer Read timeout in milliseconds (default: 0, no timeout)
---@return TlsContext|nil context
---@return string? error
function tls.TlsClient(fd, verify, timeout) end

--- Writes data to the TLS connection.
---@param context TlsContext
---@param data string
---@return integer bytes_written
---@return string? error
function tls:write(data) end

--- Reads data from the TLS connection.
---@param context TlsContext
---@param bufsiz? integer Maximum number of bytes to read (default: BUFSIZ)
---@return string? data
---@return string? error
function tls:read(bufsiz) end

---@field write fun(self: TlsContext, data: string): integer, string?
---@field read fun(self: TlsContext, bufsiz?: integer): string?, string?
---@field close fun(self: TlsContext)

---@class tls
local tls = {}

--- Creates a new TLS client.
---@param fd integer File descriptor of the socket
---@param verify? boolean Whether to verify the server's certificate (default: true)
---@param timeout? integer Read timeout in milliseconds (default: 0, no timeout)
---@return TlsContext|nil context
---@return string? error
function tls.TlsClient(fd, verify, timeout) end

--- Writes data to the TLS connection.
---@param context TlsContext
---@param data string
---@return integer bytes_written
---@return string? error
function tls:write(data) end

--- Reads data from the TLS connection.
---@param context TlsContext
---@param bufsiz? integer Maximum number of bytes to read (default: BUFSIZ)
---@return string? data
---@return string? error
function tls:read(bufsiz) end

--- Opens file.
---
--- Returns a file descriptor integer that needs to be closed, e.g.
Expand Down
292 changes: 292 additions & 0 deletions tool/net/ltls.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
╞══════════════════════════════════════════════════════════════════════════════╡
│ Copyright 2022 Justine Alexandra Roberts Tunney │
│ │
│ Permission to use, copy, modify, and/or distribute this software for │
│ any purpose with or without fee is hereby granted, provided that the │
│ above copyright notice and this permission notice appear in all copies. │
│ │
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
│ PERFORMANCE OF THIS SOFTWARE. │
╚─────────────────────────────────────────────────────────────────────────────*/
#include "ltls.h"
#include "libc/calls/struct/iovec.h"
#include "net/https/https.h"
#include "third_party/mbedtls/ctr_drbg.h"
#include "third_party/mbedtls/debug.h"
#include "third_party/mbedtls/entropy.h"
#include "third_party/mbedtls/iana.h"
#include "third_party/mbedtls/net_sockets.h"
#include "third_party/mbedtls/oid.h"
#include "third_party/mbedtls/san.h"
#include "third_party/mbedtls/ssl.h"
#include "third_party/mbedtls/ssl_ticket.h"
#include "third_party/mbedtls/x509.h"
#include "third_party/mbedtls/x509_crt.h"

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

static const char *const tls_meta = ":mbedtls";

typedef struct {
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
int ref; // Reference to self in the Lua registry
char *read_buffer;
size_t read_buffer_size;
int fd; // File descriptor
} TlsContext;

static TlsContext **checktls(lua_State *L) {
TlsContext **tls = (TlsContext **)luaL_checkudata(L, 1, tls_meta);
if (tls == NULL || *tls == NULL)
luaL_typeerror(L, 1, tls_meta);
return tls;
}

static int TlsSend(void *c, const unsigned char *p, size_t n) {
int rc;
if ((rc = write(*(int *)c, p, n)) == -1) {
return -1; // Return error code instead of exiting
}
return rc;
}

static int TlsRecv(void *c, unsigned char *p, size_t n, uint32_t o) {
int r;
struct iovec v[2];
static unsigned a, b;
static unsigned char t[4096];
if (a < b) {
r = MIN(n, b - a);
memcpy(p, t + a, r);
if ((a += r) == b) {
a = b = 0;
}
return r;
}
v[0].iov_base = p;
v[0].iov_len = n;
v[1].iov_base = t;
v[1].iov_len = sizeof(t);
if ((r = readv(*(int *)c, v, 2)) == -1) {
return -1; // Return error code instead of exiting
}
if (r > n) {
b = r - n;
}
return MIN(n, r);
}

static int tls_gc(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;

if (tls) {
mbedtls_ssl_free(&tls->ssl);
mbedtls_ssl_config_free(&tls->conf);
mbedtls_ctr_drbg_free(&tls->ctr_drbg);
mbedtls_entropy_free(&tls->entropy);
luaL_unref(L, LUA_REGISTRYINDEX, tls->ref);
free(tls->read_buffer);
free(tls);
*tlsp = NULL;
}
return 0;
}

static void my_debug(void *ctx, int level, const char *file, int line,
const char *str) {
((void)level);
fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str);
fflush((FILE *)ctx);
}
static int tls_client(lua_State *L) {
int fd = luaL_checkinteger(L, 1);

printf("fd: %d\n", fd);

TlsContext **tlsp = (TlsContext **)lua_newuserdata(L, sizeof(TlsContext *));
*tlsp = NULL;

luaL_getmetatable(L, tls_meta);
lua_setmetatable(L, -2);

TlsContext *tls = (TlsContext *)malloc(sizeof(TlsContext));
if (tls == NULL) {
lua_pushnil(L);
lua_pushstring(L, "Failed to allocate memory for TLS context");
return 2;
}
*tlsp = tls;

tls->read_buffer = NULL;
tls->read_buffer_size = 0;
tls->fd = fd;

mbedtls_ssl_init(&tls->ssl);
mbedtls_ssl_config_init(&tls->conf);
mbedtls_ctr_drbg_init(&tls->ctr_drbg);
mbedtls_entropy_init(&tls->entropy);

int sslVerify = lua_isnone(L, 2) ? 1 : lua_toboolean(L, 2);
if (sslVerify) {
mbedtls_ssl_conf_ca_chain(&tls->conf, GetSslRoots(), 0);
mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
} else {
mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_NONE);
}

int timeout = lua_isnone(L, 3) ? 0 : luaL_checkinteger(L, 3);
mbedtls_ssl_conf_read_timeout(&tls->conf, timeout);

const char *pers = "tls_client";
int ret;
if ((ret = mbedtls_ctr_drbg_seed(&tls->ctr_drbg, mbedtls_entropy_func,
&tls->entropy, (const unsigned char *)pers,
strlen(pers))) != 0) {
free(tls);
*tlsp = NULL;
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ctr_drbg_seed returned %d", ret);
return 2;
}

if ((ret = mbedtls_ssl_config_defaults(&tls->conf, MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
free(tls);
*tlsp = NULL;
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ssl_config_defaults failed: %d", ret);
return 2;
}

mbedtls_ssl_conf_rng(&tls->conf, mbedtls_ctr_drbg_random, &tls->ctr_drbg);
mbedtls_ssl_conf_dbg(&tls->conf, my_debug, stdout);

if ((ret = mbedtls_ssl_setup(&tls->ssl, &tls->conf)) != 0) {
free(tls);
*tlsp = NULL;
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ssl_setup returned %d", ret);
return 2;
}

mbedtls_ssl_set_bio(&tls->ssl, &tls->fd, TlsSend, 0, TlsRecv);

if ((ret = mbedtls_ssl_handshake(&tls->ssl)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "SSL handshake failed: %d", ret);
return 2;
}

tls->ref = luaL_ref(L, LUA_REGISTRYINDEX);
lua_rawgeti(L, LUA_REGISTRYINDEX, tls->ref);

return 1;
}

static int tls_write(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
size_t len;
const char *data = luaL_checklstring(L, 2, &len);
int ret = mbedtls_ssl_write(&tls->ssl, (const unsigned char *)data, len);

if (ret < 0) {
lua_pushnil(L);
lua_pushfstring(L, "SSL write failed: %d", ret);
return 2;
}

lua_pushinteger(L, ret);
return 1;
}

static int tls_read(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
lua_Integer bufsiz = luaL_optinteger(L, 2, BUFSIZ);
bufsiz = MIN(bufsiz, 0x7ffff000);

if (tls->read_buffer == NULL || tls->read_buffer_size < bufsiz) {
char *new_buffer = realloc(tls->read_buffer, bufsiz);
if (new_buffer == NULL) {
lua_pushnil(L);
lua_pushstring(L, "Memory allocation failed");
return 2;
}
tls->read_buffer = new_buffer;
tls->read_buffer_size = bufsiz;
}

int ret =
mbedtls_ssl_read(&tls->ssl, (unsigned char *)tls->read_buffer, bufsiz);

if (ret > 0) {
lua_pushlstring(L, tls->read_buffer, ret);
return 1;
} else if (ret == 0) {
// End of file
lua_pushnil(L);
return 1;
} else {
// All negative values are treated as errors
lua_pushnil(L);
lua_pushfstring(L, "Read error: %d", ret);
return 2;
}
}

static int tls_tostring(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;

lua_pushfstring(L, "tls.TlsClient({fd=%d})", tls->fd);
return 1;
}

static const struct luaL_Reg tls_methods[] = {
{"write", tls_write},
{"read", tls_read},
{"__gc", tls_gc},
{"__tostring", tls_tostring},
{"__repr", tls_tostring},
{NULL, NULL}
};

static const struct luaL_Reg tlslib[] = {
{"TlsClient", tls_client},
{NULL, NULL}
};

static void create_meta(lua_State *L, const char *name,
const luaL_Reg *methods) {
luaL_newmetatable(L, name);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
luaL_setfuncs(L, methods, 0);
}

LUALIB_API int luaopen_tls(lua_State *L) {
create_meta(L, tls_meta, tls_methods);

luaL_newlib(L, tlslib);

lua_pushvalue(L, -1);
lua_setmetatable(L, -2);

return 1;
}
9 changes: 9 additions & 0 deletions tool/net/ltls.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef COSMOPOLITAN_TOOL_NET_LTLS_H_
#define COSMOPOLITAN_TOOL_NET_LTLS_H_
#include "third_party/lua/lauxlib.h"
COSMOPOLITAN_C_START_

int luaopen_tls(lua_State *);

COSMOPOLITAN_C_END_
#endif /* COSMOPOLITAN_TOOL_NET_LTLS_H_ */
4 changes: 4 additions & 0 deletions tool/net/redbean.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
#include "tool/net/lfuncs.h"
#include "tool/net/ljson.h"
#include "tool/net/lpath.h"
#include "tool/net/ltls.h"
#include "tool/net/luacheck.h"
#include "tool/net/sandbox.h"

Expand Down Expand Up @@ -5401,6 +5402,9 @@ static const luaL_Reg kLuaFuncs[] = {
static const luaL_Reg kLuaLibs[] = {
{"argon2", luaopen_argon2}, //
{"lsqlite3", luaopen_lsqlite3}, //
#ifndef UNSECURE
{"tls", luaopen_tls}, //
#endif
{"maxmind", LuaMaxmind}, //
{"finger", LuaFinger}, //
{"path", LuaPath}, //
Expand Down