Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parser extension #3674

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ example:

extension-test-build:
$(call run-cmake-release, \
-DBUILD_EXTENSIONS="httpfs;duckdb;postgres" \
-DBUILD_EXTENSIONS="httpfs;duckdb;postgres;sample" \
-DBUILD_EXTENSION_TESTS=TRUE \
-DENABLE_ADDRESS_SANITIZER=TRUE \
)
Expand All @@ -175,13 +175,13 @@ extension-test: extension-test-build

extension-debug:
$(call run-cmake-debug, \
-DBUILD_EXTENSIONS="httpfs;duckdb;postgres" \
-DBUILD_EXTENSIONS="httpfs;duckdb;postgres;sample" \
-DBUILD_KUZU=FALSE \
)

extension-release:
$(call run-cmake-release, \
-DBUILD_EXTENSIONS="httpfs;duckdb;postgres" \
-DBUILD_EXTENSIONS="httpfs;duckdb;postgres;sample" \
-DBUILD_KUZU=FALSE \
)

Expand Down
4 changes: 4 additions & 0 deletions extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ if ("postgres" IN_LIST BUILD_EXTENSIONS)
endif()
endif()

if ("sample" IN_LIST BUILD_EXTENSIONS)
add_subdirectory(sample)
endif()

if (${BUILD_EXTENSION_TESTS})
include_directories(${CMAKE_SOURCE_DIR}/third_party/spdlog)
add_definitions(-DTEST_FILES_DIR="extension")
Expand Down
Empty file.
48 changes: 48 additions & 0 deletions extension/sample/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
include_directories(
${PROJECT_SOURCE_DIR}/src/include
src/include)

add_library(sample
SHARED
src/sample_extension.cpp)

set_target_properties(sample PROPERTIES
OUTPUT_NAME sample
PREFIX "lib"
SUFFIX ".kuzu_extension"
)

set_target_properties(sample
PROPERTIES
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build"
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build"
)

if (WIN32)
# On windows, there is no dynamic lookup available, so it's not
# possible to generically look for symbols on library load. There are
# two main alternatives to statically linking kuzu, neither of which is
# appealing:
# 1. Link against the shared library. This works well assuming
# the DLL is locatable, but this assumption isn't valid for users
# of kuzu.exe.
# 2. Link against the executable (kuzu.exe). This is
# strange but works well for kuzu.exe. However, it forces
# users who are embedding kuzu in their application to recompile
# the extension _and_ export the symbols for the extension to
# locate on load.
# We choose the simplest option. Windows isn't known
# for its small libraries anyways...
# Future work could make it possible to embed extension into kuzu,
# which would help fix this problem.
target_link_libraries(sample PRIVATE kuzu)
endif()

