// Patch overview (reviewer notes; everything from the `diff --git` marker onward is the patch itself).
//
// lib/CMakeLists.txt:
//   Adds test/prepared_statement_test.cc to the source list in the `if(NOT EMSCRIPTEN)` section.
//
// lib/src/webdb.cc:
//   Moves the JSON-argument conversion loop out of WebDB::Connection::ExecutePreparedStatement
//   into a file-static BuildParameters helper. The helper walks a rapidjson array and appends
//   one value per element:
//     - IsLosslessDouble() -> GetDouble()   (tested first, before the other kinds)
//     - IsString()         -> string_t built with an explicit GetStringLength()
//     - IsNull()           -> nullptr
//     - IsBool()           -> GetBool()
//     - IsArray()          -> recursive BuildParameters, wrapped in duckdb::Value::LIST (new in this patch)
//     - anything else      -> arrow Invalid status "Invalid column type encountered for argument <index>"
//   `index` restarts at 0 in every recursive call, so for a nested array the reported index is the
//   position inside that nested array, not the top-level argument position.
//   A failed nested conversion is returned to the caller unchanged.
//   The Document overload forwards doc.GetArray(). ExecutePreparedStatement checks
//   args_doc.IsArray() before calling it, returns values.status() on failure and passes *values
//   to Execute().
//
// lib/test/prepared_statement_test.cc (new file):
//   Two gtest cases. Both are skipped when the studenten.parquet fixture is missing and both
//   prepare a query ending in `WHERE semester = ANY(?)`:
//     - WithArrayParams: runs with "[[12, 2]]" and expects every `semester` value to be 2 or 12.
//     - WithArrayParamsIllegal: runs with "[[12, [2]]]" and expects StatusCode::ExecutionError.
//   The tests call conn.RunPreparedStatement, which is not part of this patch; presumably it goes
//   through the same argument conversion - TODO confirm.
//
// NOTE(review): in WithArrayParams the result of dynamic_pointer_cast is dereferenced without a
//   null check.
// NOTE(review): this copy of the patch looks mangled - several template argument lists and
//   include targets are missing (e.g. `arrow::Result>`, `make_shared(NATIVE)`, bare `#include`)
//   and line breaks are collapsed. Presumably an extraction artifact; verify against the original
//   patch before applying.
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index b3588fbf6..94e6b7b35 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -361,6 +361,7 @@ if(NOT EMSCRIPTEN) ${CMAKE_SOURCE_DIR}/test/json_dataview_test.cc ${CMAKE_SOURCE_DIR}/test/memory_filesystem_test.cc ${CMAKE_SOURCE_DIR}/test/parquet_test.cc + ${CMAKE_SOURCE_DIR}/test/prepared_statement_test.cc ${CMAKE_SOURCE_DIR}/test/readahead_buffer_test.cc ${CMAKE_SOURCE_DIR}/test/tablenames_test.cc ${CMAKE_SOURCE_DIR}/test/web_filesystem_test.cc diff --git a/lib/src/webdb.cc b/lib/src/webdb.cc index 0161443aa..3b9d1aa2e 100644 --- a/lib/src/webdb.cc +++ b/lib/src/webdb.cc @@ -301,6 +301,39 @@ arrow::Result WebDB::Connection::CreatePreparedStatement(std::string_vie } } +static arrow::Result> BuildParameters(const rapidjson::Document::ConstArray& source) { + duckdb::vector values; + size_t index = 0; + + for (const auto& v : source) { + if (v.IsLosslessDouble()) + values.emplace_back(v.GetDouble()); + else if (v.IsString()) + // Use GetStringLength, otherwise embedded null bytes will be counted as terminators + values.emplace_back(string_t(v.GetString(), v.GetStringLength())); + else if (v.IsNull()) + values.emplace_back(nullptr); + else if (v.IsBool()) + values.emplace_back(v.GetBool()); + else if (v.IsArray()) { + auto item_values = BuildParameters(v.GetArray()); + if (!item_values.ok()) { + return item_values; + } + values.emplace_back(duckdb::Value::LIST(std::move(*item_values))); + } else + return arrow::Status{arrow::StatusCode::Invalid, + "Invalid column type encountered for argument " + std::to_string(index)}; + ++index; + } + + return arrow::Result>(std::move(values)); +} + +static arrow::Result> BuildParameters(const rapidjson::Document& doc) { + return BuildParameters(doc.GetArray()); +} + arrow::Result> WebDB::Connection::ExecutePreparedStatement( size_t statement_id, std::string_view args_json) { try { @@ -313,25 +346,12 @@ arrow::Result> WebDB::Connection::Execut if (!ok) return 
arrow::Status{arrow::StatusCode::Invalid, rapidjson::GetParseError_En(ok.Code())}; if (!args_doc.IsArray()) return arrow::Status{arrow::StatusCode::Invalid, "Arguments must be given as array"}; - duckdb::vector values; - size_t index = 0; - for (const auto& v : args_doc.GetArray()) { - if (v.IsLosslessDouble()) - values.emplace_back(v.GetDouble()); - else if (v.IsString()) - // Use GetStringLenght otherwise null bytes will be counted as terminators - values.emplace_back(string_t(v.GetString(), v.GetStringLength())); - else if (v.IsNull()) - values.emplace_back(nullptr); - else if (v.IsBool()) - values.emplace_back(v.GetBool()); - else - return arrow::Status{arrow::StatusCode::Invalid, - "Invalid column type encountered for argument " + std::to_string(index)}; - ++index; + auto values = BuildParameters(args_doc); + if (!values.ok()) { + return values.status(); } - auto result = stmt->second->Execute(values); + auto result = stmt->second->Execute(*values); if (result->HasError()) return arrow::Status{arrow::StatusCode::ExecutionError, std::move(result->GetError())}; return result; } catch (std::exception& e) { diff --git a/lib/test/prepared_statement_test.cc b/lib/test/prepared_statement_test.cc new file mode 100644 index 000000000..7e9ac5ec0 --- /dev/null +++ b/lib/test/prepared_statement_test.cc @@ -0,0 +1,74 @@ +#include +#include + +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/array_primitive.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/table.h" +#include "duckdb/web/test/config.h" +#include "duckdb/web/webdb.h" + +using namespace duckdb::web; +using namespace std; +namespace fs = std::filesystem; + +namespace { + +TEST(PreparedStatent, WithArrayParams) { + auto db = make_shared(NATIVE); + WebDB::Connection conn{*db}; + + auto data = test::SOURCE_DIR / ".." 
/ "data" / "uni" / "studenten.parquet"; + if (!fs::exists(data)) GTEST_SKIP_(": Missing data"); + + std::stringstream ss; + ss << "SELECT * FROM parquet_scan('" << data.string() << "') WHERE semester = ANY(?)"; + + auto stmt = conn.CreatePreparedStatement(ss.str()); + ASSERT_TRUE(stmt.ok()) << stmt.status().message(); + + auto buffer = conn.RunPreparedStatement(*stmt, "[[12, 2]]"); + ASSERT_TRUE(buffer.ok()) << buffer.status().message(); + + ::arrow::io::BufferReader buffer_reader(*buffer); + auto const reader = ::arrow::ipc::RecordBatchFileReader::Open(&buffer_reader); + ASSERT_TRUE(reader.ok()) << reader.status().message(); + + auto const batches = (*reader)->ToRecordBatches(); + ASSERT_TRUE(batches.ok()) << batches.status().message(); + + for (auto& batch : *batches) { + auto const rows = batch->GetColumnByName("semester"); + ASSERT_TRUE(rows) << "Must contain `semester` column"; + + auto const int_rows = dynamic_pointer_cast(rows); + + for (auto i = 0; i < int_rows->length(); ++i) { + EXPECT_THAT(int_rows->Value(i), testing::AnyOf(testing::Eq(2), testing::Eq(12))); + } + } +} + +TEST(PreparedStatent, WithArrayParamsIllegal) { + auto db = make_shared(NATIVE); + WebDB::Connection conn{*db}; + + auto data = test::SOURCE_DIR / ".." / "data" / "uni" / "studenten.parquet"; + if (!fs::exists(data)) GTEST_SKIP_(": Missing data"); + + std::stringstream ss; + ss << "SELECT * FROM parquet_scan('" << data.string() << "') WHERE semester = ANY(?)"; + + auto stmt = conn.CreatePreparedStatement(ss.str()); + ASSERT_TRUE(stmt.ok()) << stmt.status().message(); + + // pass a list with non-uniform element types + auto buffer = conn.RunPreparedStatement(*stmt, "[[12, [2]]]"); + ASSERT_FALSE(buffer.ok()); + ASSERT_EQ(buffer.status().code(), arrow::StatusCode::ExecutionError); +} + +} // namespace