Skip to content

Commit

Permalink
hack query planner
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani committed Dec 11, 2024
1 parent 73d288e commit 9b830d2
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 132 deletions.
63 changes: 57 additions & 6 deletions src/include/wvlet_extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#include "duckdb/function/table_function.hpp"
#include "duckdb/main/client_context.hpp"

// C-linkage symbols this extension expects to be provided at link time.
// Only the signatures are visible here. Callers in this project pass
// wvlet_compile_main / wvlet_compile_query a JSON array of command-line style
// arguments of the form ["-q", "<query>"].
// NOTE(review): wvlet_compile_main and wvlet_compile_query are declared twice
// below (once without and once with `extern`). The redeclarations are
// compatible, but one set looks like leftover and can probably be dropped.
extern "C" {
int wvlet_compile_main(const char*);
const char* wvlet_compile_query(const char* json_query);
// NOTE(review): presumably initialises the Scala Native runtime backing the
// Wvlet library — confirm against the library's documentation.
extern int ScalaNativeInit(void);
extern int wvlet_compile_main(const char*);
extern const char* wvlet_compile_query(const char* json_query); // Changed from wvlet_compile_compile
}

namespace duckdb {
Expand All @@ -34,11 +34,62 @@ struct WvletScriptFunction {
vector<unique_ptr<Expression>> &arguments);
};


// DuckDB extension entry class for Wvlet.
// NOTE(review): Load() and Name() are each declared twice in this class body
// (Name() once as a pure declaration and once with an inline body returning
// "wvlet"). A member function cannot be redeclared inside its class, so one
// declaration of each has to go before this compiles. Keep whichever form
// matches the out-of-line definitions in the .cpp file.
class WvletExtension : public Extension {
public:
void Load(DuckDB &db) override;
std::string Name() override;
std::string Version() const override;
void Load(DuckDB &db) override;
std::string Name() override { return "wvlet"; }
};

// Bind callback for the Wvlet operator extension. WvletOperatorExtension's
// constructor assigns this function to its `Bind` member.
BoundStatement wvlet_bind(ClientContext &context, Binder &binder,
OperatorExtensionInfo *info, SQLStatement &statement);

/// Operator extension for Wvlet: routes binding through wvlet_bind and
/// refuses deserialization.
struct WvletOperatorExtension : public OperatorExtension {
    WvletOperatorExtension() : OperatorExtension() {
        // Install the free function wvlet_bind as this extension's bind hook.
        Bind = wvlet_bind;
    }

    /// Name reported for this operator extension.
    std::string GetName() override {
        return "wvlet";
    }

    /// Always throws InternalException: this extension does not support
    /// being restored from a serialized plan.
    unique_ptr<LogicalExtensionOperator> Deserialize(Deserializer &deserializer) override {
        throw InternalException("wvlet operator should not be serialized");
    }
};

// Parse hook for the Wvlet parser extension; WvletParserExtension assigns it
// to `parse_function`. Receives the raw query text.
ParserExtensionParseResult wvlet_parse(ParserExtensionInfo *,
const std::string &query);

// Plan hook for the Wvlet parser extension; WvletParserExtension assigns it
// to `plan_function`. Takes ownership of the parse data passed in.
ParserExtensionPlanResult wvlet_plan(ParserExtensionInfo *, ClientContext &,
unique_ptr<ParserExtensionParseData>);

/// Parser extension that wires the Wvlet parse and plan hooks into DuckDB.
struct WvletParserExtension : public ParserExtension {
    WvletParserExtension() : ParserExtension() {
        // Both hooks are free functions declared in this header.
        parse_function = wvlet_parse;
        plan_function = wvlet_plan;
    }
};

/// Parse data produced by the Wvlet parser extension: owns a single parsed
/// SQL statement.
struct WvletParseData : ParserExtensionParseData {
    /// Takes ownership of `statement`.
    WvletParseData(unique_ptr<SQLStatement> statement) : statement(std::move(statement)) {
    }

