generated from duckdb/extension-template
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
If the lower bound is not greater than 1, the high performance algori…
…thm is invoked
- Loading branch information
1 parent
fbc9303
commit 79b7495
Showing
10 changed files
with
593 additions
and
182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
209 changes: 209 additions & 0 deletions
209
duckpgq/src/duckpgq/functions/scalar/iterativelength_lowerbound.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
#include <duckpgq_extension.hpp> | ||
#include "duckdb/main/client_data.hpp" | ||
#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" | ||
#include "duckdb/planner/expression/bound_function_expression.hpp" | ||
#include "duckpgq/common.hpp" | ||
#include "duckpgq/duckpgq_functions.hpp" | ||
|
||
namespace duckdb { | ||
|
||
static bool IterativeLengthLowerBound(int64_t v_size, int64_t *v, vector<int64_t> &e, | ||
vector<vector<unordered_set<int64_t>>> &parents_v, | ||
vector<std::bitset<LANE_LIMIT>> &seen, | ||
vector<std::bitset<LANE_LIMIT>> &visit, | ||
vector<std::bitset<LANE_LIMIT>> &next) { | ||
bool change = false; | ||
for (auto i = 0; i < v_size; i++) { | ||
next[i] = 0; | ||
} | ||
|
||
for (auto lane = 0; lane < LANE_LIMIT; lane++) { | ||
for (auto i = 0; i < v_size; i++) { | ||
if (visit[i][lane]) { | ||
for (auto offset = v[i]; offset < v[i + 1]; offset++) { | ||
auto n = e[offset]; | ||
if (seen[n][lane] == false || parents_v[i][lane].find(n) == parents_v[i][lane].end()) { | ||
parents_v[n][lane] = parents_v[i][lane]; | ||
parents_v[n][lane].insert(i); | ||
next[n][lane] = true; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
for (auto i = 0; i < v_size; i++) { | ||
seen[i] = seen[i] | next[i]; | ||
change |= next[i].any(); | ||
} | ||
|
||
return change; | ||
} | ||
|
||
static void IterativeLengthLowerBoundFunction(DataChunk &args, ExpressionState &state, | ||
Vector &result) { | ||
auto &func_expr = (BoundFunctionExpression &)state.expr; | ||
auto &info = (IterativeLengthFunctionData &)*func_expr.bind_info; | ||
auto duckpgq_state_entry = info.context.registered_state.find("duckpgq"); | ||
if (duckpgq_state_entry == info.context.registered_state.end()) { | ||
//! Wondering how you can get here if the extension wasn't loaded, but | ||
//! leaving this check in anyways | ||
throw MissingExtensionException( | ||
"The DuckPGQ extension has not been loaded"); | ||
} | ||
auto duckpgq_state = | ||
reinterpret_cast<DuckPGQState *>(duckpgq_state_entry->second.get()); | ||
|
||
D_ASSERT(duckpgq_state->csr_list[info.csr_id]); | ||
|
||
if ((uint64_t)info.csr_id + 1 > duckpgq_state->csr_list.size()) { | ||
throw ConstraintException("Invalid ID"); | ||
} | ||
auto csr_entry = duckpgq_state->csr_list.find((uint64_t)info.csr_id); | ||
if (csr_entry == duckpgq_state->csr_list.end()) { | ||
throw ConstraintException( | ||
"Need to initialize CSR before doing shortest path"); | ||
} | ||
|
||
if (!(csr_entry->second->initialized_v && csr_entry->second->initialized_e)) { | ||
throw ConstraintException( | ||
"Need to initialize CSR before doing shortest path"); | ||
} | ||
int64_t v_size = args.data[1].GetValue(0).GetValue<int64_t>(); | ||
int64_t *v = (int64_t *)duckpgq_state->csr_list[info.csr_id]->v; | ||
vector<int64_t> &e = duckpgq_state->csr_list[info.csr_id]->e; | ||
|
||
// get src and dst vectors for searches | ||
auto &src = args.data[2]; | ||
auto &dst = args.data[3]; | ||
UnifiedVectorFormat vdata_src; | ||
UnifiedVectorFormat vdata_dst; | ||
src.ToUnifiedFormat(args.size(), vdata_src); | ||
dst.ToUnifiedFormat(args.size(), vdata_dst); | ||
auto src_data = (int64_t *)vdata_src.data; | ||
auto dst_data = (int64_t *)vdata_dst.data; | ||
|
||
// get lowerbound and upperbound | ||
auto &lower = args.data[4]; | ||
auto &upper = args.data[5]; | ||
UnifiedVectorFormat vdata_lower_bound; | ||
UnifiedVectorFormat vdata_upper_bound; | ||
lower.ToUnifiedFormat(args.size(), vdata_lower_bound); | ||
upper.ToUnifiedFormat(args.size(), vdata_upper_bound); | ||
auto lower_bound = ((int64_t *)vdata_lower_bound.data)[0]; | ||
auto upper_bound = ((int64_t *)vdata_upper_bound.data)[0]; | ||
|
||
ValidityMask &result_validity = FlatVector::Validity(result); | ||
|
||
// create result vector | ||
result.SetVectorType(VectorType::FLAT_VECTOR); | ||
auto result_data = FlatVector::GetData<int64_t>(result); | ||
|
||
// create temp SIMD arrays | ||
vector<std::bitset<LANE_LIMIT>> seen(v_size); | ||
vector<std::bitset<LANE_LIMIT>> visit1(v_size); | ||
vector<std::bitset<LANE_LIMIT>> visit2(v_size); | ||
vector<vector<unordered_set<int64_t>>> parents_v(v_size, std::vector<unordered_set<int64_t>>(LANE_LIMIT)); | ||
|
||
// maps lane to search number | ||
short lane_to_num[LANE_LIMIT]; | ||
for (int64_t lane = 0; lane < LANE_LIMIT; lane++) { | ||
lane_to_num[lane] = -1; // inactive | ||
} | ||
|
||
idx_t started_searches = 0; | ||
while (started_searches < args.size()) { | ||
|
||
// empty visit vectors | ||
for (auto i = 0; i < v_size; i++) { | ||
seen[i] = 0; | ||
visit1[i] = 0; | ||
} | ||
|
||
// add search jobs to free lanes | ||
uint64_t active = 0; | ||
for (int64_t lane = 0; lane < LANE_LIMIT; lane++) { | ||
lane_to_num[lane] = -1; | ||
while (started_searches < args.size()) { | ||
int64_t search_num = started_searches++; | ||
int64_t src_pos = vdata_src.sel->get_index(search_num); | ||
int64_t dst_pos = vdata_dst.sel->get_index(search_num); | ||
if (!vdata_src.validity.RowIsValid(src_pos)) { | ||
result_validity.SetInvalid(search_num); | ||
result_data[search_num] = (int64_t)-1; /* no path */ | ||
} else if (src_data[src_pos] == dst_data[dst_pos]) { | ||
result_data[search_num] = (int64_t)-1; /* no path */ | ||
visit1[src_data[src_pos]][lane] = true; | ||
lane_to_num[lane] = search_num; // active lane | ||
active++; | ||
break; | ||
} else { | ||
result_data[search_num] = (int64_t)-1; /* initialize to no path */ | ||
seen[src_data[src_pos]][lane] = true; | ||
visit1[src_data[src_pos]][lane] = true; | ||
lane_to_num[lane] = search_num; // active lane | ||
active++; | ||
break; | ||
} | ||
} | ||
} | ||
|
||
// make passes while a lane is still active | ||
for (int64_t iter = 1; active && iter <= upper_bound; iter++) { | ||
bool stop = !IterativeLengthLowerBound(v_size, v, e, parents_v, seen, (iter & 1) ? visit1 : visit2, | ||
(iter & 1) ? visit2 : visit1); | ||
// detect lanes that finished | ||
for (int64_t lane = 0; lane < LANE_LIMIT; lane++) { | ||
int64_t search_num = lane_to_num[lane]; | ||
if (search_num >= 0) { // active lane | ||
int64_t dst_pos = vdata_dst.sel->get_index(search_num); | ||
if (seen[dst_data[dst_pos]][lane]){ | ||
|
||
// check if the path length is within bounds | ||
// bound vector is either a constant or a flat vector | ||
if (iter < lower_bound) { | ||
// when reach the destination too early, treat destination as null | ||
// looks like the graph does not have that vertex | ||
seen[dst_data[dst_pos]][lane] = false; | ||
(iter & 1) ? visit2[dst_data[dst_pos]][lane] = false | ||
: visit1[dst_data[dst_pos]][lane] = false; | ||
continue; | ||
} else { | ||
result_data[search_num] = | ||
iter; /* found at iter => iter = path length */ | ||
lane_to_num[lane] = -1; // mark inactive | ||
active--; | ||
} | ||
|
||
} | ||
} | ||
} | ||
if (stop) { | ||
break; | ||
} | ||
} | ||
|
||
// no changes anymore: any still active searches have no path | ||
for (int64_t lane = 0; lane < LANE_LIMIT; lane++) { | ||
int64_t search_num = lane_to_num[lane]; | ||
if (search_num >= 0) { // active lane | ||
result_validity.SetInvalid(search_num); | ||
result_data[search_num] = (int64_t)-1; /* no path */ | ||
lane_to_num[lane] = -1; // mark inactive | ||
} | ||
} | ||
} | ||
duckpgq_state->csr_to_delete.insert(info.csr_id); | ||
} | ||
|
||
CreateScalarFunctionInfo DuckPGQFunctions::GetIterativeLengthLowerBoundFunction() { | ||
auto fun = ScalarFunction("iterativelength_lowerbound", | ||
{LogicalType::INTEGER, LogicalType::BIGINT, | ||
LogicalType::BIGINT, LogicalType::BIGINT, | ||
LogicalType::BIGINT, LogicalType::BIGINT}, | ||
LogicalType::BIGINT, IterativeLengthLowerBoundFunction, | ||
IterativeLengthFunctionData::IterativeLengthBind); | ||
return CreateScalarFunctionInfo(fun); | ||
} | ||
|
||
} // namespace duckdb |
Oops, something went wrong.