Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT REVIEW] Add new join order optimizer #3703

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/binder/query/query_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include "binder/expression_visitor.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

Expand All @@ -12,6 +14,33 @@ std::size_t SubqueryGraphHasher::operator()(const SubqueryGraph& key) const {
return std::hash<std::bitset<MAX_NUM_QUERY_VARIABLES>>{}(key.queryRelsSelector);
}

std::unordered_map<common::idx_t, std::vector<common::idx_t>>
SubqueryGraph::getWCOJRelCandidates() const {
std::unordered_map<common::idx_t, std::vector<common::idx_t>> candidates;
for (auto relPos : getRelNbrPositions()) {
auto rel = queryGraph.getQueryRel(relPos);
// TODO(Xiyang): is the following check relevant?
if (!queryGraph.containsQueryNode(rel->getSrcNodeName()) ||
!queryGraph.containsQueryNode(rel->getDstNodeName())) {
continue;
}
auto srcNodePos = queryGraph.getQueryNodePos(rel->getSrcNodeName());
auto dstNodePos = queryGraph.getQueryNodePos(rel->getDstNodeName());
auto isSrcConnected = queryNodesSelector[srcNodePos];
auto isDstConnected = queryNodesSelector[dstNodePos];
// Closing rel should be handled with inner join.
if (isSrcConnected && isDstConnected) {
continue;
}
auto intersectNodePos = isSrcConnected ? dstNodePos : srcNodePos;
if (!candidates.contains(intersectNodePos)) {
candidates.insert({intersectNodePos, std::vector<common::idx_t>{}});
}
candidates.at(intersectNodePos).push_back(relPos);
}
return candidates;
}