    /// Deep copy: clones the owned statement into a new WvletParseData.
    unique_ptr<ParserExtensionParseData> Copy() const override {
        auto cloned = statement->Copy();
        return make_uniq_base<ParserExtensionParseData, WvletParseData>(std::move(cloned));
    }

    /// Fixed textual tag for this parse-data type.
    string ToString() const override {
        return "WvletParseData";
    }

    /// The parsed statement carried from the parse step.
    unique_ptr<SQLStatement> statement;
};

/// Client-context state that holds Wvlet parse data and drops it in
/// QueryEnd().
class WvletState : public ClientContextState {
public:
    /// Parse data held by this state; cleared by QueryEnd().
    unique_ptr<ParserExtensionParseData> parse_data;

    /// Takes ownership of `parse_data`.
    explicit WvletState(unique_ptr<ParserExtensionParseData> parse_data) : parse_data(std::move(parse_data)) {
    }

    /// Releases the stored parse data.
    void QueryEnd() override {
        parse_data.reset();
    }
};

} // namespace duckdb
179 changes: 53 additions & 126 deletions src/wvlet_extension.cpp
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
#define DUCKDB_EXTENSION_MAIN

#include "wvlet_extension.hpp"
#include "duckdb.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/parser/parser.hpp"
#include "duckdb/parser/statement/extension_statement.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/function/table_function.hpp"
#include "duckdb/main/extension_util.hpp"
#include <duckdb/parser/parsed_data/create_table_function_info.hpp>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <codecvt>
#include <string>

// C-linkage declarations of the external Wvlet symbols used by this file.
// wvlet_extension.hpp (included above) declares ScalaNativeInit,
// wvlet_compile_main and wvlet_compile_query as well, so those are repeated
// here.
// NOTE(review): wvlet_compile_compile is declared here but not called in the
// visible part of this file; the header comment says it was replaced by
// wvlet_compile_query — confirm it can be removed.
#ifdef __cplusplus
extern "C" {
#endif
extern int ScalaNativeInit(void);

extern int wvlet_compile_main(const char*);
extern const char* wvlet_compile_compile(const char*);

extern const char* wvlet_compile_query(const char* json_query);
#ifdef __cplusplus
}
#endif


namespace duckdb {

// EXPERIMENT INIT
Expand All @@ -46,146 +50,69 @@ bool InitializeWvletRuntime() {
}
}

/// Scalar-function body: compiles each input Wvlet script to SQL text.
///
/// For every row, the script is wrapped in a JSON argument array
/// `["-q", "<script>"]`, handed to wvlet_compile_query, and the returned SQL
/// string is stored in `result`.
///
/// @param args   input chunk; column 0 holds the script text
/// @param state  unused
/// @param result output vector receiving the compiled SQL
/// @throws std::runtime_error if the compiler returns null or an empty string
void WvletScriptFunction::ParseWvletScript(DataChunk &args, ExpressionState &state, Vector &result) {
    // Escape text for embedding inside a JSON string literal. Without this, a
    // script containing '"', '\\' or a control character (e.g. a newline)
    // produces malformed JSON.
    const auto escape_json = [](const std::string &text) {
        const char hex_digits[] = "0123456789abcdef";
        std::string escaped;
        escaped.reserve(text.size());
        for (const char ch : text) {
            const auto byte = static_cast<unsigned char>(ch);
            if (ch == '"' || ch == '\\') {
                escaped += '\\';
                escaped += ch;
            } else if (byte < 0x20) {
                escaped += "\\u00";
                escaped += hex_digits[byte >> 4];
                escaped += hex_digits[byte & 0x0F];
            } else {
                escaped += ch;
            }
        }
        return escaped;
    };

    auto &input_vector = args.data[0];
    // NOTE(review): FlatVector::GetData is used on the input as-is, so this
    // relies on the argument arriving as a flat vector, and NULL rows are not
    // checked — confirm whether the input needs flattening / validity handling.
    const auto input = FlatVector::GetData<string_t>(input_vector);
    const auto output = FlatVector::GetData<string_t>(result);
    const idx_t count = args.size();

    for (idx_t i = 0; i < count; i++) {
        const std::string json = "[\"-q\", \"" + escape_json(input[i].GetString()) + "\"]";

        const char *sql_result = wvlet_compile_query(json.c_str());

        // A null pointer or empty string both mean the compiler produced nothing.
        if (!sql_result || sql_result[0] == '\0') {
            throw std::runtime_error("Failed to compile wvlet script");
        }

        output[i] = StringVector::AddString(result, sql_result);
    }

    result.Verify(count);
}

