From 765d29f00978cab890015de565d638975aaee766 Mon Sep 17 00:00:00 2001 From: dentiny Date: Fri, 6 Dec 2024 11:22:10 +0000 Subject: [PATCH] clang format all files --- Makefile | 2 +- .../iterative_length_function_data.cpp | 4 +- ...l_clustering_coefficient_function_data.cpp | 7 +- ...akly_connected_component_function_data.cpp | 27 +- src/core/functions/scalar/csr_creation.cpp | 24 +- src/core/functions/scalar/csr_deletion.cpp | 7 +- src/core/functions/scalar/csr_get_w_type.cpp | 8 +- src/core/functions/scalar/iterativelength.cpp | 13 +- .../functions/scalar/iterativelength2.cpp | 12 +- .../scalar/iterativelength_bidirectional.cpp | 2 - .../scalar/local_clustering_coefficient.cpp | 12 +- src/core/functions/scalar/pagerank.cpp | 211 +++++----- src/core/functions/scalar/reachability.cpp | 14 +- src/core/functions/scalar/shortest_path.cpp | 15 +- .../scalar/weakly_connected_component.cpp | 7 +- .../functions/table/create_property_graph.cpp | 200 ++++++--- .../table/describe_property_graph.cpp | 5 +- .../functions/table/drop_property_graph.cpp | 13 +- .../table/local_clustering_coefficient.cpp | 53 +-- src/core/functions/table/match.cpp | 176 ++++---- src/core/functions/table/pagerank.cpp | 10 +- src/core/functions/table/pgq_scan.cpp | 62 +-- .../table/weakly_connected_component.cpp | 16 +- src/core/module.cpp | 9 +- src/core/operator/duckpgq_bind.cpp | 6 +- src/core/parser/duckpgq_parser.cpp | 19 +- src/core/utils/compressed_sparse_row.cpp | 383 +++++++++++------- src/core/utils/duckpgq_bitmap.cpp | 4 +- src/core/utils/duckpgq_utils.cpp | 52 ++- src/duckpgq_extension.cpp | 20 +- src/duckpgq_state.cpp | 172 ++++---- .../cheapest_path_length_function_data.hpp | 2 - ...l_clustering_coefficient_function_data.hpp | 8 +- .../function_data/pagerank_function_data.hpp | 7 +- ...akly_connected_component_function_data.hpp | 10 +- src/include/duckpgq/core/functions/scalar.hpp | 15 +- src/include/duckpgq/core/functions/table.hpp | 7 +- .../functions/table/create_property_graph.hpp | 29 +- .../table/describe_property_graph.hpp | 1 - .../functions/table/drop_property_graph.hpp | 1 - .../table/local_clustering_coefficient.hpp | 70 ++-- .../duckpgq/core/functions/table/match.hpp | 70 ++-- .../duckpgq/core/functions/table/pagerank.hpp | 13 +- .../duckpgq/core/functions/table/pgq_scan.hpp | 3 +- .../table/weakly_connected_component.hpp | 18 +- src/include/duckpgq/core/module.hpp | 1 - .../duckpgq/core/operator/duckpgq_bind.hpp | 1 - .../core/operator/duckpgq_operator.hpp | 4 +- .../duckpgq/core/parser/duckpgq_parser.hpp | 5 +- .../core/utils/compressed_sparse_row.hpp | 36 +- .../duckpgq/core/utils/duckpgq_utils.hpp | 21 +- src/include/duckpgq_extension.hpp | 2 +- src/include/duckpgq_extension_callback.hpp | 7 +- src/include/duckpgq_state.hpp | 9 +- 54 files changed, 1097 insertions(+), 808 deletions(-) diff --git a/Makefile b/Makefile index d306f732..5e2664e2 100644 --- a/Makefile +++ b/Makefile @@ -5,4 +5,4 @@ EXT_NAME=duckpgq EXT_CONFIG=${PROJ_DIR}extension_config.cmake # Include the Makefile from extension-ci-tools -include extension-ci-tools/makefiles/duckdb_extension.Makefile \ No newline at end of file +include extension-ci-tools/makefiles/duckdb_extension.Makefile diff --git a/src/core/functions/function_data/iterative_length_function_data.cpp b/src/core/functions/function_data/iterative_length_function_data.cpp index 5d7f20df..7f4c48bb 100644 --- a/src/core/functions/function_data/iterative_length_function_data.cpp +++ b/src/core/functions/function_data/iterative_length_function_data.cpp @@ -15,7 +15,6 @@ bool IterativeLengthFunctionData::Equals(const FunctionData &other_p) const { return other.csr_id == csr_id; } - unique_ptr IterativeLengthFunctionData::IterativeLengthBind( ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { @@ -29,7 +28,6 @@ unique_ptr IterativeLengthFunctionData::IterativeLengthBind( return make_uniq(context, csr_id); } - } // namespace core -} //namespace duckpgq \ No newline at end of file +} // namespace duckpgq \ No newline at end of file diff --git a/src/core/functions/function_data/local_clustering_coefficient_function_data.cpp b/src/core/functions/function_data/local_clustering_coefficient_function_data.cpp index cbc1672e..29df3206 100644 --- a/src/core/functions/function_data/local_clustering_coefficient_function_data.cpp +++ b/src/core/functions/function_data/local_clustering_coefficient_function_data.cpp @@ -7,9 +7,7 @@ namespace core { LocalClusteringCoefficientFunctionData::LocalClusteringCoefficientFunctionData( ClientContext &context, int32_t csr_id) - : context(context), csr_id(csr_id) { - -} + : context(context), csr_id(csr_id) {} unique_ptr LocalClusteringCoefficientFunctionData::LocalClusteringCoefficientBind( @@ -29,7 +27,8 @@ unique_ptr LocalClusteringCoefficientFunctionData::Copy() const { return make_uniq(context, csr_id); } -bool LocalClusteringCoefficientFunctionData::Equals(const FunctionData &other_p) const { +bool LocalClusteringCoefficientFunctionData::Equals( + const FunctionData &other_p) const { auto &other = (const LocalClusteringCoefficientFunctionData &)other_p; return other.csr_id == csr_id; } diff --git a/src/core/functions/function_data/weakly_connected_component_function_data.cpp b/src/core/functions/function_data/weakly_connected_component_function_data.cpp index ca63eb02..c4663f26 100644 --- a/src/core/functions/function_data/weakly_connected_component_function_data.cpp +++ b/src/core/functions/function_data/weakly_connected_component_function_data.cpp @@ -4,19 +4,23 @@ namespace duckpgq { namespace core { -WeaklyConnectedComponentFunctionData::WeaklyConnectedComponentFunctionData(ClientContext &context, int32_t csr_id) +WeaklyConnectedComponentFunctionData::WeaklyConnectedComponentFunctionData( + ClientContext &context, int32_t csr_id) : context(context), csr_id(csr_id) { - componentId = vector(); - component_id_initialized = false; + componentId = vector(); + component_id_initialized = false; } -WeaklyConnectedComponentFunctionData::WeaklyConnectedComponentFunctionData(ClientContext &context, int32_t csr_id, const vector &componentId) +WeaklyConnectedComponentFunctionData::WeaklyConnectedComponentFunctionData( + ClientContext &context, int32_t csr_id, const vector &componentId) : context(context), csr_id(csr_id), componentId(componentId) { - component_id_initialized = false; + component_id_initialized = false; } -unique_ptr WeaklyConnectedComponentFunctionData::WeaklyConnectedComponentBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { +unique_ptr +WeaklyConnectedComponentFunctionData::WeaklyConnectedComponentBind( + ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { if (!arguments[0]->IsFoldable()) { throw InvalidInputException("Id must be constant."); } @@ -24,16 +28,17 @@ unique_ptr WeaklyConnectedComponentFunctionData::WeaklyConnectedCo int32_t csr_id = ExpressionExecutor::EvaluateScalar(context, *arguments[0]) .GetValue(); - return make_uniq(context, csr_id); + return make_uniq(context, csr_id); } unique_ptr WeaklyConnectedComponentFunctionData::Copy() const { - auto result = make_uniq(context, csr_id, componentId); + auto result = make_uniq(context, csr_id, + componentId); result->component_id_initialized = component_id_initialized; return std::move(result); - } -bool WeaklyConnectedComponentFunctionData::Equals(const FunctionData &other_p) const { +bool WeaklyConnectedComponentFunctionData::Equals( + const FunctionData &other_p) const { auto &other = (const WeaklyConnectedComponentFunctionData &)other_p; if (csr_id != other.csr_id) { return false; diff --git a/src/core/functions/scalar/csr_creation.cpp b/src/core/functions/scalar/csr_creation.cpp index 9cd60fc0..f3c0fbb8 100644 --- a/src/core/functions/scalar/csr_creation.cpp +++ b/src/core/functions/scalar/csr_creation.cpp @@ -124,7 +124,8 @@ static void CreateCsrEdgeFunction(DataChunk &args, ExpressionState &state, auto &func_expr = (BoundFunctionExpression &)state.expr; auto &info = (CSRFunctionData &)*func_expr.bind_info; - auto duckpgq_state = info.context.registered_state->Get("duckpgq"); + auto duckpgq_state = + info.context.registered_state->Get("duckpgq"); if (!duckpgq_state) { //! Wondering how you can get here if the extension wasn't loaded, but //! leaving this check in anyways @@ -136,8 +137,9 @@ static void CreateCsrEdgeFunction(DataChunk &args, ExpressionState &state, int64_t edge_size = args.data[2].GetValue(0).GetValue(); int64_t edge_size_count = args.data[3].GetValue(0).GetValue(); if (edge_size != edge_size_count) { - duckpgq_state->csr_to_delete.insert(info.id); - throw ConstraintException("Non-unique vertices detected. Make sure all vertices are unique for path-finding queries."); + duckpgq_state->csr_to_delete.insert(info.id); + throw ConstraintException("Non-unique vertices detected. Make sure all " + "vertices are unique for path-finding queries."); } auto csr_entry = duckpgq_state->csr_list.find(info.id); @@ -211,24 +213,25 @@ ScalarFunctionSet GetCSREdgeFunction() { //! No edge weight set.AddFunction(ScalarFunction({LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::BIGINT, LogicalType::BIGINT, + LogicalType::BIGINT, LogicalType::BIGINT, + LogicalType::BIGINT}, LogicalType::INTEGER, CreateCsrEdgeFunction, CSRFunctionData::CSREdgeBind)); //! Integer for edge weight set.AddFunction(ScalarFunction({LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, - LogicalType::BIGINT}, + LogicalType::BIGINT, LogicalType::BIGINT, + LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::INTEGER, CreateCsrEdgeFunction, CSRFunctionData::CSREdgeBind)); //! Double for edge weight set.AddFunction(ScalarFunction({LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, - LogicalType::DOUBLE}, + LogicalType::BIGINT, LogicalType::BIGINT, + LogicalType::BIGINT, LogicalType::DOUBLE}, LogicalType::INTEGER, CreateCsrEdgeFunction, CSRFunctionData::CSREdgeBind)); @@ -238,7 +241,8 @@ ScalarFunctionSet GetCSREdgeFunction() { //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ -void CoreScalarFunctions::RegisterCSRCreationScalarFunctions(DatabaseInstance &db) { +void CoreScalarFunctions::RegisterCSRCreationScalarFunctions( + DatabaseInstance &db) { ExtensionUtil::RegisterFunction(db, GetCSREdgeFunction()); ExtensionUtil::RegisterFunction(db, GetCSRVertexFunction()); } diff --git a/src/core/functions/scalar/csr_deletion.cpp b/src/core/functions/scalar/csr_deletion.cpp index b1d8c4ac..1e29846b 100644 --- a/src/core/functions/scalar/csr_deletion.cpp +++ b/src/core/functions/scalar/csr_deletion.cpp @@ -16,7 +16,6 @@ static void DeleteCsrFunction(DataChunk &args, ExpressionState &state, auto duckpgq_state = GetDuckPGQState(info.context); - int flag = duckpgq_state->csr_list.erase(info.id); result.SetVectorType(VectorType::CONSTANT_VECTOR); auto result_data = ConstantVector::GetData(result); @@ -30,12 +29,10 @@ void CoreScalarFunctions::RegisterCSRDeletionScalarFunction( DatabaseInstance &db) { ExtensionUtil::RegisterFunction( db, - ScalarFunction("delete_csr", {LogicalType::INTEGER}, LogicalType::BOOLEAN, DeleteCsrFunction, - CSRFunctionData::CSRBind)); + ScalarFunction("delete_csr", {LogicalType::INTEGER}, LogicalType::BOOLEAN, + DeleteCsrFunction, CSRFunctionData::CSRBind)); } - } // namespace core } // namespace duckpgq - diff --git a/src/core/functions/scalar/csr_get_w_type.cpp b/src/core/functions/scalar/csr_get_w_type.cpp index ea735b0c..c64e8851 100644 --- a/src/core/functions/scalar/csr_get_w_type.cpp +++ b/src/core/functions/scalar/csr_get_w_type.cpp @@ -45,13 +45,11 @@ static void GetCsrWTypeFunction(DataChunk &args, ExpressionState &state, void CoreScalarFunctions::RegisterGetCSRWTypeScalarFunction( DatabaseInstance &db) { ExtensionUtil::RegisterFunction( - db, - ScalarFunction("csr_get_w_type", {LogicalType::INTEGER}, - LogicalType::INTEGER, GetCsrWTypeFunction, - CSRFunctionData::CSRBind)); + db, ScalarFunction("csr_get_w_type", {LogicalType::INTEGER}, + LogicalType::INTEGER, GetCsrWTypeFunction, + CSRFunctionData::CSRBind)); } } // namespace core } // namespace duckpgq - diff --git a/src/core/functions/scalar/iterativelength.cpp b/src/core/functions/scalar/iterativelength.cpp index a195c67c..deb84ce3 100644 --- a/src/core/functions/scalar/iterativelength.cpp +++ b/src/core/functions/scalar/iterativelength.cpp @@ -41,7 +41,6 @@ static void IterativeLengthFunction(DataChunk &args, ExpressionState &state, auto &info = (IterativeLengthFunctionData &)*func_expr.bind_info; auto duckpgq_state = GetDuckPGQState(info.context); - D_ASSERT(duckpgq_state->csr_list[info.csr_id]); if ((uint64_t)info.csr_id + 1 > duckpgq_state->csr_list.size()) { @@ -160,15 +159,13 @@ static void IterativeLengthFunction(DataChunk &args, ExpressionState &state, void CoreScalarFunctions::RegisterIterativeLengthScalarFunction( DatabaseInstance &db) { ExtensionUtil::RegisterFunction( - db, - ScalarFunction("iterativelength", - {LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::BIGINT, IterativeLengthFunction, - IterativeLengthFunctionData::IterativeLengthBind)); + db, ScalarFunction("iterativelength", + {LogicalType::INTEGER, LogicalType::BIGINT, + LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::BIGINT, IterativeLengthFunction, + IterativeLengthFunctionData::IterativeLengthBind)); } } // namespace core } // namespace duckpgq - diff --git a/src/core/functions/scalar/iterativelength2.cpp b/src/core/functions/scalar/iterativelength2.cpp index c61fa095..9742fb0b 100644 --- a/src/core/functions/scalar/iterativelength2.cpp +++ b/src/core/functions/scalar/iterativelength2.cpp @@ -145,15 +145,13 @@ static void IterativeLength2Function(DataChunk &args, ExpressionState &state, void CoreScalarFunctions::RegisterIterativeLength2ScalarFunction( DatabaseInstance &db) { ExtensionUtil::RegisterFunction( - db, - ScalarFunction("iterativelength2", - {LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::BIGINT, IterativeLength2Function, - IterativeLengthFunctionData::IterativeLengthBind)); + db, ScalarFunction("iterativelength2", + {LogicalType::INTEGER, LogicalType::BIGINT, + LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::BIGINT, IterativeLength2Function, + IterativeLengthFunctionData::IterativeLengthBind)); } } // namespace core } // namespace duckpgq - diff --git a/src/core/functions/scalar/iterativelength_bidirectional.cpp b/src/core/functions/scalar/iterativelength_bidirectional.cpp index 87046d7f..6fbc6d0f 100644 --- a/src/core/functions/scalar/iterativelength_bidirectional.cpp +++ b/src/core/functions/scalar/iterativelength_bidirectional.cpp @@ -53,7 +53,6 @@ static void IterativeLengthBidirectionalFunction(DataChunk &args, auto duckpgq_state = GetDuckPGQState(info.context); - D_ASSERT(duckpgq_state->csr_list[info.csr_id]); int64_t v_size = args.data[1].GetValue(0).GetValue(); int64_t *v = (int64_t *)duckpgq_state->csr_list[info.csr_id]->v; @@ -182,4 +181,3 @@ void CoreScalarFunctions::RegisterIterativeLengthBidirectionalScalarFunction( } // namespace core } // namespace duckpgq - diff --git a/src/core/functions/scalar/local_clustering_coefficient.cpp b/src/core/functions/scalar/local_clustering_coefficient.cpp index d83d5d2e..a4628c48 100644 --- a/src/core/functions/scalar/local_clustering_coefficient.cpp +++ b/src/core/functions/scalar/local_clustering_coefficient.cpp @@ -10,16 +10,16 @@ namespace duckpgq { namespace core { -static void LocalClusteringCoefficientFunction(DataChunk &args, ExpressionState &state, - Vector &result) { +static void LocalClusteringCoefficientFunction(DataChunk &args, + ExpressionState &state, + Vector &result) { auto &func_expr = (BoundFunctionExpression &)state.expr; auto &info = (LocalClusteringCoefficientFunctionData &)*func_expr.bind_info; auto duckpgq_state = GetDuckPGQState(info.context); auto csr_entry = duckpgq_state->csr_list.find((uint64_t)info.csr_id); if (csr_entry == duckpgq_state->csr_list.end()) { - throw ConstraintException( - "CSR not found. Is the graph populated?"); + throw ConstraintException("CSR not found. Is the graph populated?"); } if (!(csr_entry->second->initialized_v && csr_entry->second->initialized_e)) { @@ -68,7 +68,8 @@ static void LocalClusteringCoefficientFunction(DataChunk &args, ExpressionState } } - float local_result = static_cast(count) / (number_of_edges * (number_of_edges - 1)); + float local_result = + static_cast(count) / (number_of_edges * (number_of_edges - 1)); result_data[n] = local_result; } duckpgq_state->csr_to_delete.insert(info.csr_id); @@ -90,4 +91,3 @@ void CoreScalarFunctions::RegisterLocalClusteringCoefficientScalarFunction( } // namespace core } // namespace duckpgq - diff --git a/src/core/functions/scalar/pagerank.cpp b/src/core/functions/scalar/pagerank.cpp index 889d453e..7a9f33e1 100644 --- a/src/core/functions/scalar/pagerank.cpp +++ b/src/core/functions/scalar/pagerank.cpp @@ -9,121 +9,122 @@ namespace duckpgq { namespace core { -static void PageRankFunction(DataChunk &args, - ExpressionState &state, +static void PageRankFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = (BoundFunctionExpression &)state.expr; - auto &info = (PageRankFunctionData &)*func_expr.bind_info; - auto duckpgq_state = GetDuckPGQState(info.context); - - // Locate the CSR representation of the graph - auto csr_entry = duckpgq_state->csr_list.find((uint64_t)info.csr_id); - if (csr_entry == duckpgq_state->csr_list.end()) { - throw ConstraintException("CSR not found. Is the graph populated?"); - } - - if (!(csr_entry->second->initialized_v && csr_entry->second->initialized_e)) { - throw ConstraintException("Need to initialize CSR before running PageRank."); - } - - int64_t *v = (int64_t *)duckpgq_state->csr_list[info.csr_id]->v; - vector &e = duckpgq_state->csr_list[info.csr_id]->e; - size_t v_size = duckpgq_state->csr_list[info.csr_id]->vsize; - - // State initialization (only once) - if (!info.state_initialized) { - info.rank.resize(v_size, 1.0 / v_size); // Initial rank for each node - info.temp_rank.resize(v_size, 0.0); // Temporary storage for ranks during iteration - info.damping_factor = 0.85; // Typical damping factor - info.convergence_threshold = 1e-6; // Convergence threshold - info.state_initialized = true; - info.converged = false; - info.iteration_count = 0; - } - - // Check if already converged - if (!info.converged) { - std::lock_guard guard(info.state_lock); // Thread safety - - bool continue_iteration = true; - while (continue_iteration) { - fill(info.temp_rank.begin(), info.temp_rank.end(), 0.0); - - double total_dangling_rank = 0.0; // For dangling nodes - - for (size_t i = 0; i < v_size; i++) { - int64_t start_edge = v[i]; - int64_t end_edge = (i + 1 < v_size) ? v[i + 1] : e.size(); // Adjust end_edge - if (end_edge > start_edge) { - double rank_contrib = info.rank[i] / (end_edge - start_edge); - for (int64_t j = start_edge; j < end_edge; j++) { - int64_t neighbor = e[j]; - info.temp_rank[neighbor] += rank_contrib; - } - } else { - total_dangling_rank += info.rank[i]; - } - } - - // Apply damping factor and handle dangling node ranks - double correction_factor = total_dangling_rank / v_size; - double max_delta = 0.0; - for (size_t i = 0; i < v_size; i++) { - info.temp_rank[i] = (1 - info.damping_factor) / v_size + - info.damping_factor * (info.temp_rank[i] + correction_factor); - max_delta = std::max(max_delta, std::abs(info.temp_rank[i] - info.rank[i])); - } - - info.rank.swap(info.temp_rank); - info.iteration_count++; - if (max_delta < info.convergence_threshold) { - info.converged = true; - continue_iteration = false; - } + auto &func_expr = (BoundFunctionExpression &)state.expr; + auto &info = (PageRankFunctionData &)*func_expr.bind_info; + auto duckpgq_state = GetDuckPGQState(info.context); + + // Locate the CSR representation of the graph + auto csr_entry = duckpgq_state->csr_list.find((uint64_t)info.csr_id); + if (csr_entry == duckpgq_state->csr_list.end()) { + throw ConstraintException("CSR not found. Is the graph populated?"); + } + + if (!(csr_entry->second->initialized_v && csr_entry->second->initialized_e)) { + throw ConstraintException( + "Need to initialize CSR before running PageRank."); + } + + int64_t *v = (int64_t *)duckpgq_state->csr_list[info.csr_id]->v; + vector &e = duckpgq_state->csr_list[info.csr_id]->e; + size_t v_size = duckpgq_state->csr_list[info.csr_id]->vsize; + + // State initialization (only once) + if (!info.state_initialized) { + info.rank.resize(v_size, 1.0 / v_size); // Initial rank for each node + info.temp_rank.resize(v_size, + 0.0); // Temporary storage for ranks during iteration + info.damping_factor = 0.85; // Typical damping factor + info.convergence_threshold = 1e-6; // Convergence threshold + info.state_initialized = true; + info.converged = false; + info.iteration_count = 0; + } + + // Check if already converged + if (!info.converged) { + std::lock_guard guard(info.state_lock); // Thread safety + + bool continue_iteration = true; + while (continue_iteration) { + fill(info.temp_rank.begin(), info.temp_rank.end(), 0.0); + + double total_dangling_rank = 0.0; // For dangling nodes + + for (size_t i = 0; i < v_size; i++) { + int64_t start_edge = v[i]; + int64_t end_edge = + (i + 1 < v_size) ? v[i + 1] : e.size(); // Adjust end_edge + if (end_edge > start_edge) { + double rank_contrib = info.rank[i] / (end_edge - start_edge); + for (int64_t j = start_edge; j < end_edge; j++) { + int64_t neighbor = e[j]; + info.temp_rank[neighbor] += rank_contrib; + } + } else { + total_dangling_rank += info.rank[i]; } + } + + // Apply damping factor and handle dangling node ranks + double correction_factor = total_dangling_rank / v_size; + double max_delta = 0.0; + for (size_t i = 0; i < v_size; i++) { + info.temp_rank[i] = + (1 - info.damping_factor) / v_size + + info.damping_factor * (info.temp_rank[i] + correction_factor); + max_delta = + std::max(max_delta, std::abs(info.temp_rank[i] - info.rank[i])); + } + + info.rank.swap(info.temp_rank); + info.iteration_count++; + if (max_delta < info.convergence_threshold) { + info.converged = true; + continue_iteration = false; + } } - - // Get the source vector for the current DataChunk - auto &src = args.data[1]; - UnifiedVectorFormat vdata_src; - src.ToUnifiedFormat(args.size(), vdata_src); - auto src_data = (int64_t *)vdata_src.data; - - // Create result vector - ValidityMask &result_validity = FlatVector::Validity(result); - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - - // Output the PageRank value corresponding to each source ID in the DataChunk - for (idx_t i = 0; i < args.size(); i++) { - auto id_pos = vdata_src.sel->get_index(i); - if (!vdata_src.validity.RowIsValid(id_pos)) { - result_validity.SetInvalid(i); - continue; // Skip invalid rows - } - auto node_id = src_data[id_pos]; - if (node_id < 0 || node_id >= (int64_t)v_size) { - result_validity.SetInvalid(i); - continue; - } - result_data[i] = info.rank[node_id]; + } + + // Get the source vector for the current DataChunk + auto &src = args.data[1]; + UnifiedVectorFormat vdata_src; + src.ToUnifiedFormat(args.size(), vdata_src); + auto src_data = (int64_t *)vdata_src.data; + + // Create result vector + ValidityMask &result_validity = FlatVector::Validity(result); + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + + // Output the PageRank value corresponding to each source ID in the DataChunk + for (idx_t i = 0; i < args.size(); i++) { + auto id_pos = vdata_src.sel->get_index(i); + if (!vdata_src.validity.RowIsValid(id_pos)) { + result_validity.SetInvalid(i); + continue; // Skip invalid rows + } + auto node_id = src_data[id_pos]; + if (node_id < 0 || node_id >= (int64_t)v_size) { + result_validity.SetInvalid(i); + continue; } + result_data[i] = info.rank[node_id]; + } - duckpgq_state->csr_to_delete.insert(info.csr_id); + duckpgq_state->csr_to_delete.insert(info.csr_id); } //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ -void CoreScalarFunctions::RegisterPageRankScalarFunction( - DatabaseInstance &db) { - ExtensionUtil::RegisterFunction( - db, - ScalarFunction( - "pagerank", - {LogicalType::INTEGER, LogicalType::BIGINT}, - LogicalType::DOUBLE, PageRankFunction, - PageRankFunctionData::PageRankBind)); +void CoreScalarFunctions::RegisterPageRankScalarFunction(DatabaseInstance &db) { + ExtensionUtil::RegisterFunction( + db, + ScalarFunction("pagerank", {LogicalType::INTEGER, LogicalType::BIGINT}, + LogicalType::DOUBLE, PageRankFunction, + PageRankFunctionData::PageRankBind)); } } // namespace core diff --git a/src/core/functions/scalar/reachability.cpp b/src/core/functions/scalar/reachability.cpp index 12bddf92..b7edc8eb 100644 --- a/src/core/functions/scalar/reachability.cpp +++ b/src/core/functions/scalar/reachability.cpp @@ -283,16 +283,14 @@ static void ReachabilityFunction(DataChunk &args, ExpressionState &state, void CoreScalarFunctions::RegisterReachabilityScalarFunction( DatabaseInstance &db) { ExtensionUtil::RegisterFunction( - db, - ScalarFunction("reachability", - {LogicalType::INTEGER, LogicalType::BOOLEAN, - LogicalType::BIGINT, LogicalType::BIGINT, - LogicalType::BIGINT}, - LogicalType::BOOLEAN, ReachabilityFunction, - IterativeLengthFunctionData::IterativeLengthBind)); + db, ScalarFunction("reachability", + {LogicalType::INTEGER, LogicalType::BOOLEAN, + LogicalType::BIGINT, LogicalType::BIGINT, + LogicalType::BIGINT}, + LogicalType::BOOLEAN, ReachabilityFunction, + IterativeLengthFunctionData::IterativeLengthBind)); } } // namespace core } // namespace duckpgq - diff --git a/src/core/functions/scalar/shortest_path.cpp b/src/core/functions/scalar/shortest_path.cpp index 669426a7..34c73e5c 100644 --- a/src/core/functions/scalar/shortest_path.cpp +++ b/src/core/functions/scalar/shortest_path.cpp @@ -62,7 +62,6 @@ static void ShortestPathFunction(DataChunk &args, ExpressionState &state, } auto &csr = csr_entry->second; - if (!csr->initialized_v) { throw ConstraintException( "Need to initialize CSR before doing shortest path"); @@ -238,16 +237,14 @@ static void ShortestPathFunction(DataChunk &args, ExpressionState &state, void CoreScalarFunctions::RegisterShortestPathScalarFunction( DatabaseInstance &db) { ExtensionUtil::RegisterFunction( - db, - ScalarFunction("shortestpath", - {LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ShortestPathFunction, - IterativeLengthFunctionData::IterativeLengthBind)); + db, ScalarFunction("shortestpath", + {LogicalType::INTEGER, LogicalType::BIGINT, + LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ShortestPathFunction, + IterativeLengthFunctionData::IterativeLengthBind)); } } // namespace core } // namespace duckpgq - diff --git a/src/core/functions/scalar/weakly_connected_component.cpp b/src/core/functions/scalar/weakly_connected_component.cpp index b42fbf44..2a163adf 100644 --- a/src/core/functions/scalar/weakly_connected_component.cpp +++ b/src/core/functions/scalar/weakly_connected_component.cpp @@ -116,7 +116,8 @@ static void WeaklyConnectedComponentFunction(DataChunk &args, int64_t src_node = src_data[src_pos]; // Check if the node is already part of a component if (info.componentId[src_node] != -1) { - result_data[search_num] = info.componentId[src_node]; // Already known component + result_data[search_num] = + info.componentId[src_node]; // Already known component continue; } @@ -179,8 +180,8 @@ void CoreScalarFunctions::RegisterWeaklyConnectedComponentScalarFunction( db, ScalarFunction( "weakly_connected_component", - {LogicalType::INTEGER, LogicalType::BIGINT}, - LogicalType::BIGINT, WeaklyConnectedComponentFunction, + {LogicalType::INTEGER, LogicalType::BIGINT}, LogicalType::BIGINT, + WeaklyConnectedComponentFunction, WeaklyConnectedComponentFunctionData::WeaklyConnectedComponentBind)); } diff --git a/src/core/functions/table/create_property_graph.cpp b/src/core/functions/table/create_property_graph.cpp index 1ac26093..a4b7b2a9 100644 --- a/src/core/functions/table/create_property_graph.cpp +++ b/src/core/functions/table/create_property_graph.cpp @@ -68,15 +68,17 @@ void CreatePropertyGraphFunction::CheckPropertyGraphTableColumns( } // Helper function to validate source/destination keys -void CreatePropertyGraphFunction::ValidateKeys(shared_ptr &edge_table, const string &reference, const string &key_type, - vector &pk_columns, vector &fk_columns, - const vector> &table_constraints) { +void CreatePropertyGraphFunction::ValidateKeys( + shared_ptr &edge_table, const string &reference, + const string &key_type, vector &pk_columns, + vector &fk_columns, + const vector> &table_constraints) { if (fk_columns.empty() && pk_columns.empty()) { if (table_constraints.empty()) { throw Exception(ExceptionType::INVALID, "No primary key - foreign key relationship found in " + - edge_table->table_name + " with " + StringUtil::Upper(key_type) + - " table " + reference); + edge_table->table_name + " with " + + StringUtil::Upper(key_type) + " table " + reference); } for (const auto &constraint : table_constraints) { @@ -85,13 +87,19 @@ void CreatePropertyGraphFunction::ValidateKeys(shared_ptr &e if (fk_constraint.info.table != reference) { continue; } - // If a PK-FK relationship was found earlier, throw an ambiguity exception + // If a PK-FK relationship was found earlier, throw an ambiguity + // exception if (!pk_columns.empty() && !fk_columns.empty()) { - throw Exception( - ExceptionType::INVALID, - "Multiple primary key - foreign key relationships detected between " + edge_table->table_name + " and " + reference + ". " - "Please explicitly define the primary key and foreign key columns using `" + - StringUtil::Upper(key_type) + " KEY REFERENCES " + reference + " `"); + throw Exception(ExceptionType::INVALID, + "Multiple primary key - foreign key relationships " + "detected between " + + edge_table->table_name + " and " + reference + + ". " + "Please explicitly define the primary key and " + "foreign key columns using `" + + StringUtil::Upper(key_type) + + " KEY REFERENCES " + reference + + " `"); } pk_columns = fk_constraint.pk_columns; fk_columns = fk_constraint.fk_columns; @@ -100,45 +108,56 @@ void CreatePropertyGraphFunction::ValidateKeys(shared_ptr &e if (pk_columns.empty()) { throw Exception(ExceptionType::INVALID, - "The primary key for the " + StringUtil::Upper(key_type) + " table " + reference + - " is not defined in the edge table " + edge_table->table_name); + "The primary key for the " + StringUtil::Upper(key_type) + + " table " + reference + + " is not defined in the edge table " + + edge_table->table_name); } if (fk_columns.empty()) { throw Exception(ExceptionType::INVALID, - "The foreign key for the " + StringUtil::Upper(key_type) + " table " + reference + - " is not defined in the edge table " + edge_table->table_name); + "The foreign key for the " + StringUtil::Upper(key_type) + + " table " + reference + + " is not defined in the edge table " + + edge_table->table_name); } } } -void CreatePropertyGraphFunction::ValidateForeignKeyColumns(shared_ptr &edge_table, const vector &fk_columns, optional_ptr &table) { +void CreatePropertyGraphFunction::ValidateForeignKeyColumns( + shared_ptr &edge_table, + const vector &fk_columns, optional_ptr &table) { for (const auto &fk : fk_columns) { if (!table->ColumnExists(fk)) { - throw Exception(ExceptionType::INVALID, - "Foreign key " + fk + " does not exist in table " + edge_table->table_name); + throw Exception(ExceptionType::INVALID, "Foreign key " + fk + + " does not exist in table " + + edge_table->table_name); } } } // Helper function to check if the vertex table is registered -void CreatePropertyGraphFunction::ValidateVertexTableRegistration(const string &reference, const case_insensitive_set_t &v_table_names) { +void CreatePropertyGraphFunction::ValidateVertexTableRegistration( + const string &reference, const case_insensitive_set_t &v_table_names) { if (v_table_names.find(reference) == v_table_names.end()) { throw Exception(ExceptionType::INVALID, "Referenced vertex table " + reference + - " is not registered in the vertex tables."); + " is not registered in the vertex tables."); } } // Helper function to validate primary keys in the source or destination tables -void CreatePropertyGraphFunction::ValidatePrimaryKeyInTable(Catalog &catalog, ClientContext &context, const string &schema, - const string &reference, const vector &pk_columns) { - auto &table_entry = catalog.GetEntry(context, schema, reference); +void CreatePropertyGraphFunction::ValidatePrimaryKeyInTable( + Catalog &catalog, ClientContext &context, const string &schema, + const string &reference, const vector &pk_columns) { + auto &table_entry = + catalog.GetEntry(context, schema, reference); for (const auto &pk : pk_columns) { if (!table_entry.ColumnExists(pk)) { - throw Exception(ExceptionType::INVALID, - "Primary key " + pk + " does not exist in table " + reference); + throw Exception(ExceptionType::INVALID, "Primary key " + pk + + " does not exist in table " + + reference); } } } @@ -171,29 +190,40 @@ unique_ptr CreatePropertyGraphFunction::CreatePropertyGraphBind( case_insensitive_set_t v_table_names; for (auto &vertex_table : info->vertex_tables) { - try { - auto &catalog = Catalog::GetCatalog(context, vertex_table->catalog_name); - auto table = catalog.GetEntry( - context, info->schema, vertex_table->table_name, OnEntryNotFound::RETURN_NULL); - - if (!table) { - throw Exception(ExceptionType::INVALID, - "Table " + (vertex_table->catalog_name.empty() ? DEFAULT_SCHEMA : vertex_table->catalog_name) + "." + vertex_table->table_name + " not found"); - } + try { + auto &catalog = Catalog::GetCatalog(context, vertex_table->catalog_name); + auto table = catalog.GetEntry( + context, info->schema, vertex_table->table_name, + OnEntryNotFound::RETURN_NULL); - CheckPropertyGraphTableColumns(vertex_table, *table); - CheckPropertyGraphTableLabels(vertex_table, *table); - } catch (CatalogException &e) { - auto &catalog = Catalog::GetCatalog(context, vertex_table->catalog_name); - auto table = catalog.GetEntry(context, info->schema, vertex_table->table_name, OnEntryNotFound::RETURN_NULL); - if (table) { - throw Exception(ExceptionType::INVALID, "Found a view with name " + vertex_table->table_name + ". Creating property graph tables over views is currently not supported."); - } - throw Exception(ExceptionType::INVALID, e.what()); - } catch (BinderException &e) { - throw Exception(ExceptionType::INVALID, "Catalog '" + vertex_table->catalog_name + "' does not exist!"); + if (!table) { + throw Exception(ExceptionType::INVALID, + "Table " + + (vertex_table->catalog_name.empty() + ? DEFAULT_SCHEMA + : vertex_table->catalog_name) + + "." + vertex_table->table_name + " not found"); } + CheckPropertyGraphTableColumns(vertex_table, *table); + CheckPropertyGraphTableLabels(vertex_table, *table); + } catch (CatalogException &e) { + auto &catalog = Catalog::GetCatalog(context, vertex_table->catalog_name); + auto table = catalog.GetEntry( + context, info->schema, vertex_table->table_name, + OnEntryNotFound::RETURN_NULL); + if (table) { + throw Exception(ExceptionType::INVALID, + "Found a view with name " + vertex_table->table_name + + ". Creating property graph tables over views is " + "currently not supported."); + } + throw Exception(ExceptionType::INVALID, e.what()); + } catch (BinderException &e) { + throw Exception(ExceptionType::INVALID, "Catalog '" + + vertex_table->catalog_name + + "' does not exist!"); + } v_table_names.insert(vertex_table->table_name); if (vertex_table->hasTableNameAlias()) { @@ -205,11 +235,16 @@ unique_ptr CreatePropertyGraphFunction::CreatePropertyGraphBind( try { auto &catalog = Catalog::GetCatalog(context, edge_table->catalog_name); - auto table = catalog.GetEntry(context, edge_table->schema_name, - edge_table->table_name, OnEntryNotFound::RETURN_NULL); + auto table = catalog.GetEntry( + context, edge_table->schema_name, edge_table->table_name, + OnEntryNotFound::RETURN_NULL); if (!table) { throw Exception(ExceptionType::INVALID, - "Table " + (edge_table->catalog_name.empty() ? DEFAULT_SCHEMA : edge_table->catalog_name) + "." + edge_table->table_name + " not found"); + "Table " + + (edge_table->catalog_name.empty() + ? DEFAULT_SCHEMA + : edge_table->catalog_name) + + "." + edge_table->table_name + " not found"); } CheckPropertyGraphTableColumns(edge_table, *table); @@ -218,38 +253,53 @@ unique_ptr CreatePropertyGraphFunction::CreatePropertyGraphBind( auto &table_constraints = table->GetConstraints(); ValidateKeys(edge_table, edge_table->source_reference, "source", - edge_table->source_pk, edge_table->source_fk, table_constraints); + edge_table->source_pk, edge_table->source_fk, + table_constraints); // Check source foreign key columns exist in the table ValidateForeignKeyColumns(edge_table, edge_table->source_fk, table); // Validate destination keys ValidateKeys(edge_table, edge_table->destination_reference, "destination", - edge_table->destination_pk, edge_table->destination_fk, table_constraints); + edge_table->destination_pk, edge_table->destination_fk, + table_constraints); // Check destination foreign key columns exist in the table ValidateForeignKeyColumns(edge_table, edge_table->destination_fk, table); // Validate source table registration - ValidateVertexTableRegistration(edge_table->source_reference, v_table_names); + ValidateVertexTableRegistration(edge_table->source_reference, + v_table_names); // Validate primary keys in the source table - ValidatePrimaryKeyInTable(catalog, context, info->schema, edge_table->source_reference, edge_table->source_pk); + ValidatePrimaryKeyInTable(catalog, context, info->schema, + edge_table->source_reference, + edge_table->source_pk); // Validate destination table registration - ValidateVertexTableRegistration(edge_table->destination_reference, v_table_names); + ValidateVertexTableRegistration(edge_table->destination_reference, + v_table_names); // Validate primary keys in the destination table - ValidatePrimaryKeyInTable(catalog, context, info->schema, edge_table->destination_reference, edge_table->destination_pk); + ValidatePrimaryKeyInTable(catalog, context, info->schema, + edge_table->destination_reference, + edge_table->destination_pk); } catch (CatalogException &e) { auto &catalog = Catalog::GetCatalog(context, edge_table->catalog_name); - auto table = catalog.GetEntry(context, info->schema, edge_table->table_name, OnEntryNotFound::RETURN_NULL); + auto table = catalog.GetEntry( + context, info->schema, edge_table->table_name, + OnEntryNotFound::RETURN_NULL); if (table) { - throw Exception(ExceptionType::INVALID, "Found a view with name " + edge_table->table_name + ". Creating property graph tables over views is currently not supported."); + throw Exception(ExceptionType::INVALID, + "Found a view with name " + edge_table->table_name + + ". Creating property graph tables over views is " + "currently not supported."); } throw Exception(ExceptionType::INVALID, e.what()); } catch (BinderException &e) { - throw Exception(ExceptionType::INVALID, "Catalog '" + edge_table->catalog_name + "' does not exist!"); + throw Exception(ExceptionType::INVALID, "Catalog '" + + edge_table->catalog_name + + "' does not exist!"); } } return make_uniq(info); @@ -267,32 +317,42 @@ void CreatePropertyGraphFunction::CreatePropertyGraphFunc( auto pg_info = bind_data.create_pg_info; auto duckpgq_state = GetDuckPGQState(context); - for (auto &connection : ConnectionManager::Get(*context.db).GetConnectionList()) { - auto local_state = connection->registered_state->Get("duckpgq"); + for (auto &connection : + ConnectionManager::Get(*context.db).GetConnectionList()) { + auto local_state = + connection->registered_state->Get("duckpgq"); if (!local_state) { continue; } local_state->registered_property_graphs[pg_info->property_graph_name] = - pg_info->Copy(); + pg_info->Copy(); } auto new_conn = make_shared_ptr(context.db); - auto retrieve_query = new_conn->Query("SELECT * FROM __duckpgq_internal where property_graph = '" + pg_info->property_graph_name + "';", false); + auto retrieve_query = new_conn->Query( + "SELECT * FROM __duckpgq_internal where property_graph = '" + + pg_info->property_graph_name + "';", + false); if (retrieve_query->HasError()) { throw TransactionException(retrieve_query->GetError()); } auto &query_result = retrieve_query->Cast(); if (query_result.RowCount() > 0) { if (pg_info->on_conflict == OnCreateConflict::ERROR_ON_CONFLICT) { - throw Exception(ExceptionType::INVALID, "Property graph " + pg_info->property_graph_name + " is already registered"); + throw Exception(ExceptionType::INVALID, "Property graph " + + pg_info->property_graph_name + + " is already registered"); } if (pg_info->on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT) { return; // Do nothing and silently return } if (pg_info->on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT) { // DELETE the old property graph and insert new one. - new_conn->Query("DELETE FROM __duckpgq_internal WHERE property_graph = '" + pg_info->property_graph_name + "';", false); + new_conn->Query( + "DELETE FROM __duckpgq_internal WHERE property_graph = '" + + pg_info->property_graph_name + "';", + false); } } @@ -309,11 +369,14 @@ void CreatePropertyGraphFunction::CreatePropertyGraphFunc( insert_info += "NULL, "; // destination_table insert_info += "NULL, "; // destination_pk insert_info += "NULL, "; // destination_fk - insert_info += v_table->discriminator.empty() ? "NULL, " : "'" + v_table->discriminator + "', "; + insert_info += v_table->discriminator.empty() + ? "NULL, " + : "'" + v_table->discriminator + "', "; if (!v_table->discriminator.empty()) { insert_info += "["; for (idx_t i = 0; i < v_table->sub_labels.size(); i++) { - insert_info += "'" + v_table->sub_labels[i] + (i == v_table->sub_labels.size() - 1 ? "'" : "', "); + insert_info += "'" + v_table->sub_labels[i] + + (i == v_table->sub_labels.size() - 1 ? "'" : "', "); } insert_info += "],"; } else { @@ -354,11 +417,14 @@ void CreatePropertyGraphFunction::CreatePropertyGraphFunc( } insert_info += "], "; - insert_info += e_table->discriminator.empty() ? "NULL, " : "'" + e_table->discriminator + "', "; + insert_info += e_table->discriminator.empty() + ? "NULL, " + : "'" + e_table->discriminator + "', "; if (!e_table->discriminator.empty()) { insert_info += "["; for (idx_t i = 0; i < e_table->sub_labels.size(); i++) { - insert_info += "'" + e_table->sub_labels[i] + (i == e_table->sub_labels.size() - 1 ? "'" : "', "); + insert_info += "'" + e_table->sub_labels[i] + + (i == e_table->sub_labels.size() - 1 ? "'" : "', "); } insert_info += "], "; } else { diff --git a/src/core/functions/table/describe_property_graph.cpp b/src/core/functions/table/describe_property_graph.cpp index 1a1b9f8d..611db59c 100644 --- a/src/core/functions/table/describe_property_graph.cpp +++ b/src/core/functions/table/describe_property_graph.cpp @@ -11,8 +11,6 @@ namespace duckpgq { namespace core { - - unique_ptr DescribePropertyGraphFunction::DescribePropertyGraphBind( ClientContext &context, TableFunctionBindInput &input, @@ -156,7 +154,8 @@ void DescribePropertyGraphFunction::DescribePropertyGraphFunc( //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ -void CoreTableFunctions::RegisterDescribePropertyGraphTableFunction(DatabaseInstance &db) { +void CoreTableFunctions::RegisterDescribePropertyGraphTableFunction( + DatabaseInstance &db) { ExtensionUtil::RegisterFunction(db, DescribePropertyGraphFunction()); } diff --git a/src/core/functions/table/drop_property_graph.cpp b/src/core/functions/table/drop_property_graph.cpp index d53e7e38..8da22d74 100644 --- a/src/core/functions/table/drop_property_graph.cpp +++ b/src/core/functions/table/drop_property_graph.cpp @@ -9,9 +9,7 @@ namespace duckpgq { namespace core { - -unique_ptr -DropPropertyGraphFunction::DropPropertyGraphBind( +unique_ptr DropPropertyGraphFunction::DropPropertyGraphBind( ClientContext &context, TableFunctionBindInput &, vector &return_types, vector &names) { names.emplace_back("success"); @@ -53,8 +51,10 @@ void DropPropertyGraphFunction::DropPropertyGraphFunc( pg_info->property_graph_name); } - for (auto &connection : ConnectionManager::Get(*context.db).GetConnectionList()) { - auto local_state = connection->registered_state->Get("duckpgq"); + for (auto &connection : + ConnectionManager::Get(*context.db).GetConnectionList()) { + auto local_state = + connection->registered_state->Get("duckpgq"); if (!local_state) { continue; } @@ -70,7 +70,8 @@ void DropPropertyGraphFunction::DropPropertyGraphFunc( //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ -void CoreTableFunctions::RegisterDropPropertyGraphTableFunction(DatabaseInstance &db) { +void CoreTableFunctions::RegisterDropPropertyGraphTableFunction( + DatabaseInstance &db) { ExtensionUtil::RegisterFunction(db, DropPropertyGraphFunction()); } } // namespace core diff --git a/src/core/functions/table/local_clustering_coefficient.cpp b/src/core/functions/table/local_clustering_coefficient.cpp index 6eba3768..62da7510 100644 --- a/src/core/functions/table/local_clustering_coefficient.cpp +++ b/src/core/functions/table/local_clustering_coefficient.cpp @@ -14,36 +14,41 @@ namespace duckpgq { namespace core { - - // Main binding function -unique_ptr LocalClusteringCoefficientFunction::LocalClusteringCoefficientBindReplace(ClientContext &context, TableFunctionBindInput &input) { - auto pg_name = StringUtil::Lower(StringValue::Get(input.inputs[0])); - auto node_label = StringUtil::Lower(StringValue::Get(input.inputs[1])); - auto edge_label = StringUtil::Lower(StringValue::Get(input.inputs[2])); - - auto duckpgq_state = GetDuckPGQState(context); - auto pg_info = GetPropertyGraphInfo(duckpgq_state, pg_name); - auto edge_pg_entry = ValidateSourceNodeAndEdgeTable(pg_info, node_label, edge_label); - - auto select_node = CreateSelectNode(edge_pg_entry, "local_clustering_coefficient", "local_clustering_coefficient"); - - select_node->cte_map.map["csr_cte"] = CreateUndirectedCSRCTE(edge_pg_entry, select_node); - - auto subquery = make_uniq(); - subquery->node = std::move(select_node); - - auto result = make_uniq(std::move(subquery)); - result->alias = "lcc"; - return std::move(result); +unique_ptr +LocalClusteringCoefficientFunction::LocalClusteringCoefficientBindReplace( + ClientContext &context, TableFunctionBindInput &input) { + auto pg_name = StringUtil::Lower(StringValue::Get(input.inputs[0])); + auto node_label = StringUtil::Lower(StringValue::Get(input.inputs[1])); + auto edge_label = StringUtil::Lower(StringValue::Get(input.inputs[2])); + + auto duckpgq_state = GetDuckPGQState(context); + auto pg_info = GetPropertyGraphInfo(duckpgq_state, pg_name); + auto edge_pg_entry = + ValidateSourceNodeAndEdgeTable(pg_info, node_label, edge_label); + + auto select_node = + CreateSelectNode(edge_pg_entry, "local_clustering_coefficient", + "local_clustering_coefficient"); + + select_node->cte_map.map["csr_cte"] = + CreateUndirectedCSRCTE(edge_pg_entry, select_node); + + auto subquery = make_uniq(); + subquery->node = std::move(select_node); + + auto result = make_uniq(std::move(subquery)); + result->alias = "lcc"; + return std::move(result); } //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ -void CoreTableFunctions::RegisterLocalClusteringCoefficientTableFunction(DatabaseInstance &db) { - ExtensionUtil::RegisterFunction(db, LocalClusteringCoefficientFunction()); +void CoreTableFunctions::RegisterLocalClusteringCoefficientTableFunction( + DatabaseInstance &db) { + ExtensionUtil::RegisterFunction(db, LocalClusteringCoefficientFunction()); } } // namespace core -} // namespace duckdb +} // namespace duckpgq diff --git a/src/core/functions/table/match.cpp b/src/core/functions/table/match.cpp index e67d1bab..fcbc0e8f 100644 --- a/src/core/functions/table/match.cpp +++ b/src/core/functions/table/match.cpp @@ -179,7 +179,6 @@ unique_ptr PGQMatchFunction::CreateCountCTESubquery() { return temp_cte_select_subquery; } - void PGQMatchFunction::EdgeTypeAny( const shared_ptr &edge_table, const string &edge_binding, const string &prev_binding, @@ -193,8 +192,8 @@ void PGQMatchFunction::EdgeTypeAny( auto edge_left_ref = edge_table->CreateBaseTableRef(edge_binding); src_dst_select_node->from_table = std::move(edge_left_ref); auto src_dst_children = vector>(); - src_dst_children.push_back(make_uniq( - edge_table->source_fk[0], edge_binding)); + src_dst_children.push_back( + make_uniq(edge_table->source_fk[0], edge_binding)); src_dst_children.push_back(make_uniq( edge_table->destination_fk[0], edge_binding)); src_dst_children.push_back(make_uniq()); @@ -211,8 +210,8 @@ void PGQMatchFunction::EdgeTypeAny( dst_src_children.push_back(make_uniq( edge_table->destination_fk[0], edge_binding)); - dst_src_children.push_back(make_uniq( - edge_table->source_fk[0], edge_binding)); + dst_src_children.push_back( + make_uniq(edge_table->source_fk[0], edge_binding)); dst_src_children.push_back(make_uniq()); dst_src_select_node->select_list = std::move(dst_src_children); @@ -287,7 +286,8 @@ void PGQMatchFunction::EdgeTypeLeftRight( const string &edge_binding, const string &prev_binding, const string &next_binding, vector> &conditions, - case_insensitive_map_t> &alias_map, int32_t &extra_alias_counter) { + case_insensitive_map_t> &alias_map, + int32_t &extra_alias_counter) { auto src_left_expr = CreateMatchJoinExpression( edge_table->source_pk, edge_table->source_fk, next_binding, edge_binding); auto dst_left_expr = CreateMatchJoinExpression(edge_table->destination_pk, @@ -327,8 +327,8 @@ PathElement *PGQMatchFunction::HandleNestedSubPath( return GetPathElement(subpath->path_list[element_idx]); } -unique_ptr -PGQMatchFunction::CreateWhereClause(vector> &conditions) { +unique_ptr PGQMatchFunction::CreateWhereClause( + vector> &conditions) { unique_ptr where_clause; for (auto &condition : conditions) { if (where_clause) { @@ -342,8 +342,10 @@ PGQMatchFunction::CreateWhereClause(vector> &condit return where_clause; } -unique_ptr PGQMatchFunction::GenerateShortestPathCTE(CreatePropertyGraphInfo &pg_table, SubPath *edge_subpath, - PathElement *previous_vertex_element, PathElement * next_vertex_element, vector> &path_finding_conditions) { +unique_ptr PGQMatchFunction::GenerateShortestPathCTE( + CreatePropertyGraphInfo &pg_table, SubPath *edge_subpath, + PathElement *previous_vertex_element, PathElement *next_vertex_element, + vector> &path_finding_conditions) { auto cte_info = make_uniq(); auto select_statement = make_uniq(); auto select_node = make_uniq(); @@ -360,7 +362,8 @@ unique_ptr PGQMatchFunction::GenerateShortestPathCTE( vector> pathfinding_children; pathfinding_children.push_back(std::move(csr_id)); pathfinding_children.push_back(std::move(GetCountTable( - edge_table->source_pg_table, previous_vertex_element->variable_binding, edge_table->source_pk[0]))); + edge_table->source_pg_table, previous_vertex_element->variable_binding, + edge_table->source_pk[0]))); pathfinding_children.push_back(std::move(src_row_id)); pathfinding_children.push_back(std::move(dst_row_id)); @@ -368,10 +371,12 @@ unique_ptr PGQMatchFunction::GenerateShortestPathCTE( "shortestpath", std::move(pathfinding_children)); shortest_path_function->alias = "path"; select_node->select_list.push_back(std::move(shortest_path_function)); - auto src_rowid_outer_select = make_uniq("rowid", previous_vertex_element->variable_binding); + auto src_rowid_outer_select = make_uniq( + "rowid", previous_vertex_element->variable_binding); src_rowid_outer_select->alias = "src_rowid"; select_node->select_list.push_back(std::move(src_rowid_outer_select)); - auto dst_rowid_outer_select = make_uniq("rowid", next_vertex_element->variable_binding); + auto dst_rowid_outer_select = make_uniq( + "rowid", next_vertex_element->variable_binding); dst_rowid_outer_select->alias = "dst_rowid"; select_node->select_list.push_back(std::move(dst_rowid_outer_select)); @@ -418,7 +423,6 @@ unique_ptr PGQMatchFunction::CreatePathFindingFunction( // full list of element rowids, using list_concat. For now we will only // support returning rowids - unique_ptr final_list; vector> path_finding_conditions; auto previous_vertex_element = GetPathElement(path_list[0]); @@ -427,8 +431,7 @@ unique_ptr PGQMatchFunction::CreatePathFindingFunction( // We hit a vertex element with a WHERE, but we only care about the rowid // here // In the future this might be a recursive path pattern - previous_vertex_subpath = - reinterpret_cast(path_list[0].get()); + previous_vertex_subpath = reinterpret_cast(path_list[0].get()); previous_vertex_element = GetPathElement(previous_vertex_subpath->path_list[0]); } @@ -449,28 +452,42 @@ unique_ptr PGQMatchFunction::CreatePathFindingFunction( // (un)bounded shortest path // Add the shortest path UDF as a CTE if (previous_vertex_subpath) { - path_finding_conditions.push_back(std::move(previous_vertex_subpath->where_clause)); + path_finding_conditions.push_back( + std::move(previous_vertex_subpath->where_clause)); } if (next_vertex_subpath) { - path_finding_conditions.push_back(std::move(next_vertex_subpath->where_clause)); + path_finding_conditions.push_back( + std::move(next_vertex_subpath->where_clause)); } - if (final_select_node->cte_map.map.find("cte1") == final_select_node->cte_map.map.end()) { - edge_element = reinterpret_cast(edge_subpath->path_list[0].get()); + if (final_select_node->cte_map.map.find("cte1") == + final_select_node->cte_map.map.end()) { + edge_element = + reinterpret_cast(edge_subpath->path_list[0].get()); if (edge_element->match_type == PGQMatchType::MATCH_EDGE_RIGHT) { - final_select_node->cte_map.map["cte1"] = - CreateDirectedCSRCTE(FindGraphTable(edge_element->label, pg_table), previous_vertex_element->variable_binding, edge_element->variable_binding, next_vertex_element->variable_binding); + final_select_node->cte_map.map["cte1"] = CreateDirectedCSRCTE( + FindGraphTable(edge_element->label, pg_table), + previous_vertex_element->variable_binding, + edge_element->variable_binding, + next_vertex_element->variable_binding); } else if (edge_element->match_type == PGQMatchType::MATCH_EDGE_ANY) { - final_select_node->cte_map.map["cte1"] = - CreateUndirectedCSRCTE(FindGraphTable(edge_element->label, pg_table), final_select_node); + final_select_node->cte_map.map["cte1"] = CreateUndirectedCSRCTE( + FindGraphTable(edge_element->label, pg_table), + final_select_node); } else { - throw NotImplementedException("Cannot do shortest path for edge type %s", edge_element->match_type == PGQMatchType::MATCH_EDGE_LEFT ? "MATCH_EDGE_LEFT" : "MATCH_EDGE_LEFT_RIGHT"); + throw NotImplementedException( + "Cannot do shortest path for edge type %s", + edge_element->match_type == PGQMatchType::MATCH_EDGE_LEFT + ? "MATCH_EDGE_LEFT" + : "MATCH_EDGE_LEFT_RIGHT"); } } - string shortest_path_cte_name = "shortest_path_cte" ; - if (final_select_node->cte_map.map.find(shortest_path_cte_name) == final_select_node->cte_map.map.end()) { + string shortest_path_cte_name = "shortest_path_cte"; + if (final_select_node->cte_map.map.find(shortest_path_cte_name) == + final_select_node->cte_map.map.end()) { final_select_node->cte_map.map[shortest_path_cte_name] = - GenerateShortestPathCTE(pg_table, edge_subpath, previous_vertex_element, - next_vertex_element, path_finding_conditions); + GenerateShortestPathCTE( + pg_table, edge_subpath, previous_vertex_element, + next_vertex_element, path_finding_conditions); auto cte_shortest_path_ref = make_uniq(); cte_shortest_path_ref->table_name = shortest_path_cte_name; if (!final_select_node->from_table) { @@ -482,13 +499,21 @@ unique_ptr PGQMatchFunction::CreatePathFindingFunction( final_select_node->from_table = std::move(join_ref); } - conditions.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, - make_uniq("src_rowid", shortest_path_cte_name), make_uniq("rowid", previous_vertex_element->variable_binding))); - conditions.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, - make_uniq("dst_rowid", shortest_path_cte_name), make_uniq("rowid", next_vertex_element->variable_binding))); - + conditions.push_back(make_uniq( + ExpressionType::COMPARE_EQUAL, + make_uniq("src_rowid", + shortest_path_cte_name), + make_uniq( + "rowid", previous_vertex_element->variable_binding))); + conditions.push_back(make_uniq( + ExpressionType::COMPARE_EQUAL, + make_uniq("dst_rowid", + shortest_path_cte_name), + make_uniq( + "rowid", next_vertex_element->variable_binding))); } - auto shortest_path_ref = make_uniq("path", shortest_path_cte_name); + auto shortest_path_ref = + make_uniq("path", shortest_path_cte_name); if (!final_list) { final_list = std::move(shortest_path_ref); } else { @@ -556,8 +581,8 @@ void PGQMatchFunction::AddEdgeJoins( PGQMatchType edge_type, const string &edge_binding, const string &prev_binding, const string &next_binding, vector> &conditions, - case_insensitive_map_t> &alias_map, int32_t &extra_alias_counter, - unique_ptr &from_clause) { + case_insensitive_map_t> &alias_map, + int32_t &extra_alias_counter, unique_ptr &from_clause) { if (edge_type != PGQMatchType::MATCH_EDGE_ANY) { alias_map[edge_binding] = edge_table; } @@ -597,8 +622,8 @@ unique_ptr PGQMatchFunction::AddPathQuantifierCondition( vector> pathfinding_children; pathfinding_children.push_back(std::move(csr_id)); - pathfinding_children.push_back( - std::move(GetCountTable(edge_table->source_pg_table, prev_binding, edge_table->source_pk[0]))); + pathfinding_children.push_back(std::move(GetCountTable( + edge_table->source_pg_table, prev_binding, edge_table->source_pk[0]))); pathfinding_children.push_back(std::move(src_row_id)); pathfinding_children.push_back(std::move(dst_row_id)); @@ -635,8 +660,8 @@ void PGQMatchFunction::AddPathFinding( //! FROM (SELECT count(cte1.temp) * 0 as temp from cte1) __x if (select_node->cte_map.map.find("cte1") == select_node->cte_map.map.end()) { if (edge_type == PGQMatchType::MATCH_EDGE_RIGHT) { - select_node->cte_map.map["cte1"] = - CreateDirectedCSRCTE(edge_table, prev_binding, edge_binding, next_binding); + select_node->cte_map.map["cte1"] = CreateDirectedCSRCTE( + edge_table, prev_binding, edge_binding, next_binding); } else if (edge_type == PGQMatchType::MATCH_EDGE_ANY) { select_node->cte_map.map["cte1"] = CreateUndirectedCSRCTE(edge_table, select_node); @@ -647,7 +672,8 @@ void PGQMatchFunction::AddPathFinding( : "MATCH_EDGE_LEFT_RIGHT"); } } - if (select_node->cte_map.map.find("shortest_path_cte") != select_node->cte_map.map.end()) { + if (select_node->cte_map.map.find("shortest_path_cte") != + select_node->cte_map.map.end()) { return; } auto temp_cte_select_subquery = CreateCountCTESubquery(); @@ -676,11 +702,12 @@ void PGQMatchFunction::AddPathFinding( void PGQMatchFunction::CheckNamedSubpath( SubPath &subpath, MatchExpression &original_ref, - CreatePropertyGraphInfo &pg_table, unique_ptr &final_select_node, + CreatePropertyGraphInfo &pg_table, + unique_ptr &final_select_node, vector> &conditions) { for (idx_t idx_i = 0; idx_i < original_ref.column_list.size(); idx_i++) { - auto parsed_ref = - dynamic_cast(original_ref.column_list[idx_i].get()); + auto parsed_ref = dynamic_cast( + original_ref.column_list[idx_i].get()); if (parsed_ref == nullptr) { continue; } @@ -698,8 +725,9 @@ void PGQMatchFunction::CheckNamedSubpath( if (parsed_ref->function_name == "element_id") { // Check subpath name matches the column referenced in the function --> // element_id(named_subpath) - auto shortest_path_function = - CreatePathFindingFunction(subpath.path_list, pg_table, subpath.path_variable, final_select_node, conditions); + auto shortest_path_function = CreatePathFindingFunction( + subpath.path_list, pg_table, subpath.path_variable, final_select_node, + conditions); if (column_alias.empty()) { shortest_path_function->alias = @@ -709,10 +737,11 @@ void PGQMatchFunction::CheckNamedSubpath( } original_ref.column_list.erase(original_ref.column_list.begin() + idx_i); original_ref.column_list.insert(original_ref.column_list.begin() + idx_i, - std::move(shortest_path_function)); + std::move(shortest_path_function)); } else if (parsed_ref->function_name == "path_length") { - auto shortest_path_function = - CreatePathFindingFunction(subpath.path_list, pg_table, subpath.path_variable, final_select_node, conditions); + auto shortest_path_function = CreatePathFindingFunction( + subpath.path_list, pg_table, subpath.path_variable, final_select_node, + conditions); auto path_len_children = vector>(); path_len_children.push_back(std::move(shortest_path_function)); auto path_len = @@ -728,12 +757,13 @@ void PGQMatchFunction::CheckNamedSubpath( : column_alias; original_ref.column_list.erase(original_ref.column_list.begin() + idx_i); original_ref.column_list.insert(original_ref.column_list.begin() + idx_i, - std::move(path_length_function)); + std::move(path_length_function)); } else if (parsed_ref->function_name == "vertices" || parsed_ref->function_name == "edges") { auto list_slice_children = vector>(); - auto shortest_path_function = - CreatePathFindingFunction(subpath.path_list, pg_table, subpath.path_variable, final_select_node, conditions); + auto shortest_path_function = CreatePathFindingFunction( + subpath.path_list, pg_table, subpath.path_variable, final_select_node, + conditions); list_slice_children.push_back(std::move(shortest_path_function)); if (parsed_ref->function_name == "vertices") { @@ -760,7 +790,8 @@ void PGQMatchFunction::CheckNamedSubpath( : column_alias; } original_ref.column_list.erase(original_ref.column_list.begin() + idx_i); - original_ref.column_list.insert(original_ref.column_list.begin() + idx_i, std::move(list_slice)); + original_ref.column_list.insert(original_ref.column_list.begin() + idx_i, + std::move(list_slice)); } } } @@ -769,8 +800,8 @@ void PGQMatchFunction::ProcessPathList( vector> &path_list, vector> &conditions, unique_ptr &final_select_node, - case_insensitive_map_t> &alias_map, CreatePropertyGraphInfo &pg_table, - int32_t &extra_alias_counter, + case_insensitive_map_t> &alias_map, + CreatePropertyGraphInfo &pg_table, int32_t &extra_alias_counter, MatchExpression &original_ref) { PathElement *previous_vertex_element = GetPathElement(path_list[0]); if (!previous_vertex_element) { @@ -779,8 +810,10 @@ void PGQMatchFunction::ProcessPathList( if (previous_vertex_subpath->where_clause) { conditions.push_back(std::move(previous_vertex_subpath->where_clause)); } - if (!previous_vertex_subpath->path_variable.empty() && previous_vertex_subpath->path_list.size() > 1) { - CheckNamedSubpath(*previous_vertex_subpath, original_ref, pg_table, final_select_node, conditions); + if (!previous_vertex_subpath->path_variable.empty() && + previous_vertex_subpath->path_list.size() > 1) { + CheckNamedSubpath(*previous_vertex_subpath, original_ref, pg_table, + final_select_node, conditions); } if (previous_vertex_subpath->path_list.size() == 1) { previous_vertex_element = @@ -796,8 +829,7 @@ void PGQMatchFunction::ProcessPathList( auto previous_vertex_table = FindGraphTable(previous_vertex_element->label, pg_table); CheckInheritance(previous_vertex_table, previous_vertex_element, conditions); - alias_map[previous_vertex_element->variable_binding] = - previous_vertex_table; + alias_map[previous_vertex_element->variable_binding] = previous_vertex_table; for (idx_t idx_j = 1; idx_j < path_list.size(); idx_j = idx_j + 2) { PathElement *next_vertex_element = GetPathElement(path_list[idx_j + 1]); @@ -839,16 +871,17 @@ void PGQMatchFunction::ProcessPathList( if (edge_subpath->upper > 1) { // Add the path-finding AddPathFinding(final_select_node, conditions, - previous_vertex_element->variable_binding, + previous_vertex_element->variable_binding, edge_element->variable_binding, next_vertex_element->variable_binding, edge_table, - pg_table, edge_subpath, edge_element->match_type); - } else { - AddEdgeJoins(edge_table, previous_vertex_table, next_vertex_table, - edge_element->match_type, edge_element->variable_binding, - previous_vertex_element->variable_binding, - next_vertex_element->variable_binding, conditions, - alias_map, extra_alias_counter, final_select_node->from_table); + pg_table, edge_subpath, edge_element->match_type); + } else { + AddEdgeJoins(edge_table, previous_vertex_table, next_vertex_table, + edge_element->match_type, edge_element->variable_binding, + previous_vertex_element->variable_binding, + next_vertex_element->variable_binding, conditions, + alias_map, extra_alias_counter, + final_select_node->from_table); } } else { // The edge element is a path element without WHERE or path-finding. @@ -894,9 +927,8 @@ PGQMatchFunction::MatchBindReplace(ClientContext &context, auto &path_pattern = ref->path_patterns[idx_i]; // Check if the element is PathElement or a Subpath with potentially many // items - ProcessPathList(path_pattern->path_elements, conditions, - final_select_node, alias_map, *pg_table, extra_alias_counter, - *ref); + ProcessPathList(path_pattern->path_elements, conditions, final_select_node, + alias_map, *pg_table, extra_alias_counter, *ref); } // Go through all aliases encountered @@ -993,4 +1025,4 @@ void CoreTableFunctions::RegisterMatchTableFunction(DatabaseInstance &db) { } // namespace core -} // namespace duckdb +} // namespace duckpgq diff --git a/src/core/functions/table/pagerank.cpp b/src/core/functions/table/pagerank.cpp index 40fa6d42..36e808ea 100644 --- a/src/core/functions/table/pagerank.cpp +++ b/src/core/functions/table/pagerank.cpp @@ -10,18 +10,22 @@ namespace duckpgq { namespace core { // Main binding function -unique_ptr PageRankFunction::PageRankBindReplace(ClientContext &context, TableFunctionBindInput &input) { +unique_ptr +PageRankFunction::PageRankBindReplace(ClientContext &context, + TableFunctionBindInput &input) { auto pg_name = StringUtil::Lower(StringValue::Get(input.inputs[0])); auto node_table = StringUtil::Lower(StringValue::Get(input.inputs[1])); auto edge_table = StringUtil::Lower(StringValue::Get(input.inputs[2])); auto duckpgq_state = GetDuckPGQState(context); auto pg_info = GetPropertyGraphInfo(duckpgq_state, pg_name); - auto edge_pg_entry = ValidateSourceNodeAndEdgeTable(pg_info, node_table, edge_table); + auto edge_pg_entry = + ValidateSourceNodeAndEdgeTable(pg_info, node_table, edge_table); auto select_node = CreateSelectNode(edge_pg_entry, "pagerank", "pagerank"); - select_node->cte_map.map["csr_cte"] = CreateDirectedCSRCTE(edge_pg_entry, "src", "edge", "dst"); + select_node->cte_map.map["csr_cte"] = + CreateDirectedCSRCTE(edge_pg_entry, "src", "edge", "dst"); auto subquery = make_uniq(); subquery->node = std::move(select_node); diff --git a/src/core/functions/table/pgq_scan.cpp b/src/core/functions/table/pgq_scan.cpp index 82ae3dc3..4356519c 100644 --- a/src/core/functions/table/pgq_scan.cpp +++ b/src/core/functions/table/pgq_scan.cpp @@ -14,7 +14,6 @@ namespace duckpgq { namespace core { - static void ScanCSREFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { bool &gstate = ((CSRScanState &)*data_p.global_state).finished; @@ -246,30 +245,47 @@ static void ScanPGEColFunction(ClientContext &context, // Register functions //------------------------------------------------------------------------------ void CoreTableFunctions::RegisterScanTableFunctions(DatabaseInstance &db) { - ExtensionUtil::RegisterFunction(db, TableFunction("get_csr_e", {LogicalType::INTEGER}, ScanCSREFunction, - CSRScanEData::ScanCSREBind, CSRScanState::Init)); - - ExtensionUtil::RegisterFunction(db, TableFunction("get_csr_v", {LogicalType::INTEGER}, ScanCSRVFunction, - CSRScanVData::ScanCSRVBind, CSRScanState::Init)); - - ExtensionUtil::RegisterFunction(db, TableFunction("get_csr_w", {LogicalType::INTEGER}, ScanCSRWFunction, - CSRScanWData::ScanCSRWBind, CSRScanState::Init)); - - ExtensionUtil::RegisterFunction(db, TableFunction("get_pg_vtablenames", {LogicalType::VARCHAR}, ScanPGVTableFunction, - PGScanVTableData::ScanPGVTableBind, CSRScanState::Init)); - - ExtensionUtil::RegisterFunction(db, TableFunction("get_pg_vcolnames", {LogicalType::VARCHAR, LogicalType::VARCHAR}, ScanPGVColFunction, - PGScanVColData::ScanPGVColBind, CSRScanState::Init)); - - ExtensionUtil::RegisterFunction(db, TableFunction("get_csr_ptr", {LogicalType::INTEGER}, ScanCSRPtrFunction, + ExtensionUtil::RegisterFunction( + db, TableFunction("get_csr_e", {LogicalType::INTEGER}, ScanCSREFunction, + CSRScanEData::ScanCSREBind, CSRScanState::Init)); + + ExtensionUtil::RegisterFunction( + db, TableFunction("get_csr_v", {LogicalType::INTEGER}, ScanCSRVFunction, + CSRScanVData::ScanCSRVBind, CSRScanState::Init)); + + ExtensionUtil::RegisterFunction( + db, TableFunction("get_csr_w", {LogicalType::INTEGER}, ScanCSRWFunction, + CSRScanWData::ScanCSRWBind, CSRScanState::Init)); + + ExtensionUtil::RegisterFunction( + db, + TableFunction("get_pg_vtablenames", {LogicalType::VARCHAR}, + ScanPGVTableFunction, PGScanVTableData::ScanPGVTableBind, + CSRScanState::Init)); + + ExtensionUtil::RegisterFunction( + db, TableFunction("get_pg_vcolnames", + {LogicalType::VARCHAR, LogicalType::VARCHAR}, + ScanPGVColFunction, PGScanVColData::ScanPGVColBind, + CSRScanState::Init)); + + ExtensionUtil::RegisterFunction( + db, + TableFunction("get_csr_ptr", {LogicalType::INTEGER}, ScanCSRPtrFunction, CSRScanPtrData::ScanCSRPtrBind, CSRScanState::Init)); - ExtensionUtil::RegisterFunction(db, TableFunction("get_pg_etablenames", {LogicalType::VARCHAR}, ScanPGETableFunction, - PGScanETableData::ScanPGETableBind, CSRScanState::Init)); - - ExtensionUtil::RegisterFunction(db, TableFunction("get_pg_ecolnames", {LogicalType::VARCHAR, LogicalType::VARCHAR}, ScanPGEColFunction, - PGScanEColData::ScanPGEColBind, CSRScanState::Init)); + ExtensionUtil::RegisterFunction( + db, + TableFunction("get_pg_etablenames", {LogicalType::VARCHAR}, + ScanPGETableFunction, PGScanETableData::ScanPGETableBind, + CSRScanState::Init)); + + ExtensionUtil::RegisterFunction( + db, TableFunction("get_pg_ecolnames", + {LogicalType::VARCHAR, LogicalType::VARCHAR}, + ScanPGEColFunction, PGScanEColData::ScanPGEColBind, + CSRScanState::Init)); } } // namespace core -} // namespace duckdb +} // namespace duckpgq diff --git a/src/core/functions/table/weakly_connected_component.cpp b/src/core/functions/table/weakly_connected_component.cpp index 2f076631..1b863656 100644 --- a/src/core/functions/table/weakly_connected_component.cpp +++ b/src/core/functions/table/weakly_connected_component.cpp @@ -10,18 +10,23 @@ namespace duckpgq { namespace core { // Main binding function -unique_ptr WeaklyConnectedComponentFunction::WeaklyConnectedComponentBindReplace(ClientContext &context, TableFunctionBindInput &input) { +unique_ptr +WeaklyConnectedComponentFunction::WeaklyConnectedComponentBindReplace( + ClientContext &context, TableFunctionBindInput &input) { auto pg_name = StringUtil::Lower(StringValue::Get(input.inputs[0])); auto node_table = StringUtil::Lower(StringValue::Get(input.inputs[1])); auto edge_table = StringUtil::Lower(StringValue::Get(input.inputs[2])); auto duckpgq_state = GetDuckPGQState(context); auto pg_info = GetPropertyGraphInfo(duckpgq_state, pg_name); - auto edge_pg_entry = ValidateSourceNodeAndEdgeTable(pg_info, node_table, edge_table); + auto edge_pg_entry = + ValidateSourceNodeAndEdgeTable(pg_info, node_table, edge_table); - auto select_node = CreateSelectNode(edge_pg_entry, "weakly_connected_component", "componentId"); + auto select_node = CreateSelectNode( + edge_pg_entry, "weakly_connected_component", "componentId"); - select_node->cte_map.map["csr_cte"] = CreateUndirectedCSRCTE(edge_pg_entry, select_node); + select_node->cte_map.map["csr_cte"] = + CreateUndirectedCSRCTE(edge_pg_entry, select_node); auto subquery = make_uniq(); subquery->node = std::move(select_node); @@ -34,7 +39,8 @@ unique_ptr WeaklyConnectedComponentFunction::WeaklyConnectedComponentB //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ -void CoreTableFunctions::RegisterWeaklyConnectedComponentTableFunction(DatabaseInstance &db) { +void CoreTableFunctions::RegisterWeaklyConnectedComponentTableFunction( + DatabaseInstance &db) { ExtensionUtil::RegisterFunction(db, WeaklyConnectedComponentFunction()); } diff --git a/src/core/module.cpp b/src/core/module.cpp index c4887294..ef4d9d39 100644 --- a/src/core/module.cpp +++ b/src/core/module.cpp @@ -11,13 +11,12 @@ namespace duckpgq { namespace core { void CoreModule::Register(DatabaseInstance &db) { - CoreTableFunctions::Register(db); - CoreScalarFunctions::Register(db); - CorePGQParser::Register(db); - CorePGQOperator::Register(db); + CoreTableFunctions::Register(db); + CoreScalarFunctions::Register(db); + CorePGQParser::Register(db); + CorePGQOperator::Register(db); } - } // namespace core } // namespace duckpgq \ No newline at end of file diff --git a/src/core/operator/duckpgq_bind.cpp b/src/core/operator/duckpgq_bind.cpp index 7b11d547..0d42bc6f 100644 --- a/src/core/operator/duckpgq_bind.cpp +++ b/src/core/operator/duckpgq_bind.cpp @@ -29,13 +29,11 @@ BoundStatement duckpgq_bind(ClientContext &context, Binder &binder, //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ -void CorePGQOperator::RegisterPGQBindOperator( - DatabaseInstance &db) { +void CorePGQOperator::RegisterPGQBindOperator(DatabaseInstance &db) { auto &config = DBConfig::GetConfig(db); config.operator_extensions.push_back(make_uniq()); - } -} +} // namespace core } // namespace duckpgq \ No newline at end of file diff --git a/src/core/parser/duckpgq_parser.cpp b/src/core/parser/duckpgq_parser.cpp index 96e141bb..4d3d395a 100644 --- a/src/core/parser/duckpgq_parser.cpp +++ b/src/core/parser/duckpgq_parser.cpp @@ -35,8 +35,6 @@ ParserExtensionParseResult duckpgq_parse(ParserExtensionInfo *info, std::move(parser.statements[0]))); } - - void duckpgq_find_match_function(TableRef *table_ref, DuckPGQState &duckpgq_state) { if (auto table_function_ref = dynamic_cast(table_ref)) { @@ -60,7 +58,6 @@ void duckpgq_find_match_function(TableRef *table_ref, } } - ParserExtensionPlanResult duckpgq_handle_statement(SQLStatement *statement, DuckPGQState &duckpgq_state) { if (statement->type == StatementType::SELECT_STATEMENT) { @@ -80,7 +77,8 @@ duckpgq_handle_statement(SQLStatement *statement, DuckPGQState &duckpgq_state) { // Check if node is a ShowRef if (node) { - const auto describe_node = dynamic_cast(node->from_table.get()); + const auto describe_node = + dynamic_cast(node->from_table.get()); if (describe_node) { ParserExtensionPlanResult result; result.function = DescribePropertyGraphFunction(); @@ -99,11 +97,13 @@ duckpgq_handle_statement(SQLStatement *statement, DuckPGQState &duckpgq_state) { } for (auto &key : cte_keys) { auto cte = node->cte_map.map.find(key); - auto cte_select_statement = dynamic_cast(cte->second->query.get()); + auto cte_select_statement = + dynamic_cast(cte->second->query.get()); if (cte_select_statement == nullptr) { continue; // Skip non-select statements } - auto cte_node = dynamic_cast(cte_select_statement->node.get()); + auto cte_node = + dynamic_cast(cte_select_statement->node.get()); if (cte_node) { duckpgq_find_match_function(cte_node->from_table.get(), duckpgq_state); } @@ -164,7 +164,8 @@ duckpgq_plan(ParserExtensionInfo *, ClientContext &context, unique_ptr parse_data) { auto duckpgq_state = context.registered_state->Get("duckpgq"); if (duckpgq_state == nullptr) { - throw Exception(ExceptionType::INVALID, "DuckPGQ extension has not been properly initialized"); + throw Exception(ExceptionType::INVALID, + "DuckPGQ extension has not been properly initialized"); } duckpgq_state->parse_data = std::move(parse_data); auto duckpgq_parse_data = @@ -178,12 +179,10 @@ duckpgq_plan(ParserExtensionInfo *, ClientContext &context, return duckpgq_handle_statement(statement, *duckpgq_state); } - //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ -void CorePGQParser::RegisterPGQParserExtension( - DatabaseInstance &db) { +void CorePGQParser::RegisterPGQParserExtension(DatabaseInstance &db) { auto &config = DBConfig::GetConfig(db); config.parser_extensions.push_back(DuckPGQParserExtension()); } diff --git a/src/core/utils/compressed_sparse_row.cpp b/src/core/utils/compressed_sparse_row.cpp index bea9d714..346b4e12 100644 --- a/src/core/utils/compressed_sparse_row.cpp +++ b/src/core/utils/compressed_sparse_row.cpp @@ -50,7 +50,6 @@ bool CSRFunctionData::Equals(const FunctionData &other_p) const { return id == other.id && weight_type == other.weight_type; } - unique_ptr CSRFunctionData::CSRVertexBind(ClientContext &context, ScalarFunction &bound_function, @@ -66,8 +65,7 @@ CSRFunctionData::CSRVertexBind(ClientContext &context, logical_type); } return make_uniq(context, id.GetValue(), - arguments[3]->return_type); - + arguments[3]->return_type); } unique_ptr @@ -98,39 +96,54 @@ CSRFunctionData::CSRBind(ClientContext &context, ScalarFunction &bound_function, LogicalType::BOOLEAN); } - // Helper function to create a JoinRef unique_ptr CreateJoin(const string &fk_column, const string &pk_column, - const shared_ptr &fk_table, const shared_ptr &pk_table) { + const shared_ptr &fk_table, + const shared_ptr &pk_table) { auto join = make_uniq(JoinRefType::REGULAR); join->left = fk_table->CreateBaseTableRef(); join->right = pk_table->CreateBaseTableRef(); - join->condition = make_uniq(ExpressionType::COMPARE_EQUAL, - make_uniq(fk_column, fk_table->table_name), make_uniq(pk_column, pk_table->table_name)); + join->condition = make_uniq( + ExpressionType::COMPARE_EQUAL, + make_uniq(fk_column, fk_table->table_name), + make_uniq(pk_column, pk_table->table_name)); return join; } - // Helper function to setup SelectNode -void SetupSelectNode(unique_ptr &select_node, const shared_ptr &edge_table, bool reverse) { +void SetupSelectNode(unique_ptr &select_node, + const shared_ptr &edge_table, + bool reverse) { select_node = make_uniq(); - select_node->select_list.emplace_back(CreateColumnRefExpression("rowid", edge_table->source_reference, "dense_id")); + select_node->select_list.emplace_back(CreateColumnRefExpression( + "rowid", edge_table->source_reference, "dense_id")); if (!reverse) { - select_node->select_list.emplace_back(CreateColumnRefExpression(edge_table->source_fk[0], edge_table->table_name, "outgoing_edges")); - select_node->select_list.emplace_back(CreateColumnRefExpression(edge_table->destination_fk[0], edge_table->table_name, "incoming_edges")); - select_node->from_table = CreateJoin(edge_table->source_fk[0], edge_table->source_pk[0], edge_table, edge_table->source_pg_table); + select_node->select_list.emplace_back(CreateColumnRefExpression( + edge_table->source_fk[0], edge_table->table_name, "outgoing_edges")); + select_node->select_list.emplace_back( + CreateColumnRefExpression(edge_table->destination_fk[0], + edge_table->table_name, "incoming_edges")); + select_node->from_table = + CreateJoin(edge_table->source_fk[0], edge_table->source_pk[0], + edge_table, edge_table->source_pg_table); } else { - select_node->select_list.emplace_back(CreateColumnRefExpression(edge_table->destination_fk[0], edge_table->table_name, "outgoing_edges")); - select_node->select_list.emplace_back(CreateColumnRefExpression(edge_table->source_fk[0], edge_table->table_name, "incoming_edges")); - select_node->from_table = CreateJoin(edge_table->destination_fk[0], edge_table->source_pk[0], edge_table, edge_table->source_pg_table); + select_node->select_list.emplace_back( + CreateColumnRefExpression(edge_table->destination_fk[0], + edge_table->table_name, "outgoing_edges")); + select_node->select_list.emplace_back(CreateColumnRefExpression( + edge_table->source_fk[0], edge_table->table_name, "incoming_edges")); + select_node->from_table = + CreateJoin(edge_table->destination_fk[0], edge_table->source_pk[0], + edge_table, edge_table->source_pg_table); } } - // Function to create a subquery expression for counting table entries -unique_ptr GetCountTable(const shared_ptr &table, const string &table_alias, const string &primary_key) { +unique_ptr +GetCountTable(const shared_ptr &table, + const string &table_alias, const string &primary_key) { auto select_count = make_uniq(); auto select_inner = make_uniq(); auto ref = table->CreateBaseTableRef(table_alias); @@ -139,7 +152,8 @@ unique_ptr GetCountTable(const shared_ptr> children; children.push_back(make_uniq(primary_key, table_alias)); - auto count_function = make_uniq("count", std::move(children)); + auto count_function = + make_uniq("count", std::move(children)); select_inner->select_list.push_back(std::move(count_function)); select_count->node = std::move(select_inner); @@ -149,11 +163,10 @@ unique_ptr GetCountTable(const shared_ptr -GetJoinRef(const shared_ptr &edge_table, - const string &edge_binding, - const string &prev_binding, - const string &next_binding) { +unique_ptr GetJoinRef(const shared_ptr &edge_table, + const string &edge_binding, + const string &prev_binding, + const string &next_binding) { auto first_join_ref = make_uniq(JoinRefType::REGULAR); first_join_ref->type = JoinType::INNER; @@ -161,7 +174,8 @@ GetJoinRef(const shared_ptr &edge_table, second_join_ref->type = JoinType::INNER; second_join_ref->left = edge_table->CreateBaseTableRef(edge_binding); - second_join_ref->right = edge_table->source_pg_table->CreateBaseTableRef(prev_binding); + second_join_ref->right = + edge_table->source_pg_table->CreateBaseTableRef(prev_binding); auto t_from_ref = make_uniq(edge_table->source_fk[0], edge_binding); auto src_cid_ref = @@ -170,7 +184,8 @@ GetJoinRef(const shared_ptr &edge_table, ExpressionType::COMPARE_EQUAL, std::move(t_from_ref), std::move(src_cid_ref)); first_join_ref->left = std::move(second_join_ref); - first_join_ref->right = edge_table->destination_pg_table->CreateBaseTableRef(next_binding); + first_join_ref->right = + edge_table->destination_pg_table->CreateBaseTableRef(next_binding); auto t_to_ref = make_uniq(edge_table->destination_fk[0], edge_binding); @@ -182,50 +197,66 @@ GetJoinRef(const shared_ptr &edge_table, return first_join_ref; } -unique_ptr CreateDirectedCSRVertexSubquery(const shared_ptr &edge_table, const string &prev_binding) { - auto count_create_vertex_expr = GetCountTable(edge_table->source_pg_table, prev_binding, edge_table->source_pk[0]); +unique_ptr CreateDirectedCSRVertexSubquery( + const shared_ptr &edge_table, + const string &prev_binding) { + auto count_create_vertex_expr = GetCountTable( + edge_table->source_pg_table, prev_binding, edge_table->source_pk[0]); vector> csr_vertex_children; - csr_vertex_children.push_back(make_uniq(Value::INTEGER(0))); + csr_vertex_children.push_back( + make_uniq(Value::INTEGER(0))); csr_vertex_children.push_back(std::move(count_create_vertex_expr)); - csr_vertex_children.push_back(make_uniq("dense_id", "sub")); + csr_vertex_children.push_back( + make_uniq("dense_id", "sub")); csr_vertex_children.push_back(make_uniq("cnt", "sub")); - auto create_vertex_function = make_uniq("create_csr_vertex", std::move(csr_vertex_children)); + auto create_vertex_function = make_uniq( + "create_csr_vertex", std::move(csr_vertex_children)); vector> sum_children; sum_children.push_back(std::move(create_vertex_function)); - auto sum_function = make_uniq("sum", std::move(sum_children)); + auto sum_function = + make_uniq("sum", std::move(sum_children)); auto inner_select_statement = make_uniq(); auto inner_select_node = make_uniq(); - inner_select_node->select_list.emplace_back(CreateColumnRefExpression("rowid", prev_binding, "dense_id")); - auto edge_src_colref = make_uniq(edge_table->source_fk[0], edge_table->table_name); + inner_select_node->select_list.emplace_back( + CreateColumnRefExpression("rowid", prev_binding, "dense_id")); + auto edge_src_colref = make_uniq( + edge_table->source_fk[0], edge_table->table_name); vector> count_children; count_children.push_back(std::move(edge_src_colref)); - auto count_function = make_uniq("count", std::move(count_children)); + auto count_function = + make_uniq("count", std::move(count_children)); count_function->alias = "cnt"; inner_select_node->select_list.emplace_back(std::move(count_function)); auto left_join_ref = make_uniq(JoinRefType::REGULAR); left_join_ref->type = JoinType::LEFT; - left_join_ref->left = edge_table->source_pg_table->CreateBaseTableRef(prev_binding); - left_join_ref->right = edge_table->CreateBaseTableRef(edge_table->table_name_alias); - - auto join_condition = make_uniq(ExpressionType::COMPARE_EQUAL, - make_uniq(edge_table->source_fk[0], edge_table->table_name), - make_uniq(edge_table->source_pk[0], prev_binding)); + left_join_ref->left = + edge_table->source_pg_table->CreateBaseTableRef(prev_binding); + left_join_ref->right = + edge_table->CreateBaseTableRef(edge_table->table_name_alias); + + auto join_condition = make_uniq( + ExpressionType::COMPARE_EQUAL, + make_uniq(edge_table->source_fk[0], + edge_table->table_name), + make_uniq(edge_table->source_pk[0], prev_binding)); left_join_ref->condition = std::move(join_condition); inner_select_node->from_table = std::move(left_join_ref); auto dense_id_colref = make_uniq("dense_id"); - inner_select_node->groups.group_expressions.push_back(std::move(dense_id_colref)); + inner_select_node->groups.group_expressions.push_back( + std::move(dense_id_colref)); GroupingSet grouping_set = {0}; inner_select_node->groups.grouping_sets.push_back(grouping_set); inner_select_statement->node = std::move(inner_select_node); - auto inner_from_subquery = make_uniq(std::move(inner_select_statement), "sub"); + auto inner_from_subquery = + make_uniq(std::move(inner_select_statement), "sub"); auto cast_select_node = make_uniq(); cast_select_node->from_table = std::move(inner_from_subquery); @@ -242,26 +273,33 @@ unique_ptr CreateDirectedCSRVertexSubquery(const shared_ptr< } // Helper function to create CSR Vertex Subquery -unique_ptr CreateUndirectedCSRVertexSubquery(const shared_ptr &edge_table, const string &binding) { - auto count_create_vertex_expr = GetCountTable(edge_table->source_pg_table, binding, edge_table->source_pk[0]); +unique_ptr CreateUndirectedCSRVertexSubquery( + const shared_ptr &edge_table, const string &binding) { + auto count_create_vertex_expr = GetCountTable( + edge_table->source_pg_table, binding, edge_table->source_pk[0]); vector> csr_vertex_children; - csr_vertex_children.push_back(make_uniq(Value::INTEGER(0))); + csr_vertex_children.push_back( + make_uniq(Value::INTEGER(0))); csr_vertex_children.push_back(std::move(count_create_vertex_expr)); - csr_vertex_children.push_back(make_uniq("dense_id", "sub")); + csr_vertex_children.push_back( + make_uniq("dense_id", "sub")); csr_vertex_children.push_back(make_uniq("cnt", "sub")); - auto create_vertex_function = make_uniq("create_csr_vertex", std::move(csr_vertex_children)); + auto create_vertex_function = make_uniq( + "create_csr_vertex", std::move(csr_vertex_children)); vector> sum_children; sum_children.push_back(std::move(create_vertex_function)); - auto sum_function = make_uniq("sum", std::move(sum_children)); + auto sum_function = + make_uniq("sum", std::move(sum_children)); vector> multiply_csr_vertex_children; auto two_constant = make_uniq(Value::INTEGER(2)); multiply_csr_vertex_children.push_back(std::move(two_constant)); multiply_csr_vertex_children.push_back(std::move(sum_function)); - auto multiply_function = make_uniq("multiply", std::move(multiply_csr_vertex_children)); + auto multiply_function = make_uniq( + "multiply", std::move(multiply_csr_vertex_children)); auto inner_select_statement = make_uniq(); auto inner_select_node = make_uniq(); @@ -273,18 +311,21 @@ unique_ptr CreateUndirectedCSRVertexSubquery(const shared_pt auto outgoing_edges_ref = make_uniq("outgoing_edges"); vector> inner_count_children; inner_count_children.push_back(std::move(outgoing_edges_ref)); - auto inner_count_function = make_uniq("count", std::move(inner_count_children)); + auto inner_count_function = + make_uniq("count", std::move(inner_count_children)); inner_count_function->alias = "cnt"; inner_select_node->select_list.push_back(std::move(dense_id_ref)); inner_select_node->select_list.push_back(std::move(inner_count_function)); auto dense_id_colref = make_uniq("dense_id"); - inner_select_node->groups.group_expressions.push_back(std::move(dense_id_colref)); + inner_select_node->groups.group_expressions.push_back( + std::move(dense_id_colref)); GroupingSet grouping_set = {0}; inner_select_node->groups.grouping_sets.push_back(grouping_set); - unique_ptr unique_edges_select_node, unique_edges_select_node_reverse; + unique_ptr unique_edges_select_node, + unique_edges_select_node_reverse; SetupSelectNode(unique_edges_select_node, edge_table, false); SetupSelectNode(unique_edges_select_node_reverse, edge_table, true); @@ -296,12 +337,14 @@ unique_ptr CreateUndirectedCSRVertexSubquery(const shared_pt auto subquery_select_statement = make_uniq(); subquery_select_statement->node = std::move(union_all_node); - auto unique_edges_subquery = make_uniq(std::move(subquery_select_statement), "unique_edges"); + auto unique_edges_subquery = make_uniq( + std::move(subquery_select_statement), "unique_edges"); inner_select_node->from_table = std::move(unique_edges_subquery); inner_select_statement->node = std::move(inner_select_node); - auto inner_from_subquery = make_uniq(std::move(inner_select_statement), "sub"); + auto inner_from_subquery = + make_uniq(std::move(inner_select_statement), "sub"); auto cast_select_node = make_uniq(); cast_select_node->from_table = std::move(inner_from_subquery); @@ -317,21 +360,25 @@ unique_ptr CreateUndirectedCSRVertexSubquery(const shared_pt return cast_subquery_expr; } - // Helper function to create outer select edges node unique_ptr CreateOuterSelectEdgesNode() { auto outer_select_edges_node = make_uniq(); - outer_select_edges_node->select_list.push_back(make_uniq("src")); - outer_select_edges_node->select_list.push_back(make_uniq("dst")); + outer_select_edges_node->select_list.push_back( + make_uniq("src")); + outer_select_edges_node->select_list.push_back( + make_uniq("dst")); vector> any_value_children; any_value_children.push_back(make_uniq("edges")); - auto any_value_function = make_uniq("any_value", std::move(any_value_children)); + auto any_value_function = + make_uniq("any_value", std::move(any_value_children)); any_value_function->alias = "edge"; outer_select_edges_node->select_list.push_back(std::move(any_value_function)); - outer_select_edges_node->groups.group_expressions.push_back(make_uniq("src")); - outer_select_edges_node->groups.group_expressions.push_back(make_uniq("dst")); + outer_select_edges_node->groups.group_expressions.push_back( + make_uniq("src")); + outer_select_edges_node->groups.group_expressions.push_back( + make_uniq("dst")); GroupingSet outer_grouping_set = {0, 1}; outer_select_edges_node->groups.grouping_sets.push_back(outer_grouping_set); @@ -339,7 +386,8 @@ unique_ptr CreateOuterSelectEdgesNode() { } // Helper function to create outer select node -unique_ptr CreateOuterSelectNode(unique_ptr create_csr_edge_function) { +unique_ptr +CreateOuterSelectNode(unique_ptr create_csr_edge_function) { auto outer_select_node = make_uniq(); create_csr_edge_function->alias = "temp"; outer_select_node->select_list.push_back(std::move(create_csr_edge_function)); @@ -347,73 +395,90 @@ unique_ptr CreateOuterSelectNode(unique_ptr crea } // Function to create the CTE for the edges -unique_ptr MakeEdgesCTE(const shared_ptr &edge_table) { - std::vector> select_expression; - auto src_col_ref = make_uniq("rowid", "src_table"); - src_col_ref->alias = "src"; +unique_ptr +MakeEdgesCTE(const shared_ptr &edge_table) { + std::vector> select_expression; + auto src_col_ref = make_uniq("rowid", "src_table"); + src_col_ref->alias = "src"; - select_expression.emplace_back(std::move(src_col_ref)); + select_expression.emplace_back(std::move(src_col_ref)); - auto dst_col_ref = make_uniq("rowid", "dst_table"); - dst_col_ref->alias = "dst"; - select_expression.emplace_back(std::move(dst_col_ref)); + auto dst_col_ref = make_uniq("rowid", "dst_table"); + dst_col_ref->alias = "dst"; + select_expression.emplace_back(std::move(dst_col_ref)); - auto edge_col_ref = make_uniq("rowid", edge_table->table_name); - edge_col_ref->alias = "edges"; - select_expression.emplace_back(std::move(edge_col_ref)); + auto edge_col_ref = + make_uniq("rowid", edge_table->table_name); + edge_col_ref->alias = "edges"; + select_expression.emplace_back(std::move(edge_col_ref)); - auto select_node = make_uniq(); - select_node->select_list = std::move(select_expression); + auto select_node = make_uniq(); + select_node->select_list = std::move(select_expression); - auto join_ref = make_uniq(JoinRefType::REGULAR); - auto first_join_ref = make_uniq(JoinRefType::REGULAR); - first_join_ref->type = JoinType::INNER; - first_join_ref->left = edge_table->CreateBaseTableRef(); - first_join_ref->right = edge_table->source_pg_table->CreateBaseTableRef("src_table"); + auto join_ref = make_uniq(JoinRefType::REGULAR); + auto first_join_ref = make_uniq(JoinRefType::REGULAR); + first_join_ref->type = JoinType::INNER; + first_join_ref->left = edge_table->CreateBaseTableRef(); + first_join_ref->right = + edge_table->source_pg_table->CreateBaseTableRef("src_table"); - auto edge_from_ref = make_uniq(edge_table->source_fk[0], edge_table->table_name); - auto src_cid_ref = make_uniq(edge_table->source_pk[0], "src_table"); - first_join_ref->condition = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(edge_from_ref), std::move(src_cid_ref)); + auto edge_from_ref = make_uniq(edge_table->source_fk[0], + edge_table->table_name); + auto src_cid_ref = + make_uniq(edge_table->source_pk[0], "src_table"); + first_join_ref->condition = make_uniq( + ExpressionType::COMPARE_EQUAL, std::move(edge_from_ref), + std::move(src_cid_ref)); - auto second_join_ref = make_uniq(JoinRefType::REGULAR); - second_join_ref->type = JoinType::INNER; - second_join_ref->left = std::move(first_join_ref); - second_join_ref->right = edge_table->destination_pg_table->CreateBaseTableRef("dst_table"); + auto second_join_ref = make_uniq(JoinRefType::REGULAR); + second_join_ref->type = JoinType::INNER; + second_join_ref->left = std::move(first_join_ref); + second_join_ref->right = + edge_table->destination_pg_table->CreateBaseTableRef("dst_table"); - auto edge_to_ref = make_uniq(edge_table->destination_fk[0], edge_table->table_name); - auto dst_cid_ref = make_uniq(edge_table->destination_pk[0], "dst_table"); - second_join_ref->condition = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(edge_to_ref), std::move(dst_cid_ref)); + auto edge_to_ref = make_uniq( + edge_table->destination_fk[0], edge_table->table_name); + auto dst_cid_ref = make_uniq( + edge_table->destination_pk[0], "dst_table"); + second_join_ref->condition = make_uniq( + ExpressionType::COMPARE_EQUAL, std::move(edge_to_ref), + std::move(dst_cid_ref)); - select_node->from_table = std::move(second_join_ref); + select_node->from_table = std::move(second_join_ref); - auto select_statement = make_uniq(); - select_statement->node = std::move(select_node); + auto select_statement = make_uniq(); + select_statement->node = std::move(select_node); - auto result = make_uniq(); - result->query = std::move(select_statement); - return result; + auto result = make_uniq(); + result->query = std::move(select_statement); + return result; } - // Function to create the CTE for the Undirected CSR -unique_ptr CreateUndirectedCSRCTE(const shared_ptr &edge_table, +unique_ptr +CreateUndirectedCSRCTE(const shared_ptr &edge_table, const unique_ptr &select_node) { - if (select_node->cte_map.map.find("edges_cte") == select_node->cte_map.map.end()) { + if (select_node->cte_map.map.find("edges_cte") == + select_node->cte_map.map.end()) { select_node->cte_map.map["edges_cte"] = MakeEdgesCTE(edge_table); } auto csr_edge_id_constant = make_uniq(Value::INTEGER(0)); - auto count_create_edge_select = GetCountTable(edge_table->source_pg_table, edge_table->source_reference, edge_table->source_pk[0]); + auto count_create_edge_select = + GetCountTable(edge_table->source_pg_table, edge_table->source_reference, + edge_table->source_pk[0]); auto count_edges_subquery = GetCountUndirectedEdgeTable(); - auto cast_subquery_expr = CreateUndirectedCSRVertexSubquery(edge_table, edge_table->source_reference); + auto cast_subquery_expr = CreateUndirectedCSRVertexSubquery( + edge_table, edge_table->source_reference); auto src_rowid_colref = make_uniq("src"); auto dst_rowid_colref = make_uniq("dst"); auto edge_rowid_colref = make_uniq("edge"); - auto cast_expression = make_uniq(LogicalType::BIGINT, std::move(cast_subquery_expr)); + auto cast_expression = make_uniq( + LogicalType::BIGINT, std::move(cast_subquery_expr)); vector> csr_edge_children; csr_edge_children.push_back(std::move(csr_edge_id_constant)); @@ -424,8 +489,10 @@ unique_ptr CreateUndirectedCSRCTE(const shared_ptr("create_csr_edge", std::move(csr_edge_children)); - auto outer_select_node = CreateOuterSelectNode(std::move(create_csr_edge_function)); + auto create_csr_edge_function = make_uniq( + "create_csr_edge", std::move(csr_edge_children)); + auto outer_select_node = + CreateOuterSelectNode(std::move(create_csr_edge_function)); auto outer_select_edges_node = CreateOuterSelectEdgesNode(); @@ -436,26 +503,35 @@ unique_ptr CreateUndirectedCSRCTE(const shared_ptr(); src_dst_select_node->from_table = std::move(CreateBaseTableRef("edges_cte")); - src_dst_select_node->select_list.push_back(make_uniq("src")); - src_dst_select_node->select_list.push_back(make_uniq("dst")); - src_dst_select_node->select_list.push_back(make_uniq("edges")); + src_dst_select_node->select_list.push_back( + make_uniq("src")); + src_dst_select_node->select_list.push_back( + make_uniq("dst")); + src_dst_select_node->select_list.push_back( + make_uniq("edges")); auto dst_src_select_node = make_uniq(); dst_src_select_node->from_table = std::move(CreateBaseTableRef("edges_cte")); - dst_src_select_node->select_list.push_back(make_uniq("dst")); - dst_src_select_node->select_list.push_back(make_uniq("src")); - dst_src_select_node->select_list.push_back(make_uniq("edges")); + dst_src_select_node->select_list.push_back( + make_uniq("dst")); + dst_src_select_node->select_list.push_back( + make_uniq("src")); + dst_src_select_node->select_list.push_back( + make_uniq("edges")); outer_union_all_node->left = std::move(src_dst_select_node); outer_union_all_node->right = std::move(dst_src_select_node); auto outer_union_select_statement = make_uniq(); outer_union_select_statement->node = std::move(outer_union_all_node); - outer_select_edges_node->from_table = make_uniq(std::move(outer_union_select_statement)); + outer_select_edges_node->from_table = + make_uniq(std::move(outer_union_select_statement)); auto outer_select_edges_select_statement = make_uniq(); - outer_select_edges_select_statement->node = std::move(outer_select_edges_node); - outer_select_node->from_table = make_uniq(std::move(outer_select_edges_select_statement)); + outer_select_edges_select_statement->node = + std::move(outer_select_edges_node); + outer_select_node->from_table = + make_uniq(std::move(outer_select_edges_select_statement)); auto outer_select_statement = make_uniq(); outer_select_statement->node = std::move(outer_select_node); @@ -468,25 +544,32 @@ unique_ptr GetCountUndirectedEdgeTable() { auto count_edges_select_statement = make_uniq(); auto count_edges_select_node = make_uniq(); vector> count_children; - auto count_function = make_uniq("count", std::move(count_children)); + auto count_function = + make_uniq("count", std::move(count_children)); vector> multiply_children; auto constant_two = make_uniq(Value::BIGINT(2)); multiply_children.push_back(std::move(constant_two)); multiply_children.push_back(std::move(count_function)); - auto multiply_function = make_uniq("multiply", std::move(multiply_children)); - count_edges_select_node->select_list.emplace_back(std::move(multiply_function)); + auto multiply_function = + make_uniq("multiply", std::move(multiply_children)); + count_edges_select_node->select_list.emplace_back( + std::move(multiply_function)); auto inner_select_statement = make_uniq(); auto src_dst_select_node = make_uniq(); - src_dst_select_node->select_list.emplace_back(CreateColumnRefExpression("src")); - src_dst_select_node->select_list.emplace_back(CreateColumnRefExpression("dst")); + src_dst_select_node->select_list.emplace_back( + CreateColumnRefExpression("src")); + src_dst_select_node->select_list.emplace_back( + CreateColumnRefExpression("dst")); src_dst_select_node->from_table = std::move(CreateBaseTableRef("edges_cte")); auto dst_src_select_node = make_uniq(); - dst_src_select_node->select_list.emplace_back(CreateColumnRefExpression("dst", "", "src")); - dst_src_select_node->select_list.emplace_back(CreateColumnRefExpression("src", "", "dst")); + dst_src_select_node->select_list.emplace_back( + CreateColumnRefExpression("dst", "", "src")); + dst_src_select_node->select_list.emplace_back( + CreateColumnRefExpression("src", "", "dst")); dst_src_select_node->from_table = CreateBaseTableRef("edges_cte"); auto union_by_name_node = make_uniq(); @@ -495,7 +578,8 @@ unique_ptr GetCountUndirectedEdgeTable() { union_by_name_node->left = std::move(src_dst_select_node); union_by_name_node->right = std::move(dst_src_select_node); inner_select_statement->node = std::move(union_by_name_node); - auto inner_from_subquery = make_uniq(std::move(inner_select_statement)); + auto inner_from_subquery = + make_uniq(std::move(inner_select_statement)); count_edges_select_node->from_table = std::move(inner_from_subquery); count_edges_select_statement->node = std::move(count_edges_select_node); auto result = make_uniq(); @@ -504,22 +588,33 @@ unique_ptr GetCountUndirectedEdgeTable() { return result; } -unique_ptr GetCountEdgeTable(const shared_ptr &edge_table) { +unique_ptr +GetCountEdgeTable(const shared_ptr &edge_table) { auto result = make_uniq(); auto outer_select_statement = make_uniq(); auto outer_select_node = make_uniq(); vector> count_children; - outer_select_node->select_list.push_back(make_uniq("count", std::move(count_children))); + outer_select_node->select_list.push_back( + make_uniq("count", std::move(count_children))); auto inner_select_node = make_uniq(); auto first_join = make_uniq(JoinRefType::REGULAR); first_join->left = edge_table->CreateBaseTableRef(); first_join->right = edge_table->source_pg_table->CreateBaseTableRef("src"); - first_join->condition = make_uniq(ExpressionType::COMPARE_EQUAL, make_uniq(edge_table->source_fk[0], edge_table->table_name), make_uniq(edge_table->source_pk[0], "src")); + first_join->condition = make_uniq( + ExpressionType::COMPARE_EQUAL, + make_uniq(edge_table->source_fk[0], + edge_table->table_name), + make_uniq(edge_table->source_pk[0], "src")); auto second_join = make_uniq(JoinRefType::REGULAR); second_join->left = std::move(first_join); - second_join->right = edge_table->destination_pg_table->CreateBaseTableRef("dst"); - second_join->condition = make_uniq(ExpressionType::COMPARE_EQUAL, make_uniq(edge_table->destination_fk[0], edge_table->table_name), make_uniq(edge_table->destination_pk[0], "dst")); + second_join->right = + edge_table->destination_pg_table->CreateBaseTableRef("dst"); + second_join->condition = make_uniq( + ExpressionType::COMPARE_EQUAL, + make_uniq(edge_table->destination_fk[0], + edge_table->table_name), + make_uniq(edge_table->destination_pk[0], "dst")); outer_select_node->from_table = std::move(second_join); outer_select_statement->node = std::move(outer_select_node); result->subquery = std::move(outer_select_statement); @@ -527,20 +622,27 @@ unique_ptr GetCountEdgeTable(const shared_ptr CreateDirectedCSRCTE(const shared_ptr &edge_table, const string &prev_binding, const string &edge_binding, const string &next_binding) { +unique_ptr +CreateDirectedCSRCTE(const shared_ptr &edge_table, + const string &prev_binding, const string &edge_binding, + const string &next_binding) { auto csr_edge_id_constant = make_uniq(Value::INTEGER(0)); - auto count_create_edge_select = GetCountTable(edge_table->source_pg_table, prev_binding, edge_table->source_pk[0]); + auto count_create_edge_select = GetCountTable( + edge_table->source_pg_table, prev_binding, edge_table->source_pk[0]); - auto cast_subquery_expr = CreateDirectedCSRVertexSubquery(edge_table, prev_binding); - auto count_edge_table = GetCountEdgeTable(edge_table); // Count the number of edges + auto cast_subquery_expr = + CreateDirectedCSRVertexSubquery(edge_table, prev_binding); + auto count_edge_table = + GetCountEdgeTable(edge_table); // Count the number of edges auto src_rowid_colref = make_uniq("rowid", prev_binding); auto dst_rowid_colref = make_uniq("rowid", next_binding); - auto edge_rowid_colref = make_uniq("rowid", edge_binding); + auto edge_rowid_colref = + make_uniq("rowid", edge_binding); - auto cast_expression = make_uniq(LogicalType::BIGINT, std::move(cast_subquery_expr)); + auto cast_expression = make_uniq( + LogicalType::BIGINT, std::move(cast_subquery_expr)); vector> csr_edge_children; csr_edge_children.push_back(std::move(csr_edge_id_constant)); @@ -551,10 +653,13 @@ unique_ptr CreateDirectedCSRCTE(const shared_ptr("create_csr_edge", std::move(csr_edge_children)); - auto outer_select_node = CreateOuterSelectNode(std::move(create_csr_edge_function)); + auto create_csr_edge_function = make_uniq( + "create_csr_edge", std::move(csr_edge_children)); + auto outer_select_node = + CreateOuterSelectNode(std::move(create_csr_edge_function)); - outer_select_node->from_table = GetJoinRef(edge_table, edge_binding, prev_binding, next_binding); + outer_select_node->from_table = + GetJoinRef(edge_table, edge_binding, prev_binding, next_binding); auto outer_select_statement = make_uniq(); outer_select_statement->node = std::move(outer_select_node); @@ -575,14 +680,16 @@ unique_ptr CreateCountCTESubquery() { vector> children; children.push_back(make_uniq("temp", "csr_cte")); - auto count_function = make_uniq("count", std::move(children)); + auto count_function = + make_uniq("count", std::move(children)); auto zero = make_uniq(Value::INTEGER((int32_t)0)); vector> multiply_children; multiply_children.push_back(std::move(zero)); multiply_children.push_back(std::move(count_function)); - auto multiply_function = make_uniq("multiply", std::move(multiply_children)); + auto multiply_function = + make_uniq("multiply", std::move(multiply_children)); multiply_function->alias = "temp"; temp_cte_select_node->select_list.push_back(std::move(multiply_function)); diff --git a/src/core/utils/duckpgq_bitmap.cpp b/src/core/utils/duckpgq_bitmap.cpp index f5a5535f..166e9835 100644 --- a/src/core/utils/duckpgq_bitmap.cpp +++ b/src/core/utils/duckpgq_bitmap.cpp @@ -16,8 +16,6 @@ bool DuckPGQBitmap::test(size_t index) const { return (bitmap[index / 64] & (1ULL << (index % 64))) != 0; } -void DuckPGQBitmap::reset() { - fill(bitmap.begin(), bitmap.end(), 0); -} +void DuckPGQBitmap::reset() { fill(bitmap.begin(), bitmap.end(), 0); } } // namespace core } // namespace duckpgq diff --git a/src/core/utils/duckpgq_utils.cpp b/src/core/utils/duckpgq_utils.cpp index 35f15fb8..b07c2eae 100644 --- a/src/core/utils/duckpgq_utils.cpp +++ b/src/core/utils/duckpgq_utils.cpp @@ -17,55 +17,72 @@ namespace core { shared_ptr GetDuckPGQState(ClientContext &context) { auto lookup = context.registered_state->Get("duckpgq"); if (!lookup) { - throw Exception(ExceptionType::INVALID, "Registered DuckPGQ state not found"); + throw Exception(ExceptionType::INVALID, + "Registered DuckPGQ state not found"); } return lookup; } // Function to get PropertyGraphInfo from DuckPGQState -CreatePropertyGraphInfo* GetPropertyGraphInfo(const shared_ptr &duckpgq_state, const string &pg_name) { +CreatePropertyGraphInfo * +GetPropertyGraphInfo(const shared_ptr &duckpgq_state, + const string &pg_name) { auto property_graph = duckpgq_state->registered_property_graphs.find(pg_name); if (property_graph == duckpgq_state->registered_property_graphs.end()) { - throw Exception(ExceptionType::INVALID, "Property graph " + pg_name + " not found"); + throw Exception(ExceptionType::INVALID, + "Property graph " + pg_name + " not found"); } - return dynamic_cast(property_graph->second.get()); + return dynamic_cast(property_graph->second.get()); } // Function to validate the source node and edge table -shared_ptr ValidateSourceNodeAndEdgeTable(CreatePropertyGraphInfo *pg_info, const std::string &node_label, const std::string &edge_label) { +shared_ptr +ValidateSourceNodeAndEdgeTable(CreatePropertyGraphInfo *pg_info, + const std::string &node_label, + const std::string &edge_label) { auto source_node_pg_entry = pg_info->GetTableByLabel(node_label, true, true); if (!source_node_pg_entry->is_vertex_table) { - throw Exception(ExceptionType::INVALID, node_label + " is an edge table, expected a vertex table"); + throw Exception(ExceptionType::INVALID, + node_label + " is an edge table, expected a vertex table"); } auto edge_pg_entry = pg_info->GetTableByLabel(edge_label, true, false); if (edge_pg_entry->is_vertex_table) { - throw Exception(ExceptionType::INVALID, edge_label + " is a vertex table, expected an edge table"); + throw Exception(ExceptionType::INVALID, + edge_label + " is a vertex table, expected an edge table"); } if (!edge_pg_entry->IsSourceTable(source_node_pg_entry->table_name)) { - throw Exception(ExceptionType::INVALID, "Vertex table " + node_label + " is not a source of edge table " + edge_label); + throw Exception(ExceptionType::INVALID, + "Vertex table " + node_label + + " is not a source of edge table " + edge_label); } return edge_pg_entry; } // Function to create the SELECT node -unique_ptr CreateSelectNode(const shared_ptr &edge_pg_entry, const string& function_name, const string& function_alias) { +unique_ptr +CreateSelectNode(const shared_ptr &edge_pg_entry, + const string &function_name, const string &function_alias) { auto select_node = make_uniq(); std::vector> select_expression; - select_expression.emplace_back(make_uniq(edge_pg_entry->source_pk[0], edge_pg_entry->source_reference)); + select_expression.emplace_back(make_uniq( + edge_pg_entry->source_pk[0], edge_pg_entry->source_reference)); auto cte_col_ref = make_uniq("temp", "__x"); vector> function_children; function_children.push_back(make_uniq(Value::INTEGER(0))); - function_children.push_back(make_uniq("rowid", edge_pg_entry->source_reference)); - auto function = make_uniq(function_name, std::move(function_children)); + function_children.push_back( + make_uniq("rowid", edge_pg_entry->source_reference)); + auto function = make_uniq(function_name, + std::move(function_children)); std::vector> addition_children; addition_children.emplace_back(std::move(cte_col_ref)); addition_children.emplace_back(std::move(function)); - auto addition_function = make_uniq("add", std::move(addition_children)); + auto addition_function = + make_uniq("add", std::move(addition_children)); addition_function->alias = function_alias; select_expression.emplace_back(std::move(addition_function)); select_node->select_list = std::move(select_expression); @@ -83,7 +100,8 @@ unique_ptr CreateSelectNode(const shared_ptr &ed return select_node; } -unique_ptr CreateBaseTableRef(const string &table_name, const string &alias) { +unique_ptr CreateBaseTableRef(const string &table_name, + const string &alias) { auto base_table_ref = make_uniq(); base_table_ref->table_name = table_name; if (!alias.empty()) { @@ -92,11 +110,13 @@ unique_ptr CreateBaseTableRef(const string &table_name, const stri return base_table_ref; } -unique_ptr CreateColumnRefExpression(const string &column_name, const string &table_name, const string& alias) { +unique_ptr +CreateColumnRefExpression(const string &column_name, const string &table_name, + const string &alias) { unique_ptr column_ref; if (table_name.empty()) { column_ref = make_uniq(column_name); - } else { + } else { column_ref = make_uniq(column_name, table_name); } if (!alias.empty()) { diff --git a/src/duckpgq_extension.cpp b/src/duckpgq_extension.cpp index f170c693..8cfc2ad9 100644 --- a/src/duckpgq_extension.cpp +++ b/src/duckpgq_extension.cpp @@ -12,8 +12,10 @@ static void LoadInternal(DatabaseInstance &instance) { duckpgq::core::CoreModule::Register(instance); auto &config = DBConfig::GetConfig(instance); config.extension_callbacks.push_back(make_uniq()); - for (auto &connection : ConnectionManager::Get(instance).GetConnectionList()) { - connection->registered_state->Insert("duckpgq", make_shared_ptr(connection)); + for (auto &connection : + ConnectionManager::Get(instance).GetConnectionList()) { + connection->registered_state->Insert( + "duckpgq", make_shared_ptr(connection)); } } @@ -21,17 +23,17 @@ void DuckpgqExtension::Load(DuckDB &db) { LoadInternal(*db.instance); } std::string DuckpgqExtension::Name() { return "duckpgq"; } -} // namespace duckpgq +} // namespace duckdb extern "C" { - DUCKDB_EXTENSION_API void duckpgq_init(DatabaseInstance &db) { - LoadInternal(db); - } +DUCKDB_EXTENSION_API void duckpgq_init(DatabaseInstance &db) { + LoadInternal(db); +} - DUCKDB_EXTENSION_API const char *duckpgq_version() { - return DuckDB::LibraryVersion(); - } +DUCKDB_EXTENSION_API const char *duckpgq_version() { + return DuckDB::LibraryVersion(); +} } #ifndef DUCKDB_EXTENSION_MAIN diff --git a/src/duckpgq_state.cpp b/src/duckpgq_state.cpp index 35d5df51..6297bd8e 100644 --- a/src/duckpgq_state.cpp +++ b/src/duckpgq_state.cpp @@ -27,100 +27,116 @@ DuckPGQState::DuckPGQState(shared_ptr context) { RetrievePropertyGraphs(new_conn); } -void DuckPGQState::RetrievePropertyGraphs(const shared_ptr &context) { - // Retrieve and process vertex property graphs - auto vertex_property_graphs = context->Query("SELECT * FROM __duckpgq_internal WHERE is_vertex_table", false); - ProcessPropertyGraphs(vertex_property_graphs, true); - - // Retrieve and process edge property graphs - auto edge_property_graphs = context->Query("SELECT * FROM __duckpgq_internal WHERE NOT is_vertex_table", false); - ProcessPropertyGraphs(edge_property_graphs, false); +void DuckPGQState::RetrievePropertyGraphs( + const shared_ptr &context) { + // Retrieve and process vertex property graphs + auto vertex_property_graphs = context->Query( + "SELECT * FROM __duckpgq_internal WHERE is_vertex_table", false); + ProcessPropertyGraphs(vertex_property_graphs, true); + + // Retrieve and process edge property graphs + auto edge_property_graphs = context->Query( + "SELECT * FROM __duckpgq_internal WHERE NOT is_vertex_table", false); + ProcessPropertyGraphs(edge_property_graphs, false); } -void DuckPGQState::ProcessPropertyGraphs(unique_ptr &property_graphs, bool is_vertex) { - if (!property_graphs || property_graphs->type != QueryResultType::MATERIALIZED_RESULT) { - throw std::runtime_error("Failed to fetch property graphs or invalid result type."); - } +void DuckPGQState::ProcessPropertyGraphs( + unique_ptr &property_graphs, bool is_vertex) { + if (!property_graphs || + property_graphs->type != QueryResultType::MATERIALIZED_RESULT) { + throw std::runtime_error( + "Failed to fetch property graphs or invalid result type."); + } + + auto &materialized_result = property_graphs->Cast(); + auto row_count = materialized_result.RowCount(); + if (row_count == 0) { + return; // No results + } - auto &materialized_result = property_graphs->Cast(); - auto row_count = materialized_result.RowCount(); - if (row_count == 0) { - return; // No results + auto chunk = materialized_result.Fetch(); + for (idx_t i = 0; i < row_count; i++) { + auto table = make_shared_ptr(); + + // Extract and validate common properties + table->table_name = chunk->GetValue(1, i).GetValue(); + table->main_label = chunk->GetValue(2, i).GetValue(); + table->is_vertex_table = chunk->GetValue(3, i).GetValue(); + table->all_columns = true; // TODO: Be stricter on properties + + // Handle discriminator and sub-labels + const auto &discriminator = chunk->GetValue(10, i).GetValue(); + if (discriminator != "NULL") { + table->discriminator = discriminator; + auto sublabels = ListValue::GetChildren(chunk->GetValue(11, i)); + for (const auto &sublabel : sublabels) { + table->sub_labels.push_back(sublabel.GetValue()); + } } - auto chunk = materialized_result.Fetch(); - for (idx_t i = 0; i < row_count; i++) { - auto table = make_shared_ptr(); - - // Extract and validate common properties - table->table_name = chunk->GetValue(1, i).GetValue(); - table->main_label = chunk->GetValue(2, i).GetValue(); - table->is_vertex_table = chunk->GetValue(3, i).GetValue(); - table->all_columns = true; // TODO: Be stricter on properties - - // Handle discriminator and sub-labels - const auto &discriminator = chunk->GetValue(10, i).GetValue(); - if (discriminator != "NULL") { - table->discriminator = discriminator; - auto sublabels = ListValue::GetChildren(chunk->GetValue(11, i)); - for (const auto &sublabel : sublabels) { - table->sub_labels.push_back(sublabel.GetValue()); - } - } - - // Extract catalog and schema names - table->catalog_name = chunk->GetValue(12, i).GetValue(); - table->schema_name = chunk->GetValue(13, i).GetValue(); - - // Additional edge-specific handling - if (!is_vertex) { - PopulateEdgeSpecificFields(chunk, i, *table); - } - - RegisterPropertyGraph(table, chunk->GetValue(0, i).GetValue(), is_vertex); + // Extract catalog and schema names + table->catalog_name = chunk->GetValue(12, i).GetValue(); + table->schema_name = chunk->GetValue(13, i).GetValue(); + + // Additional edge-specific handling + if (!is_vertex) { + PopulateEdgeSpecificFields(chunk, i, *table); } + + RegisterPropertyGraph(table, chunk->GetValue(0, i).GetValue(), + is_vertex); + } } -void DuckPGQState::PopulateEdgeSpecificFields(unique_ptr &chunk, idx_t row_idx, PropertyGraphTable &table) { - table.source_reference = chunk->GetValue(4, row_idx).GetValue(); - ExtractListValues(chunk->GetValue(5, row_idx), table.source_pk); - ExtractListValues(chunk->GetValue(6, row_idx), table.source_fk); - table.destination_reference = chunk->GetValue(7, row_idx).GetValue(); - ExtractListValues(chunk->GetValue(8, row_idx), table.destination_pk); - ExtractListValues(chunk->GetValue(9, row_idx), table.destination_fk); +void DuckPGQState::PopulateEdgeSpecificFields(unique_ptr &chunk, + idx_t row_idx, + PropertyGraphTable &table) { + table.source_reference = chunk->GetValue(4, row_idx).GetValue(); + ExtractListValues(chunk->GetValue(5, row_idx), table.source_pk); + ExtractListValues(chunk->GetValue(6, row_idx), table.source_fk); + table.destination_reference = chunk->GetValue(7, row_idx).GetValue(); + ExtractListValues(chunk->GetValue(8, row_idx), table.destination_pk); + ExtractListValues(chunk->GetValue(9, row_idx), table.destination_fk); } -void DuckPGQState::ExtractListValues(const Value &list_value, vector &output) { - auto children = ListValue::GetChildren(list_value); - for (const auto &child : children) { - output.push_back(child.GetValue()); - } +void DuckPGQState::ExtractListValues(const Value &list_value, + vector &output) { + auto children = ListValue::GetChildren(list_value); + for (const auto &child : children) { + output.push_back(child.GetValue()); + } } -void DuckPGQState::RegisterPropertyGraph(const shared_ptr &table, const string &graph_name, bool is_vertex) { - // Ensure the property graph exists in the registry - if (registered_property_graphs.find(graph_name) == registered_property_graphs.end()) { - registered_property_graphs[graph_name] = make_uniq(graph_name); - } +void DuckPGQState::RegisterPropertyGraph( + const shared_ptr &table, const string &graph_name, + bool is_vertex) { + // Ensure the property graph exists in the registry + if (registered_property_graphs.find(graph_name) == + registered_property_graphs.end()) { + registered_property_graphs[graph_name] = + make_uniq(graph_name); + } - auto &pg_info = registered_property_graphs[graph_name]->Cast(); - pg_info.label_map[table->main_label] = table; + auto &pg_info = + registered_property_graphs[graph_name]->Cast(); + pg_info.label_map[table->main_label] = table; - if (!table->discriminator.empty()) { - for (const auto &label : table->sub_labels) { - pg_info.label_map[label] = table; - } + if (!table->discriminator.empty()) { + for (const auto &label : table->sub_labels) { + pg_info.label_map[label] = table; } + } - if (is_vertex) { - pg_info.vertex_tables.push_back(table); - } else { - table->source_pg_table = pg_info.GetTableByName(table->source_reference); - D_ASSERT(table->source_pg_table); - table->destination_pg_table = pg_info.GetTableByName(table->destination_reference); - D_ASSERT(table->destination_pg_table); - pg_info.edge_tables.push_back(table); - } + if (is_vertex) { + pg_info.vertex_tables.push_back(table); + } else { + table->source_pg_table = pg_info.GetTableByName(table->source_reference); + D_ASSERT(table->source_pg_table); + table->destination_pg_table = + pg_info.GetTableByName(table->destination_reference); + D_ASSERT(table->destination_pg_table); + pg_info.edge_tables.push_back(table); + } } void DuckPGQState::QueryEnd() { diff --git a/src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp b/src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp index b9c5e993..1aa34361 100644 --- a/src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp +++ b/src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// - #pragma once #include "duckpgq/common.hpp" #include "duckdb/main/client_context.hpp" @@ -15,7 +14,6 @@ namespace duckpgq { namespace core { - struct CheapestPathLengthFunctionData final : FunctionData { ClientContext &context; int32_t csr_id; diff --git a/src/include/duckpgq/core/functions/function_data/local_clustering_coefficient_function_data.hpp b/src/include/duckpgq/core/functions/function_data/local_clustering_coefficient_function_data.hpp index d3fb177a..20b63f44 100644 --- a/src/include/duckpgq/core/functions/function_data/local_clustering_coefficient_function_data.hpp +++ b/src/include/duckpgq/core/functions/function_data/local_clustering_coefficient_function_data.hpp @@ -18,10 +18,12 @@ struct LocalClusteringCoefficientFunctionData final : FunctionData { ClientContext &context; int32_t csr_id; - LocalClusteringCoefficientFunctionData(ClientContext &context, int32_t csr_id); + LocalClusteringCoefficientFunctionData(ClientContext &context, + int32_t csr_id); static unique_ptr - LocalClusteringCoefficientBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments); + LocalClusteringCoefficientBind(ClientContext &context, + ScalarFunction &bound_function, + vector> &arguments); unique_ptr Copy() const override; bool Equals(const FunctionData &other_p) const override; diff --git a/src/include/duckpgq/core/functions/function_data/pagerank_function_data.hpp b/src/include/duckpgq/core/functions/function_data/pagerank_function_data.hpp index 7e44de75..7907029b 100644 --- a/src/include/duckpgq/core/functions/function_data/pagerank_function_data.hpp +++ b/src/include/duckpgq/core/functions/function_data/pagerank_function_data.hpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// - #pragma once #include "duckdb/main/client_context.hpp" #include "duckpgq/common.hpp" @@ -26,16 +25,16 @@ struct PageRankFunctionData final : FunctionData { bool converged; PageRankFunctionData(ClientContext &context, int32_t csr_id); - PageRankFunctionData(ClientContext &context, int32_t csr_id, const vector &componentId); + PageRankFunctionData(ClientContext &context, int32_t csr_id, + const vector &componentId); static unique_ptr PageRankBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments); + vector> &arguments); unique_ptr Copy() const override; bool Equals(const FunctionData &other_p) const override; }; - } // namespace core } // namespace duckpgq \ No newline at end of file diff --git a/src/include/duckpgq/core/functions/function_data/weakly_connected_component_function_data.hpp b/src/include/duckpgq/core/functions/function_data/weakly_connected_component_function_data.hpp index c39299c0..8e71a7ef 100644 --- a/src/include/duckpgq/core/functions/function_data/weakly_connected_component_function_data.hpp +++ b/src/include/duckpgq/core/functions/function_data/weakly_connected_component_function_data.hpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// - #pragma once #include "duckdb/main/client_context.hpp" #include "duckpgq/common.hpp" @@ -20,16 +19,17 @@ struct WeaklyConnectedComponentFunctionData final : FunctionData { std::mutex component_lock; bool component_id_initialized; // if componentId is initialized WeaklyConnectedComponentFunctionData(ClientContext &context, int32_t csr_id); - WeaklyConnectedComponentFunctionData(ClientContext &context, int32_t csr_id, const vector &componentId); + WeaklyConnectedComponentFunctionData(ClientContext &context, int32_t csr_id, + const vector &componentId); static unique_ptr - WeaklyConnectedComponentBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments); + WeaklyConnectedComponentBind(ClientContext &context, + ScalarFunction &bound_function, + vector> &arguments); unique_ptr Copy() const override; bool Equals(const FunctionData &other_p) const override; }; - } // namespace core } // namespace duckpgq \ No newline at end of file diff --git a/src/include/duckpgq/core/functions/scalar.hpp b/src/include/duckpgq/core/functions/scalar.hpp index 25d063a9..f02159d3 100644 --- a/src/include/duckpgq/core/functions/scalar.hpp +++ b/src/include/duckpgq/core/functions/scalar.hpp @@ -28,16 +28,17 @@ struct CoreScalarFunctions { static void RegisterGetCSRWTypeScalarFunction(DatabaseInstance &db); static void RegisterIterativeLengthScalarFunction(DatabaseInstance &db); static void RegisterIterativeLength2ScalarFunction(DatabaseInstance &db); - static void RegisterIterativeLengthBidirectionalScalarFunction(DatabaseInstance &db); - static void RegisterLocalClusteringCoefficientScalarFunction(DatabaseInstance &db); + static void + RegisterIterativeLengthBidirectionalScalarFunction(DatabaseInstance &db); + static void + RegisterLocalClusteringCoefficientScalarFunction(DatabaseInstance &db); static void RegisterReachabilityScalarFunction(DatabaseInstance &db); - static void RegisterShortestPathScalarFunction(DatabaseInstance &db); - static void RegisterWeaklyConnectedComponentScalarFunction(DatabaseInstance &db); - static void RegisterPageRankScalarFunction(DatabaseInstance &db); - + static void RegisterShortestPathScalarFunction(DatabaseInstance &db); + static void + RegisterWeaklyConnectedComponentScalarFunction(DatabaseInstance &db); + static void RegisterPageRankScalarFunction(DatabaseInstance &db); }; - } // namespace core } // namespace duckpgq diff --git a/src/include/duckpgq/core/functions/table.hpp b/src/include/duckpgq/core/functions/table.hpp index 6dff40f5..d2da78a2 100644 --- a/src/include/duckpgq/core/functions/table.hpp +++ b/src/include/duckpgq/core/functions/table.hpp @@ -22,13 +22,14 @@ struct CoreTableFunctions { static void RegisterMatchTableFunction(DatabaseInstance &db); static void RegisterDropPropertyGraphTableFunction(DatabaseInstance &db); static void RegisterDescribePropertyGraphTableFunction(DatabaseInstance &db); - static void RegisterLocalClusteringCoefficientTableFunction(DatabaseInstance &db); + static void + RegisterLocalClusteringCoefficientTableFunction(DatabaseInstance &db); static void RegisterScanTableFunctions(DatabaseInstance &db); - static void RegisterWeaklyConnectedComponentTableFunction(DatabaseInstance &db); + static void + RegisterWeaklyConnectedComponentTableFunction(DatabaseInstance &db); static void RegisterPageRankTableFunction(DatabaseInstance &db); }; - } // namespace core } // namespace duckpgq diff --git a/src/include/duckpgq/core/functions/table/create_property_graph.hpp b/src/include/duckpgq/core/functions/table/create_property_graph.hpp index bb55c81b..02c25eeb 100644 --- a/src/include/duckpgq/core/functions/table/create_property_graph.hpp +++ b/src/include/duckpgq/core/functions/table/create_property_graph.hpp @@ -48,21 +48,28 @@ class CreatePropertyGraphFunction : public TableFunction { vector &return_types, vector &names); - static void ValidateVertexTableRegistration(const string &reference, const case_insensitive_set_t &v_table_names); + static void + ValidateVertexTableRegistration(const string &reference, + const case_insensitive_set_t &v_table_names); - static void ValidatePrimaryKeyInTable(Catalog &catalog, ClientContext &context, const string &schema, - const string &reference, const vector &pk_columns); + static void ValidatePrimaryKeyInTable(Catalog &catalog, + ClientContext &context, + const string &schema, + const string &reference, + const vector &pk_columns); - static void ValidateKeys(shared_ptr &edge_table, - const string &reference, const string &key_type, - vector &pk_columns, vector &fk_columns, - const vector> &table_constraints); + static void + ValidateKeys(shared_ptr &edge_table, + const string &reference, const string &key_type, + vector &pk_columns, vector &fk_columns, + const vector> &table_constraints); - static void ValidateForeignKeyColumns(shared_ptr &edge_table, - const vector &fk_columns, - optional_ptr &table); + static void + ValidateForeignKeyColumns(shared_ptr &edge_table, + const vector &fk_columns, + optional_ptr &table); - static unique_ptr + static unique_ptr CreatePropertyGraphInit(ClientContext &context, TableFunctionInitInput &input); diff --git a/src/include/duckpgq/core/functions/table/describe_property_graph.hpp b/src/include/duckpgq/core/functions/table/describe_property_graph.hpp index f1061b09..270995d6 100644 --- a/src/include/duckpgq/core/functions/table/describe_property_graph.hpp +++ b/src/include/duckpgq/core/functions/table/describe_property_graph.hpp @@ -16,7 +16,6 @@ namespace duckpgq { namespace core { - class DescribePropertyGraphFunction : public TableFunction { public: DescribePropertyGraphFunction() { diff --git a/src/include/duckpgq/core/functions/table/drop_property_graph.hpp b/src/include/duckpgq/core/functions/table/drop_property_graph.hpp index 158a825e..a0034fa1 100644 --- a/src/include/duckpgq/core/functions/table/drop_property_graph.hpp +++ b/src/include/duckpgq/core/functions/table/drop_property_graph.hpp @@ -17,7 +17,6 @@ namespace duckpgq { namespace core { - class DropPropertyGraphFunction : public TableFunction { public: DropPropertyGraphFunction() { diff --git a/src/include/duckpgq/core/functions/table/local_clustering_coefficient.hpp b/src/include/duckpgq/core/functions/table/local_clustering_coefficient.hpp index 3770eb3d..6558e44b 100644 --- a/src/include/duckpgq/core/functions/table/local_clustering_coefficient.hpp +++ b/src/include/duckpgq/core/functions/table/local_clustering_coefficient.hpp @@ -11,44 +11,44 @@ class LocalClusteringCoefficientFunction : public TableFunction { public: LocalClusteringCoefficientFunction() { name = "local_clustering_coefficient"; - arguments = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; + arguments = {LogicalType::VARCHAR, LogicalType::VARCHAR, + LogicalType::VARCHAR}; bind_replace = LocalClusteringCoefficientBindReplace; } - static unique_ptr LocalClusteringCoefficientBindReplace(ClientContext &context, - TableFunctionBindInput &input); - - }; - - struct LocalClusteringCoefficientData : TableFunctionData { - static unique_ptr - LocalClusteringCoefficientBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - result->pg_name = StringValue::Get(input.inputs[0]); - result->node_table = StringValue::Get(input.inputs[1]); - result->edge_table = StringValue::Get(input.inputs[2]); - return_types.emplace_back(LogicalType::BIGINT); - return_types.emplace_back(LogicalType::FLOAT); - names.emplace_back("rowid"); - names.emplace_back("local_clustering_coefficient"); - return std::move(result); - } - - string pg_name; - string node_table; - string edge_table; - }; - - - struct LocalClusteringCoefficientScanState : GlobalTableFunctionState { - static unique_ptr - Init(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - return std::move(result); - } - - bool finished = false; + static unique_ptr + LocalClusteringCoefficientBindReplace(ClientContext &context, + TableFunctionBindInput &input); +}; + +struct LocalClusteringCoefficientData : TableFunctionData { + static unique_ptr LocalClusteringCoefficientBind( + ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto result = make_uniq(); + result->pg_name = StringValue::Get(input.inputs[0]); + result->node_table = StringValue::Get(input.inputs[1]); + result->edge_table = StringValue::Get(input.inputs[2]); + return_types.emplace_back(LogicalType::BIGINT); + return_types.emplace_back(LogicalType::FLOAT); + names.emplace_back("rowid"); + names.emplace_back("local_clustering_coefficient"); + return std::move(result); + } + + string pg_name; + string node_table; + string edge_table; +}; + +struct LocalClusteringCoefficientScanState : GlobalTableFunctionState { + static unique_ptr + Init(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + return std::move(result); + } + + bool finished = false; }; } // namespace core diff --git a/src/include/duckpgq/core/functions/table/match.hpp b/src/include/duckpgq/core/functions/table/match.hpp index f319578d..0eda8f8c 100644 --- a/src/include/duckpgq/core/functions/table/match.hpp +++ b/src/include/duckpgq/core/functions/table/match.hpp @@ -24,7 +24,6 @@ namespace duckpgq { namespace core { - struct PGQMatchFunction : public TableFunction { public: PGQMatchFunction() { @@ -90,13 +89,13 @@ struct PGQMatchFunction : public TableFunction { const string &next_binding, vector> &conditions); - static void - EdgeTypeLeftRight(const shared_ptr &edge_table, - const string &edge_binding, const string &prev_binding, - const string &next_binding, - vector> &conditions, - case_insensitive_map_t> &alias_map, - int32_t &extra_alias_counter); + static void EdgeTypeLeftRight( + const shared_ptr &edge_table, + const string &edge_binding, const string &prev_binding, + const string &next_binding, + vector> &conditions, + case_insensitive_map_t> &alias_map, + int32_t &extra_alias_counter); static PathElement * HandleNestedSubPath(unique_ptr &path_reference, @@ -104,8 +103,8 @@ struct PGQMatchFunction : public TableFunction { idx_t element_idx); static unique_ptr AddPathQuantifierCondition( - const string &prev_binding, const string &next_binding, - const shared_ptr &edge_table, const SubPath *subpath); + const string &prev_binding, const string &next_binding, + const shared_ptr &edge_table, const SubPath *subpath); static unique_ptr MatchBindReplace(ClientContext &context, TableFunctionBindInput &input); @@ -115,29 +114,35 @@ struct PGQMatchFunction : public TableFunction { vector> &column_list, unordered_set &named_subpaths); - static unique_ptr GenerateShortestPathCTE(CreatePropertyGraphInfo & pg_table, SubPath * edge_subpath, - PathElement * path_element, PathElement * next_vertex_element, vector> &path_finding_conditions); + static unique_ptr GenerateShortestPathCTE( + CreatePropertyGraphInfo &pg_table, SubPath *edge_subpath, + PathElement *path_element, PathElement *next_vertex_element, + vector> &path_finding_conditions); static unique_ptr CreatePathFindingFunction(vector> &path_list, - CreatePropertyGraphInfo &pg_table, const string &path_variable, unique_ptr &final_select_node, vector> &conditions); - - static void AddPathFinding( - unique_ptr &select_node, - vector> &conditions, - const string &prev_binding, const string &edge_binding, - const string &next_binding, - const shared_ptr &edge_table, - CreatePropertyGraphInfo &pg_table, SubPath *subpath, PGQMatchType edge_type); + CreatePropertyGraphInfo &pg_table, + const string &path_variable, + unique_ptr &final_select_node, + vector> &conditions); - static void - AddEdgeJoins(const shared_ptr &edge_table, - const shared_ptr &previous_vertex_table, - const shared_ptr &next_vertex_table, - PGQMatchType edge_type, const string &edge_binding, - const string &prev_binding, const string &next_binding, - vector> &conditions, - case_insensitive_map_t> &alias_map, - int32_t &extra_alias_counter, unique_ptr &from_clause); + static void AddPathFinding(unique_ptr &select_node, + vector> &conditions, + const string &prev_binding, + const string &edge_binding, + const string &next_binding, + const shared_ptr &edge_table, + CreatePropertyGraphInfo &pg_table, + SubPath *subpath, PGQMatchType edge_type); + + static void AddEdgeJoins( + const shared_ptr &edge_table, + const shared_ptr &previous_vertex_table, + const shared_ptr &next_vertex_table, + PGQMatchType edge_type, const string &edge_binding, + const string &prev_binding, const string &next_binding, + vector> &conditions, + case_insensitive_map_t> &alias_map, + int32_t &extra_alias_counter, unique_ptr &from_clause); static void ProcessPathList( vector> &path_pattern, @@ -148,8 +153,7 @@ struct PGQMatchFunction : public TableFunction { MatchExpression &original_ref); static void - CheckNamedSubpath(SubPath &subpath, - MatchExpression &original_ref, + CheckNamedSubpath(SubPath &subpath, MatchExpression &original_ref, CreatePropertyGraphInfo &pg_table, unique_ptr &final_select_node, vector> &conditions); @@ -157,4 +161,4 @@ struct PGQMatchFunction : public TableFunction { } // namespace core -} // namespace duckdb +} // namespace duckpgq diff --git a/src/include/duckpgq/core/functions/table/pagerank.hpp b/src/include/duckpgq/core/functions/table/pagerank.hpp index 65518d64..e11f489d 100644 --- a/src/include/duckpgq/core/functions/table/pagerank.hpp +++ b/src/include/duckpgq/core/functions/table/pagerank.hpp @@ -15,19 +15,19 @@ class PageRankFunction : public TableFunction { public: PageRankFunction() { name = "pagerank"; - arguments = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; + arguments = {LogicalType::VARCHAR, LogicalType::VARCHAR, + LogicalType::VARCHAR}; bind_replace = PageRankBindReplace; } - static unique_ptr PageRankBindReplace(ClientContext &context, - TableFunctionBindInput &input); - + static unique_ptr + PageRankBindReplace(ClientContext &context, TableFunctionBindInput &input); }; struct PageRankData : TableFunctionData { static unique_ptr PageRankBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { + vector &return_types, vector &names) { auto result = make_uniq(); result->pg_name = StringValue::Get(input.inputs[0]); result->node_table = StringValue::Get(input.inputs[1]); @@ -44,8 +44,7 @@ struct PageRankData : TableFunctionData { string edge_table; }; - -struct PageRankScanState : GlobalTableFunctionState { +struct PageRankScanState : GlobalTableFunctionState { static unique_ptr Init(ClientContext &context, TableFunctionInitInput &input) { auto result = make_uniq(); diff --git a/src/include/duckpgq/core/functions/table/pgq_scan.hpp b/src/include/duckpgq/core/functions/table/pgq_scan.hpp index fc0ee010..47e24d9a 100644 --- a/src/include/duckpgq/core/functions/table/pgq_scan.hpp +++ b/src/include/duckpgq/core/functions/table/pgq_scan.hpp @@ -79,7 +79,6 @@ struct CSRScanWData : public TableFunctionData { throw InternalException("The DuckPGQ extension has not been loaded"); } - CSR *csr = duckpgq_state->GetCSR(result->csr_id); if (!csr->w.empty()) { @@ -194,6 +193,6 @@ struct CSRScanState : public GlobalTableFunctionState { bool finished = false; }; -} // namespace core +} // namespace core } // namespace duckpgq diff --git a/src/include/duckpgq/core/functions/table/weakly_connected_component.hpp b/src/include/duckpgq/core/functions/table/weakly_connected_component.hpp index 34fd0eae..e94584ef 100644 --- a/src/include/duckpgq/core/functions/table/weakly_connected_component.hpp +++ b/src/include/duckpgq/core/functions/table/weakly_connected_component.hpp @@ -15,19 +15,20 @@ class WeaklyConnectedComponentFunction : public TableFunction { public: WeaklyConnectedComponentFunction() { name = "weakly_connected_component"; - arguments = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; + arguments = {LogicalType::VARCHAR, LogicalType::VARCHAR, + LogicalType::VARCHAR}; bind_replace = WeaklyConnectedComponentBindReplace; } - static unique_ptr WeaklyConnectedComponentBindReplace(ClientContext &context, - TableFunctionBindInput &input); - + static unique_ptr + WeaklyConnectedComponentBindReplace(ClientContext &context, + TableFunctionBindInput &input); }; struct WeaklyConnectedComponentData : TableFunctionData { - static unique_ptr - WeaklyConnectedComponentBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { + static unique_ptr WeaklyConnectedComponentBind( + ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { auto result = make_uniq(); result->pg_name = StringValue::Get(input.inputs[0]); result->node_table = StringValue::Get(input.inputs[1]); @@ -44,8 +45,7 @@ struct WeaklyConnectedComponentData : TableFunctionData { string edge_table; }; - -struct WeaklyConnectedComponentScanState : GlobalTableFunctionState { +struct WeaklyConnectedComponentScanState : GlobalTableFunctionState { static unique_ptr Init(ClientContext &context, TableFunctionInitInput &input) { auto result = make_uniq(); diff --git a/src/include/duckpgq/core/module.hpp b/src/include/duckpgq/core/module.hpp index 77dbaaa9..6fa47e9d 100644 --- a/src/include/duckpgq/core/module.hpp +++ b/src/include/duckpgq/core/module.hpp @@ -8,7 +8,6 @@ namespace core { struct CoreModule { public: static void Register(DatabaseInstance &db); - }; } // namespace core diff --git a/src/include/duckpgq/core/operator/duckpgq_bind.hpp b/src/include/duckpgq/core/operator/duckpgq_bind.hpp index ed5f4468..7329d32e 100644 --- a/src/include/duckpgq/core/operator/duckpgq_bind.hpp +++ b/src/include/duckpgq/core/operator/duckpgq_bind.hpp @@ -21,7 +21,6 @@ struct DuckPGQOperatorExtension : public OperatorExtension { } }; - } // namespace core } // namespace duckpgq \ No newline at end of file diff --git a/src/include/duckpgq/core/operator/duckpgq_operator.hpp b/src/include/duckpgq/core/operator/duckpgq_operator.hpp index ff1346b8..3aa28510 100644 --- a/src/include/duckpgq/core/operator/duckpgq_operator.hpp +++ b/src/include/duckpgq/core/operator/duckpgq_operator.hpp @@ -6,9 +6,7 @@ namespace duckpgq { namespace core { struct CorePGQOperator { - static void Register(DatabaseInstance &db) { - RegisterPGQBindOperator(db); - } + static void Register(DatabaseInstance &db) { RegisterPGQBindOperator(db); } private: static void RegisterPGQBindOperator(DatabaseInstance &db); diff --git a/src/include/duckpgq/core/parser/duckpgq_parser.hpp b/src/include/duckpgq/core/parser/duckpgq_parser.hpp index d856a558..a7536865 100644 --- a/src/include/duckpgq/core/parser/duckpgq_parser.hpp +++ b/src/include/duckpgq/core/parser/duckpgq_parser.hpp @@ -6,9 +6,7 @@ namespace duckpgq { namespace core { struct CorePGQParser { - static void Register(DatabaseInstance &db) { - RegisterPGQParserExtension(db); - } + static void Register(DatabaseInstance &db) { RegisterPGQParserExtension(db); } private: static void RegisterPGQParserExtension(DatabaseInstance &db); @@ -51,7 +49,6 @@ struct DuckPGQParseData : ParserExtensionParseData { : statement(std::move(statement)) {} }; - } // namespace core } // namespace duckpgq \ No newline at end of file diff --git a/src/include/duckpgq/core/utils/compressed_sparse_row.hpp b/src/include/duckpgq/core/utils/compressed_sparse_row.hpp index 1eca95e4..6d6a9048 100644 --- a/src/include/duckpgq/core/utils/compressed_sparse_row.hpp +++ b/src/include/duckpgq/core/utils/compressed_sparse_row.hpp @@ -66,22 +66,38 @@ struct CSRFunctionData : FunctionData { }; // CSR BindReplace functions -unique_ptr CreateUndirectedCSRCTE(const shared_ptr &edge_table, +unique_ptr +CreateUndirectedCSRCTE(const shared_ptr &edge_table, const unique_ptr &select_node); -unique_ptr CreateDirectedCSRCTE(const shared_ptr &edge_table, const string &prev_binding, const string &edge_binding, const string &next_binding); +unique_ptr +CreateDirectedCSRCTE(const shared_ptr &edge_table, + const string &prev_binding, const string &edge_binding, + const string &next_binding); // Helper functions -unique_ptr MakeEdgesCTE(const shared_ptr &edge_table); -unique_ptr CreateDirectedCSRVertexSubquery(const shared_ptr &edge_table, const string &binding); -unique_ptr CreateUndirectedCSRVertexSubquery(const shared_ptr &edge_table, const string &binding); +unique_ptr +MakeEdgesCTE(const shared_ptr &edge_table); +unique_ptr CreateDirectedCSRVertexSubquery( + const shared_ptr &edge_table, const string &binding); +unique_ptr CreateUndirectedCSRVertexSubquery( + const shared_ptr &edge_table, const string &binding); unique_ptr CreateOuterSelectEdgesNode(); -unique_ptr CreateOuterSelectNode(unique_ptr create_csr_edge_function); -unique_ptr GetJoinRef(const shared_ptr &edge_table,const string &edge_binding, const string &prev_binding, const string &next_binding); -unique_ptr GetCountTable(const shared_ptr &table, const string &table_alias, const string &primary_key); -void SetupSelectNode(unique_ptr &select_node, const shared_ptr &edge_table, bool reverse = false); +unique_ptr +CreateOuterSelectNode(unique_ptr create_csr_edge_function); +unique_ptr GetJoinRef(const shared_ptr &edge_table, + const string &edge_binding, + const string &prev_binding, + const string &next_binding); +unique_ptr +GetCountTable(const shared_ptr &table, + const string &table_alias, const string &primary_key); +void SetupSelectNode(unique_ptr &select_node, + const shared_ptr &edge_table, + bool reverse = false); unique_ptr CreateCountCTESubquery(); unique_ptr GetCountUndirectedEdgeTable(); -unique_ptr GetCountEdgeTable(const shared_ptr &edge_table); +unique_ptr +GetCountEdgeTable(const shared_ptr &edge_table); } // namespace core } // namespace duckpgq diff --git a/src/include/duckpgq/core/utils/duckpgq_utils.hpp b/src/include/duckpgq/core/utils/duckpgq_utils.hpp index a68861f7..973136dd 100644 --- a/src/include/duckpgq/core/utils/duckpgq_utils.hpp +++ b/src/include/duckpgq/core/utils/duckpgq_utils.hpp @@ -10,11 +10,22 @@ namespace core { // Function to get DuckPGQState from ClientContext shared_ptr GetDuckPGQState(ClientContext &context); -CreatePropertyGraphInfo* GetPropertyGraphInfo(const shared_ptr &duckpgq_state, const string &pg_name); -shared_ptr ValidateSourceNodeAndEdgeTable(CreatePropertyGraphInfo *pg_info, const std::string &node_table, const std::string &edge_table); -unique_ptr CreateSelectNode(const shared_ptr &edge_pg_entry, const string& function_name, const string& function_alias); -unique_ptr CreateBaseTableRef(const string &table_name, const string &alias = ""); -unique_ptr CreateColumnRefExpression(const string &column_name, const string &table_name = "", const string& alias = ""); +CreatePropertyGraphInfo * +GetPropertyGraphInfo(const shared_ptr &duckpgq_state, + const string &pg_name); +shared_ptr +ValidateSourceNodeAndEdgeTable(CreatePropertyGraphInfo *pg_info, + const std::string &node_table, + const std::string &edge_table); +unique_ptr +CreateSelectNode(const shared_ptr &edge_pg_entry, + const string &function_name, const string &function_alias); +unique_ptr CreateBaseTableRef(const string &table_name, + const string &alias = ""); +unique_ptr +CreateColumnRefExpression(const string &column_name, + const string &table_name = "", + const string &alias = ""); } // namespace core } // namespace duckpgq diff --git a/src/include/duckpgq_extension.hpp b/src/include/duckpgq_extension.hpp index b55e1790..2531cbe4 100644 --- a/src/include/duckpgq_extension.hpp +++ b/src/include/duckpgq_extension.hpp @@ -10,4 +10,4 @@ class DuckpgqExtension : public Extension { std::string Name() override; }; -} // namespace duckpgq +} // namespace duckdb diff --git a/src/include/duckpgq_extension_callback.hpp b/src/include/duckpgq_extension_callback.hpp index 61ade80f..38e84a6f 100644 --- a/src/include/duckpgq_extension_callback.hpp +++ b/src/include/duckpgq_extension_callback.hpp @@ -1,6 +1,5 @@ #pragma once - #include "duckpgq/common.hpp" #include "duckdb/planner/extension_callback.hpp" #include @@ -8,8 +7,8 @@ namespace duckdb { class DuckpgqExtensionCallback : public ExtensionCallback { void OnConnectionOpened(ClientContext &context) override { - context.registered_state->Insert("duckpgq", - make_shared_ptr(context.shared_from_this())); + context.registered_state->Insert( + "duckpgq", make_shared_ptr(context.shared_from_this())); } }; -} \ No newline at end of file +} // namespace duckdb \ No newline at end of file diff --git a/src/include/duckpgq_state.hpp b/src/include/duckpgq_state.hpp index dcecaf82..c1e65389 100644 --- a/src/include/duckpgq_state.hpp +++ b/src/include/duckpgq_state.hpp @@ -15,11 +15,14 @@ class DuckPGQState : public ClientContextState { duckpgq::core::CSR *GetCSR(int32_t id); void RetrievePropertyGraphs(const shared_ptr &context); - void ProcessPropertyGraphs(unique_ptr &property_graphs, bool is_vertex); + void ProcessPropertyGraphs(unique_ptr &property_graphs, + bool is_vertex); void PopulateEdgeSpecificFields(unique_ptr &chunk, idx_t row_idx, PropertyGraphTable &table); - static void ExtractListValues(const Value &list_value, vector &output); - void RegisterPropertyGraph(const shared_ptr &table, const string &graph_name, bool is_vertex); + static void ExtractListValues(const Value &list_value, + vector &output); + void RegisterPropertyGraph(const shared_ptr &table, + const string &graph_name, bool is_vertex); public: unique_ptr parse_data;