Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add database and statement readonly checks to sqlite in redbean #914

Merged
merged 4 commits into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions test/tool/net/sqlite_test.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
-- Copyright 2023 Justine Alexandra Roberts Tunney
--
-- Permission to use, copy, modify, and/or distribute this software for
-- any purpose with or without fee is hereby granted, provided that the
-- above copyright notice and this permission notice appear in all copies.
--
-- THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
-- WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
-- WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
-- AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL
-- DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR
-- PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER
-- TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
-- PERFORMANCE OF THIS SOFTWARE.

local sqlite3 = require "lsqlite3"
local db = assert(sqlite3.open("file:/memdb1?vfs=memdb&mode=ro",
sqlite3.OPEN_URI + sqlite3.OPEN_READWRITE + sqlite3.OPEN_CREATE))
assert(db:readonly() == true)
db = sqlite3.open("file:/memdb1?vfs=memdb",
sqlite3.OPEN_URI + sqlite3.OPEN_READWRITE + sqlite3.OPEN_CREATE)
assert(db:readonly() == false)
assert(db:readonly("main") == false)
assert(db:readonly("foo") == nil)

assert(db:exec("create table foo(a)") == 0)

local st = assert(db:prepare("select * from foo"))
assert(st:readonly() == true)
st = assert(db:prepare("insert into foo (a) values (1)"))
assert(st:readonly() == false)
53 changes: 38 additions & 15 deletions tool/net/lsqlite3.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ static int dbvm_isopen(lua_State *L) {
return 1;
}

static int dbvm_readonly(lua_State *L) {
sdb_vm *svm = lsqlite_checkvm(L, 1);
lua_pushboolean(L, sqlite3_stmt_readonly(svm->vm));
return 1;
}

static int dbvm_tostring(lua_State *L) {
char buff[40];
sdb_vm *svm = lsqlite_getvm(L, 1);
Expand Down Expand Up @@ -908,6 +914,21 @@ static int pusherr(lua_State *L, int rc) {
return 2;
}

static int pusherrstr(lua_State *L, char *str) {
lua_pushnil(L);
lua_pushstring(L, str);
return 2;
}

static int db_readonly(lua_State *L) {
sdb *db = lsqlite_checkdb(L, 1);
const char *zDb = luaL_optstring(L, 2, "main");
int res = sqlite3_db_readonly(db->db, zDb);
if (res == -1) return pusherrstr(L, "unknown (not attached) database name");
lua_pushboolean(L, res);
return 1;
}

static int db_wal_checkpoint(lua_State *L) {
sdb *db = lsqlite_checkdb(L, 1);
int eMode = luaL_optinteger(L, 2, SQLITE_CHECKPOINT_PASSIVE);
Expand Down Expand Up @@ -1748,39 +1769,37 @@ static int db_gc(lua_State *L) {
return 0;
}

#ifdef SQLITE_ENABLE_DESERIALIZE

static int db_serialize(lua_State *L) {
sdb *db = lsqlite_getdb(L, 1);
sdb *db = lsqlite_checkdb(L, 1);
sqlite_int64 size = 0;

if (db->db == NULL) /* ignore closed databases */
return 0;

char *buffer = (char *)sqlite3_serialize(db->db, "main", &size, 0);
if (buffer == NULL) /* ignore failed database serialization */
return 0;
if (buffer == NULL)
return pusherrstr(L, "failed to serialize");

lua_pushlstring(L, buffer, size);
free(buffer);
return 1;
}

static int db_deserialize(lua_State *L) {
sdb *db = lsqlite_getdb(L, 1);
sdb *db = lsqlite_checkdb(L, 1);
size_t size = 0;

if (db->db == NULL) /* ignore closed databases */
return 0;

const char *buffer = luaL_checklstring(L, 2, &size);
if (buffer == NULL || size == 0) /* ignore empty database content */
return 0;
if (buffer == NULL || size == 0)
return pusherrstr(L, "failed to deserialize");

const char *sqlbuf = memcpy(sqlite3_malloc(size), buffer, size);
sqlite3_deserialize(db->db, "main", (void *)sqlbuf, size, size,
SQLITE_DESERIALIZE_FREEONCLOSE + SQLITE_DESERIALIZE_RESIZEABLE);
return 0;
}

#endif

#ifdef SQLITE_ENABLE_SESSION

/*
Expand Down Expand Up @@ -2602,6 +2621,7 @@ static const struct {

static const luaL_Reg dblib[] = {
{"isopen", db_isopen },
{"readonly", db_readonly },
{"last_insert_rowid", db_last_insert_rowid },
{"changes", db_changes },
{"total_changes", db_total_changes },
Expand Down Expand Up @@ -2634,9 +2654,6 @@ static const luaL_Reg dblib[] = {
{"close", db_close },
{"close_vm", db_close_vm },

{"serialize", db_serialize },
{"deserialize", db_deserialize },

#ifdef SQLITE_ENABLE_SESSION
{"create_session", db_create_session },
{"create_rebaser", db_create_rebaser },
Expand All @@ -2646,6 +2663,11 @@ static const luaL_Reg dblib[] = {
{"iterate_changeset", db_iterate_changeset },
#endif

#ifdef SQLITE_ENABLE_DESERIALIZE
{"serialize", db_serialize },
{"deserialize", db_deserialize },
#endif

{"__tostring", db_tostring },
{"__gc", db_gc },

Expand All @@ -2654,6 +2676,7 @@ static const luaL_Reg dblib[] = {

static const luaL_Reg vmlib[] = {
{"isopen", dbvm_isopen },
{"readonly", dbvm_readonly },

{"step", dbvm_step },
{"reset", dbvm_reset },
Expand Down