// Bind callback passed to the scalar `wvlet` function in LoadInternal.
// Ignores all of its parameters and returns nullptr, i.e. the scalar function
// carries no bind data.
unique_ptr<FunctionData> WvletScriptFunction::Bind(ClientContext &context, ScalarFunction &bound_function,
vector<unique_ptr<Expression>> &arguments) {
return nullptr;
}

/// Bind step of the `wvlet` table function.
///
/// Compiles the Wvlet script given as the first argument to SQL, stores the
/// SQL in the bind data, and derives the output schema by running that SQL
/// once on a temporary connection.
///
/// @param context       client context; its database backs the temporary connection
/// @param input         bind input; `inputs[0]` is the Wvlet script
/// @param return_types  receives the column types of the compiled query
/// @param names         receives the column names of the compiled query
/// @return bind data whose `query` holds the compiled SQL
/// @throws std::runtime_error if compilation yields nothing or the SQL fails to run
static unique_ptr<FunctionData> WvletBind(ClientContext &context, TableFunctionBindInput &input,
                                          vector<LogicalType> &return_types, vector<string> &names) {
    // Escape text for embedding inside a JSON string literal. Without this, a
    // script containing '"', '\\' or a control character (e.g. a newline)
    // produces malformed JSON.
    const auto escape_json = [](const std::string &text) {
        const char hex_digits[] = "0123456789abcdef";
        std::string escaped;
        escaped.reserve(text.size());
        for (const char ch : text) {
            const auto byte = static_cast<unsigned char>(ch);
            if (ch == '"' || ch == '\\') {
                escaped += '\\';
                escaped += ch;
            } else if (byte < 0x20) {
                escaped += "\\u00";
                escaped += hex_digits[byte >> 4];
                escaped += hex_digits[byte & 0x0F];
            } else {
                escaped += ch;
            }
        }
        return escaped;
    };

    auto result = make_uniq<WvletBindData>();
    result->query = input.inputs[0].GetValue<string>();

    const std::string json = "[\"-q\", \"" + escape_json(result->query) + "\"]";

    // NOTE(review): the return value of wvlet_compile_main is ignored and the
    // same arguments are compiled again right below — confirm this call is
    // still needed.
    wvlet_compile_main(json.c_str());
    const char *sql_result = wvlet_compile_query(json.c_str());

    // A null pointer or empty string both mean the compiler produced nothing.
    if (!sql_result || sql_result[0] == '\0') {
        throw std::runtime_error("Failed to compile wvlet script");
    }

    // From here on the bind data carries the compiled SQL, not the Wvlet text.
    result->query = std::string(sql_result);

    // Run the compiled SQL once on a temporary connection to learn its schema.
    Connection conn(*context.db);
    auto result_set = conn.Query(result->query);

    if (result_set->HasError()) {
        throw std::runtime_error(result_set->GetError());
    }

    for (const auto &column_type : result_set->types) {
        return_types.push_back(column_type);
    }
    for (const auto &column_name : result_set->names) {
        names.push_back(column_name);
    }

    // NOTE(review): `query_result` is not assigned here, while WvletFunction
    // throws when it is null — confirm WvletBindData's constructor creates it.
    return std::move(result);
}