bool SubqueryGraph::containAllVariables(std::unordered_set<std::string>& variables) const {
for (auto& var : variables) {
if (queryGraph.containsQueryNode(var) &&
Expand Down Expand Up @@ -168,6 +197,26 @@ std::vector<std::shared_ptr<NodeOrRelExpression>> QueryGraph::getAllPatterns() c
return patterns;
}

std::vector<std::shared_ptr<NodeExpression>> QueryGraph::getQueryNodes(
const std::vector<idx_t>& indices) const {
std::vector<std::shared_ptr<NodeExpression>> result;
result.reserve(indices.size());
for (auto idx : indices) {
result.push_back(queryNodes[idx]);
}
return result;
}

std::vector<std::shared_ptr<RelExpression>> QueryGraph::getQueryRels(
const std::vector<idx_t>& indices) const {
std::vector<std::shared_ptr<RelExpression>> result;
result.reserve(indices.size());
for (auto idx : indices) {
result.push_back(queryRels[idx]);
}
return result;
}

void QueryGraph::addQueryNode(std::shared_ptr<NodeExpression> queryNode) {
// Note that a node may be added multiple times. We should only keep one of it.
// E.g. MATCH (a:person)-[:knows]->(b:person), (a)-[:knows]->(c:person)
Expand Down
37 changes: 18 additions & 19 deletions src/include/binder/query/query_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ struct SubqueryGraph {
queryNodesSelector |= other.queryNodesSelector;
}

uint32_t getNumQueryRels() const { return queryRelsSelector.count(); }
uint32_t getTotalNumVariables() const {
common::idx_t getNumQueryNodes() const { return queryNodesSelector.count(); }
common::idx_t getNumQueryRels() const { return queryRelsSelector.count(); }
common::idx_t getTotalNumVariables() const {
return queryNodesSelector.count() + queryRelsSelector.count();
}
bool isSingleRel() const {
return queryRelsSelector.count() == 1 && queryNodesSelector.count() == 0;
}
std::unordered_map<common::idx_t, std::vector<common::idx_t>> getWCOJRelCandidates() const;

bool containAllVariables(std::unordered_set<std::string>& variables) const;

Expand Down Expand Up @@ -79,27 +81,20 @@ class QueryGraph {

std::vector<std::shared_ptr<NodeOrRelExpression>> getAllPatterns() const;

uint32_t getNumQueryNodes() const { return queryNodes.size(); }
common::idx_t getNumQueryNodes() const { return queryNodes.size(); }
bool containsQueryNode(const std::string& queryNodeName) const {
return queryNodeNameToPosMap.contains(queryNodeName);
}
std::vector<std::shared_ptr<NodeExpression>> getQueryNodes() const { return queryNodes; }
std::shared_ptr<NodeExpression> getQueryNode(const std::string& queryNodeName) const {
return queryNodes[getQueryNodePos(queryNodeName)];
}
std::vector<std::shared_ptr<NodeExpression>> getQueryNodes(
const std::vector<uint32_t>& nodePoses) const {
std::vector<std::shared_ptr<NodeExpression>> result;
result.reserve(nodePoses.size());
for (auto nodePos : nodePoses) {
result.push_back(queryNodes[nodePos]);
}
return result;
}
std::shared_ptr<NodeExpression> getQueryNode(uint32_t nodePos) const {
std::shared_ptr<NodeExpression> getQueryNode(common::idx_t nodePos) const {
return queryNodes[nodePos];
}
uint32_t getQueryNodePos(NodeExpression& node) const {
std::vector<std::shared_ptr<NodeExpression>> getQueryNodes(
const std::vector<common::idx_t>& nodePoses) const;
common::idx_t getQueryNodePos(NodeExpression& node) const {
return getQueryNodePos(node.getUniqueName());
}
uint32_t getQueryNodePos(const std::string& queryNodeName) const {
Expand All @@ -115,8 +110,12 @@ class QueryGraph {
std::shared_ptr<RelExpression> getQueryRel(const std::string& queryRelName) const {
return queryRels.at(queryRelNameToPosMap.at(queryRelName));
}
std::shared_ptr<RelExpression> getQueryRel(uint32_t relPos) const { return queryRels[relPos]; }
uint32_t getQueryRelPos(const std::string& queryRelName) const {
std::shared_ptr<RelExpression> getQueryRel(common::idx_t relPos) const {
return queryRels[relPos];
}
std::vector<std::shared_ptr<RelExpression>> getQueryRels(
const std::vector<common::idx_t>& indices) const;
common::idx_t getQueryRelPos(const std::string& queryRelName) const {
return queryRelNameToPosMap.at(queryRelName);
}
void addQueryRel(std::shared_ptr<RelExpression> queryRel);
Expand Down Expand Up @@ -146,9 +145,9 @@ class QueryGraphCollection {
void addAndMergeQueryGraphIfConnected(QueryGraph queryGraphToAdd);
void finalize();

uint32_t getNumQueryGraphs() const { return queryGraphs.size(); }
QueryGraph* getQueryGraphUnsafe(uint32_t idx) { return &queryGraphs[idx]; }
const QueryGraph* getQueryGraph(uint32_t idx) const { return &queryGraphs[idx]; }
common::idx_t getNumQueryGraphs() const { return queryGraphs.size(); }
QueryGraph* getQueryGraphUnsafe(common::idx_t idx) { return &queryGraphs[idx]; }
const QueryGraph* getQueryGraph(common::idx_t idx) const { return &queryGraphs[idx]; }

std::vector<std::shared_ptr<NodeExpression>> getQueryNodes() const;
std::vector<std::shared_ptr<RelExpression>> getQueryRels() const;
Expand Down
31 changes: 17 additions & 14 deletions src/include/planner/join_order/cardinality_estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,33 @@ class CardinalityEstimator {
void addNodeIDDom(const binder::Expression& nodeID,
const std::vector<common::table_id_t>& tableIDs, transaction::Transaction* transaction);

uint64_t estimateScanNode(LogicalOperator* op);
uint64_t estimateHashJoin(const binder::expression_vector& joinKeys,
cardianlity_t estimateScanNode(LogicalOperator* op);
cardianlity_t estimateHashJoin(const binder::expression_vector& joinKeys,
const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
uint64_t estimateCrossProduct(const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
uint64_t estimateIntersect(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probePlan, const std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
uint64_t estimateFlatten(const LogicalPlan& childPlan, f_group_pos groupPosToFlatten);
uint64_t estimateFilter(const LogicalPlan& childPlan, const binder::Expression& predicate);
cardianlity_t estimateHashJoin(const binder::expression_vector& joinKeys,
cardianlity_t probeCard, cardianlity_t buildCard);
cardianlity_t estimateCrossProduct(const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
cardianlity_t estimateIntersect(const binder::expression_vector& joinNodeIDs,
cardianlity_t probeCard, const std::vector<cardianlity_t>& buildCard);
cardianlity_t estimateFlatten(const LogicalPlan& childPlan, f_group_pos groupPosToFlatten);
cardianlity_t estimateFilters(cardianlity_t inCardinality,
const binder::expression_vector& predicates);
cardianlity_t estimateFilter(cardianlity_t inCardinality, const binder::Expression& predicate);

double getExtensionRate(const binder::RelExpression& rel,
const binder::NodeExpression& boundNode, transaction::Transaction* transaction);
cardianlity_t getNumNodes(const std::vector<common::table_id_t>& tableIDs,
transaction::Transaction* transaction);
cardianlity_t getNumRels(const std::vector<common::table_id_t>& tableIDs,
transaction::Transaction* transaction);

private:
inline uint64_t atLeastOne(uint64_t x) { return x == 0 ? 1 : x; }
uint64_t atLeastOne(uint64_t x) { return x == 0 ? 1 : x; }

inline uint64_t getNodeIDDom(const std::string& nodeIDName) {
uint64_t getNodeIDDom(const std::string& nodeIDName) {
KU_ASSERT(nodeIDName2dom.contains(nodeIDName));
return nodeIDName2dom.at(nodeIDName);
}
uint64_t getNumNodes(const std::vector<common::table_id_t>& tableIDs,
transaction::Transaction* transaction);

uint64_t getNumRels(const std::vector<common::table_id_t>& tableIDs,
transaction::Transaction* transaction);

private:
main::ClientContext* context;
Expand Down
16 changes: 9 additions & 7 deletions src/include/planner/join_order/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ namespace planner {

class CostModel {
public:
static uint64_t computeExtendCost(const LogicalPlan& childPlan);
static uint64_t computeRecursiveExtendCost(uint8_t upperBound, double extensionRate,
const LogicalPlan& childPlan);
static uint64_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
static cost_t computeExtendCost(cardianlity_t inCardinality);
static cost_t computeRecursiveExtendCost(cardianlity_t inCardinality, uint8_t upperBound,
double extensionRate);
static cost_t computeHashJoinCost(cost_t probeCost, cost_t buildCost, cardianlity_t probeCard,
cardianlity_t buildCard);
static cost_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
static cost_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeIntersectCost(const LogicalPlan& probePlan,
const std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
static cost_t computeIntersectCost(cost_t probeCost, std::vector<cost_t> buildCosts,
cardianlity_t probeCard);
};

} // namespace planner
Expand Down
42 changes: 42 additions & 0 deletions src/include/planner/join_order/dp_table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include "binder/query/query_graph.h"
#include "join_tree.h"

namespace kuzu {
namespace planner {

class DPLevel {
public:
bool contains(const binder::SubqueryGraph& subqueryGraph) const {
return subgraphToJoinTree.contains(subqueryGraph);
}
const JoinTree& getJoinTree(const binder::SubqueryGraph& subqueryGraph) const {
KU_ASSERT(contains(subqueryGraph));
return subgraphToJoinTree.at(subqueryGraph);
}

void add(const binder::SubqueryGraph& subqueryGraph, const JoinTree& joinTree);

const binder::subquery_graph_V_map_t<JoinTree>& getSubgraphAndJoinTrees() const {
return subgraphToJoinTree;
}

private:
binder::subquery_graph_V_map_t<JoinTree> subgraphToJoinTree;
};

class DPTable {
public:
void init(common::idx_t maxLevel);

void add(const binder::SubqueryGraph& subqueryGraph, const JoinTree& joinTree);

const DPLevel& getLevel(common::idx_t idx) const { return levels[idx]; }

private:
std::vector<DPLevel> levels;
};

} // namespace planner
} // namespace kuzu
88 changes: 88 additions & 0 deletions src/include/planner/join_order/join_order_solver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#pragma once

#include "binder/query/query_graph.h"
#include "cardinality_estimator.h"
#include "dp_table.h"
#include "join_tree.h"
#include "planner/join_order_enumerator_context.h"

namespace kuzu {
namespace planner {

class PropertyExprCollection {
public:
binder::expression_vector getProperties(std::shared_ptr<binder::Expression> pattern) const {
if (!patternToProperties.contains(pattern)) {
return binder::expression_vector{};
}
return patternToProperties.at(pattern);
}

void addProperties(std::shared_ptr<binder::Expression> pattern,
const binder::expression_vector& properties) {
KU_ASSERT(!patternToProperties.contains(pattern));
patternToProperties.insert({pattern, properties});
}

private:
binder::expression_map<binder::expression_vector> patternToProperties;
};

/*
* JoinOrderSolver solves a reasonable join order for
*/
class JoinOrderSolver {
public:
explicit JoinOrderSolver(const binder::QueryGraph& queryGraph,
binder::expression_vector predicates, PropertyExprCollection propertyExprCollection,
main::ClientContext* context)
: queryGraph{queryGraph}, queryGraphPredicates{std::move(predicates)},
propertyCollection{std::move(propertyExprCollection)}, context{context} {}

void setCorrExprs(SubqueryType subqueryType_, binder::expression_vector exprs,
cardianlity_t card) {
subqueryType = subqueryType_;
corrExprs = std::move(exprs);
corrExprsCardinality = card;
}

JoinTree solve();

private:
void planLevel(common::idx_t level);
void planBaseScans();
void planCorrelatedExpressionsScan(const binder::SubqueryGraph& newSubgraph);
void planBaseNodeScan(common::idx_t nodeIdx);
void planBaseRelScan(common::idx_t relIdx);
void planBinaryJoin(common::idx_t leftSize, common::idx_t rightSize);
void planWorstCaseOptimalJoin(common::idx_t size, common::idx_t otherSize);
void planBinaryJoin(const binder::SubqueryGraph& subqueryGraph, const JoinTree& joinTree,
const binder::SubqueryGraph& otherSubqueryGraph, const JoinTree& otherJoinTree,
std::vector<std::shared_ptr<binder::NodeExpression>> joinNodes);
void planHashJoin(const JoinTree& joinTree, const JoinTree& otherJoinTree,
std::vector<std::shared_ptr<binder::NodeExpression>> joinNodes,
const binder::SubqueryGraph& newSubqueryGraph, const binder::expression_vector& predicates);
void planWorstCaseOptimalJoin(const JoinTree& joinTree,
const std::vector<JoinTree>& relJoinTrees, std::shared_ptr<binder::NodeExpression> joinNode,
const binder::SubqueryGraph& newSubqueryGraph, const binder::expression_vector& predicates);
bool tryPlanIndexNestedLoopJoin(const JoinTree& joinTree, const JoinTree& otherJoinTree,
std::shared_ptr<binder::NodeExpression> joinNode,
const binder::SubqueryGraph& newSubqueryGraph, const binder::expression_vector& predicates);

private:
// Query graph to plan
const binder::QueryGraph& queryGraph;
// Predicates to apply for given query graph
binder::expression_vector queryGraphPredicates;
//
SubqueryPlanInfo subqueryPlanInfo;
// Properties to scan for given query graph.
PropertyExprCollection propertyCollection;

main::ClientContext* context;
DPTable dpTable;
CardinalityEstimator cardinalityEstimator;
};

} // namespace planner
} // namespace kuzu
32 changes: 32 additions & 0 deletions src/include/planner/join_order/join_plan_solver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include "join_tree.h"
#include "planner/planner.h"

namespace kuzu {
namespace planner {

/*
* JoinPlanSolver solves a JoinTree into a LogicalPlan
* */
class JoinPlanSolver {
public:
JoinPlanSolver(Planner* planner) : planner{planner} {}

LogicalPlan solve(const JoinTree& joinTree);

private:
LogicalPlan solveTreeNode(const JoinTreeNode& current, const JoinTreeNode* parent);

LogicalPlan solveExprScanTreeNode(const JoinTreeNode& treeNode);
LogicalPlan solveNodeScanTreeNode(const JoinTreeNode& treeNode);
LogicalPlan solveRelScanTreeNode(const JoinTreeNode& treeNode, const JoinTreeNode& parent);
LogicalPlan solveBinaryJoinTreeNode(const JoinTreeNode& treeNode);
LogicalPlan solveMultiwayJoinTreeNode(const JoinTreeNode& treeNode);

private:
Planner* planner;
};

} // namespace planner
} // namespace kuzu
Loading