if (APPLE)
set_target_properties(sample PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif ()

if (${BUILD_EXTENSION_TESTS})
add_subdirectory(test)
endif()
23 changes: 23 additions & 0 deletions extension/sample/src/include/sample_extension.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "common/exception/parser.h"
#include "extension/extension.h"
#include "extension/extension_clause.h"
#include "extension/extension_clause_handler.h"
#include "parser/statement.h"

namespace kuzu {
namespace sample {

class SampleExtension : public extension::Extension {
public:
static void load(main::ClientContext* context);
};

class SampleClauseHandler : public extension::ExtensionClauseHandler {
public:
SampleClauseHandler();
};

} // namespace sample
} // namespace kuzu
188 changes: 188 additions & 0 deletions extension/sample/src/sample_extension.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
#include "sample_extension.h"

#include "binder/binder.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "main/client_context.h"
#include "main/database.h"
#include "storage/storage_manager.h"

namespace kuzu {
namespace sample {

enum class SampleTableType : uint8_t { NODE = 0, REL = 1 };

class Sample final : public extension::ExtensionClause {
public:
explicit Sample(std::string tableName)
: extension::ExtensionClause{SampleClauseHandler{}}, tableName{std::move(tableName)} {};

std::string getTableName() const { return tableName; }

private:
std::string tableName;
};

class BoundSample final : public extension::BoundExtensionClause {
public:
explicit BoundSample(std::string tableName)
: extension::BoundExtensionClause{SampleClauseHandler{},
binder::BoundStatementResult::createSingleStringColumnResult()},
tableName{std::move(tableName)} {};

std::string getTableName() const { return tableName; }

private:
std::string tableName;
};

class LogicalSample final : public extension::LogicalExtensionClause {
public:
explicit LogicalSample(std::string tableName,
std::shared_ptr<binder::Expression> outputExpression)
: extension::LogicalExtensionClause{SampleClauseHandler{}}, tableName{std::move(tableName)},
outputExpression{std::move(outputExpression)} {}

void computeFactorizedSchema() override {
createEmptySchema();
auto groupPos = schema->createGroup();
schema->insertToGroupAndScope(outputExpression, groupPos);
schema->setGroupAsSingleState(groupPos);
}
void computeFlatSchema() override {
createEmptySchema();
schema->createGroup();
schema->insertToGroupAndScope(outputExpression, 0);
}

std::string getExpressionsForPrinting() const override { return "sample extension"; }

std::string getTableName() const { return tableName; }

std::shared_ptr<binder::Expression> getOutputExpression() const { return outputExpression; }

std::unique_ptr<LogicalOperator> copy() final {
return std::make_unique<LogicalSample>(tableName, outputExpression);
}

private:
std::string tableName;
std::shared_ptr<binder::Expression> outputExpression;
};

struct SampleInfo {
std::string tableName;
processor::DataPos outputPos;
};

class PhysicalSample : public extension::PhysicalExtensionClause {
public:
PhysicalSample(SampleInfo sampleInfo, uint32_t id,
std::unique_ptr<processor::OPPrintInfo> printInfo)
: PhysicalExtensionClause{id, std::move(printInfo)}, sampleInfo{std::move(sampleInfo)} {}

bool isSource() const override { return true; }
bool isParallel() const final { return false; }

void initLocalStateInternal(processor::ResultSet* resultSet,
processor::ExecutionContext* /*context*/) override {
outputVector = resultSet->getValueVector(sampleInfo.outputPos).get();
}

bool getNextTuplesInternal(processor::ExecutionContext* context) override {
if (hasExecuted) {
return false;
}
hasExecuted = true;
std::vector<binder::PropertyInfo> propertyInfos;
propertyInfos.push_back(binder::PropertyInfo{"name", common::LogicalType::STRING()});
binder::BoundCreateTableInfo info{common::TableType::NODE, sampleInfo.tableName,
common::ConflictAction::ON_CONFLICT_THROW,
std::make_unique<binder::BoundExtraCreateNodeTableInfo>(0, std::move(propertyInfos))};
auto catalog = context->clientContext->getCatalog();
auto newTableID = catalog->createTableSchema(context->clientContext->getTx(), info);
auto storageManager = context->clientContext->getStorageManager();
storageManager->createTable(newTableID, catalog, context->clientContext);
outputVector->setValue(0, std::string("New table has been created by extension."));
outputVector->state->getSelVectorUnsafe().setSelSize(1);
metrics->numOutputTuple.increase(1);
return true;
}

std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<PhysicalSample>(sampleInfo, id, printInfo->copy());
}

private:
SampleInfo sampleInfo;
common::ValueVector* outputVector;
bool hasExecuted = false;
};

void SampleExtension::load(main::ClientContext* context) {
auto db = context->getDatabase();
db->registerExtensionClauseHandler("sample", std::make_unique<SampleClauseHandler>());
}

static std::vector<std::shared_ptr<parser::Statement>> parseFunction(std::string_view query) {
auto tableName = std::string(query).substr(std::string(query).find_last_of(' ') + 1);
if (std::string(query).find("NODE") == std::string::npos) {
throw common::ParserException{"Invalid query: " + std::string(query)};
}
return {std::make_shared<Sample>(tableName)};
}

static std::unique_ptr<binder::BoundStatement> bindFunction(
const extension::ExtensionClause& statement, const binder::Binder& binder) {
auto& sample = statement.constCast<Sample>();
auto context = binder.getClientContext();
if (context->getCatalog()->containsTable(context->getTx(), sample.getTableName())) {
throw common::BinderException{
common::stringFormat("Table {} already exists.", sample.getTableName())};
}
return std::make_unique<BoundSample>(sample.getTableName());
}

static void planFunction(planner::LogicalPlan& plan,
const extension::BoundExtensionClause& statement) {
auto& boundSample = statement.constCast<BoundSample>();
auto op = std::make_unique<LogicalSample>(boundSample.getTableName(),
statement.getStatementResult()->getSingleColumnExpr());
plan.setLastOperator(std::move(op));
}

static std::unique_ptr<processor::PhysicalOperator> mapFunction(
const extension::LogicalExtensionClause& op, uint32_t operatorID) {
auto& logicalSample = op.constCast<LogicalSample>();
auto outSchema = logicalSample.getSchema();
auto outputExpression = logicalSample.getOutputExpression();
auto dataPos = processor::DataPos(outSchema->getExpressionPos(*outputExpression));
SampleInfo sampleInfo{logicalSample.getTableName(), dataPos};
return std::make_unique<PhysicalSample>(sampleInfo, operatorID,
std::make_unique<processor::OPPrintInfo>(logicalSample.getExpressionsForPrinting()));
}

static bool defaultReadWriteAnalyzer(const extension::ExtensionClause& /*statement*/) {
return false;
}

SampleClauseHandler::SampleClauseHandler()
: extension::ExtensionClauseHandler{parseFunction, bindFunction, planFunction, mapFunction} {
readWriteAnalyzer = defaultReadWriteAnalyzer;
}

} // namespace sample
} // namespace kuzu