/// Scan step of the `wvlet` table function.
///
/// On the first call it runs the compiled SQL stored in the bind data and
/// re-initialises `output` with the result's column types. Every call then
/// fetches the next chunk of that result into `output`; cardinality 0 signals
/// that no rows are left.
///
/// @throws std::runtime_error if the bind data has no `query_result` or the
///         query reports an error
static void WvletFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
    auto &bind_data = data_p.bind_data->Cast<WvletBindData>();

    if (!bind_data.query_result) {
        throw std::runtime_error("query_result is nullptr");
    }

    // Lazily execute the query the first time the function is scanned.
    if (!bind_data.query_result->initialized) {
        Connection conn(*context.db);
        auto query_outcome = conn.Query(bind_data.query);
        if (query_outcome->HasError()) {
            throw std::runtime_error(query_outcome->GetError());
        }

        bind_data.query_result->result = std::move(query_outcome);
        bind_data.query_result->initialized = true;

        // Rebuild the output chunk so its layout matches the actual result types.
        output.Destroy();
        output.Initialize(context, bind_data.query_result->result->types);
    }

    auto next_chunk = bind_data.query_result->result->Fetch();

    if (next_chunk && next_chunk->size() != 0) {
        output.Reference(*next_chunk);
        output.SetCardinality(next_chunk->size());
        return;
    }

    // Nothing left to emit.
    output.SetCardinality(0);
}

/// Registers everything the extension provides on `instance`: the scalar
/// `wvlet` function, the `wvlet` table function, and the Wvlet parser
/// extension. No operator extension is registered.
static void LoadInternal(DatabaseInstance &instance) {
    // Scalar form: wvlet(VARCHAR) -> VARCHAR.
    ScalarFunction scalar_wvlet("wvlet", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
                                WvletScriptFunction::ParseWvletScript, WvletScriptFunction::Bind);
    ExtensionUtil::RegisterFunction(instance, scalar_wvlet);

    // Table form: wvlet(VARCHAR).
    TableFunction table_wvlet("wvlet", {LogicalType::VARCHAR}, WvletFunction, WvletBind);
    ExtensionUtil::RegisterFunction(instance, table_wvlet);

    // Add the custom Wvlet parser extension to the database configuration.
    WvletParserExtension parser_extension;
    auto &db_config = DBConfig::GetConfig(instance);
    db_config.parser_extensions.push_back(parser_extension);
}

/// Loads the extension into `db`: registers its functions and parser
/// extension, then initialises the Wvlet runtime.
///
/// The registration and the runtime-init check each run exactly once. The
/// earlier body repeated both inside the failure branch, which would have
/// registered everything a second time whenever the first initialisation
/// attempt failed.
///
/// @throws std::runtime_error if InitializeWvletRuntime() reports failure
void WvletExtension::Load(DuckDB &db) {
    LoadInternal(*db.instance);

    // EXPERIMENT: runtime initialisation happens after registration.
    if (!InitializeWvletRuntime()) {
        throw std::runtime_error("Failed to initialize Wvlet runtime");
    }
}

