Skip to content

Commit

Permalink
Reorganized user-defined function proxy, and used a provenance fence
Browse files Browse the repository at this point in the history
  • Loading branch information
trueqbit committed Nov 15, 2023
1 parent f60f2a1 commit 42999ac
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 260 deletions.
81 changes: 0 additions & 81 deletions dev/function.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#pragma once

#include <sqlite3.h>
#include <type_traits> // std::is_member_function_pointer, std::remove_const, std::decay, std::is_same, std::false_type, std::true_type
#include <string> // std::string
#include <tuple> // std::tuple, std::tuple_size, std::tuple_element
#include <functional> // std::function
#include <algorithm> // std::min
#include <utility> // std::move, std::forward

Expand All @@ -21,84 +18,6 @@ namespace sqlite_orm {
class pointer_binding;

namespace internal {

/*
* Stores type-erased information about a user-defined scalar or aggregate function:
* - name and argument count
* - function pointers for construction/destruction
* - function dispatch
* - allocated memory and location
*
* As such, it also serves as a context for aggregation operations instead of using `sqlite3_aggregate_context()`.
*/
struct udf_proxy_base {
using func_call_fn_t = void (*)(void* udfHandle,
sqlite3_context* context,
int argsCount,
sqlite3_value** values);
using final_call_fn_t = void (*)(void* udfHandle, sqlite3_context* context);

struct destruct_only_deleter {
template<class F>
void operator()(F* f) const noexcept {
f->~F();
}
};

std::string name;
int argumentsCount;
std::function<void*(void* place)> constructAt;
xdestroy_fn_t destroy;
func_call_fn_t func;
final_call_fn_t finalAggregateCall;
// flag whether the UDF has been constructed at `udfHandle`;
// necessary for aggregation operations
bool constructed;
// pointer to memory for UDF in derived proxy class
void* const udfHandle;
};

template<class UDF>
struct scalar_udf_proxy : udf_proxy_base {
// allocated memory for user-defined function
alignas(UDF) char udfMem[sizeof(UDF)];

scalar_udf_proxy(std::string name,
int argumentsCount,
std::function<void*(void* place)> constructAt,
xdestroy_fn_t destroy,
func_call_fn_t run) :
udf_proxy_base{std::move(name),
argumentsCount,
std::move(constructAt),
destroy,
run,
nullptr,
false,
udfMem} {}
};

template<class UDF>
struct aggregate_udf_proxy : udf_proxy_base {
// allocated memory for user-defined function
alignas(UDF) char udfMem[sizeof(UDF)];

aggregate_udf_proxy(std::string name,
int argumentsCount,
std::function<void*(void* place)> constructAt,
xdestroy_fn_t destroy,
func_call_fn_t step,
final_call_fn_t finalCall) :
udf_proxy_base{std::move(name),
argumentsCount,
std::move(constructAt),
destroy,
step,
finalCall,
false,
udfMem} {}
};

template<class F>
using scalar_call_function_t = decltype(&F::operator());

Expand Down
64 changes: 15 additions & 49 deletions dev/storage_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arg_values.h"
#include "util.h"
#include "xdestroy_handling.h"
#include "udf_proxy.h"
#include "serializing_util.h"

namespace sqlite_orm {
Expand Down Expand Up @@ -263,15 +264,15 @@ namespace sqlite_orm {
constexpr auto argsCount = std::is_same<args_tuple, std::tuple<arg_values>>::value
? -1
: int(std::tuple_size<args_tuple>::value);
this->scalarFunctions.push_back(std::make_unique<scalar_udf_proxy<F>>(
this->scalarFunctions.push_back(std::make_unique<udf_proxy_veneer<F>>(
std::move(name),
argsCount,
/* constructAt = */
[](void* place) -> void* {
return new(place) F();
[](void* location) -> void* {
return new(location) F();
},
/* destroy = */
obtain_xdestroy_for<F>(udf_proxy_base::destruct_only_deleter{}),
obtain_xdestroy_for<F>(udf_proxy::destruct_only_deleter{}),
/* call = */
[](void* udfHandle, sqlite3_context* context, int argsCount, sqlite3_value** values) {
F& udf = *static_cast<F*>(udfHandle);
Expand Down Expand Up @@ -330,15 +331,15 @@ namespace sqlite_orm {
constexpr auto argsCount = std::is_same<args_tuple, std::tuple<arg_values>>::value
? -1
: int(std::tuple_size<args_tuple>::value);
this->aggregateFunctions.push_back(std::make_unique<aggregate_udf_proxy<F>>(
this->aggregateFunctions.push_back(std::make_unique<udf_proxy_veneer<F>>(
std::move(name),
argsCount,
/* constructAt = */
[](void* place) -> void* {
return new(place) F();
[](void* location) -> void* {
return new(location) F();
},
/* destroy = */
obtain_xdestroy_for<F>(udf_proxy_base::destruct_only_deleter{}),
obtain_xdestroy_for<F>(udf_proxy::destruct_only_deleter{}),
/* step = */
[](void* udfHandle, sqlite3_context*, int argsCount, sqlite3_value** values) {
F& udf = *static_cast<F*>(udfHandle);
Expand Down Expand Up @@ -669,9 +670,9 @@ namespace sqlite_orm {
}

void delete_function_impl(const std::string& name,
std::vector<std::unique_ptr<udf_proxy_base>>& functions) const {
std::vector<std::unique_ptr<udf_proxy>>& functions) const {
#if __cpp_lib_ranges >= 201911L
auto it = std::ranges::find(functions, name, &udf_proxy_base::name);
auto it = std::ranges::find(functions, name, &udf_proxy::name);
#else
auto it = std::find_if(functions.begin(), functions.end(), [&name](auto& udfProxy) {
return udfProxy->name == name;
Expand Down Expand Up @@ -699,7 +700,7 @@ namespace sqlite_orm {
}
}

static void try_to_create_scalar_function(sqlite3* db, udf_proxy_base& udfProxy) {
static void try_to_create_scalar_function(sqlite3* db, udf_proxy& udfProxy) {
int rc = sqlite3_create_function_v2(db,
udfProxy.name.c_str(),
udfProxy.argumentsCount,
Expand All @@ -714,7 +715,7 @@ namespace sqlite_orm {
}
}

static void try_to_create_aggregate_function(sqlite3* db, udf_proxy_base& udfProxy) {
static void try_to_create_aggregate_function(sqlite3* db, udf_proxy& udfProxy) {
int rc = sqlite3_create_function(db,
udfProxy.name.c_str(),
udfProxy.argumentsCount,
Expand All @@ -728,41 +729,6 @@ namespace sqlite_orm {
}
}

static void
aggregate_function_step_callback(sqlite3_context* context, int argsCount, sqlite3_value** values) {
auto* udfProxy = static_cast<udf_proxy_base*>(sqlite3_user_data(context));
if(!udfProxy->constructed) {
if(udfProxy->argumentsCount != -1 && udfProxy->argumentsCount != argsCount) {
throw std::system_error{orm_error_code::arguments_count_does_not_match};
}
udfProxy->constructAt(udfProxy->udfHandle);
udfProxy->constructed = true;
}
udfProxy->func(udfProxy->udfHandle, context, argsCount, values);
}

static void aggregate_function_final_callback(sqlite3_context* context) {
auto* udfProxy = static_cast<udf_proxy_base*>(sqlite3_user_data(context));
// note: it is possible that the 'step' function was never called
if(!udfProxy->constructed) {
udfProxy->constructAt(udfProxy->udfHandle);
udfProxy->constructed = true;
}
udfProxy->finalAggregateCall(udfProxy->udfHandle, context);
udfProxy->destroy(udfProxy->udfHandle);
udfProxy->constructed = false;
}

static void scalar_function_callback(sqlite3_context* context, int argsCount, sqlite3_value** values) {
auto* udfProxy = static_cast<udf_proxy_base*>(sqlite3_user_data(context));
if(udfProxy->argumentsCount != -1 && udfProxy->argumentsCount != argsCount) {
throw std::system_error{orm_error_code::arguments_count_does_not_match};
}
const std::unique_ptr<void, xdestroy_fn_t> udfHandle{udfProxy->constructAt(udfProxy->udfHandle),
udfProxy->destroy};
udfProxy->func(udfHandle.get(), context, argsCount, values);
}

std::string current_time(sqlite3* db) {
std::string result;
perform_exec(db, "SELECT CURRENT_TIME", extract_single_value<std::string>, &result);
Expand Down Expand Up @@ -851,8 +817,8 @@ namespace sqlite_orm {
std::map<std::string, collating_function> collatingFunctions;
const int cachedForeignKeysCount;
std::function<int(int)> _busy_handler;
std::vector<std::unique_ptr<udf_proxy_base>> scalarFunctions;
std::vector<std::unique_ptr<udf_proxy_base>> aggregateFunctions;
std::vector<std::unique_ptr<udf_proxy>> scalarFunctions;
std::vector<std::unique_ptr<udf_proxy>> aggregateFunctions;
};
}
}
124 changes: 124 additions & 0 deletions dev/udf_proxy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#pragma once

#include <sqlite3.h>
#include <new> // std::launder
#include <string> // std::string
#include <functional> // std::function
#include <utility> // std::move

#include "error_code.h"

namespace sqlite_orm {
namespace internal {
/*
* Stores type-erased information in relation to a user-defined scalar or aggregate function object:
* - name and argument count
* - function pointers for construction/destruction
* - function dispatch
* - allocated memory and location
*
* As such, it also serves as a context for aggregation operations instead of using `sqlite3_aggregate_context()`.
*/
struct udf_proxy {
using func_call_fn_t = void (*)(void* udfHandle,
sqlite3_context* context,
int argsCount,
sqlite3_value** values);
using final_call_fn_t = void (*)(void* udfHandle, sqlite3_context* context);

struct destruct_only_deleter {
template<class F>
void operator()(F* f) const noexcept {
f->~F();
}
};

std::string name;
int argumentsCount;
std::function<void*(void* location)> constructAt;
xdestroy_fn_t destroy;
func_call_fn_t func;
final_call_fn_t finalAggregateCall;
// flag whether the UDF has been constructed at `udfHandle`;
// necessary for aggregation operations
bool constructed;
// pointer to memory for UDF in derived proxy veneer
void* const udfHandle;
};

/*
* A veneer to `udf_proxy` that provides memory space for a user-defined function object.
*
* Note: it must be a veneer, i.e. w/o any non-trivially destructible member variables.
*/
template<class UDF>
struct udf_proxy_veneer : udf_proxy {
// allocated memory for user-defined function
alignas(UDF) char storage[sizeof(UDF)];

udf_proxy_veneer(std::string name,
int argumentsCount,
std::function<void*(void* location)> constructAt,
xdestroy_fn_t destroy,
func_call_fn_t run) :
udf_proxy {
std::move(name), argumentsCount, std::move(constructAt), destroy, run, nullptr, false,
#if __cpp_lib_launder >= 201606L
std::launder((UDF*)this->storage)
#else
storage
#endif
}
{}

udf_proxy_veneer(std::string name,
int argumentsCount,
std::function<void*(void* location)> constructAt,
xdestroy_fn_t destroy,
func_call_fn_t step,
final_call_fn_t finalCall) :
udf_proxy {
std::move(name), argumentsCount, std::move(constructAt), destroy, step, finalCall, false,
#if __cpp_lib_launder >= 201606L
std::launder((UDF*)this->storage)
#else
storage
#endif
}
{}
};

inline void scalar_function_callback(sqlite3_context* context, int argsCount, sqlite3_value** values) {
udf_proxy* proxy = static_cast<udf_proxy*>(sqlite3_user_data(context));
if(proxy->argumentsCount != -1 && proxy->argumentsCount != argsCount) {
throw std::system_error{orm_error_code::arguments_count_does_not_match};
}
const std::unique_ptr<void, xdestroy_fn_t> udfHandle{proxy->constructAt(proxy->udfHandle), proxy->destroy};
proxy->func(udfHandle.get(), context, argsCount, values);
}

inline void aggregate_function_step_callback(sqlite3_context* context, int argsCount, sqlite3_value** values) {
udf_proxy* proxy = static_cast<udf_proxy*>(sqlite3_user_data(context));
if(!proxy->constructed) {
if(proxy->argumentsCount != -1 && proxy->argumentsCount != argsCount) {
throw std::system_error{orm_error_code::arguments_count_does_not_match};
}
proxy->constructAt(proxy->udfHandle);
proxy->constructed = true;
}
proxy->func(proxy->udfHandle, context, argsCount, values);
}

inline void aggregate_function_final_callback(sqlite3_context* context) {
udf_proxy* proxy = static_cast<udf_proxy*>(sqlite3_user_data(context));
// note: it is possible that the 'step' function was never called
if(!proxy->constructed) {
proxy->constructAt(proxy->udfHandle);
proxy->constructed = true;
}
proxy->finalAggregateCall(proxy->udfHandle, context);
proxy->destroy(proxy->udfHandle);
proxy->constructed = false;
}
}
}
Loading

0 comments on commit 42999ac

Please sign in to comment.