extern "C" {
// Because we link against the static library on windows, we implicitly inherit KUZU_STATIC_DEFINE,
// which cancels out any exporting, so we can't use KUZU_API.
#if defined(_WIN32)
#define INIT_EXPORT __declspec(dllexport)
#else
#define INIT_EXPORT __attribute__((visibility("default")))
#endif
INIT_EXPORT void init(kuzu::main::ClientContext* context) {
kuzu::sample::SampleExtension::load(context);
}
}
Empty file.
18 changes: 18 additions & 0 deletions extension/sample/test/test_files/sample.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-DATASET CSV empty

--

-CASE SampleTest
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/sample/build/libsample.kuzu_extension"
---- ok
-STATEMENT NEW NODE TABLE UW
---- 1
New table has been created by extension.
-STATEMENT MATCH (u:UW) return count(u)
---- 1
0
-STATEMENT NEW REL TABLE UW
---- error
Parser exception: Kuzu default parser throws an exception: "Parser exception: extraneous input 'NEW' expecting {ALTER, ATTACH, BEGIN, CALL, COMMENT, COMMIT, COMMIT_SKIP_CHECKPOINT, COPY, CREATE, DELETE, DETACH, DROP, EXPLAIN, EXPORT, IMPORT, INSTALL, LOAD, MATCH, MERGE, OPTIONAL, PROFILE, PROJECT, RETURN, ROLLBACK, ROLLBACK_SKIP_CHECKPOINT, SET, UNWIND, USE, WITH, SP} (line: 1, offset: 0)
"NEW REL TABLE UW"
^^^", and none of the extensions can compile the query.
1 change: 1 addition & 0 deletions src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_library(
bind_attach_database.cpp
bind_create_macro.cpp
bind_project_graph.cpp
bind_extension_clause.cpp
bind_ddl.cpp
bind_detach_database.cpp
bind_explain.cpp
Expand Down
3 changes: 2 additions & 1 deletion src/binder/bind/bind_export_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ bool Binder::bindExportTableData(ExportedTableData& tableData, const TableCatalo
if (!bindExportQuery(exportQuery, entry, catalog, tx)) {
return false;
}
auto parsedStatement = Parser::parseQuery(exportQuery);
auto parser = parser::Parser(clientContext->getDatabase());
auto parsedStatement = parser.parseQuery(exportQuery);
KU_ASSERT(parsedStatement.size() == 1);
auto parsedQuery = parsedStatement[0]->constPtrCast<RegularQuery>();
auto query = bindQuery(*parsedQuery);
Expand Down
14 changes: 14 additions & 0 deletions src/binder/bind/bind_extension_clause.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "binder/binder.h"
#include "extension/extension_clause.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindExtensionClause(
const parser::Statement& statement) const {
auto& extensionClause = statement.constCast<extension::ExtensionClause>();
return extensionClause.getExtensionClauseHandler().bindFunc(extensionClause, *this);
}

} // namespace binder
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/binder/bind/bind_import_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ std::unique_ptr<BoundStatement> Binder::bindImportDatabaseClause(const Statement
auto copyQuery =
getQueryFromFile(fs, boundFilePath, ImportDBConstants::COPY_NAME, clientContext);
if (!copyQuery.empty()) {
auto parsedStatements = Parser::parseQuery(copyQuery);
auto parser = parser::Parser{clientContext->getDatabase()};
auto parsedStatements = parser.parseQuery(copyQuery);
for (auto& parsedStatement : parsedStatements) {
KU_ASSERT(parsedStatement->getStatementType() == StatementType::COPY_FROM);
auto copyFromStatement =
Expand Down
3 changes: 3 additions & 0 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::USE_DATABASE: {
boundStatement = bindUseDatabase(statement);
} break;
case StatementType::EXTENSION_CLAUSE: {
boundStatement = bindExtensionClause(statement);
} break;
default: {
KU_UNREACHABLE;
}
Expand Down
3 changes: 3 additions & 0 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ void BoundStatementVisitor::visit(const BoundStatement& statement) {
case StatementType::USE_DATABASE: {
visitUseDatabase(statement);
} break;
case StatementType::EXTENSION_CLAUSE: {
visitExtensionClause(statement);
} break;
default:
KU_UNREACHABLE;
}
Expand Down
3 changes: 1 addition & 2 deletions src/function/list/list_slice_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ struct ListSlice {
auto startIdx = begin;
auto endIdx = end;
normalizeIndices(startIdx, endIdx, str.len);
SubStr::operation(str, startIdx, std::min(endIdx - startIdx + 1, str.len - startIdx + 1),
result, resultValueVector);
SubStr::operation(str, startIdx, endIdx - startIdx + 1, result, resultValueVector);
}

private:
Expand Down
Loading
Loading