std::string WvletExtension::Name() {
return "wvlet";
/// Parse hook of the Wvlet parser extension.
///
/// Compiles `query` to SQL through the Wvlet compiler, parses that SQL with a
/// fresh DuckDB Parser, and returns the first resulting statement wrapped in
/// WvletParseData.
///
/// Fixes over the previous body:
///  - the compiled SQL is what gets parsed (the original query text was being
///    re-parsed and the compiler output discarded);
///  - the compiler result is no longer streamed to stdout before the null
///    check, and the debug prints are gone;
///  - the query is JSON-escaped before being embedded in the argument array;
///  - an empty statement list is reported explicitly.
///
/// @param query text handed to the extension by DuckDB's parser
/// @throws std::runtime_error if compilation yields nothing or no statement
///         results from parsing the compiled SQL
ParserExtensionParseResult wvlet_parse(ParserExtensionInfo *, const std::string &query) {
    // Escape text for embedding inside a JSON string literal. Without this, a
    // query containing '"', '\\' or a control character (e.g. a newline)
    // produces malformed JSON.
    const auto escape_json = [](const std::string &text) {
        const char hex_digits[] = "0123456789abcdef";
        std::string escaped;
        escaped.reserve(text.size());
        for (const char ch : text) {
            const auto byte = static_cast<unsigned char>(ch);
            if (ch == '"' || ch == '\\') {
                escaped += '\\';
                escaped += ch;
            } else if (byte < 0x20) {
                escaped += "\\u00";
                escaped += hex_digits[byte >> 4];
                escaped += hex_digits[byte & 0x0F];
            } else {
                escaped += ch;
            }
        }
        return escaped;
    };

    const std::string json = "[\"-q\", \"" + escape_json(query) + "\"]";

    // NOTE(review): the return value of wvlet_compile_main is ignored and the
    // same arguments are compiled again right below — confirm this call is
    // still needed.
    wvlet_compile_main(json.c_str());
    const char *sql_result = wvlet_compile_query(json.c_str());

    // A null pointer or empty string both mean the compiler produced nothing.
    if (!sql_result || sql_result[0] == '\0') {
        throw std::runtime_error("Failed to compile wvlet script");
    }

    // Copy right away: the lifetime of the returned buffer is not visible here.
    const std::string compiled_sql(sql_result);

    // Parse the SQL emitted by the compiler, not the incoming Wvlet text.
    Parser parser;
    parser.ParseQuery(compiled_sql);

    if (parser.statements.empty()) {
        throw std::runtime_error("Wvlet compilation produced no SQL statement");
    }

    return ParserExtensionParseResult(
        make_uniq_base<ParserExtensionParseData, WvletParseData>(std::move(parser.statements[0])));
}

std::string WvletExtension::Version() const {
#ifdef EXT_VERSION_WVLET
return EXT_VERSION_WVLET;
#else
return "";
#endif
// Plan hook of the Wvlet parser extension. Currently a placeholder: it ignores
// both `context` and `parse_data` and returns a default-constructed
// ParserExtensionPlanResult.
// NOTE(review): the statement carried in `parse_data` is dropped here, so
// nothing produced by the parse hook is planned yet.
ParserExtensionPlanResult wvlet_plan(ParserExtensionInfo *, ClientContext &context,
unique_ptr<ParserExtensionParseData> parse_data) {
// Placeholder plan result
return ParserExtensionPlanResult();
}

// Bind hook for the Wvlet operator extension. Currently a no-op: all four
// parameters are ignored and an empty (value-initialised) BoundStatement is
// returned.
BoundStatement wvlet_bind(ClientContext &context, Binder &binder,
OperatorExtensionInfo *info, SQLStatement &statement) {
// Directly return a no-op bound statement
return {};
}

} // namespace duckdb

extern "C" {

// Extension entry point: loads WvletExtension into `db`.
//
// WvletExtension::Load already calls LoadInternal, so the extra direct
// LoadInternal(db) call that used to follow LoadExtension has been removed.
// It would have registered the same functions and parser extension a second
// time.
// NOTE(review): if the intent was to bypass WvletExtension::Load, keep only
// the LoadInternal(db) call instead — but note that path skips
// InitializeWvletRuntime().
DUCKDB_EXTENSION_API void wvlet_init(duckdb::DatabaseInstance &db) {
    duckdb::DuckDB db_wrapper(db);
    db_wrapper.LoadExtension<duckdb::WvletExtension>();
}

// Returns the version string of the DuckDB library this extension is built
// against (the duplicated, unreachable second return was removed).
DUCKDB_EXTENSION_API const char *wvlet_version() {
    return duckdb::DuckDB::LibraryVersion();
}
}

#ifndef DUCKDB_EXTENSION_MAIN
#error DUCKDB_EXTENSION_MAIN not defined
#endif

0 comments on commit 9b830d2

Please sign in to comment.