From d269b406b681fcffc4cac3f1759fd1e41633c342 Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 16 Oct 2023 16:02:34 -0400 Subject: [PATCH 01/32] compiler build --- lib/compiler/CMakeLists.txt | 1 + .../include/compiler/machine_mapping.h | 24 +- .../include/compiler/unity_algorithm.h | 17 +- lib/compiler/src/graph_utils.cc | 2 +- lib/compiler/src/machine_mapping.cc | 126 +- lib/compiler/src/old/basic_graph.h | 158 - lib/compiler/src/old/dominators.h | 494 --- lib/compiler/src/old/graph.cc | 1255 ------ lib/compiler/src/old/graph.h | 248 -- lib/compiler/src/old/graph_structures.h | 269 -- lib/compiler/src/old/node.h | 47 - .../src/old/parallel_dim_mapping_record.h | 4 - lib/compiler/src/old/search_helper.cc | 525 --- lib/compiler/src/old/search_helper.h | 122 - lib/compiler/src/old/simplification.cc | 189 - lib/compiler/src/old/simplification.h | 34 - lib/compiler/src/old/split_types.cc | 36 - lib/compiler/src/old/split_types.h | 32 - lib/compiler/src/old/substitution.cc | 3733 ----------------- lib/compiler/src/old/substitution.h | 309 -- lib/compiler/src/unity_algorithm.cc | 16 +- lib/pcg/include/pcg/machine_view.h | 5 +- .../include/pcg/parallel_computation_graph.h | 10 + .../include/substitutions/substitution.h | 8 + lib/utils/include/utils/graph/algorithms.h | 3 + .../utils/graph/labelled/node_labelled.h | 1 + .../utils/graph/labelled/node_labelled_open.h | 8 +- .../include/utils/graph/labelled/open_views.h | 10 + .../utils/graph/labelled/output_labelled.h | 5 +- .../graph/labelled/output_labelled_open.h | 10 +- .../include/utils/graph/labelled/views.h | 3 +- 31 files changed, 137 insertions(+), 7567 deletions(-) delete mode 100644 lib/compiler/src/old/basic_graph.h delete mode 100644 lib/compiler/src/old/dominators.h delete mode 100644 lib/compiler/src/old/graph.cc delete mode 100644 lib/compiler/src/old/graph.h delete mode 100644 lib/compiler/src/old/graph_structures.h delete mode 100644 lib/compiler/src/old/node.h delete mode 100644 lib/compiler/src/old/parallel_dim_mapping_record.h delete mode 100644 lib/compiler/src/old/search_helper.cc delete mode 100644 lib/compiler/src/old/search_helper.h delete mode 100644 lib/compiler/src/old/simplification.cc delete mode 100644 lib/compiler/src/old/simplification.h delete mode 100644 lib/compiler/src/old/split_types.cc delete mode 100644 lib/compiler/src/old/split_types.h delete mode 100644 lib/compiler/src/old/substitution.cc delete mode 100644 lib/compiler/src/old/substitution.h diff --git a/lib/compiler/CMakeLists.txt b/lib/compiler/CMakeLists.txt index daa96b08bc..45c369fcdf 100644 --- a/lib/compiler/CMakeLists.txt +++ b/lib/compiler/CMakeLists.txt @@ -14,6 +14,7 @@ ff_add_library( optional pcg spdlog + substitutions ) add_subdirectory(ffi) diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 4089260735..e8d7457fbf 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -5,10 +5,12 @@ #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" namespace FlexFlow { +using SubParallelComputationGraphView = OutputLabelledOpenMultiDiGraphView; + struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); static bool nodes_are_disjoint(MachineMapping const &m1, @@ -20,14 +22,13 @@ FF_VISITABLE_STRUCT(MachineMapping, machine_views); struct OptimalCostState { SerialParallelDecomposition subgraph; - MachineSpecification resource; - req> source_machine_view, sink_machine_view; + req resource; + // req> given_machine_views; + // req> frontier_machine_views; }; FF_VISITABLE_STRUCT(OptimalCostState, subgraph, - resource, - source_machine_view, - sink_machine_view); + resource); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, @@ -37,7 +38,7 @@ struct OptimalCostResult { static OptimalCostResult infinity(); float runtime; - MachineMapping machine_mapping; + req machine_mapping; }; FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); @@ -67,4 +68,13 @@ OptimalCostResult } // namespace FlexFlow +namespace std { + +template <> +struct hash> { + size_t operator()(std::unordered_map const &g) const; +}; + +}; + #endif diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 57f1c8c063..fc068d48c5 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -4,18 +4,15 @@ #include "cost_estimate.h" #include "machine_mapping.h" #include "pcg/computation_graph.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" namespace FlexFlow { -struct Substitution {}; - struct Strategy { ParallelComputationGraph pcg; MachineMapping machine_mapping; req runtime; }; -FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); struct StrategyRuntimeCmp { bool operator()(Strategy const &, Strategy const &); @@ -30,7 +27,7 @@ struct OptimizerConfig { Strategy graph_optimize(ComputationGraph &cg, - ICostEstimator const &cost_estimator, + CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( Operator const &, MachineSpecification const &)> const @@ -39,4 +36,14 @@ Strategy } // namespace FlexFlow +VISITABLE_STRUCT(FlexFlow::Strategy, pcg, machine_mapping, runtime); +namespace std { + +template <> +struct hash { + size_t operator()(FlexFlow::Strategy const &) const; +}; + +}; + #endif diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 4f22490ffa..d7f15e0796 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -4,7 +4,7 @@ namespace FlexFlow { SerialParallelDecomposition get_serial_parallel_decomposition(ParallelComputationGraph const &pcg) { - return get_serial_parallel_decomposition(as_digraph(pcg)); + return get_serial_parallel_decomposition(pcg.value()); } std::vector diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 2f6af8a62b..fb04f57eac 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -45,9 +45,12 @@ bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, optional OptimalCostCache::load(OptimalCostState const &state) const { - if (contains_key(cache, state)) { - return make_optional(cache.at(state)); - } + auto it = cache.find(state); + // if (contains_key(cache, state)) { + // // auto result = cache.at(state); + // OptimalCostResult result = OptimalCostResult::infinity(); + // return make_optional(result); + // } return nullopt; } @@ -88,31 +91,10 @@ GraphSplit return {get_nodes(pre_decomposition), get_nodes(post_decomposition)}; } -std::pair - apply_split(SubParallelComputationGraph const &g, GraphSplit const &split) { - OpenMultiDiGraphView g1 = get_subgraph(g, split.first); - OpenMultiDiGraphView g2 = get_subgraph(g, split.second); - - if (get_edge_splits(g, split).size() > 0) { - // Sequential split - if (get_open_sinks(g1).size() <= get_open_sources(g2).size()) { - // get_open_sinks(*g1).size() should be 1 in perfect sp graphs - return {get_subgraph(g, split.first), - get_subgraph(g, split.second)}; - } else { - return {get_subgraph(g, split.first), - get_subgraph(g, split.first)}; - } - } else { - // Parallel split - return {get_subgraph(g, split.first), - get_subgraph(g, split.second)}; - } -} - -float estimate_cost(SubParallelComputationGraph const &g, +float estimate_cost(SubParallelComputationGraphView const &g, CostEstimator const &estimator, - MachineMapping const &device_mapping) { + MachineMapping const &device_mapping, + std::unordered_map const &frontier_machine_views) { NOT_IMPLEMENTED(); } @@ -122,26 +104,26 @@ void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { struct OptimalCost { OptimalCost( - SubParallelComputationGraph const &g, + SubParallelComputationGraphView const &g, CostEstimator const &cost_estimator, MachineSpecification const &resource, - optional const &source_machine_view, // assume perfect SP - optional const &sink_machine_view, + std::unordered_map const &given_machine_views, + std::unordered_map const &frontier_machine_views, std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views, OptimalCostCache &cached_subgraph_costs) : g(g), cost_estimator(cost_estimator), resource(resource), - source_machine_view(source_machine_view), - sink_machine_view(sink_machine_view), + given_machine_views(restrict_keys(given_machine_views, get_nodes(g))), + frontier_machine_views(restrict_keys(frontier_machine_views, get_edges(g))), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} - SubParallelComputationGraph const &g; + SubParallelComputationGraphView const &g; CostEstimator const &cost_estimator; MachineSpecification const &resource; - optional const &source_machine_view; - optional const &sink_machine_view; + std::unordered_map const &given_machine_views; + std::unordered_map const &frontier_machine_views; std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views; @@ -149,7 +131,7 @@ struct OptimalCost { template OptimalCostResult operator()(T const &t) const { - OptimalCostState state{g, resource, source_machine_view, sink_machine_view}; + OptimalCostState state{t, resource/*, given_machine_views, frontier_machine_views*/}; optional cached_result = cached_subgraph_costs.load(state); @@ -168,44 +150,40 @@ struct OptimalCost { SerialParallelDecomposition pre_decompn = decomposed.first; SerialParallelDecomposition post_decompn = decomposed.second; - auto subgraphs = apply_split(g, get_graph_split(pre_decompn, post_decompn)); - SubParallelComputationGraph pre_graph = subgraphs.first, - post_graph = subgraphs.second; + GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); + SubParallelComputationGraphView pre_graph = get_subgraph(g, graph_split.first); + SubParallelComputationGraphView post_graph = get_subgraph(g, graph_split.second); - std::unordered_set pre_graph_sinks = get_closed_sinks(pre_graph); std::unordered_set post_graph_sources = get_closed_sources(post_graph); - assert(pre_graph_sinks.size() + post_graph_sources.size() == - 1); // assume perfect SP - - Node const &split_point = - get_only(set_union(pre_graph_sinks, post_graph_sources)); + assert(post_graph_sources.size() == 1); // assume perfect SP + Node split_point = get_only(post_graph_sources); + OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); + OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { - optional pre_sink_mv = - contains(pre_graph_sinks, split_point) ? make_optional(mv) : nullopt; - optional post_source_mv = - contains(post_graph_sources, split_point) ? make_optional(mv) - : nullopt; + auto new_given_machine_views = merge_maps(given_machine_views, std::unordered_map{{split_point, mv}}); + auto new_frontier_machine_views = merge_maps(frontier_machine_views, + std::unordered_map{{split_edge, mv}}); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( visit(OptimalCost(pre_graph, cost_estimator, resource, - source_machine_view, - pre_sink_mv, + given_machine_views, + new_frontier_machine_views, allowed_machine_views, cached_subgraph_costs), pre_decompn), visit(OptimalCost(post_graph, cost_estimator, resource, - post_source_mv, - sink_machine_view, + new_given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), post_decompn))); @@ -219,23 +197,24 @@ struct OptimalCost { SerialParallelDecomposition decompn1 = decomposed.first; SerialParallelDecomposition decompn2 = decomposed.second; - auto subgraphs = apply_split(g, get_graph_split(decompn1, decompn2)); - SubParallelComputationGraph g1 = subgraphs.first, g2 = subgraphs.second; + GraphSplit graph_split = get_graph_split(decompn1, decompn2); + SubParallelComputationGraphView g1 = get_subgraph(g, graph_split.first), + g2 = get_subgraph(g, graph_split.second); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( visit(OptimalCost(g1, cost_estimator, resource, - source_machine_view, - sink_machine_view, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn1), visit(OptimalCost(g2, cost_estimator, resource, - source_machine_view, - sink_machine_view, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn2)); @@ -246,16 +225,16 @@ struct OptimalCost { visit(OptimalCost(g1, cost_estimator, resource_split.first, - source_machine_view, - sink_machine_view, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn1), visit(OptimalCost(g2, cost_estimator, resource_split.second, - source_machine_view, - sink_machine_view, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn2))); @@ -265,24 +244,17 @@ struct OptimalCost { } OptimalCostResult optimal_cost(Node const &node) const { - if (source_machine_view) { - assert(get_closed_sources(g).empty()); + if (contains_key(given_machine_views, node)) { assert(contains(allowed_machine_views(g.at(node), resource), source_machine_view.value())); - MachineMapping mv_map{{{node, source_machine_view.value()}}}; - return {estimate_cost(g, cost_estimator, mv_map), mv_map}; - } else if (sink_machine_view) { - assert(get_closed_sinks(g).empty()); - assert(contains(allowed_machine_views(g.at(node), resource), - sink_machine_view.value())); - MachineMapping mv_map{{{node, sink_machine_view.value()}}}; - return {estimate_cost(g, cost_estimator, mv_map), mv_map}; + MachineMapping mv_map{given_machine_views}; + return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}; } else { OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { MachineMapping mv_map{{{node, mv}}}; minimize_runtime(optimal_result, - {estimate_cost(g, cost_estimator, mv_map), mv_map}); + {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}); } return optimal_result; } @@ -300,8 +272,8 @@ OptimalCostResult return visit(OptimalCost(pcg_to_subpcg(g), cost_estimator, resources, - nullopt, - nullopt, + {}, + {}, allowed_machine_views, cached_subgraph_costs), get_serial_parallel_decomposition(g)); diff --git a/lib/compiler/src/old/basic_graph.h b/lib/compiler/src/old/basic_graph.h deleted file mode 100644 index fca575e42a..0000000000 --- a/lib/compiler/src/old/basic_graph.h +++ /dev/null @@ -1,158 +0,0 @@ -#ifndef _BASIC_GRAPH_H -#define _BASIC_GRAPH_H - -#include "utils/hash-utils.h" -#include -#include - -namespace FlexFlow { -namespace PCG { -namespace Utils { - -template -struct GraphStructure; -/* -{ - using graph_type = ...; - using node_type = - using tGraph = G; - using tNode = N; - using tEdge = E; - - std::unordered_set get_nodes(G const &) const; - std::unordered_set get_incoming_edges(G const &, N const &) const; - std::unordered_set get_outgoing_edges(G const &, N const &) const; - N get_src(G const &, E const &) const; - N get_dst(G const &, E const &) const; -}; -*/ - -template -struct BasicGraph { - using N = T; - using E = std::pair; - - std::unordered_set nodes; - std::unordered_map> in_edges, out_edges; - - BasicGraph() : BasicGraph({}, {}) {} - - BasicGraph(std::unordered_set const &nodes, std::unordered_set edges) - : nodes(), in_edges(), out_edges() { - this->add_nodes(nodes); - this->add_edges(edges); - } - - void add_edge(N const &src, N const &dst) { - nodes.insert(src); - nodes.insert(dst); - out_edges[src].insert({src, dst}); - in_edges[dst].insert({src, dst}); - } - - void add_edge(E const &e) { - nodes.insert(e.first); - nodes.insert(e.second); - out_edges[e.first].insert(e); - in_edges[e.second].insert(e); - } - - bool has_edge(N const &src, N const &dst) const { - auto iter = this->in_edges.find(dst); - if (iter == this->in_edges.end()) { - return false; - } - - std::unordered_set const &dst_in_edges = iter->second; - return dst_in_edges.find({src, dst}) != dst_in_edges.end(); - } - - bool has_edge(E const &e) const { - return this->has_edge(e.first, e.second); - } - - void remove_edge(N const &src, N const &dst) { - out_edges[src].erase({src, dst}); - in_edges[dst].erase({src, dst}); - } - - void remove_edge(E const &e) { - out_edges[e.first].erase(e); - in_edges[e.second].erase(e); - } - - void add_node(N const &n) { - nodes.insert(n); - } - - template > - void add_nodes(Container const &nodes) { - for (auto const &n : nodes) { - this->add_node(n); - } - } - - template > - void add_edges(Container const &edges) { - for (auto const &e : edges) { - this->add_edge(e); - } - } - - bool operator==(BasicGraph const &other) const { - return this->nodes == other.nodes && this->in_edges == other.in_edges && - this->out_edges == other.out_edges; - } -}; - -template -struct GraphStructure> { - using graph_type = BasicGraph; - using vertex_type = T; - using edge_type = std::pair; - - std::unordered_set get_nodes(graph_type const &g) const { - std::unordered_set nodes(g.nodes); - return nodes; - } - - std::unordered_set get_incoming_edges(graph_type const &g, - vertex_type const &n) const { - std::unordered_set edges; - if (g.in_edges.find(n) != g.in_edges.end()) { - edges.insert(g.in_edges.at(n).begin(), g.in_edges.at(n).end()); - } - return edges; - } - - std::unordered_set get_outgoing_edges(graph_type const &g, - vertex_type const &n) const { - std::unordered_set edges; - if (g.out_edges.find(n) != g.out_edges.end()) { - edges.insert(g.out_edges.at(n).begin(), g.out_edges.at(n).end()); - } - return edges; - } - - vertex_type get_src(graph_type const &g, edge_type const &e) const { - return e.first; - } - - vertex_type get_dst(graph_type const &g, edge_type const &e) const { - return e.second; - } - - void set_src(graph_type const &g, edge_type &e, vertex_type const &n) const { - e.first = n; - } - - void set_dst(graph_type const &g, edge_type &e, vertex_type const &n) const { - e.second = n; - } -}; - -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -#endif // _BASIC_GRAPH_H diff --git a/lib/compiler/src/old/dominators.h b/lib/compiler/src/old/dominators.h deleted file mode 100644 index 70449ee001..0000000000 --- a/lib/compiler/src/old/dominators.h +++ /dev/null @@ -1,494 +0,0 @@ -#ifndef _DOMINATORS_H -#define _DOMINATORS_H - -#include "basic_graph.h" -#include "graph_structures.h" -#include "tl/optional.hpp" -#include "utils/dot_file.h" -#include "utils/record_formatter.h" -#include -#include -#include -#include - -namespace FlexFlow { -namespace PCG { -namespace Utils { - -template > -std::unordered_set nodes(G const &g) { - Structure s; - - return s.get_nodes(g); -} - -template > -bool has_edge(G const &g, - typename Structure::vertex_type const &src, - typename Structure::vertex_type const &dst) { - Structure s; - - for (auto const &e : s.get_outgoing_edges(g, src)) { - if (s.get_dst(g, e) == dst) { - return true; - } - } - - return false; -} - -template > -std::unordered_set - outgoing_edges(G const &g, typename Structure::vertex_type const &n) { - Structure s; - return s.get_outgoing_edges(g, n); -} - -template > -std::pair - get_basic_edge(G const &g, typename Structure::edge_type const &e) { - Structure s; - - return {s.get_src(g, e), s.get_dst(g, e)}; -} - -template > -std::vector get_edges(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - - std::vector edges; - - for (N const &n : s.get_nodes(g)) { - for (E const &e : s.get_outgoing_edges(g, n)) { - edges.push_back(e); - } - } - - return edges; -} - -template > -void successors(G const &g, - typename Structure::vertex_type const &node, - std::unordered_set *succ) { - Structure s; - for (auto const &edge : s.get_outgoing_edges(g, node)) { - succ->insert(s.get_dst(g, edge)); - } -} - -template > -std::unordered_set - successors(G const &g, typename Structure::vertex_type const &node) { - // using N = typename Structure::vertex_type; - - std::unordered_set succ; - successors(g, node, &succ); - - return succ; -} - -template > -tl::optional - successor(G const &g, typename Structure::vertex_type const &node) { - auto succs = successors(g, node); - if (succs.size() == 1) { - return *succs.begin(); - } else { - return tl::nullopt; - } -} - -template > -void predecessors(G const &g, - typename Structure::vertex_type const &node, - std::unordered_set *pred) { - Structure s; - for (auto const &edge : s.get_incoming_edges(g, node)) { - pred->insert(s.get_src(g, edge)); - } -} - -template > -std::unordered_set - predecessors(G const &g, typename Structure::vertex_type const &node) { - // using N = typename Structure::vertex_type; - - std::unordered_set pred; - predecessors(g, node, &pred); - - return pred; -} - -template > -tl::optional - predecessor(G const &g, typename Structure::vertex_type const &node) { - auto preds = predecessors(g, node); - if (preds.size() == 1) { - return *preds.begin(); - } else { - return tl::nullopt; - } -} - -template > -std::unordered_set roots(G const &g) { - using N = typename Structure::vertex_type; - - Structure s; - - std::unordered_set nodes = s.get_nodes(g); - std::unordered_set roots; - for (auto const &node : nodes) { - if (s.get_incoming_edges(g, node).empty()) { - roots.insert(node); - } - } - - return roots; -} - -template > -std::unordered_set leaves(G const &g) { - return roots>(g); -} - -template > -void topo_sort(G const &g, - std::vector *ordering) { - using N = typename Structure::vertex_type; - - Structure s; - std::unordered_map> predecessors; - - std::queue q; - for (auto const &node : s.get_nodes(g)) { - predecessors[node]; - for (auto const &edge : s.get_incoming_edges(g, node)) { - predecessors.at(node).insert(s.get_src(g, edge)); - } - } - - for (auto it = predecessors.begin(); it != predecessors.end();) { - if (it->second.empty()) { - q.push(it->first); - it = predecessors.erase(it); - } else { - it++; - } - } - - std::unordered_set node_successors; - while (!q.empty()) { - N const ¤t = q.front(); - - ordering->push_back(current); - - node_successors.clear(); - successors(g, current, &node_successors); - for (auto const &succ : node_successors) { - if (predecessors.find(succ) != predecessors.end()) { - predecessors.at(succ).erase(current); - if (predecessors.at(succ).empty()) { - predecessors.erase(succ); - q.push(succ); - } - } - } - - q.pop(); - } -} - -template > -std::unordered_map> - dominators(G const &g) { - using N = typename Structure::vertex_type; - // using E = typename Structure::edge_type; - - // Structure s; - - std::vector nodes; - topo_sort(g, &nodes); - std::unordered_map> dom; - - std::unordered_set pred_part; - for (auto const &node : nodes) { - pred_part.clear(); - predecessors(g, node, &pred_part); - for (auto const &p : pred_part) { - if (dom.find(node) == dom.end()) { - dom[node] = dom.at(p); - } else { - auto &node_dom_set = dom.at(node); - auto const &p_dom_set = dom.at(p); - for (auto it = node_dom_set.begin(); it != node_dom_set.end();) { - if (p_dom_set.find(*it) == p_dom_set.end()) { - it = node_dom_set.erase(it); - } else { - it++; - } - } - } - } - dom[node].insert(node); - } - - return dom; -} - -template > -std::unordered_map> - post_dominators(G const &g) { - return dominators>(g); -} - -template > -std::unordered_map - imm_dominators(G const &g) { - using N = typename Structure::vertex_type; - // using E = typename Structure::edge_type; - - std::vector topo; - topo_sort(g, &topo); - std::unordered_map topo_rank; - for (int i = 0; i < (int)topo.size(); i++) { - topo_rank[topo[i]] = i; - } - std::unordered_map> dom = - dominators(g); - - std::unordered_map imm_dom; - for (auto const &kv : dom) { - N const &n = kv.first; - std::unordered_set const &n_doms = kv.second; - - // if a node is only dominated by itself, set the dominator to itself to - // signify that it has no immediate dominator - if (n_doms.size() == 1) { - imm_dom[n] = n; - continue; - } - - N const *n_imm_dom = nullptr; - int current_topo_rank = std::numeric_limits::min(); - for (auto const &d : n_doms) { - if (topo_rank.at(d) > current_topo_rank && d != n) { - n_imm_dom = &d; - current_topo_rank = topo_rank.at(d); - } - } - imm_dom[n] = *n_imm_dom; - } - - return imm_dom; -} - -template > -void dfs(G const &g, - typename Structure::vertex_type const &n, - std::function const - &visitor) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - - /* auto i_visitor = std::bind(visitor, g, s, n); */ - auto i_visitor = [&](N const &nn) { return visitor(g, s, n, nn); }; - - std::queue q; - std::unordered_set visited; - - auto visit = [&](N const &n) { - if (visited.find(n) == visited.end()) { - q.push(n); - visited.insert(n); - } - }; - - visit(n); - - while (!q.empty()) { - N current = q.front(); - q.pop(); - - i_visitor(current); - - for (E const &edge : s.get_outgoing_edges(g, current)) { - N const &dst = s.get_dst(g, edge); - visit(dst); - } - } - - return; -} - -template > -std::unordered_set - descendants(G const &g, typename Structure::vertex_type const &n) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - std::unordered_set descendants; - - auto dfs_visitor = [&](G const &gg, - Structure const &ss, - N const &dfs_src, - N const ¤t_node) { - descendants.insert(current_node); - }; - - dfs(g, n, dfs_visitor); - - return descendants; -} - -template > -std::vector> - weakly_connected_components(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - std::vector> result; - std::unordered_set seen; - - for (N const &n : nodes>(g)) { - if (seen.find(n) != seen.end()) { - continue; - } - - auto component = descendants>(g, n); - seen.insert(component.begin(), component.end()); - result.emplace_back(component); - } - - return result; -} - -template > -std::unordered_map - imm_post_dominators(G const &g) { - return imm_dominators>(g); -} - -template > -BasicGraph transitive_reduction(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - BasicGraph reduction; - - std::unordered_set nodes = s.get_nodes(g); - - reduction.add_nodes(nodes); - - std::unordered_set> to_delete; - - auto dfs_visitor = [&](N const &src, - G const &gg, - Structure const &ss, - N const &dfs_src, - N const &nn) { - if (nn != dfs_src && to_delete.find({src, nn}) == to_delete.end() && - has_edge(gg, src, nn)) { - to_delete.insert({src, nn}); - } - }; - - for (N const &n : nodes) { - /* auto n_dfs_visitor = std::bind(dfs_visitor, n); */ - auto n_dfs_visitor = - [&](G const &gg, Structure const &ss, N const &dfs_src, N const &nn) { - return dfs_visitor(n, gg, ss, dfs_src, nn); - }; - - for (N const &child : successors(g, n)) { - dfs(g, child, n_dfs_visitor); - } - } - - for (E const &e : get_edges(g)) { - std::pair basic_edge = get_basic_edge(g, e); - - if (to_delete.find(basic_edge) == to_delete.end()) { - reduction.add_edge(basic_edge); - } - } - - return reduction; -} - -template -void inplace_transitive_reduction(BasicGraph &g) { - using Structure = GraphStructure>; - using G = BasicGraph; - using E = std::pair; - - std::unordered_set to_delete; - - auto dfs_visitor = [&](N const &src, - G const &gg, - Structure const &ss, - N const &dfs_src, - N const &nn) { - if (nn != dfs_src && to_delete.find({src, nn}) == to_delete.end() && - has_edge(gg, src, nn)) { - to_delete.insert({src, nn}); - } - }; - - for (N const &n : g.nodes) { - auto n_dfs_visitor = - [&](G const &gg, Structure const &ss, N const &dfs_src, N const &nn) { - return dfs_visitor(n, gg, ss, dfs_src, nn); - }; - - for (N const &child : successors(g, n)) { - dfs(g, child, n_dfs_visitor); - } - } - - for (E const &e : to_delete) { - g.remove_edge(e); - } -}; - -template > -void export_as_dot( - DotFile &dotfile, - G const &g, - std::function const - &pretty) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - GraphStructure s; - - for (N const &n : s.get_nodes(g)) { - dotfile.add_record_node(n, pretty(n)); - - for (E const &edge : s.get_incoming_edges(g, n)) { - dotfile.add_edge(s.get_src(g, edge), s.get_dst(g, edge)); - } - } - - dotfile.close(); -} - -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -#endif // _DOMINATORS_H diff --git a/lib/compiler/src/old/graph.cc b/lib/compiler/src/old/graph.cc deleted file mode 100644 index 191b1028b7..0000000000 --- a/lib/compiler/src/old/graph.cc +++ /dev/null @@ -1,1255 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "graph.h" -#include "dominators.h" -#include "op-attrs/op-attrs.h" -#include "utils/disjoint_set.h" -#include "utils/unique.h" -#include - -// using FlexFlow::utils::Node; -// using FlexFlow::opmeta::OperatorParameters; - -namespace FlexFlow { - -ParallelComputationGraph::Graph(std::string const &logger_name) - : Graph(spdlog::get(logger_name)) {} - -ParallelComputationGraph::Graph(std::shared_ptr const &logger) - : logger(logger) {} - -Graph::Graph(utils::AdjacencyMultiDiGraph const &g, - utils::bidict const &nodeMap, - std::shared_ptr const &logger) - : g(g), nodeMap(nodeMap), logger(logger) {} - -/* using namespace Legion; */ -/* using FlexFlow::MachineView; */ - -/* LegionRuntime::Logger::Category log_graph("graph"); */ -/* LegionRuntime::Logger::Category log_simplify("graph_simplify"); */ - -void Graph::add_edge(Node const &srcOp, - Node const &dstOp, - int srcIdx, - int dstIdx) { - this->g.add_edge({srcOp, dstOp, (std::size_t)srcIdx, (std::size_t)dstIdx}); -} - -Node Graph::add_node(PCGOperatorAttrs const ¶ms) { - Node n = this->g.add_node(); - this->nodeMap.equate(n, params); - return n; -} - -void Graph::add_edge(utils::MultiDiEdge const &e) { - this->g.add_edge(e); -} - -void Graph::remove_edge(utils::MultiDiEdge const &e, - bool remove_node_if_unused) { - this->g.remove_edge(e); - utils::remove_node_if_unused(this->g, e.src); - utils::remove_node_if_unused(this->g, e.dst); -} - -bool Graph::has_edge(utils::MultiDiEdge const &e) const { - return utils::contains_edge(this->g, e); -} - -void Graph::print_dot() const { - this->print_dot(std::cout); -} - -void Graph::print_dot(std::ostream &s) const { - auto directed = unsafe_view_as_digraph(this->g); - - DotFile dot(s); - - export_as_dot(dot, directed, [&](utils::Node const &node) -> RecordFormatter { - RecordFormatter rf; - rf << node.to_string(); - tl::optional sub_rf = as_dot(this->nodeMap.at_l(node)); - if (sub_rf.has_value()) { - rf << sub_rf.value(); - } - - return rf; - }); - s << std::endl; -} - -bool Graph::has_loop() { - return !utils::is_acyclic(this->g).value_or(true); -} - -/* Node Graph::find_bottleneck_node(Node const &sink_node, */ -/* Node const &source_node) const { */ -/* using FlexFlow::PCG::Utils::GraphStructure; */ -/* using FlexFlow::PCG::Utils::imm_post_dominators; */ -/* using FlexFlow::PCG::Utils::MultisourceGraphStructure; */ -/* using FlexFlow::PCG::Utils::roots; */ - -/* Node source(source_node); */ -/* std::unordered_map ipd; */ -/* std::unordered_set graph_roots = roots(*this); */ -/* if (source_node != Node::INVALID_NODE) { */ -/* ipd = imm_post_dominators(*this); */ -/* } else if (graph_roots.size() == 1) { */ -/* ipd = imm_post_dominators(*this); */ -/* source = *graph_roots.begin(); */ -/* } else { */ -/* ipd = imm_post_dominators>(*this); */ -/* } */ - -/* Node bn_node = ipd.at(source); */ -/* if (bn_node == source || bn_node == sink_node) { */ -/* return Node::INVALID_NODE; */ -/* } */ - -/* return bn_node; */ -/* } */ - -Graph Graph::subgraph(std::unordered_set const &nodes) const { - AdjacencyMultiDiGraph sub_g = subgraph(this->g, nodes); - - bidict sub_nodeMap; - for (auto const &kv : this->nodeMap) { - if (contains(nodes, kv.first)) { - sub_nodeMap.equate(kv.first, kv.second); - } - } - - return {sub_g, sub_nodeMap, this->logger}; -} - -void Graph::remove_node(Node const &node, bool purge_edges) { - assert(purge_edges == true); - utils::remove_node(this->g, node); - this->nodeMap.erase_l(node); -} - -/*static*/ -Graph Graph::singleton(PCGOperatorAttrs const ¶ms) { - Graph g; - g.add_node(params); - return g; -} - -bool Graph::empty() const { - return utils::empty(this->g); -} - -void Graph::replace_subgraph(std::unordered_set const ¤tNodes, - Graph const &replaceWith) { - assert(currentNodes.size() > 0); - if (replaceWith.empty()) { - Graph subgraph = this->subgraph(currentNodes); - assert(!subgraph.empty()); - Node source_node = subgraph.find_source_node(); - Node noop = - this->model->get_or_create_noop_node(source_node.ptr->inputs[0]); - this->replace_subgraph_with_nonempty(currentNodes, - Graph::singleton(this->model, noop)); - this->contract_out_node(noop); - } else { - this->replace_subgraph_with_nonempty(currentNodes, replaceWith); - } -} - -void Graph::replace_subgraph_with_nonempty( - std::unordered_set const ¤tNodes, Graph const &replaceWith) { - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::nodes; - - Node new_sink_node = replaceWith.find_sink_node(); - - Graph old_subgraph = this->subgraph(currentNodes); - Node old_sink_node = old_subgraph.find_sink_node(); - Node old_source_node = old_subgraph.find_source_node(); - - std::unordered_set all_nodes = nodes(*this); - - for (Edge const &old_inner_edge : get_edges(old_subgraph)) { - this->remove_edge(old_inner_edge, false); - } - for (Edge const &new_inner_edge : get_edges(replaceWith)) { - this->add_edge(new_inner_edge); - } - - std::unordered_set old_in_edges = this->inEdges.at(old_source_node); - if (!old_in_edges.empty()) { - Node new_source_node = replaceWith.find_source_node(); - for (Edge const &old_in_edge : old_in_edges) { - Edge new_in_edge(old_in_edge); - new_in_edge.dstOp = new_source_node; - this->remove_edge(old_in_edge, false); - this->add_edge(new_in_edge); - } - } - - std::unordered_set old_out_edges = this->outEdges.at(old_sink_node); - for (Edge const &old_out_edge : old_out_edges) { - Edge new_out_edge(old_out_edge); - new_out_edge.srcOp = new_sink_node; - this->remove_edge(old_out_edge, false); - this->add_edge(new_out_edge); - } - - for (Node const &node : currentNodes) { - this->remove_node(node); - } - - assert(this->check_correctness()); -} - -void Graph::contract_out_node(Node const &node) { - contract_node(this->g, node); - this->nodeMap.erase_l(node); -} - -/* std::pair, std::unique_ptr> */ -/* Graph::split_at_node(Node const &bottleneck) const { */ -/* using FlexFlow::PCGe:Utils::topo_sort; */ - -/* auto first_graph = std::unique_ptr(new Graph(this->model)); */ -/* auto second_graph = std::unique_ptr(new Graph(this->model)); */ - -/* std::unordered_set used_nodes; */ -/* { */ -/* std::vector topo_sorted; */ -/* topo_sort(*this, &topo_sorted); */ - -/* for (auto const &node : topo_sorted) { */ -/* if (node == bottleneck) { */ -/* break; */ -/* } */ - -/* used_nodes.insert(node); */ -/* } */ -/* used_nodes.insert(bottleneck); */ - -/* assert(used_nodes.size() < topo_sorted.size()); */ -/* } */ - -/* for (auto const &it : this->inEdges) { */ -/* auto const &inList = it.second; */ -/* if (used_nodes.find(it.first) != used_nodes.end()) { */ -/* // Add all in-edges of used_nodes in to the first_graph */ -/* for (auto const &it2 : inList) { */ -/* first_graph->add_edge(it2); */ -/* } */ -/* } else { */ -/* // Add all in-edges of not_used_nodes into the second_graph */ -/* for (auto const &it2 : inList) { */ -/* second_graph->add_edge(it2); */ -/* } */ -/* } */ -/* } */ - -/* return {std::move(first_graph), std::move(second_graph)}; */ -/* } */ - -void Graph::remove_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - for (auto const &n : nodes(*this)) { - if (n.ptr->op_type == OP_INPUT) { - this->remove_node(n, true /*purge_edges*/); - } - } -} - -Node Graph::clone_node(Node const &n) { - Node cloned = n; - cloned.original_guid = n.guid; - cloned.guid = this->model->node_global_guid++; - this->add_node(cloned); - return cloned; -} - -Node Graph::declone_node(Node const &n) { - assert(n.original_guid.has_value()); - Node decloned = n; - decloned.guid = n.original_guid.value(); - decloned.original_guid = tl::nullopt; - this->add_node(decloned); - return decloned; -} - -std::pair> - Graph::deduplicate_input_node(Node const &n) { - using FlexFlow::PCG::Utils::nodes; - using FlexFlow::PCG::Utils::outgoing_edges; - - assert(n.original_guid.has_value()); - std::unordered_set old_all_nodes = nodes(*this); - Node decloned = this->declone_node(n); - - std::unordered_set old_nodes; - std::unordered_set new_edges; - for (Node const &nn : old_all_nodes) { - if (nn.original_guid == n.original_guid) { - old_nodes.insert(nn); - for (Edge const &e : outgoing_edges(*this, nn)) { - Edge decloned_edge(e); - decloned_edge.replace_node(nn, decloned); - new_edges.insert(decloned_edge); - } - this->remove_node(nn, true /*purge_edges*/); - } - } - - for (Edge const &e : new_edges) { - this->add_edge(e); - } - - return {decloned, old_nodes}; -} - -std::unordered_map Graph::deduplicate_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - std::unordered_map deduplication_map; - - bool done; - while (true) { - done = true; - for (Node const &n : nodes(*this)) { - if (n.original_guid.has_value()) { - done = false; - auto kv = this->deduplicate_input_node(n); - for (auto const &r : kv.second) { - deduplication_map[r] = kv.first; - } - break; - } - } - if (done) { - break; - } - } - - return deduplication_map; -} - -void Graph::duplicate_input_node(Node const &n) { - using FlexFlow::PCG::Utils::outgoing_edges; - using FlexFlow::PCG::Utils::successors; - - assert(n.ptr->op_type == OP_INPUT); - - std::unordered_map clones; - - for (auto const &s : successors(*this, n)) { - clones[s] = this->clone_node(n); - } - - for (auto const &e : outgoing_edges(*this, n)) { - Edge cloned(e); - cloned.srcOp = clones.at(e.dstOp); - this->add_edge(cloned); - } - this->remove_node(n, true /*purge_edges*/); -} - -void Graph::duplicate_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - for (auto const &n : nodes(*this)) { - if (n.ptr->op_type == OP_INPUT) { - this->duplicate_input_node(n); - } - } -} - -std::pair, std::unique_ptr> - Graph::split_horizontal(Node const &source_node, - Node const &sink_node) const { - using FlexFlow::PCG::Utils::weakly_connected_components; - - Graph trimmed_graph(*this); - assert(sink_node != - Node::INVALID_NODE); // sink node should never be invalid node - if (source_node != Node::INVALID_NODE) { - trimmed_graph.remove_node(source_node, true /*purge_edges*/); - } - trimmed_graph.remove_node(sink_node, true /*purge_edges*/); - std::vector> wccs = - weakly_connected_components(trimmed_graph); - assert(wccs.size() >= 2); - std::unordered_set first_branch = wccs.back(); - wccs.pop_back(); - std::unordered_set rest; - for (auto const &wcc : wccs) { - rest.insert(wcc.begin(), wcc.end()); - } - if (source_node != Node::INVALID_NODE) { - first_branch.insert(source_node); - rest.insert(source_node); - } - first_branch.insert(sink_node); - rest.insert(sink_node); - - auto first_graph = - std::unique_ptr(new Graph(this->subgraph(first_branch))); - auto second_graph = std::unique_ptr(new Graph(this->subgraph(rest))); - - return {std::move(first_graph), std::move(second_graph)}; -} - -GraphCostResult GraphCostResult::invalid() { - return {std::numeric_limits::infinity(), {}}; -} - -bool GraphCostResult::operator<(GraphCostResult const &other) const { - return this->cost < other.cost; -} - -std::ostream &operator<<(std::ostream &s, GraphCostResult const &r) { - s << "GraphCostResult{cost=" << r.cost << "}"; - return s; -} - -std::ostream &operator<<(std::ostream &s, GraphOptimizeResult const &r) { - s << "GraphOptimizeResult{cost=" << r.cost << "}"; - return s; -} - -template <> -GraphCostResult sequence_cost(GraphCostResult const &first, - GraphCostResult const &second) { - GraphCostResult result(first); - result.cost += second.cost; - result.views.insert(second.views.cbegin(), second.views.cend()); - return result; -} - -template <> -float sequence_cost(float const &first, float const &second) { - return first + second; -} - -template <> -GraphOptimizeResult - sequence_cost(GraphOptimizeResult const &first, - GraphOptimizeResult const &second) { - GraphOptimizeResult result; - result.cost = first.cost + second.cost; - result.views.insert(first.views.cbegin(), first.views.cend()); - result.views.insert(second.views.cbegin(), second.views.cend()); - - result.graph = second.graph; - Node second_src = result.graph.value().find_source_node(); - result.graph.value().replace_subgraph({second_src}, first.graph.value()); - return result; -} - -template <> -GraphCostResult parallel_cost(GraphCostResult const &first, - GraphCostResult const &second) { - GraphCostResult result; - result.cost = std::max(first.cost, second.cost); - result.views.insert(first.views.cbegin(), first.views.cend()); - result.views.insert(second.views.cbegin(), second.views.cend()); - - return result; -} - -template <> -float parallel_cost(float const &first, float const &second) { - return std::max(first, second); -} - -float Graph::optimal_cost() const { - return this->generic_optimal_cost(); -} - -std::unordered_map Graph::optimal_views() const { - return this->generic_optimal_cost().views; -} - -Graph Graph::reduced() const { - using FlexFlow::PCG::Utils::BasicGraph; - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::transitive_reduction; - - BasicGraph transitive_skeleton = transitive_reduction(*this); - - Graph reduced_graph(this->model); - - for (Edge const &e : get_edges(*this)) { - if (transitive_skeleton.has_edge(e.srcOp, e.dstOp)) { - reduced_graph.add_edge(e); - } - } - - return reduced_graph; -} - -/** - * @brief A generic cost function for a graph capable of finding both the cost - * and the optimal views - * - * @note A templated function is used here because while the caching behaviors - * of the cost and the optimal views are different, much of the code between the - * two versions is almost identical. By using a few template specializations we - * can avoid duplicating all this code. - * - * @tparam T the result type (can be either float or GraphCostResult) - * @return T the cost of the graph (along with any additional data in the return - * type) - */ -template -T Graph::generic_optimal_cost() const { - using FlexFlow::PCG::Utils::GraphStructure; - - Graph reduced_graph = this->reduced(); - // GraphStructure s; - // if (source_node.ptr->op_type == OP_INPUT) { - // for (auto const &e : s.get_outgoing_edges(reduced_graph, source_node)) { - // reduced_graph.remove_edge(e, false/*remove_node_if_unused*/); - // } - // reduced_graph.remove_node(source_node); - // } - - Node sink_node = reduced_graph.find_sink_node(); - this->search->logger->info() << "Found sink node: " << sink_node.to_string(); - - MachineResource resource(model->config); - - std::vector valid_views = - search->get_valid_machine_views(sink_node, resource, true); - - T optimal = search->infinity(); - - this->search->logger->info() - << "Exploring " << valid_views.size() << " valid views"; - for (MachineView const &sink_view : valid_views) { - this->search->logger->info() << " Exploring valid view " << sink_view; - T new_cost = - search->graph_cost(&reduced_graph, - {Node::INVALID_NODE, MachineView::NO_VIEW}, - {sink_node, sink_view}, - resource, - true); - if (new_cost < optimal) { - optimal = new_cost; - } - } - - return optimal; -} - -size_t Graph::hash(void) const { - // Graph hash should be additive and independent to the ordering of the nodes - size_t total_hash = 0; - for (auto const &it : inEdges) { - auto const &inList = it.second; - size_t node_hash = std::hash()((size_t)it.first.ptr); - for (auto const &e : inList) { - size_t edge_hash = 17; - edge_hash = edge_hash * 31 + std::hash()((size_t)e.srcOp.ptr); - edge_hash = edge_hash * 31 + std::hash()(e.srcIdx); - edge_hash = edge_hash * 31 + std::hash()(e.dstIdx); - node_hash *= edge_hash; - } - total_hash += node_hash; - } - return total_hash; -} - -size_t dp_state_hash(Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resource) { - size_t key = graph->hash(); - hash_combine(key, sink_node.ptr); - hash_combine(key, sink_view.hash()); - hash_combine(key, source_node.ptr); - hash_combine(key, resource.hash()); - return key; -} - -GraphOptimalViewSerialized - Graph::graph_optimize_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - FFModel *model = *((FFModel **)task->args); - if (model->config.search_num_nodes.has_value()) { - model->config.numNodes = model->config.search_num_nodes.value(); - } - if (model->config.search_num_workers.has_value()) { - model->config.workersPerNode = model->config.search_num_workers.value(); - } - model->all_valid_views.clear(); - model->register_all_machine_views(model->config.numNodes, - model->config.workersPerNode, - model->config.cpusPerNode, - model->all_valid_views); - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - MachineModel *machine; - if (model->config.machine_model_version == 0) { - machine = - (MachineModel *)new SimpleMachineModel(model->config.numNodes, - model->config.workersPerNode, - gpu_mem.capacity()); - } else if (model->config.machine_model_version == 1 and - !model->config.machine_model_file.empty()) { - machine = (MachineModel *)new EnhancedMachineModel( - model->config.machine_model_file, gpu_mem.capacity()); - } else { - assert(false && - "machine model creation error: currently only support " - "machine-model-version = 0 or 1. When machine-model-version = 1, " - "machine-model-file should not be empty."); - } - model->simulator = - make_unique(model, model->handlers[0], gpu_mem, machine); - std::unique_ptr best_graph; - std::unordered_map optimal_views; - if (model->config.only_data_parallel) { - Graph *graph = new Graph(model); - std::unordered_map op_to_node_map; - for (FlexFlow::Op const *dstOp : model->operators) { - Node dstNode; - dstNode.ptr = dstOp; - dstNode.guid = model->node_global_guid++; - op_to_node_map[dstOp] = dstNode; - for (int j = 0; j < dstOp->numInputs; j++) { - FlexFlow::Op const *srcOp = dstOp->inputs[j]->owner_op; - assert(op_to_node_map.find(srcOp) != op_to_node_map.end()); - Node srcNode = op_to_node_map[srcOp]; - graph->add_edge(srcNode, dstNode, dstOp->inputs[j]->owner_idx, j); - } - } - best_graph = std::unique_ptr(graph); - MachineView data_parallel_view; - data_parallel_view.device_type = MachineView::GPU; - data_parallel_view.ndims = 1; - data_parallel_view.dim[0] = - model->config.numNodes * model->config.workersPerNode; - data_parallel_view.stride[0] = 1; - data_parallel_view.start_device_id = 0; - for (auto const &node : best_graph->inEdges) { - optimal_views[node.first] = data_parallel_view; - } - } else { - model->graph_optimize(model->config.search_budget, - model->config.only_data_parallel, - best_graph, - optimal_views); - } - /* Serializer sez; */ - /* // First serialize graph */ - /* sez.serialize(best_graph->inEdges.size()); */ - /* std::unordered_map todos; */ - /* std::vector opList; */ - /* for (auto const &it : best_graph->inEdges) { */ - /* auto const &inList = it.second; */ - /* todos[it.first] = (int)inList.size(); */ - /* if (todos[it.first] == 0) { */ - /* opList.push_back(it.first); */ - /* } */ - /* } */ - /* size_t node_idx = 0; */ - /* while (node_idx < opList.size()) { */ - /* Node cur_node = opList[node_idx++]; */ - /* auto const &outList = best_graph->outEdges[cur_node]; */ - /* for (auto const &e : outList) { */ - /* todos[e.dstOp]--; */ - /* if (todos[e.dstOp] == 0) { */ - /* opList.push_back(e.dstOp); */ - /* } */ - /* } */ - /* auto const &inList = best_graph->inEdges[cur_node]; */ - /* sez.serialize(inList.size()); */ - /* for (auto const &e : inList) { */ - /* sez.serialize(e.srcOp.guid); */ - /* assert(e.dstOp.guid == cur_node.guid); */ - /* sez.serialize(e.srcIdx); */ - /* sez.serialize(e.dstIdx); */ - /* } */ - /* sez.serialize((size_t)10101010); // safe guard for the end of inedges */ - /* Op const *op = cur_node.ptr; */ - /* assert(op != NULL); */ - /* sez.serialize(cur_node.guid); */ - /* sez.serialize(op->op_type); */ - /* switch (op->op_type) { */ - /* case OP_INPUT: { */ - /* assert(op->numOutputs == 1); */ - /* NoOp *noop = (NoOp *)op; */ - /* sez.serialize(noop->op_type); */ - /* sez.serialize(noop->input_tensor_guid); */ - /* sez.serialize(noop->outputs[0]->data_type); */ - /* sez.serialize(noop->outputs[0]->num_dims); */ - /* for (int i = 0; i < noop->outputs[0]->num_dims; i++) { */ - /* sez.serialize(noop->outputs[0]->dims[i]); */ - /* } */ - /* break; */ - /* } */ - /* case OP_NOOP: { */ - /* break; */ - /* } */ - /* case OP_CONCAT: { */ - /* Concat *concat = (Concat *)op; */ - /* sez.serialize(concat->legion_axis); */ - /* break; */ - /* } */ - /* case OP_SPLIT: { */ - /* Split *split = (Split *)op; */ - /* sez.serialize(split->legion_axis); */ - /* sez.serialize(split->numOutputs); */ - /* for (int i = 0; i < split->numOutputs; i++) { */ - /* sez.serialize(split->outputs[i]->dims[split->legion_axis].size); */ - /* } */ - /* break; */ - /* } */ - /* case OP_EMBEDDING: { */ - /* Embedding *embed = (Embedding *)op; */ - /* sez.serialize(embed->layer_guid.id); */ - /* sez.serialize(embed->num_entries); */ - /* sez.serialize(embed->out_channels); */ - /* sez.serialize(embed->aggr); */ - /* sez.serialize(embed->data_type); */ - /* break; */ - /* } */ - /* case OP_EW_ADD: */ - /* case OP_EW_SUB: */ - /* case OP_EW_MUL: */ - /* case OP_EW_MAX: */ - /* case OP_EW_MIN: { */ - /* sez.serialize(op->op_type); */ - /* break; */ - /* } */ - /* case OP_MULTIHEAD_ATTENTION: { */ - /* MultiHeadAttention *attn = (MultiHeadAttention *)op; */ - /* sez.serialize(attn->layer_guid.id); */ - /* sez.serialize(attn->oProjSize); */ - /* sez.serialize(attn->num_heads); */ - /* sez.serialize(attn->qProjSize); */ - /* sez.serialize(attn->vProjSize); */ - /* sez.serialize(attn->dropout); */ - /* sez.serialize(attn->bias); */ - /* sez.serialize(attn->add_bias_kv); */ - /* sez.serialize(attn->add_zero_attn); */ - /* break; */ - /* } */ - /* case OP_SOFTMAX: { */ - /* Softmax *softmax = (Softmax *)op; */ - /* sez.serialize(softmax->dim); */ - /* break; */ - /* } */ - /* case OP_REPARTITION: { */ - /* Repartition *repart = (Repartition *)op; */ - /* sez.serialize(repart->repartition_dim); */ - /* sez.serialize(repart->repartition_degree); */ - /* break; */ - /* } */ - /* case OP_REPLICATE: { */ - /* Replicate *replicate = (Replicate *)op; */ - /* sez.serialize(replicate->replicate_dim); */ - /* sez.serialize(replicate->replicate_degree); */ - /* break; */ - /* } */ - /* case OP_REDUCTION: { */ - /* Reduction *reduction = (Reduction *)op; */ - /* sez.serialize(reduction->reduction_dim); */ - /* sez.serialize(reduction->reduction_degree); */ - /* break; */ - /* } */ - /* case OP_COMBINE: { */ - /* Combine *combine = (Combine *)op; */ - /* sez.serialize(combine->combine_dim); */ - /* sez.serialize(combine->combine_degree); */ - /* break; */ - /* } */ - /* case OP_FUSED_PARALLEL: { */ - /* FusedParallelOp *fused = (FusedParallelOp *)op; */ - /* sez.serialize(fused->num_parallel_ops); */ - /* for (int i = 0; i < fused->num_parallel_ops; i++) { */ - /* sez.serialize(fused->parallel_ops[i]); */ - /* } */ - /* break; */ - /* } */ - /* default: { */ - /* op->serialize(sez); */ - /* } */ - /* } */ - /* sez.serialize((size_t)12345678); // safe guard for the end of an op */ - /* } */ - /* assert(node_idx == best_graph->inEdges.size()); */ - /* // Second, serialize optimal machine view */ - /* printf("opotimal_views.size = %zu\n", optimal_views.size()); */ - /* sez.serialize(optimal_views.size()); */ - /* for (auto const &it : optimal_views) { */ - /* sez.serialize((size_t)98765432); // safe guard */ - /* sez.serialize(it.first.guid); */ - /* sez.serialize(it.second); */ - /* } */ - /* assert(sez.get_used_bytes() < GraphOptimalViewSerialized::buffer_size); */ - /* GraphOptimalViewSerialized ret; */ - /* ret.total_bytes = sez.get_used_bytes(); */ - /* memcpy(ret.data, sez.get_buffer(), ret.total_bytes); */ - /* // Deallocate best_graph */ - /* // delete best_graph; */ - /* return ret; */ -} - -}; // namespace FlexFlow - -namespace FlexFlow { - -using PCG::Edge; -using PCG::Graph; -using PCG::GraphCostResult; -using PCG::Node; - -void FFModel::register_all_machine_views( - int num_nodes, - int gpus_per_node, - int cpus_per_node, - std::vector &valid_views) { - // Single-parallelism-dimension views - for (int i = 1; i <= num_nodes * gpus_per_node; i++) { - if (num_nodes * gpus_per_node % i == 0) { - MachineView view; - view.device_type = MachineView::GPU; - view.ndims = 1; - view.dim[0] = i; - view.stride[0] = 1; - view.start_device_id = 0; - valid_views.push_back(view); - } - } - // Two-dimensional views - /* for (int i = 1; i <= num_nodes; i++) { */ - /* for (int j = 1; j <= gpus_per_node; j++) { */ - /* MachineView view; */ - /* view.device_type = MachineView::GPU; */ - /* view.ndims = 2; */ - /* view.dim[0] = i; */ - /* view.stride[0] = 1; */ - /* view.dim[1] = j; */ - /* view.stride[1] = 1; */ - /* view.start_device_id = 0; */ - /* valid_views.push_back(view); */ - /* } */ - /* } */ -} - -float FFModel::graph_cost(Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resources, - bool include_sink_compute_time, - bool constructing_optimal_view) { - assert(!graph->inEdges.empty()); - - return this->search->graph_cost(graph, - {source_node, source_view}, - {sink_node, sink_view}, - resources, - include_sink_compute_time); -} - -void FFModel::construct_optimal_view( - Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resources, - bool include_sink_compute_time, - float optimal_cost, - std::unordered_map &optimal_views) { - GraphCostResult result = - this->search->graph_cost(graph, - {source_node, source_view}, - {sink_node, sink_view}, - resources, - include_sink_compute_time); - - optimal_views.insert(result.views.begin(), result.views.end()); -} - -/* void FFModel::deserialize_graph_optimal_view( */ -/* Legion::Deserializer &dez, */ -/* Graph *graph, */ -/* std::unordered_map &optimal_views) { */ -/* // Deserializer dez(serialized.data, serialized.total_bytes); */ -/* std::unordered_map guid_to_nodes; */ -/* size_t num_nodes; */ -/* dez.deserialize(num_nodes); */ -/* // best_graph = new Graph(this); */ -/* for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { */ -/* Edge inedges[MAX_NUM_INPUTS]; */ -/* ParallelTensor inputs[MAX_NUM_INPUTS]; */ -/* size_t num_inputs; */ -/* dez.deserialize(num_inputs); */ -/* for (size_t j = 0; j < num_inputs; j++) { */ -/* size_t src_guid; */ -/* int src_idx, dst_idx; */ -/* dez.deserialize(src_guid); */ -/* assert(guid_to_nodes.find(src_guid) != guid_to_nodes.end()); */ -/* dez.deserialize(src_idx); */ -/* dez.deserialize(dst_idx); */ -/* assert(dst_idx < (int)num_inputs); */ -/* inedges[dst_idx].srcOp = guid_to_nodes[src_guid]; */ -/* inedges[dst_idx].srcIdx = src_idx; */ -/* inedges[dst_idx].dstIdx = dst_idx; */ -/* inputs[dst_idx] = inedges[dst_idx].srcOp.ptr->outputs[src_idx]; */ -/* } */ -/* { */ -/* size_t safecode; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 10101010); */ -/* } */ -/* Node node = Node::INVALID_NODE; */ -/* size_t guid; */ -/* OperatorType op_type; */ -/* dez.deserialize(guid); */ -/* dez.deserialize(op_type); */ -/* switch (op_type) { */ -/* case OP_INPUT: { */ -/* assert(num_inputs == 0); */ -/* int num_dims; */ -/* ParallelDim dims[MAX_TENSOR_DIM]; */ -/* OperatorType op_type; */ -/* dez.deserialize(op_type); */ -/* size_t input_tensor_guid; */ -/* dez.deserialize(input_tensor_guid); */ -/* DataType data_type; */ -/* dez.deserialize(data_type); */ -/* dez.deserialize(num_dims); */ -/* for (int i = 0; i < num_dims; i++) { */ -/* dez.deserialize(dims[i]); */ -/* } */ -/* ParallelTensor t = */ -/* create_parallel_tensor_legion_ordering(num_dims, */ -/* dims, */ -/* data_type, */ -/* nullptr, */ -/* 0, */ -/* true create_grad, */ -/* input_tensor_guid); */ -/* node.ptr = t->owner_op; */ -/* node.guid = node_global_guid++; */ -/* break; */ -/* } */ -/* case OP_NOOP: { */ -/* assert(num_inputs == 1); */ -/* node = get_or_create_noop_node(inputs[0]); */ -/* break; */ -/* } */ -/* case OP_BATCHMATMUL: { */ -/* node = BatchMatmul::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_CAST: { */ -/* node = Cast::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_CONCAT: { */ -/* int legion_axis; */ -/* dez.deserialize(legion_axis); */ -/* node = get_or_create_node( */ -/* {std::begin(inputs), std::begin(inputs) + num_inputs}, */ -/* {legion_axis}); */ -/* break; */ -/* } */ -/* case OP_SPLIT: { */ -/* int legion_axis; */ -/* dez.deserialize(legion_axis); */ -/* int num_outputs; */ -/* dez.deserialize(num_outputs); */ -/* std::vector splits; */ -/* for (int i = 0; i < num_outputs; i++) { */ -/* int dim_size; */ -/* dez.deserialize(dim_size); */ -/* splits.push_back(dim_size); */ -/* } */ -/* node = get_or_create_node(inputs[0], {splits, legion_axis}); - */ -/* break; */ -/* } */ -/* case OP_EMBEDDING: { */ -/* assert(num_inputs == 1); */ -/* AggrMode aggr; */ -/* int num_entries, out_channels; */ -/* size_t id; */ -/* DataType data_type; */ -/* dez.deserialize(id); */ -/* LayerID layer_guid(id); */ -/* dez.deserialize(num_entries); */ -/* dez.deserialize(out_channels); */ -/* dez.deserialize(aggr); */ -/* dez.deserialize(data_type); */ - -/* EmbeddingParams params; */ -/* params.aggr = aggr; */ -/* params.num_entries = num_entries; */ -/* params.out_channels = out_channels; */ -/* params.layer_guid = layer_guid; */ -/* params.data_type = data_type; */ -/* node = get_or_create_node(inputs[0], params); */ -/* break; */ -/* } */ -/* case OP_EW_ADD: */ -/* case OP_EW_SUB: */ -/* case OP_EW_MUL: */ -/* case OP_EW_MAX: */ -/* case OP_EW_MIN: { */ -/* assert(num_inputs == 2); */ -/* OperatorType op_type; */ -/* dez.deserialize(op_type); */ -/* node = get_or_create_node({inputs[0], inputs[1]}, */ -/* {op_type}); */ -/* break; */ -/* } */ -/* case OP_CONV2D: { */ -/* node = Conv2D::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_DROPOUT: { */ -/* node = Dropout::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_EXP: */ -/* case OP_SIN: */ -/* case OP_COS: */ -/* case OP_SCALAR_MULTIPLY: */ -/* case OP_SCALAR_FLOOR_DIV: */ -/* case OP_SCALAR_TRUE_DIV: */ -/* case OP_SCALAR_ADD: */ -/* case OP_SCALAR_SUB: */ -/* case OP_RELU: */ -/* case OP_SIGMOID: */ -/* case OP_TANH: */ -/* case OP_POW: */ -/* case OP_IDENTITY: */ -/* case OP_GELU: */ -/* case OP_ELU: { */ -/* node = ElementUnary::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_FLAT: { */ -/* node = Flat::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_GATHER: { */ -/* node = Gather::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_LAYERNORM: { */ -/* node = LayerNorm::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_LINEAR: { */ -/* node = Linear::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_MULTIHEAD_ATTENTION: { */ -/* assert(num_inputs == 3); */ -/* int embed_dim, num_heads, k_dim, v_dim; */ -/* float dropout; */ -/* bool bias, add_bias_kv, add_zero_attn; */ -/* size_t id; */ -/* dez.deserialize(id); */ -/* LayerID layer_guid(id); */ -/* dez.deserialize(embed_dim); */ -/* dez.deserialize(num_heads); */ -/* dez.deserialize(k_dim); */ -/* dez.deserialize(v_dim); */ -/* dez.deserialize(dropout); */ -/* dez.deserialize(bias); */ -/* dez.deserialize(add_bias_kv); */ -/* dez.deserialize(add_zero_attn); */ - -/* MultiHeadAttentionParams params; */ -/* params.embed_dim = embed_dim; */ -/* params.num_heads = num_heads; */ -/* params.kdim = k_dim; */ -/* params.vdim = v_dim; */ -/* params.dropout = dropout; */ -/* params.bias = bias; */ -/* params.add_bias_kv = add_bias_kv; */ -/* params.add_zero_attn = add_zero_attn; */ -/* params.layer_guid = layer_guid; */ -/* node = get_or_create_node( */ -/* {inputs[0], inputs[1], inputs[2]}, params); */ -/* break; */ -/* } */ -/* case OP_TOPK: { */ -/* node = TopK::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_GROUP_BY: { */ -/* node = Group_by::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_AGGREGATE: { */ -/* // node = Aggregate::deserialize(*this, dez, inputs, num_inputs); */ -/* int n; */ -/* float lambda_bal; */ -/* dez.deserialize(n); */ -/* dez.deserialize(lambda_bal); */ -/* assert(num_inputs == n + 4); */ -/* AggregateParams params; */ -/* params.n = n; */ -/* params.lambda_bal = lambda_bal; */ -/* node = get_or_create_node( */ -/* {std::begin(inputs), std::begin(inputs) + num_inputs}, params); - */ -/* break; */ -/* } */ -/* case OP_POOL2D: { */ -/* node = Pool2D::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_REDUCE_SUM: { */ -/* node = Reduce::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_RESHAPE: { */ -/* node = Reshape::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_SOFTMAX: { */ -/* assert(num_inputs == 1); */ -/* int softmax_dim; */ -/* dez.deserialize(softmax_dim); */ -/* node = get_or_create_node(inputs[0], {softmax_dim}); */ -/* break; */ -/* } */ -/* case OP_TRANSPOSE: { */ -/* node = Transpose::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_COMBINE: { */ -/* assert(num_inputs == 1); */ -/* int combine_dim, combine_degree; */ -/* dez.deserialize(combine_dim); */ -/* dez.deserialize(combine_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {combine_dim, combine_degree}); */ -/* break; */ -/* } */ -/* case OP_REPARTITION: { */ -/* assert(num_inputs == 1); */ -/* int repartition_dim, repartition_degree; */ -/* dez.deserialize(repartition_dim); */ -/* dez.deserialize(repartition_degree); */ -/* node = get_or_create_node( */ -/* inputs[0], {repartition_dim, repartition_degree}); */ -/* break; */ -/* } */ -/* case OP_REPLICATE: { */ -/* assert(num_inputs == 1); */ -/* int replicate_dim, replicate_degree; */ -/* dez.deserialize(replicate_dim); */ -/* dez.deserialize(replicate_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {replicate_dim, - * replicate_degree}); */ -/* break; */ -/* } */ -/* case OP_REDUCTION: { */ -/* assert(num_inputs == 1); */ -/* int reduction_dim, reduction_degree; */ -/* dez.deserialize(reduction_dim); */ -/* dez.deserialize(reduction_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {reduction_dim, - * reduction_degree}); */ -/* break; */ -/* } */ -/* case OP_FUSED_PARALLEL: { */ -/* assert(num_inputs == 1); */ -/* std::vector parallel_ops; */ -/* int num_parallel_ops; */ -/* dez.deserialize(num_parallel_ops); */ -/* for (int i = 0; i < num_parallel_ops; i++) { */ -/* ParallelOpInfo info; */ -/* dez.deserialize(info); */ -/* parallel_ops.push_back(info); */ -/* } */ -/* node = get_or_create_node(inputs[0], - * {parallel_ops}); */ -/* break; */ -/* } */ -/* default: { */ -/* fprintf(stderr, */ -/* "The following operator type is currently not supported" */ -/* " for graph deserialization: %s\n" */ -/* "Report the issue to the FlexFlow developers\n", */ -/* get_operator_type_name(op_type).c_str()); */ -/* assert(false && "Unsupported operator type"); */ -/* } */ -/* } */ -/* { */ -/* size_t safecode; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 12345678); */ -/* } */ -/* assert(node.ptr != nullptr); */ -/* guid_to_nodes[guid] = node; */ -/* for (size_t i = 0; i < num_inputs; i++) { */ -/* inedges[i].dstOp = node; */ -/* graph->add_edge(inedges[i]); */ -/* } */ -/* } */ -/* // Second, deserialize optimal machine view */ -/* size_t num_views; */ -/* dez.deserialize(num_views); */ -/* printf("views.size() = %zu\n", num_views); */ -/* for (size_t i = 0; i < num_views; i++) { */ -/* size_t safecode, guid; */ -/* MachineView view; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 98765432); */ -/* dez.deserialize(guid); */ -/* assert(guid_to_nodes.find(guid) != guid_to_nodes.end()); */ -/* dez.deserialize(view); */ -/* optimal_views[guid_to_nodes[guid]] = view; */ -/* } */ -/* assert(dez.get_remaining_bytes() == 0); */ -/* printf("Deserialized Views...\n"); */ -/* for (auto const &it : optimal_views) { */ -/* printf("node[%zu]: type(%s) view(%d %d %d) ", */ -/* it.first.guid, */ -/* it.first.to_string().c_str(), */ -/* it.second.ndims, */ -/* it.second.dim[0], */ -/* it.second.start_device_id); */ -/* auto const &list = graph->inEdges.at(it.first); */ -/* for (auto const &it2 : list) { */ -/* Edge e = it2; */ -/* printf(" inEdge(node(%zu) idx(%d))", e.srcOp.guid, e.srcIdx); */ -/* } */ -/* printf("\n"); */ -/* } */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/compiler/src/old/graph.h b/lib/compiler/src/old/graph.h deleted file mode 100644 index db313b080d..0000000000 --- a/lib/compiler/src/old/graph.h +++ /dev/null @@ -1,248 +0,0 @@ -/* Copyright 2021 CMU, Facebook, LANL, MIT, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef _FLEXFLOW_GRAPH_H_ -#define _FLEXFLOW_GRAPH_H_ -#include "basic_graph.h" -/* #include "node.h" */ -#include "graph_structures.h" -#include "op-attrs/op-attrs.h" -#include "pcg/machine_view.h" -#include "utils/bidict.h" -#include "utils/dot_file.h" -#include "utils/graph.h" -#include "utils/graph/serialparallel.h" -#include "utils/recursive_logger.h" -#include -#include - -// extern LegionRuntime::Logger::Category log_dp; - -/* namespace FlexFlow { */ -/* namespace ffc { */ - -/* class SearchHelper; */ - -/* struct GraphOptimalViewSerialized { */ -/* #ifdef LEGION_MAX_RETURN_SIZE */ -/* static const size_t buffer_size = LEGION_MAX_RETURN_SIZE - 8; */ -/* #else */ -/* static const size_t buffer_size = 1024 * 1024 - 8; */ -/* #endif */ -/* size_t total_bytes; */ -/* char data[buffer_size]; */ -/* }; */ - -/* class Graph { */ -/* public: */ -/* Graph() = default; */ -/* Graph(std::string const &logger_name); */ -/* Graph(std::shared_ptr const &logger); */ - -/* void add_edge(utils::Node const &srcOp, utils::Node const &dstOp, int - * srcIdx, int dstIdx); */ -/* utils::Node add_node(opmeta::OperatorParameters const &); */ -/* void add_edge(utils::MultiDiEdge const &e); */ -/* void remove_node(utils::Node const &, bool purge_edges = false); */ -/* void remove_edge(utils::MultiDiEdge const &e, bool remove_node_if_unused = - * true); */ -/* bool has_edge(utils::MultiDiEdge const &e) const; */ -/* void replace_subgraph(std::unordered_set const - * ¤tNodes, */ -/* Graph const &replaceWith); */ -/* Graph subgraph(std::unordered_set const &nodes) const; */ -/* void contract_out_node(opmeta::OperatorParameters const &); */ -/* float optimal_cost() const; */ -/* std::unordered_map optimal_views() - * const; */ -/* void remove_input_nodes(); */ -/* void duplicate_input_node(opmeta::OperatorParameters const &); */ -/* void duplicate_input_nodes(); */ -/* opmeta::OperatorParameters clone_node(opmeta::OperatorParameters const &); - */ -/* std::pair> */ -/* deduplicate_input_node(opmeta::OperatorParameters const &); */ -/* std::unordered_map - * deduplicate_input_nodes(); */ -/* opmeta::OperatorParameters declone_node(opmeta::OperatorParameters const - * &); */ - -/* size_t hash(void) const; */ -/* void print(void) const; */ -/* void print_dot() const; */ -/* void print_dot(std::ostream &) const; */ - -/* bool check_correctness(void); */ -/* bool has_loop(void); */ -/* //bool map_operators_to_layers(std::vector &layers) const; */ -/* //static GraphOptimalViewSerialized */ -/* // graph_optimize_task(Legion::Task const *task, */ -/* // std::vector const - * ®ions, */ -/* // Legion::Context ctx, */ -/* // Legion::Runtime *runtime); */ -/* /1* opmeta::OperatorParameters - * find_bottleneck_node(opmeta::OperatorParameters const &sink_node, *1/ */ -/* /1* opmeta::OperatorParameters const - * &source_node) const; *1/ */ -/* void print_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy) const; */ -/* void export_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy, */ -/* std::string const &out_filename) const; */ -/* void export_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy, */ -/* DotFile &dot) const; */ - -/* /1* std::pair, std::unique_ptr> *1/ */ -/* /1* split_at_node(opmeta::OperatorParameters const &bottleneck) const; - * *1/ */ -/* /1* std::pair, std::unique_ptr> *1/ */ -/* /1* split_horizontal(opmeta::OperatorParameters const &source_node, - * opmeta::OperatorParameters const &sink_node) const; *1/ */ - -/* Graph reduced() const; */ - -/* opmeta::OperatorParameters find_sink_node() const; */ -/* opmeta::OperatorParameters find_source_node() const; */ -/* void reshape_output_tensor(opmeta::ParallelTensorShape const &shape); */ -/* std::unique_ptr */ -/* with_output_tensor_reshaped_to(opmeta::ParallelTensorShape const - * &shape) const; */ - -/* static Graph singleton(opmeta::OperatorParameters const &); */ -/* bool empty() const; */ - -/* template */ -/* T generic_optimal_cost() const; */ - -/* private: */ -/* void remove_inverse_parallel_ops(); */ -/* void replace_subgraph_with_nonempty( */ -/* std::unordered_set const ¤tNodes, - * Graph const &replaceWith); */ -/* private: */ -/* Graph(utils::AdjacencyMultiDiGraph const &, utils::bidict const &, std::shared_ptr const - * &); */ - -/* utils::AdjacencyMultiDiGraph g; */ -/* utils::bidict nodeMap; */ -/* std::shared_ptr logger; */ -/* }; */ - -/* struct GraphOptimizeResult { */ -/* tl::optional graph; */ -/* float cost; */ -/* std::unordered_map views; */ - -/* friend std::ostream &operator<<(std::ostream &, GraphOptimizeResult const - * &); */ -/* }; */ - -/* /1* namespace Utils { *1/ */ -/* /1* template <> *1/ */ -/* /1* struct GraphStructure { *1/ */ -/* /1* using G = FlexFlow::PCG::Graph; *1/ */ -/* /1* using graph_type = FlexFlow::PCG::Graph; *1/ */ -/* /1* using vertex_type = FlexFlow::PCG::Node; *1/ */ -/* /1* using edge_type = FlexFlow::PCG::Edge; *1/ */ - -/* /1* std::unordered_set get_nodes(G const &g) const { *1/ */ -/* /1* std::unordered_set nodes; *1/ */ -/* /1* for (auto const &kv : g.inEdges) { *1/ */ -/* /1* nodes.insert(kv.first); *1/ */ -/* /1* } *1/ */ -/* /1* for (auto const &kv : g.outEdges) { *1/ */ -/* /1* nodes.insert(kv.first); *1/ */ -/* /1* } *1/ */ - -/* /1* return nodes; *1/ */ -/* /1* } *1/ */ - -/* /1* std::unordered_set get_incoming_edges(G const &g, *1/ */ -/* /1* vertex_type const &n) - * const { *1/ */ -/* /1* if (g.inEdges.find(n) == g.inEdges.end()) { *1/ */ -/* /1* return {}; *1/ */ -/* /1* } else { *1/ */ -/* /1* return {g.inEdges.at(n).begin(), g.inEdges.at(n).end()}; *1/ */ -/* /1* } *1/ */ -/* /1* } *1/ */ - -/* /1* std::unordered_set get_outgoing_edges(G const &g, *1/ */ -/* /1* vertex_type const &n) - * const { *1/ */ -/* /1* if (g.outEdges.find(n) == g.outEdges.end()) { *1/ */ -/* /1* return {}; *1/ */ -/* /1* } else { *1/ */ -/* /1* return {g.outEdges.at(n).begin(), g.outEdges.at(n).end()}; *1/ */ -/* /1* } *1/ */ -/* /1* } *1/ */ - -/* /1* vertex_type get_src(G const &g, edge_type const &e) const { *1/ */ -/* /1* return e.srcOp; *1/ */ -/* /1* } *1/ */ - -/* /1* vertex_type get_dst(G const &g, edge_type const &e) const { *1/ */ -/* /1* return e.dstOp; *1/ */ -/* /1* } *1/ */ - -/* /1* void set_src(G const &g, edge_type &e, vertex_type const &n) const { - * *1/ */ -/* /1* e.srcOp = n; *1/ */ -/* /1* } *1/ */ - -/* /1* void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - * *1/ */ -/* /1* e.dstOp = n; *1/ */ -/* /1* } *1/ */ -/* /1* }; *1/ */ - -/* size_t dp_state_hash(Graph const *graph, */ -/* opmeta::OperatorParameters const &sink_node, */ -/* MachineView const &sink_view, */ -/* opmeta::OperatorParameters const &source_node, */ -/* MachineView const &source_view, */ -/* MachineResource const &resource); */ - -/* // template <> */ -/* // struct invalid_node> { */ -/* // using G = Graph; */ -/* // using Structure = GraphStructure; */ -/* // using vertex_type = typename Structure::vertex_type; */ -/* // */ -/* // vertex_type operator()() const { */ -/* // return vertex_type::INVALID_NODE; */ -/* // } */ -/* // }; */ -/* // */ -/* // template <> */ -/* // struct invalid_node, GraphStructure>> { - */ -/* // Node operator()() const { */ -/* // return Node::INVALID_NODE; */ -/* // } */ -/* // }; */ - -/* /1* } // namespace Utils *1/ */ -/* } // namespace ffc */ -/* } // namespace FlexFlow */ - -#endif diff --git a/lib/compiler/src/old/graph_structures.h b/lib/compiler/src/old/graph_structures.h deleted file mode 100644 index 8b921794e1..0000000000 --- a/lib/compiler/src/old/graph_structures.h +++ /dev/null @@ -1,269 +0,0 @@ -#ifndef _GRAPH_STRUCTURES_H -#define _GRAPH_STRUCTURES_H - -#include "basic_graph.h" - -namespace FlexFlow { -namespace PCG { -namespace Utils { - -template -struct ReverseStructure { - using graph_type = typename BaseStructure::graph_type; - using G = graph_type; - using vertex_type = typename BaseStructure::vertex_type; - using edge_type = typename BaseStructure::edge_type; - - std::unordered_set get_nodes(G const &g) const { - return this->base.get_nodes(g); - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - return this->base.get_outgoing_edges(g, n); - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - return this->base.get_incoming_edges(g, n); - } - - vertex_type get_src(G const &g, edge_type const &e) const { - return this->base.get_dst(g, e); - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - return this->base.get_src(g, e); - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_dst(g, e, n); - } - - void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_src(g, e, n); - } - - BaseStructure base; -}; - -template -struct UndirectedEdge { - union Edge { - NotReversed not_reversed; - Reversed reversed; - - Edge() {} - }; - - bool is_reversed; - Edge edge; - - UndirectedEdge() {} - - bool operator==(UndirectedEdge const &other) const { - if (other.is_reversed != this->is_reversed) { - return false; - } - if (this->is_reversed) { - return this->edge.reversed == other.edge.reversed; - } else { - return this->edge.not_reversed == other.edge.not_reversed; - } - } -}; - -template > -struct UndirectedStructure { - using graph_type = typename BaseStructure::graph_type; - using vertex_type = typename BaseStructure::vertex_type; - using not_reversed_edge_type = typename BaseStructure::edge_type; - using reversed_edge_type = - typename ReverseStructure::edge_type; - using edge_type = UndirectedEdge; - - std::unordered_set get_nodes(G const &g) const { - return this->base.get_nodes(g); - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - std::unordered_set incoming; - auto base_edges = this->base.get_incoming_edges(g, n); - auto reversed_edges = this->reversed.get_incoming_edges(g, n); - - for (auto const &e : base_edges) { - edge_type lifted; - lifted.is_reversed = false; - lifted.edge.not_reversed = e; - incoming.insert(lifted); - } - - for (auto const &e : reversed_edges) { - edge_type lifted; - lifted.is_reversed = true; - lifted.edge.reversed = e; - incoming.insert(lifted); - } - - return incoming; - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - std::unordered_set outgoing; - auto base_edges = this->base.get_outgoing_edges(g, n); - auto reversed_edges = this->reversed.get_outgoing_edges(g, n); - - for (auto const &e : base_edges) { - edge_type lifted; - lifted.is_reversed = false; - lifted.edge.not_reversed = e; - outgoing.insert(lifted); - } - - for (auto const &e : reversed_edges) { - edge_type lifted; - lifted.is_reversed = true; - lifted.edge.reversed = e; - outgoing.insert(lifted); - } - - return outgoing; - } - - vertex_type get_src(G const &g, edge_type const &e) const { - if (e.is_reversed) { - return this->reversed.get_src(g, e.edge.reversed); - } else { - return this->base.get_src(g, e.edge.not_reversed); - } - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - if (e.is_reversed) { - return this->reversed.get_dst(g, e.edge.reversed); - } else { - return this->base.get_dst(g, e.edge.not_reversed); - } - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - if (e.is_reversed) { - this->reversed.set_src(g, e.edge.reversed, n); - } else { - this->base.set_src(g, e.edge.not_reversed, n); - } - } - - void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - if (e.is_reversed) { - this->reversed.set_src(g, e.edge.reversed, n); - } else { - this->base.set_src(g, e.edge.not_reversed, n); - } - } - - BaseStructure base; - ReverseStructure reversed; -}; - -template > -struct invalid_node; - -template , - typename Invalid = invalid_node> -struct MultisourceGraphStructure { - using graph_type = typename BaseStructure::graph_type; - using vertex_type = typename BaseStructure::vertex_type; - using edge_type = typename BaseStructure::edge_type; - - std::unordered_set get_nodes(G const &g) const { - Invalid invalid; - - std::unordered_set nodes = this->base.get_nodes(g); - nodes.insert(invalid()); - return nodes; - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - Invalid invalid; - - if (n == invalid()) { - return {}; - } - - std::unordered_set edges = this->base.get_incoming_edges(g, n); - if (edges.empty()) { - edge_type e; - this->base.set_src(g, e, invalid()); - this->base.set_dst(g, e, n); - return {e}; - } - - return edges; - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - Invalid invalid; - - if (n == invalid()) { - std::unordered_set edges; - for (auto const &node : this->base.get_nodes(g)) { - if (this->base.get_incoming_edges(g, node).empty()) { - edge_type e; - this->base.set_src(g, e, invalid()); - this->base.set_dst(g, e, node); - edges.insert(e); - } - } - return edges; - } - - return this->base.get_outgoing_edges(g, n); - } - - vertex_type get_src(G const &g, edge_type const &e) const { - return this->base.get_src(g, e); - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - return this->base.get_dst(g, e); - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_src(g, e, n); - } - - void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_dst(g, e, n); - } - - BaseStructure base; -}; -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -namespace std { -using FlexFlow::PCG::Utils::UndirectedEdge; - -template -struct hash> { - size_t operator()(UndirectedEdge const &e) const { - size_t result; - result = std::hash()(e.is_reversed); - if (e.is_reversed) { - hash_combine(result, e.edge.reversed); - } else { - hash_combine(result, e.edge.not_reversed); - } - return result; - } -}; -} // namespace std - -#endif // _GRAPH_STRUCTURES_H diff --git a/lib/compiler/src/old/node.h b/lib/compiler/src/old/node.h deleted file mode 100644 index eb33a39ae7..0000000000 --- a/lib/compiler/src/old/node.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef _FLEXFLOW_FFC_NODE_H -#define _FLEXFLOW_FFC_NODE_H - -#include - -#include "op-attrs/op-attrs.h" -#include "tl/optional.hpp" - -namespace FlexFlow { -namespace ffc { - -struct Node { - Node() = delete; - Node(size_t guid, PCGOperatorAttrs const &op_params); - - std::string to_string(void) const; - - using AsTuple = - std::tuple &>; - using AsConstTuple = std::tuple const &>; - - AsTuple as_tuple(); - AsConstTuple as_tuple() const; - -public: - size_t guid; - PCGOperatorAttrs op_params; - tl::optional original_guid = tl::nullopt; -}; - -bool operator==(Node const &, Node const &); -bool operator!=(Node const &, Node const &); -bool operator<(Node const &, Node const &); - -} // namespace ffc -} // namespace FlexFlow - -namespace std { -template <> -struct hash<::FlexFlow::ffc::Node> { - size_t operator()(::FlexFlow::ffc::Node const &n) const; -}; -} // namespace std - -#endif diff --git a/lib/compiler/src/old/parallel_dim_mapping_record.h b/lib/compiler/src/old/parallel_dim_mapping_record.h deleted file mode 100644 index 8e2c265489..0000000000 --- a/lib/compiler/src/old/parallel_dim_mapping_record.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef _FLEXFLOW_FFC_PARALLEL_DIM_MAPPING_RECORD_H -#define _FLEXFLOW_FFC_PARALLEL_DIM_MAPPING_RECORD_H - -#endif diff --git a/lib/compiler/src/old/search_helper.cc b/lib/compiler/src/old/search_helper.cc deleted file mode 100644 index 2e7eafa5fd..0000000000 --- a/lib/compiler/src/old/search_helper.cc +++ /dev/null @@ -1,525 +0,0 @@ -#include "search_helper.h" - -namespace FlexFlow { -namespace PCG { - -SearchHelper::SearchHelper() { - this->logger = std::unique_ptr(new RecursiveLogger("DP")); -} - -template -T SearchHelper::execute_sequence_split(std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - SequenceSplit const &bn) const { - return sequence_cost( - this->graph_cost(pre_graph.get(), source, bn, resources, true), - this->graph_cost(post_graph.get(), bn, sink, resources, false)); -} - -template -T SearchHelper::find_optimal_sequence_graph_time( - Graph const *g, - Node const &bn_node, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const { - std::unique_ptr pre_graph; - std::unique_ptr post_graph; - std::tie(pre_graph, post_graph) = g->split_at_node(bn_node); - - T optimal = this->infinity(); - - std::vector valid_views = - this->get_valid_machine_views(bn_node.op_params, resources); - // A Corner Case: - // If bn_node is a parallel_op and an input to sink_node, - // Add sink_node's view to the list, since sink_node's view - // may not be a valid view for resources, but UniFlow support - // this case since parallel_op does not trigger computation - if (is_parallel_op(bn_node.op_params)) { - bool found = false; - auto const &inList = g->inEdges.find(sink.node)->second; - for (auto const &e : inList) { - if (e.srcOp == bn_node) { - found = true; - break; - } - } - if (found) { - for (int j = 0; j < bn_node.ptr->numOutputs; j++) { - if (!bn_node.ptr->outputs[j]->is_valid_machine_view(sink.view)) { - found = false; - } - } - } - if (found) { - valid_views.push_back(sink.view); - } - } - - if (valid_views.empty()) { - return optimal; - } - - float optimal_cost = std::numeric_limits::infinity(); - MachineView best_view; - - for (MachineView const &bn_view : valid_views) { - float cost = this->execute_sequence_split( - pre_graph, post_graph, source, sink, resources, {bn_node, bn_view}); - - if (cost < optimal_cost) { - best_view = bn_view; - optimal_cost = cost; - } - } - - if (optimal_cost != std::numeric_limits::infinity()) { - optimal = this->execute_sequence_split( - pre_graph, post_graph, source, sink, resources, {bn_node, best_view}); - } - - check_matches_graph(g, optimal, sink.node); - - return optimal; -} - -template -T SearchHelper::execute_nonsequence_split( - std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - NonsequenceSplit const &split) const { - Graph const *first = first_graph.get(); - Graph const *second = second_graph.get(); - if (split.flip_graphs) { - std::swap(first, second); - } - switch (split.type) { - case SplitType::SEQUENTIAL: - this->logger->debug() << "Exploring sequential nonsequence split"; - return sequence_cost( - this->graph_cost(first, source, sink, resources, false), - this->graph_cost(second, source, sink, resources, false)); - case SplitType::VERTICAL: { - this->logger->debug() << "Exploring vertical nonsequence split (" - << split.param << ", " << split.flip_graphs << ")"; - MachineResource firstRes = resources, secondRes = resources; - firstRes.num_nodes = split.param; - secondRes.num_nodes = resources.num_nodes - split.param; - secondRes.start_gpu_id = - resources.start_gpu_id + resources.all_gpus_per_node * split.param; - - return parallel_cost( - this->graph_cost(first, source, sink, firstRes, false), - this->graph_cost(second, source, sink, secondRes, false)); - } - case SplitType::HORIZONTAL: { - this->logger->debug() << "Exploring horizontal nonsequence split (" - << split.param << ", " << split.flip_graphs << ")"; - MachineResource firstRes = resources, secondRes = resources; - firstRes.available_gpus_per_node = split.param; - secondRes.available_gpus_per_node = - resources.available_gpus_per_node - split.param; - secondRes.start_gpu_id = resources.start_gpu_id + split.param; - - return parallel_cost( - this->graph_cost(first, source, sink, firstRes, false), - this->graph_cost(second, source, sink, secondRes, false)); - } - default: - assert(false); - } -} - -template -T SearchHelper::find_optimal_nonsequence_graph_time( - Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const { - std::unique_ptr first_graph; - std::unique_ptr second_graph; - std::tie(first_graph, second_graph) = - g->split_horizontal(source.node, sink.node); - - std::vector potential_splits; - - for (int i = 1; i < resources.num_nodes; i++) { - potential_splits.push_back(NonsequenceSplit::vertical(i, false)); - potential_splits.push_back(NonsequenceSplit::vertical(i, true)); - } - for (int i = 1; i < resources.available_gpus_per_node; i++) { - potential_splits.push_back(NonsequenceSplit::horizontal(i, false)); - potential_splits.push_back(NonsequenceSplit::horizontal(i, true)); - } - - NonsequenceSplit best_split = NonsequenceSplit::sequential(); - float best_cost = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, best_split); - for (NonsequenceSplit const &split : potential_splits) { - float cost = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, split); - this->logger->debug() << "Found cost: " << cost; - - if (cost < best_cost) { - best_cost = cost; - best_split = split; - } - } - - switch (best_split.type) { - case SplitType::SEQUENTIAL: - this->logger->debug() << "Best split: SEQUENTIAL"; - break; - case SplitType::VERTICAL: - this->logger->debug() << "Best split: VERTICAL(" << best_split.param - << ", " << best_split.flip_graphs << ")"; - break; - case SplitType::HORIZONTAL: - this->logger->debug() << "Best split: HORIZONTAL(" << best_split.param - << ", " << best_split.flip_graphs << ")"; - break; - } - T optimal = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, best_split); - - check_matches_graph(g, optimal, sink.node); - - return optimal; -} - -std::vector SearchHelper::get_valid_machine_views( - Node const &node, MachineResource const &resource, bool log) const { - this->logger->info() << "Getting valid machine views for " - << node.to_string(); - return this->get_valid_machine_views(node.ptr, resource, log); -} - -std::vector SearchHelper::get_valid_machine_views( - Op const *op, MachineResource const &resource, bool log) const { - std::vector const *cached_op_views = NULL; - std::vector valid_views; - - auto const &iter = cached_operator_valid_views.find(op->op_guid); - if (iter != cached_operator_valid_views.end()) { - cached_op_views = iter->second.get(); - } else { - auto to_cache = std::unique_ptr>( - new std::vector()); - if (log) { - this->logger->info() << "Considering a total of " - << this->model->all_valid_views.size() - << " potential valid views"; - } - for (size_t i = 0; i < this->model->all_valid_views.size(); i++) { - bool valid = true; - for (int j = 0; j < op->numOutputs; j++) { - if (!op->outputs[j]->is_valid_machine_view( - this->model->all_valid_views[i])) { - valid = false; - { - MachineView const &view = this->model->all_valid_views[i]; - std::ostringstream oss; - oss << "[" << view.ndims << "]("; - for (int i = 0; i < view.ndims; i++) { - oss << view.dim[i] << "/" << view.stride[i]; - if (i != view.ndims - 1) { - oss << " "; - } - } - oss << ")"; - if (log) { - this->logger->info() << "Rejecting machine view: " << oss.str(); - } - } - break; - } - } - if (valid) { - { - MachineView const &view = this->model->all_valid_views[i]; - std::ostringstream oss; - oss << "[" << view.ndims << "]("; - for (int i = 0; i < view.ndims; i++) { - oss << view.dim[i] << "/" << view.stride[i]; - if (i != view.ndims - 1) { - oss << " "; - } - } - oss << ")"; - if (log) { - this->logger->info() << "Accepting machine view: " << oss.str(); - } - } - to_cache->push_back(this->model->all_valid_views[i]); - } - } - cached_operator_valid_views[op->op_guid] = std::move(to_cache); - cached_op_views = cached_operator_valid_views.at(op->op_guid).get(); - } - if (log) { - this->logger->info() << "Found " << cached_op_views->size() - << " cached op views"; - } - for (size_t i = 0; i < cached_op_views->size(); i++) { - MachineView view = (*cached_op_views)[i]; - if (view.device_type == MachineView::GPU) { - view.start_device_id = resource.start_gpu_id; - } else if (view.device_type == MachineView::CPU) { - view.start_device_id = resource.start_cpu_id; - } else { - assert(false); - } - if (resource.is_valid_machine_view(view)) { - valid_views.push_back(view); - } - } - return valid_views; -} - -template <> -bool SearchHelper::is_invalid(float const &cost) const { - return cost == std::numeric_limits::infinity(); -} - -template <> -bool SearchHelper::is_invalid( - GraphCostResult const &cost) const { - return cost.cost == std::numeric_limits::infinity(); -} - -/** - * @brief Asserts that the results of graph optimization are valid for the graph - * - * @param g the graph to check against - * @param r the results to check - * @param sink the sink node of the graph g - * @param include_sink whether or not to include the sink node - */ -template <> -void SearchHelper::check_matches_graph( - Graph const *g, GraphCostResult const &r, Node const &sink) const { - using FlexFlow::PCG::Utils::nodes; - - if (this->is_invalid(r)) { - return; - } - - std::unordered_set g_nodes = nodes(*g); - g_nodes.erase(sink); - - std::unordered_set r_nodes; - for (auto const &kv : r.views) { - r_nodes.insert(kv.first); - } - - assert(g_nodes == r_nodes); -} - -template <> -void SearchHelper::check_matches_graph(Graph const *g, - float const &r, - Node const &sink) const {} - -template <> -std::pair - SearchHelper::try_get_cost_from_cache(size_t hash) const { - if (this->cached_graph_costs.find(hash) == this->cached_graph_costs.end()) { - return {false, std::numeric_limits::infinity()}; - } else { - return {true, this->cached_graph_costs.at(hash)}; - } -} - -template <> -std::pair - SearchHelper::try_get_cost_from_cache(size_t hash) const { - return {false, GraphCostResult::invalid()}; -} - -template <> -void SearchHelper::try_cache_result(size_t hash, - float const &value) const { - this->logger->debug() << "cached_graph_costs[" << hash << "] = " << value; - this->cached_graph_costs[hash] = value; -} - -template <> -void SearchHelper::try_cache_result( - size_t hash, GraphCostResult const &value) const { - this->logger->debug() << "cached_graph_costs[" << hash << "=" << value.cost - << "]"; - this->cached_graph_costs[hash] = value.cost; -} - -template <> -float SearchHelper::infinity() const { - return std::numeric_limits::infinity(); -} - -template <> -GraphCostResult SearchHelper::infinity() const { - return {std::numeric_limits::infinity(), {}}; -} - -template <> -float SearchHelper::empty() const { - return 0.0f; -} - -template <> -GraphCostResult SearchHelper::empty() const { - return {0.0f, {}}; -} - -template -T SearchHelper::estimate_xfer_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink) const { - T result = this->empty(); - - if (source.node != Node::INVALID_NODE) { - auto const &inList = graph->inEdges.find(sink.node)->second; - float op_cost = 0.0f; - for (auto const &it2 : inList) { - assert(it2.srcOp == source.node); - assert(sink.node.ptr->inputs[it2.dstIdx]->is_valid_machine_view( - source.view)); - - float estimated_xfer_cost = this->model->simulator->estimate_xfer_cost( - sink.node.ptr, it2.dstIdx, source.view, sink.view); - // printf("Estimated xfer cost from %s to %s: %fms\n", - // source.node.ptr->name, sink.node.ptr->name, estimated_xfer_cost); - op_cost += estimated_xfer_cost; - } - this->add_operator_cost(source, op_cost, &result); - } else { - Node real_source = graph->find_source_node(); - assert(real_source.ptr->op_type == OP_INPUT); - this->add_operator_cost({real_source, MachineView::NO_VIEW}, 0.0f, &result); - } - - return result; -} - -template <> -void SearchHelper::add_operator_cost(NodeAssignment const &node, - float node_cost, - float *cost) const { - *cost += node_cost; -} - -template <> -void SearchHelper::add_operator_cost( - NodeAssignment const &node, float node_cost, GraphCostResult *cost) const { - cost->cost += node_cost; - cost->views[node.node] = node.view; -} - -template <> -float SearchHelper::get_cost(float const &f) const { - return f; -} - -template <> -float SearchHelper::get_cost( - GraphCostResult const &gcr) const { - return gcr.cost; -} - -template -T SearchHelper::graph_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - bool include_sink_compute_time) const { - TAG_ENTER(this->logger); - this->logger->debug() << "sink(" << sink.node.guid << ") " - << "sink.view(" << sink.view.ndims << " " - << sink.view.start_device_id << " " << sink.view.dim[0] - << ") " - << "source(" << source.node.guid << ") " - << "source.view(" << source.view.ndims << " " - << source.view.start_device_id << " " - << source.view.dim[0] << ") " - << "resources(" << resources.num_nodes << " " - << resources.start_gpu_id << " " - << resources.available_gpus_per_node << ")"; - if (this->model->config.profiling) { - graph->print_dot(); - } - - assert(graph->inEdges.find(sink.node) != graph->inEdges.end()); - if (source.node != Node::INVALID_NODE) { - assert(graph->outEdges.find(source.node) != graph->outEdges.end()); - } - - size_t hash = dp_state_hash( - graph, sink.node, sink.view, source.node, source.view, resources); - this->logger->spew() << "hash = " << hash; - - T result; - - std::pair from_cache = this->try_get_cost_from_cache(hash); - if (from_cache.first) { - // cached_graph_costs does not include sink_compute_time - result = from_cache.second; - } else { - if (graph->inEdges.size() <= 2) { - result = this->estimate_xfer_cost(graph, source, sink); - this->logger->debug() - << "Estimated xfer cost is " << this->get_cost(result); - } else { - Node bn_node = graph->find_bottleneck_node(sink.node, source.node); - if (bn_node != Node::INVALID_NODE) { - // We found a bottleneck node - this->logger->debug() << "Found bn_node = " << bn_node.guid; - - result = this->find_optimal_sequence_graph_time( - graph, - bn_node, - {source.node, source.view}, - {sink.node, sink.view}, - resources); - } else { - // sink node must have multiple branches - // otherwise we should not be here - assert(graph->inEdges.find(sink.node)->second.size() > 1); - - result = this->find_optimal_nonsequence_graph_time( - graph, - {source.node, source.view}, - {sink.node, sink.view}, - resources); - } - } - - this->try_cache_result(hash, result); - } - - check_matches_graph(graph, result, sink.node); - - if (include_sink_compute_time) { - CostMetrics metrics = - this->model->simulator->measure_operator_cost(sink.node.ptr, sink.view); - this->logger->debug() << "Sink node cost: " - << "forward(" << metrics.forward_time << ") " - << "backward(" << metrics.backward_time << ") " - << "sync(" << metrics.sync_time << ")"; - this->add_operator_cost(sink, - metrics.forward_time + metrics.backward_time + - metrics.sync_time, - &result); - } - - return result; -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/search_helper.h b/lib/compiler/src/old/search_helper.h deleted file mode 100644 index 95350ce6af..0000000000 --- a/lib/compiler/src/old/search_helper.h +++ /dev/null @@ -1,122 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SRC_SEARCH_HELPER_H -#define _FLEXFLOW_FFC_SRC_SEARCH_HELPER_H - -#include "graph.h" -#include "split_types.h" - -namespace FlexFlow { - -struct GraphCostResult { - float cost; - std::unordered_map views; - - static GraphCostResult invalid(); - - bool operator<(GraphCostResult const &other) const; - - friend std::ostream &operator<<(std::ostream &, GraphCostResult const &); -}; - -template -T sequence_cost(T const &first, T const &second); - -template -T parallel_cost(T const &first, T const &second); - -class SearchHelper { -public: - SearchHelper(); - - template - T graph_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - bool include_sink_compute_time) const; - template - T find_optimal_sequence_graph_time(Graph const *g, - Node const &bottleneck_node, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const; - template - T find_optimal_nonsequence_graph_time(Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const; - /* void find_optimal_nonsequence_graph_views(Graph const *g, */ - /* NodeAssignment const &source, */ - /* NodeAssignment const &sink, */ - /* MachineResource const &resources, - */ - /* float optimal_cost, */ - /* std::unordered_map& optimal_views) const; */ - std::vector - get_valid_machine_views(Node const &node, - MachineResource const &resource, - bool log = false) const; - std::vector - get_valid_machine_views(PCGOperatorAttrs const &op, - MachineResource const &resource, - bool log = false) const; - - template - std::pair try_get_cost_from_cache(size_t hash) const; - - template - void try_cache_result(size_t hash, T const &value) const; - - template - T infinity() const; - - template - T empty() const; - - template - bool is_invalid(T const &) const; - - template - T estimate_xfer_cost(Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink) const; - - template - void add_operator_cost(NodeAssignment const &, float, T *) const; - - template - float get_cost(T const &) const; - - template - void check_matches_graph(Graph const *, T const &, Node const &) const; - -public: - mutable std::unique_ptr logger; - -private: - template - T execute_nonsequence_split(std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - NonsequenceSplit const &split) const; - - template - T execute_sequence_split(std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - SequenceSplit const &split) const; - -private: - mutable std::unordered_map cached_graph_costs; - mutable std::unordered_map>> - cached_operator_valid_views; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/simplification.cc b/lib/compiler/src/old/simplification.cc deleted file mode 100644 index 18fc2fb71a..0000000000 --- a/lib/compiler/src/old/simplification.cc +++ /dev/null @@ -1,189 +0,0 @@ -#include "simplification.h" -#include "spdlog/spdlog.h" -#include - -namespace FlexFlow { -namespace PCG { - -Simplifier::Simplifier(std::string const &logger_name) - : logger(spdlog::get(logger_name)) {} - -void Simplifier::simplify_parallel_ops() { - logger->debug("Trying to simplify parallel ops"); - - /* using FlexFlow::PCG::Utils::nodes; */ - /* using FlexFlow::PCG::Utils::predecessor; */ - /* using FlexFlow::PCG::Utils::predecessors; */ - /* using FlexFlow::PCG::Utils::successor; */ - - std::queue work_queue; - for (Node const &node : nodes(*this)) { - if (node.ptr->is_parallel_op()) { - work_queue.push(node); - } - } - - while (!work_queue.empty()) { - Node node = work_queue.front(); - log_simplify.debug() << "Trying to simplify starting from " - << node.to_string(); - work_queue.pop(); - - auto opt_succ = successor(*this, node); - if (!opt_succ.has_value()) { - log_simplify.debug() << "Skipping because does not have single successor"; - continue; - } - Node succ = opt_succ.value(); - if (!succ.ptr->is_parallel_op()) { - log_simplify.debug() << "Skipping because successor is not a parallel op"; - continue; - } - - std::vector node_parallel_op_info, - successor_parallel_op_info; - ((ParallelOp *)node.ptr)->append_parallel_op_info(node_parallel_op_info); - ((ParallelOp *)succ.ptr) - ->append_parallel_op_info(successor_parallel_op_info); - ParallelOpJoinResult result = try_join_parallel_ops( - node_parallel_op_info.front(), successor_parallel_op_info.front()); - - if (!result.join_did_succeed) { - log_simplify.debug() << "Skipping because join did not succeed"; - continue; - } - log_simplify.debug() << "Did join nodes"; - log_simplify.debug() << " " << node.to_string(); - log_simplify.debug() << " " << succ.to_string(); - - for (Node const &p : predecessors(*this, node)) { - if (p.ptr->is_parallel_op()) { - work_queue.push(p); - } - } - - Graph new_g(this->model); - if (result.op.has_value()) { - Node new_op = this->model->get_or_create_parallel_op_node( - node.ptr->inputs[0], result.op.value()); - work_queue.push(new_op); - new_g.add_node(new_op); - } - this->replace_subgraph({node, succ}, new_g); - } - log_simplify.debug() << "Finished simplifying parallel ops"; -} - -void Graph::simplify(SimplificationSettings const &settings) { - // Simplify the graph by eliminating reverse parallel ops - // and fusing multiple parallel ops - // old graph: e1->n1->e2->n2->en - // new graph: e1->new_node->en - // TODO: temporarily disabled graph simplification - if (settings.simplify_parallel_ops) { - this->simplify_parallel_ops(); - } - if (settings.fuse_parallel_ops) { - bool simplify = true; - while (simplify) { - simplify = false; - for (auto const &it : this->inEdges) { - if (it.first.ptr == NULL) { - continue; - } - if (it.first.ptr->is_parallel_op()) { - Node n2 = it.first; - assert(it.second.size() == 1); - Edge e2 = *it.second.begin(); - Node n1 = e2.srcOp; - // Check that n1 is a parallel op - // Check that n1 must have a single out edge - if (n1.ptr->is_parallel_op() && - this->outEdges.find(n1)->second.size() == 1) { - // merge n1 and n2 - std::vector parallel_ops; - ((ParallelOp *)n1.ptr)->append_parallel_op_info(parallel_ops); - ((ParallelOp *)n2.ptr)->append_parallel_op_info(parallel_ops); - Node new_node = model->get_or_create_fused_parallel_node( - n1.ptr->inputs[0], parallel_ops); - auto const &inList = this->inEdges.find(n1)->second; - assert(inList.size() == 1); - Edge e1 = *inList.begin(); - // Update graph by adding edges - this->add_edge(e1.srcOp, new_node, e1.srcIdx, 0); - this->remove_edge(e1); - this->remove_edge(e2); - // make a copy of outList - if (this->outEdges.find(n2) != this->outEdges.end()) { - auto const outList = this->outEdges.find(n2)->second; - for (auto const &e : outList) { - this->add_edge(new_node, e.dstOp, 0, e.dstIdx); - this->remove_edge(e); - } - } - simplify = true; - } - } - if (simplify) { - break; - } - } - } - } - - if (settings.remove_trailing_parallel_ops) { - // Remove final parallel ops - std::vector candidates; - for (auto const &it : this->outEdges) { - if (it.second.size() == 0 && it.first.ptr->op_type != OP_REDUCTION && - it.first.ptr->op_type != OP_FUSED_PARALLEL && - it.first.ptr->is_parallel_op()) { - candidates.push_back(it.first); - } - } - size_t index = 0; - while (index < candidates.size()) { - Node parallel_op = candidates[index++]; - auto const &inList = this->inEdges.find(parallel_op)->second; - assert(inList.size() == 1); - Edge e = *inList.begin(); - this->remove_edge(e); - if (this->outEdges.find(e.srcOp)->second.size() == 0 && - e.srcOp.ptr->is_parallel_op()) { - candidates.push_back(e.srcOp); - } - } - } - - if (settings.remove_noops) { - // Remove NoOps - std::vector noop_nodes; - for (auto const &it : this->inEdges) { - if (it.first.ptr == NULL) { - continue; - } - if (it.first.ptr->op_type == OP_NOOP) { - noop_nodes.push_back(it.first); - } - } - size_t index = 0; - while (index < noop_nodes.size()) { - Node noop = noop_nodes[index++]; - auto const &inList = this->inEdges.find(noop)->second; - assert(inList.size() == 1); - Edge in_edge = *inList.begin(); - // make a copy of outList - if (this->outEdges.find(noop) != this->outEdges.end()) { - auto const outList = this->outEdges.find(noop)->second; - for (auto const &e : outList) { - this->add_edge(in_edge.srcOp, e.dstOp, in_edge.srcIdx, e.dstIdx); - this->remove_edge(e); - } - } - this->remove_edge(in_edge); - } - } -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/simplification.h b/lib/compiler/src/old/simplification.h deleted file mode 100644 index d83c16eb91..0000000000 --- a/lib/compiler/src/old/simplification.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SIMPLIFICATION_H -#define _FLEXFLOW_FFC_SIMPLIFICATION_H - -#include "graph.h" -#include "spdlog/spdlog.h" -#include - -namespace FlexFlow { -namespace PCG { - -struct SimplificationSettings { - bool simplify_parallel_ops = false; - bool fuse_parallel_ops = false; - bool remove_trailing_parallel_ops = false; - bool remove_noops = false; -}; - -class Simplifier { -public: - Simplifier(std::string const &logger_name); - - Graph const &simplify(SimplificationSettings const &, Graph const &); - -private: - void simplify_parallel_ops(); - -private: - std::shared_ptr logger; -}; - -} // namespace PCG -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/split_types.cc b/lib/compiler/src/old/split_types.cc deleted file mode 100644 index e9648344d4..0000000000 --- a/lib/compiler/src/old/split_types.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include "split_types.h" - -namespace FlexFlow { -namespace PCG { - -/*static*/ -NonsequenceSplit NonsequenceSplit::sequential() { - NonsequenceSplit s; - s.type = SplitType::SEQUENTIAL; - s.flip_graphs = false; - - return s; -} - -/*static*/ -NonsequenceSplit NonsequenceSplit::vertical(int param, bool flip_graphs) { - NonsequenceSplit s; - s.type = SplitType::VERTICAL; - s.param = param; - s.flip_graphs = flip_graphs; - - return s; -} - -/*static*/ -NonsequenceSplit NonsequenceSplit::horizontal(int param, bool flip_graphs) { - NonsequenceSplit s; - s.type = SplitType::HORIZONTAL; - s.param = param; - s.flip_graphs = flip_graphs; - - return s; -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/split_types.h b/lib/compiler/src/old/split_types.h deleted file mode 100644 index 3c49ad5b7a..0000000000 --- a/lib/compiler/src/old/split_types.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SPLIT_TYPES_H -#define _FLEXFLOW_FFC_SPLIT_TYPES_H - -#include "node.h" -#include "pcg/machine_view.h" - -namespace FlexFlow { -namespace PCG { - -enum class SplitType { SEQUENTIAL, VERTICAL, HORIZONTAL }; - -struct NonsequenceSplit { - SplitType type; - int param; - bool flip_graphs; - - static NonsequenceSplit sequential(); - static NonsequenceSplit vertical(int param, bool flip_graphs); - static NonsequenceSplit horizontal(int param, bool flip_graphs); -}; - -struct NodeAssignment { - Node node; - MachineView view; -}; - -using SequenceSplit = NodeAssignment; - -} // namespace PCG -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/substitution.cc b/lib/compiler/src/old/substitution.cc deleted file mode 100644 index 9f8381093c..0000000000 --- a/lib/compiler/src/old/substitution.cc +++ /dev/null @@ -1,3733 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "substitution.h" -#include "graph.h" -#include "graph_structures.h" -#include "op-meta/op-meta.h" -#include "parallel_ops/combine.h" -#include "parallel_ops/fused_parallel_op.h" -#include "parallel_ops/partition.h" -#include "parallel_ops/reduction.h" -#include "parallel_ops/replicate.h" -#include "utils/dot/dot_file.h" -#include -#include - -using namespace ::FlexFlow::substitutions; - -namespace FlexFlow { -namespace ffc { - -const TensorX TensorX::NO_TX = TensorX(); - -bool TensorX::operator==(TensorX const &other) const { - return this->op == other.op && this->idx == other.idx; -} - -bool TensorX::operator!=(TensorX const &other) const { - return !this->operator==(other); -} - -Rule create_combine_inception(int num_convs, int num_dims, int num_parts); -Rule create_combine_concat(int num_inputs, int num_dims, int num_parts); -Rule create_replicate_linear_combine(int num_dims, - int num_parts, - ActiMode activation, - bool use_bias); -Rule create_partition_linear_combine(int num_dims, - int num_parts, - ActiMode activation, - bool use_bias); -Rule create_partition_conv2d_combine(int num_dims, int num_parts); -Rule create_partition_attention_combine(int num_heads, int num_parts); -Rule create_replicate_attention_reduce(int num_heads, int num_parts); -Rule create_partition_add_combine(int parallel_dim, int num_parts); -Rule create_partition_relu_combine(int parallel_dim, int num_parts); -Rule create_partition_concat_combine(int num_inputs, - int concat_dim, - int parallel_dim, - int num_parts); -Rule create_partition_softmax_combine(int softmax_dim, - int part_dim, - int num_parts); -Rule leading_relu_branch_combine(int parallel_dim, - int num_parts, - int num_combines); -Rule leading_relu_branch_partition(int parallel_dim, - int num_parts, - int num_partitions); -Rule create_linear_relu_merge(int num_dims, bool use_bias); - -PMConstraint::PMConstraint(Compare c, PMParameter p, int v) - : comp(c), para(p), value(v) {} - -TNConstraint::TNConstraint(Compare c, TNParameter p, DIMParameter d, int v) - : singlePara(true), comp(c), para1(p), dim1(d), value(v) {} - -TNConstraint::TNConstraint( - Compare c, TNParameter p1, DIMParameter d1, TNParameter p2, DIMParameter d2) - : singlePara(false), comp(c), para1(p1), para2(p2), dim1(d1), dim2(d2) {} - -tl::optional TensorX::to_tensor(GraphXfer const *xfer) const { - if (op != NULL) { - assert(op->mapOp.ptr != NULL); - return op->mapOp.ptr->outputs[idx]; - } else { - auto const &it = xfer->mappedInputs.find(idx); - if (it == xfer->mappedInputs.end()) { - return tl::nullopt; - } - assert(it != xfer->mappedInputs.end()); - Node op = it->second.first; - int outIdx = it->second.second; - return op.ptr->outputs[outIdx]; - } -} - -OpX::OpX(const OperatorType _type, - int num_inputs, - int num_outputs, - TensorX const &input0, - TensorX const &input1, - TensorX const &input2, - TensorX const &input3) - : type(_type), mapOp(Node::INVALID_NODE), matchOpX(NULL) { - TensorX all_inputs[MAX_NUM_INPUTS]; - all_inputs[0] = input0; - all_inputs[1] = input1; - all_inputs[2] = input2; - all_inputs[3] = input3; - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(all_inputs[i]); - } - for (int i = 0; i < num_outputs; i++) { - TensorX out(this, i); - outputs.push_back(out); - } -} - -OpX::OpX(const OperatorType _type, - int num_inputs, - int num_outputs, - TensorX const *input_array) - : type(_type), mapOp(Node::INVALID_NODE), matchOpX(NULL) { - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(input_array[i]); - } - for (int i = 0; i < num_outputs; i++) { - TensorX out(this, i); - outputs.push_back(out); - } -} - -bool OpX::add_pm_constraint(Compare comp, PMParameter para, int value) { - PMConstraint pmc(comp, para, value); - pmConstraints.push_back(pmc); - return true; -} - -bool OpX::add_input_constraint(Compare comp, - TNParameter para, - DIMParameter dim, - int value) { - TNConstraint tnc(comp, para, dim, value); - tnConstraints.push_back(tnc); - return true; -} - -bool OpX::add_input_constraint(Compare comp, - TNParameter para1, - DIMParameter dim1, - TNParameter para2, - DIMParameter dim2) { - TNConstraint tnc(comp, para1, dim1, para2, dim2); - tnConstraints.push_back(tnc); - return true; -} - -bool OpX::get_pm_constraint(PMParameter para, int &value) const { - for (size_t i = 0; i < pmConstraints.size(); i++) { - if ((pmConstraints[i].comp == COMPARE_EQ) && - (pmConstraints[i].para == para)) { - value = pmConstraints[i].value; - return true; - } - } - return false; -} - -GraphXfer::GraphXfer(FFModel *_model) : model(_model), tensorId(10) {} - -TensorX GraphXfer::new_tensor(void) { - TensorX t; - t.op = NULL; - t.idx = tensorId++; - return t; -} - -bool GraphXfer::map_output(TensorX const &src, TensorX const &dst) { - mappedOutputs[src] = dst; - return true; -} - -bool GraphXfer::can_match(OpX *srcOp, Node const &op, Graph const *graph) { - if (srcOp->type != op.ptr->op_type) { - return false; - } - // check num input tensors - if ((int)srcOp->inputs.size() != op.ptr->numInputs) { - return false; - } - // check pmConstraints - for (size_t i = 0; i < srcOp->pmConstraints.size(); i++) { - PMConstraint pmc = srcOp->pmConstraints[i]; - int actValue = 0; - assert(op.ptr->get_int_parameter(pmc.para, &actValue)); - // printf("pmc[%d] para(%d) comp(%d) value(%d) actValue(%d)\n", - // i, pmc.para, pmc.comp, pmc.value, actValue); - switch (pmc.comp) { - case COMPARE_EQ: { - if (actValue != pmc.value) { - return false; - } - break; - } - case COMPARE_NE: { - if (actValue == pmc.value) { - return false; - } - break; - } - case COMPARE_LT: { - if (actValue >= pmc.value) { - return false; - } - break; - } - case COMPARE_LE: { - if (actValue > pmc.value) { - return false; - } - break; - } - case COMPARE_GT: { - if (actValue <= pmc.value) { - return false; - } - break; - } - case COMPARE_GE: { - if (actValue < pmc.value) { - return false; - } - break; - } - default: - assert(false); - } - } - // check inputs - std::map> newMapInputs; - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // input tensor - std::multimap>::const_iterator it; - it = mappedInputs.find(in.idx); - if (it != mappedInputs.end()) { - Node mappedOp = it->second.first; - int mappedIdx = it->second.second; - if (!(graph->has_edge(mappedOp, op, mappedIdx, i))) { - return false; - } - } else { - std::map>::const_iterator newit; - newit = newMapInputs.find(in.idx); - if (newit != newMapInputs.end()) { - Node mappedOp = newit->second.first; - int mappedIdx = newit->second.second; - if (!(graph->has_edge(mappedOp, op, mappedIdx, i))) { - return false; - } - } else { - auto const &list = graph->inEdges.find(op)->second; - for (auto const &e : list) { - if (e.dstIdx == (int)i) { - newMapInputs.insert( - std::make_pair(in.idx, std::make_pair(e.srcOp, e.srcIdx))); - } - } - } - // Do nothing when we check the match - /* mapped in.idx to an op - std::set list = graph->inEdges.find(op)->second; - std::set::const_iterator it2; - for (it2 = list.begin(); it2 != list.end(); it2++) { - Edge e = *it2; - if (e.dstIdx == i) - mappedInputs[in.idx] = std::make_pair(e.srcOp, e.srcIdx); - }*/ - } - } else { - // intermediate tensor - assert(in.op->mapOp != Node::INVALID_NODE); - if (!(graph->has_edge(in.op->mapOp, op, in.idx, i))) { - return false; - } - } - } - // check tnConstraints - for (size_t i = 0; i < srcOp->tnConstraints.size(); i++) { - TNConstraint tnc = srcOp->tnConstraints[i]; - int actValue = 0, expValue = 0; - if (tnc.singlePara) { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - expValue = tnc.value; - } else { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - assert(op.ptr->get_tensor_parameter(tnc.para2, tnc.dim2, &expValue)); - } - switch (tnc.comp) { - case COMPARE_EQ: { - if (actValue != expValue) { - return false; - } - break; - } - case COMPARE_NE: { - if (actValue == expValue) { - return false; - } - break; - } - case COMPARE_LT: { - if (actValue >= expValue) { - return false; - } - break; - } - case COMPARE_LE: { - if (actValue > expValue) { - return false; - } - break; - } - case COMPARE_GT: { - if (actValue <= expValue) { - return false; - } - break; - } - case COMPARE_GE: { - if (actValue < expValue) { - return false; - } - break; - } - default: - assert(false); - } - } - return true; -} - -void GraphXfer::match(OpX *srcOp, Node const &op, Graph const *graph) { - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // Update mappedInputs - auto const &list = graph->inEdges.find(op)->second; - for (auto const &e : list) { - if (e.dstIdx == (int)i) { - mappedInputs.insert( - std::make_pair(in.idx, std::make_pair(e.srcOp, e.srcIdx))); - } - } - } - } - // Map srcOp to Op - srcOp->mapOp = op; - mappedOps[op] = srcOp; -} - -void GraphXfer::unmatch(OpX *srcOp, Node const &op, Graph const *graph) { - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - log_xfer_matches.spew() << "umatch iteration " << i; - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // Update mappedInputsa - std::multimap>::iterator it; - log_xfer_matches.spew() << "Starting find"; - it = mappedInputs.find(in.idx); - log_xfer_matches.spew() << "Finished find"; - if (it != mappedInputs.end()) { - mappedInputs.erase(it); - } - } - } - log_xfer_matches.spew() << "Finished the unmatch loop"; - // Unmap op - mappedOps.erase(op); - srcOp->mapOp.guid = 0; - srcOp->mapOp.ptr = NULL; - log_xfer_matches.spew() << "Returning from unmatch"; -} - -GraphXferMatch::GraphXferMatch(GraphXfer const *xfer) : xfer(xfer) {} - -void GraphXferMatch::add_mapping(Node const &node, OpX *opx) { - this->nodeToOpX[node] = opx; - this->opXToNode[opx] = node; -} - -void GraphXferMatch::add_mapping(OpX *opx, Node const &node) { - this->add_mapping(node, opx); -} - -void GraphXferMatch::add_output_mapping(TensorX const &src, - TensorX const &dst) { - this->mappedOutputs[src] = dst; -} - -OpX *GraphXferMatch::at(Node const &n) const { - return this->nodeToOpX.at(n); -} - -Node GraphXferMatch::at(OpX *opx) const { - return this->opXToNode.at(opx); -} - -void GraphXferMatch::set_graph(Graph const *g) { - this->graph_hash = g->hash(); -} - -bool GraphXferMatch::containsNode(Graph const *g, Node const &n) const { - assert(g->hash() == this->graph_hash); - - return this->nodeToOpX.find(n) != this->nodeToOpX.end(); -} - -bool GraphXferMatch::containsEdge(Graph const *g, Edge const &e) const { - assert(g->hash() == this->graph_hash); - - bool contains_src = this->containsNode(g, e.srcOp); - bool contains_dst = this->containsNode(g, e.dstOp); - - return contains_src && contains_dst; -} - -GraphXfer const *GraphXferMatch::get_xfer() const { - return this->xfer; -} - -std::unordered_set GraphXferMatch::get_nodes() const { - std::unordered_set nodes; - for (auto const &kv : nodeToOpX) { - nodes.insert(kv.first); - } - - return nodes; -} - -GraphXferMatch GraphXfer::get_match_record(Graph const *g) const { - GraphXferMatch match(this); - - for (auto const &kv : this->mappedOps) { - match.add_mapping(kv.first, kv.second); - } - - for (auto const &kv : this->mappedOutputs) { - match.add_output_mapping(kv.first, kv.second); - } - - match.set_graph(g); - - return match; -} - -void GraphXfer::find_matches(Graph const *graph, - std::vector &matches) { - this->find_matches(0, graph, matches); -} - -void GraphXfer::find_matches(int depth, - Graph const *graph, - std::vector &matches) { - log_xfer_matches.spew() << "find_matches at depth: " << depth; - if (depth >= (int)srcOps.size()) { - log_xfer_matches.spew() << "Achieved adequate depth"; - // Create dst operators - bool pass = true; - for (OpX *dstOp : this->dstOps) { - pass &= create_new_operator(dstOp, dstOp->mapOp); - if (!pass) { - break; - } - } - log_xfer_matches.spew() << "Completed create dst operators"; - if (!pass) { - log_xfer_matches.spew() << "Did not pass. Returning."; - return; - } - log_xfer_matches.spew() << "Checking external edges"; - // Check that output tensors with external edges are mapped - for (auto const &opIt : mappedOps) { - auto const &list = graph->outEdges.at(opIt.first); - for (auto const &e : list) { - if (mappedOps.find(e.dstOp) == mappedOps.end()) { - // dstOp is external, (srcOp, srcIdx) must be in mappedOutputs - TensorX srcTen; - srcTen.op = opIt.second; - srcTen.idx = e.srcIdx; - if (mappedOutputs.find(srcTen) == mappedOutputs.end()) { - pass = false; - return; - } - } - } - } - log_xfer_matches.spew() << "Completed checking external edges"; - // Generate a new graph by applying xfer rule - log_xfer_matches.spew() << "Creating new graph"; - SimplificationSettings - settings; // leave everything disabeld since we don't care about cost - Graph *newGraph = this->create_new_graph(graph, settings); - log_xfer_matches.spew() << "Completed creating new graph"; - - // Check that the new graph should not have any loop - log_xfer_matches.spew() << "Checking for loop"; - if (newGraph->has_loop()) { - printf("Found a new graph with LOOP!!!!\n"); - newGraph->print(); - delete newGraph; - return; - } - log_xfer_matches.spew() << "Finished checking for loop"; - // TODO: remove me for better performance - log_xfer_matches.spew() << "Checking correctness"; - assert(newGraph->check_correctness()); - log_xfer_matches.spew() << "Finished checking correctness"; - log_xfer_matches.spew() << "Getting match record"; - GraphXferMatch match_record = this->get_match_record(graph); - log_xfer_matches.spew() << "Finished getting match record"; - matches.push_back(match_record); - } else { - OpX *srcOp = srcOps[depth]; - for (auto const &it : graph->inEdges) { - log_xfer_matches.spew() << "Exploring node " << it.first.to_string(); - // printf("can_match(%d)\n", can_match(srcOp, it->first, graph)); - if (can_match(srcOp, it.first, graph) && - (mappedOps.find(it.first) == mappedOps.end())) { - Node op = it.first; - // Check mapOutput - this->match(srcOp, op, graph); - this->find_matches(depth + 1, graph, matches); - log_xfer_matches.spew() << "Completed find matches. Unmatching"; - this->unmatch(srcOp, op, graph); - log_xfer_matches.spew() << "Finished unmatching"; - } - } - } -} - -template -void GraphXfer::run( - int depth, - Graph *graph, - std::priority_queue, GraphComparator> - &candidates, - std::unordered_set &hashmap, - float threshold, - int maxNumOps, - SimplificationSettings const &simplification_settings, - int &num_matches_found, - int &num_matches_rejected) { - // printf("run: depth(%d) srcOps.size(%zu) graph.size(%zu) candidates(%zu)\n", - // depth, srcOps.size(), graph->inEdges.size(), candidates.size()); - if (depth >= (int)srcOps.size()) { - // Create dst operators - bool pass = true; - for (OpX *dstOp : this->dstOps) { - if (pass) { - pass &= create_new_operator(dstOp, dstOp->mapOp); - } - } - if (!pass) { - return; - } - // Check that output tensors with external edges are mapped - for (auto const &opIt : mappedOps) { - auto const &list = graph->outEdges[opIt.first]; - for (auto const &e : list) { - if (mappedOps.find(e.dstOp) == mappedOps.end()) { - // dstOp is external, (srcOp, srcIdx) must be in mappedOutputs - TensorX srcTen; - srcTen.op = opIt.second; - srcTen.idx = e.srcIdx; - if (mappedOutputs.find(srcTen) == mappedOutputs.end()) { - pass = false; - return; - } - } - } - } - // Generate a new graph by applying xfer rule - log_xfers.spew() << "Found a match for xfer: " << this->get_name(); - num_matches_found++; - Graph *newGraph = this->create_new_graph(graph, simplification_settings); - // Check that the new graph should not have any loop - if (newGraph->has_loop()) { - printf("Found a new graph with LOOP!!!!\n"); - newGraph->print(); - delete newGraph; - return; - } - // TODO: remove me for better performance - assert(newGraph->check_correctness()); - if (newGraph->optimal_cost() < threshold && - (int)newGraph->inEdges.size() < maxNumOps) { - if (hashmap.find(newGraph->hash()) == hashmap.end()) { - hashmap.insert(newGraph->hash()); - log_xfers.spew() << "Found new candidate"; - // newGraph->print_dot(); - candidates.push(newGraph); - } - } else { - num_matches_rejected++; - delete newGraph; - } - } else { - OpX *srcOp = srcOps[depth]; - for (auto const &it : graph->inEdges) { - // printf("can_match(%d)\n", can_match(srcOp, it->first, graph)); - if (can_match(srcOp, it.first, graph) && - (mappedOps.find(it.first) == mappedOps.end())) { - Node op = it.first; - // Check mapOutput - match(srcOp, op, graph); - run(depth + 1, - graph, - candidates, - hashmap, - threshold, - maxNumOps, - simplification_settings, - num_matches_found, - num_matches_rejected); - unmatch(srcOp, op, graph); - } - } - } -} - -void Graph::reshape_output_tensor(ParallelTensorShape const &desired_shape) { - Node output_node = this->find_sink_node(); - - assert(output_node.ptr->numOutputs == 1); - ParallelTensor output_tensor = output_node.ptr->outputs[0]; - - assert(output_tensor->num_dims == desired_shape.num_dims); - - for (int i = 0; i < output_tensor->num_dims; i++) { - int current_size = output_tensor->dims[i].size; - int current_degree = output_tensor->dims[i].degree; - - int desired_size = desired_shape.dims[i].size; - int desired_degree = desired_shape.dims[i].degree; - - assert(current_size == desired_size); - - if (current_degree < desired_degree) { - // we need to partition - assert(desired_degree % current_degree == 0); - int partition_factor = desired_degree / current_degree; - - Node partition_node = model->get_or_create_node( - output_tensor, {i /*legion_dim*/, partition_factor}); - this->add_edge(output_node, partition_node, 0, 0); - - output_node = partition_node; - output_tensor = partition_node.ptr->outputs[0]; - current_degree *= partition_factor; - - } else if (current_degree > desired_degree) { - // we need to combine - assert(current_degree % desired_degree == 0); - int combine_factor = current_degree / desired_degree; - - Node combine_node = model->get_or_create_node( - output_tensor, {i /*legion_dim*/, combine_factor}); - this->add_edge(output_node, combine_node, 0, 0); - - output_node = combine_node; - output_tensor = combine_node.ptr->outputs[0]; - current_degree /= combine_factor; - } - - assert(current_degree == desired_degree); - } - - assert(output_tensor == output_node.ptr->outputs[0]); - assert(output_tensor->num_dims == desired_shape.num_dims); - for (int i = 0; i < desired_shape.num_dims; i++) { - assert(output_tensor->dims[i].size == desired_shape.dims[i].size); - assert(output_tensor->dims[i].degree == desired_shape.dims[i].degree); - } -} - -std::unique_ptr Graph::with_output_tensor_reshaped_to( - ParallelTensorShape const &shape) const { - auto g = std::unique_ptr(new Graph(*this)); - g->reshape_output_tensor(shape); - return g; -} - -/* Graph::Graph(Graph const &graph) */ -/* : Graph(&graph) */ -/* { } */ - -/* Graph::Graph(Graph const *graph) */ -/* : Graph(graph->model) */ -/* { */ -/* for (auto const &kv : graph->inEdges) { */ -/* Node const &node = kv.first; */ -/* std::unordered_set const &edge_set = kv.second; */ - -/* for (auto const &edge : edge_set) { */ -/* this->add_edge(edge.srcOp, edge.dstOp, edge.srcIdx) */ -/* } */ -/* } */ -/* } */ - -Graph *GraphXfer::create_new_graph( - Graph const *graph, SimplificationSettings const &simplification_settings) { - Graph *newGraph = new Graph(model); - // Step 1: map dst ops - std::vector::const_iterator dstIt; - // Step 2: add edges to the graph - for (auto const &opIt : graph->inEdges) { - if (mappedOps.find(opIt.first) == mappedOps.end()) { - // Unmapped ops - auto const &list = opIt.second; - for (auto const &it : list) { - if (mappedOps.find(it.srcOp) != mappedOps.end()) { - // mapped src -> unmapped dst - TensorX srcTen; - srcTen.op = mappedOps[it.srcOp]; - srcTen.idx = it.srcIdx; - assert(mappedOutputs.find(srcTen) != mappedOutputs.end()); - TensorX dstTen = mappedOutputs[srcTen]; - newGraph->add_edge(dstTen.op->mapOp, it.dstOp, dstTen.idx, it.dstIdx); - } else { - // unmapped src -> unmmaped dst - newGraph->add_edge(it.srcOp, it.dstOp, it.srcIdx, it.dstIdx); - } - } - } - } - // Step 3: add edges for mapped ops - for (dstIt = dstOps.begin(); dstIt != dstOps.end(); dstIt++) { - OpX *dstOp = *dstIt; - for (size_t i = 0; i < dstOp->inputs.size(); i++) { - if (dstOp->inputs[i].op == NULL) { - // unmapped src -> mapped dst - std::multimap>::const_iterator it = - mappedInputs.find(dstOp->inputs[i].idx); - assert(it != mappedInputs.end()); - std::pair const &srcEdge = it->second; - newGraph->add_edge(srcEdge.first, dstOp->mapOp, srcEdge.second, i); - } else { - // mapped src -> mapped dst - OpX *srcOp = dstOp->inputs[i].op; - int srcIdx = dstOp->inputs[i].idx; - newGraph->add_edge(srcOp->mapOp, dstOp->mapOp, srcIdx, i); - } - } - } - newGraph->simplify(simplification_settings); - - return newGraph; -} - -bool GraphXfer::create_new_operator(OpX const *opx, Node &op) { - ParallelTensor inputs[MAX_NUM_INPUTS]; - for (size_t i = 0; i < opx->inputs.size(); i++) { - tl::optional mapped = opx->inputs[i].to_tensor(this); - if (!mapped.has_value()) { - return false; - } - inputs[i] = mapped.value(); - } - // Check that the total degree of inputs[0] does not exceed available - // resources - if (opx->inputs.size() > 0) { - int degree = 1; - for (int i = 0; i < inputs[0]->num_dims; i++) { - degree *= inputs[0]->dims[i].degree; - } - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - return false; - } - } - int num_inputs; - if (opx->get_pm_constraint(PM_NUM_INPUTS, num_inputs) && - opx->inputs.size() != num_inputs) { - return false; - } - int num_outputs; - if (opx->get_pm_constraint(PM_NUM_OUTPUTS, num_outputs) && - opx->outputs.size() != num_outputs) { - return false; - } - switch (opx->type) { - case OP_NOOP: { - op = model->get_or_create_noop_node(inputs[0]); - break; - } - case OP_CONCAT: { - int axis; - assert(opx->get_pm_constraint(PM_AXIS, axis)); - op = model->get_or_create_node( - {std::begin(inputs), std::end(inputs)}, {axis}); - break; - } - case OP_SPLIT: { - int axis; - assert(opx->get_pm_constraint(PM_AXIS, axis)); - int num_outputs = opx->outputs.size(); - int input_size = inputs[0]->dims[axis].size; - - if (input_size % num_outputs != 0) { - op = Node::INVALID_NODE; - } else { - int split_size = input_size / num_outputs; - std::vector split_sizes(num_outputs, split_size); - assert(split_sizes.size() == num_outputs); - op = model->get_or_create_node(inputs[0], {split_sizes, axis}); - } - break; - } - case OP_EW_ADD: - case OP_EW_SUB: - case OP_EW_MUL: - case OP_EW_MAX: - case OP_EW_MIN: { - op = model->get_or_create_node({inputs[0], inputs[1]}, - {opx->type}); - break; - } - case OP_RELU: { - ElementUnaryParams params; - params.op_type = opx->type; - params.inplace = false; - params.scalar = 0.0f; - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_CONV2D: { - Conv2D *conv = (Conv2D *)opx->matchOpX->mapOp.ptr; - Conv2DParams params = conv->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_POOL2D: { - Pool2D *pool = (Pool2D *)opx->matchOpX->mapOp.ptr; - Pool2DParams params = pool->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_FLAT: { - Flat *flat = (Flat *)opx->matchOpX->mapOp.ptr; - op = model->get_or_create_node(inputs[0], {}); - break; - } - case OP_LINEAR: { - int activation; - assert(opx->matchOpX != NULL); - assert(opx->matchOpX->mapOp.ptr != NULL); - Linear *linear = (Linear *)opx->matchOpX->mapOp.ptr; - // assert(opx->get_pm_constraint(PM_OUTPUT_CHANNELS, output_channels)); - assert(opx->get_pm_constraint(PM_ACTI, activation)); - LinearParams params = linear->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_MULTIHEAD_ATTENTION: { - int num_heads; - assert(opx->matchOpX != NULL); - assert(opx->matchOpX->mapOp.ptr != NULL); - MultiHeadAttention *attn = (MultiHeadAttention *)opx->matchOpX->mapOp.ptr; - assert(opx->get_pm_constraint(PM_NUM_HEADS, num_heads)); - MultiHeadAttentionParams params = attn->get_params(); - op = model->get_or_create_node( - {inputs[0], inputs[1], inputs[2]}, params); - break; - } - case OP_SOFTMAX: { - int softmax_dim; - assert(opx->get_pm_constraint(PM_SOFTMAX_DIM, softmax_dim)); - op = model->get_or_create_node(inputs[0], {softmax_dim}); - break; - } - case OP_REPARTITION: { - int repartition_dim, repartition_degree; - assert(opx->get_pm_constraint(PM_REPARTITION_DIM, repartition_dim)); - assert(opx->get_pm_constraint(PM_REPARTITION_DEGREE, repartition_degree)); - - int degree = inputs[0]->get_total_num_parts() * repartition_degree; - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - op = Node::INVALID_NODE; - } else { - op = model->get_or_create_node( - inputs[0], {repartition_dim, repartition_degree}); - } - break; - } - case OP_REPLICATE: { - int replicate_dim, replicate_degree; - assert(opx->get_pm_constraint(PM_REPLICATE_DIM, replicate_dim)); - assert(opx->get_pm_constraint(PM_REPLICATE_DEGREE, replicate_degree)); - - if (inputs[0]->dims[replicate_dim].degree * replicate_degree > - model->config.workersPerNode) { - op = Node::INVALID_NODE; - } else { - int degree = inputs[0]->get_total_num_parts() * replicate_degree; - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - op = Node::INVALID_NODE; - } else { - op = model->get_or_create_node( - inputs[0], {replicate_dim, replicate_degree}); - } - } - break; - } - case OP_REDUCTION: { - int reduction_dim, reduction_degree; - assert(opx->get_pm_constraint(PM_REDUCTION_DIM, reduction_dim)); - assert(opx->get_pm_constraint(PM_REDUCTION_DEGREE, reduction_degree)); - op = model->get_or_create_node( - inputs[0], {reduction_dim, reduction_degree}); - break; - } - case OP_COMBINE: { - int combine_dim, combine_degree; - assert(opx->get_pm_constraint(PM_COMBINE_DIM, combine_dim)); - assert(opx->get_pm_constraint(PM_COMBINE_DEGREE, combine_degree)); - op = model->get_or_create_node(inputs[0], - {combine_dim, combine_degree}); - break; - } - default: { - std::cout << "opx->type = " << get_operator_type_name(opx->type) - << std::endl; - assert(false); - } - } - // Check operator validness - if (op == Node::INVALID_NODE) { - return false; - } - // Check tnConstraints - for (size_t i = 0; i < opx->tnConstraints.size(); i++) { - TNConstraint tnc = opx->tnConstraints[i]; - int actValue = 0, expValue = 0; - if (tnc.singlePara) { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - expValue = tnc.value; - } else { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - assert(op.ptr->get_tensor_parameter(tnc.para2, tnc.dim2, &expValue)); - } - switch (tnc.comp) { - case COMPARE_EQ: - if (actValue != expValue) { - return false; - } - break; - case COMPARE_NE: - if (actValue == expValue) { - return false; - } - break; - case COMPARE_LT: - if (actValue >= expValue) { - return false; - } - break; - case COMPARE_LE: - if (actValue > expValue) { - return false; - } - break; - case COMPARE_GT: - if (actValue <= expValue) { - return false; - } - break; - case COMPARE_GE: - if (actValue < expValue) { - return false; - } - break; - default: - assert(false); - } - } - return true; -} - -OpX *GraphXfer::create_noop(TensorX const &input) { - OpX *noop = new OpX(OP_NOOP, 1, 1, input); - return noop; -} - -OpX *GraphXfer::create_concat(TensorX const *inputs, - int num_inputs, - OpX const *_matchOpX, - int concat_dim) { - OpX *concat = new OpX(OP_CONCAT, num_inputs, 1 /*outputs*/, inputs); - concat->matchOpX = _matchOpX; - concat->add_pm_constraint(COMPARE_EQ, PM_AXIS, concat_dim); - return concat; -} - -OpX *GraphXfer::create_element_unary(TensorX const &input, - OperatorType op_type) { - OpX *eu = new OpX(op_type, 1 /*numInputs*/, 1, input); - return eu; -} - -OpX *GraphXfer::create_relu(TensorX const &input) { - return this->create_element_unary(input, OP_RELU); -} - -OpX *GraphXfer::create_element_binary(TensorX const &input1, - TensorX const &input2, - OperatorType op_type) { - OpX *eb = new OpX(op_type, 2 /*numInputs*/, 1, input1, input2); - return eb; -} - -OpX *GraphXfer::create_linear(TensorX const &input, - OpX const *_matchOpX, - int num_dims, - ActiMode acti_mode, - bool use_bias) { - // TODO FIXME @lockshaw @zhihao use_bias is completely unused - OpX *li = new OpX(OP_LINEAR, 1, 1, input); - li->matchOpX = _matchOpX; - // li->add_pm_constraint(COMPARE_EQ, PM_OUTPUT_CHANNELS, out_channels); - li->add_pm_constraint(COMPARE_EQ, PM_ACTI, acti_mode); - li->add_input_constraint(COMPARE_EQ, INPUT_0, DIM_ND, num_dims); - return li; -} - -OpX *GraphXfer::create_conv2d(TensorX const &input, OpX const *matchOpX) { - OpX *conv = new OpX(OP_CONV2D, 1, 1, input); - conv->matchOpX = matchOpX; - return conv; -} - -OpX *GraphXfer::create_pool2d(TensorX const &input, OpX const *matchOpX) { - OpX *pool = new OpX(OP_POOL2D, 1, 1, input); - pool->matchOpX = matchOpX; - return pool; -} - -OpX *GraphXfer::create_attention(TensorX const &query, - TensorX const &key, - TensorX const &value, - OpX const *_matchOpX, - int num_heads) { - OpX *attn = new OpX(OP_MULTIHEAD_ATTENTION, 3, 1, query, key, value); - attn->matchOpX = _matchOpX; - attn->add_pm_constraint(COMPARE_EQ, PM_NUM_HEADS, num_heads); - attn->add_input_constraint(COMPARE_EQ, INPUT_0, DIM_ND, 4); - attn->add_input_constraint(COMPARE_EQ, INPUT_1, DIM_ND, 4); - attn->add_input_constraint(COMPARE_EQ, INPUT_2, DIM_ND, 4); - return attn; -} - -OpX *GraphXfer::create_softmax(TensorX const &input, int softmax_dim) { - OpX *softmax = new OpX(OP_SOFTMAX, 1, 1, input); - softmax->add_pm_constraint(COMPARE_EQ, PM_SOFTMAX_DIM, softmax_dim); - return softmax; -} - -OpX *GraphXfer::create_repartition(TensorX const &input, - int repartition_dim, - int num_parts) { - OpX *part = new OpX(OP_REPARTITION, 1, 1, input); - part->add_pm_constraint(COMPARE_EQ, PM_REPARTITION_DIM, repartition_dim); - part->add_pm_constraint(COMPARE_EQ, PM_REPARTITION_DEGREE, num_parts); - return part; -} - -OpX *GraphXfer::create_replicate(TensorX const &input, - int replicate_dim, - int num_parts) { - OpX *replicate = new OpX(OP_REPLICATE, 1, 1, input); - replicate->add_pm_constraint(COMPARE_EQ, PM_REPLICATE_DIM, replicate_dim); - replicate->add_pm_constraint(COMPARE_EQ, PM_REPLICATE_DEGREE, num_parts); - return replicate; -} - -OpX *GraphXfer::create_reduction(TensorX const &input, - int reduction_dim, - int num_parts) { - OpX *reduction = new OpX(OP_REDUCTION, 1, 1, input); - reduction->add_pm_constraint(COMPARE_EQ, PM_REDUCTION_DIM, reduction_dim); - reduction->add_pm_constraint(COMPARE_EQ, PM_REDUCTION_DEGREE, num_parts); - return reduction; -} - -OpX *GraphXfer::create_combine(TensorX const &input, - int combine_dim, - int num_parts) { - OpX *part = new OpX(OP_COMBINE, 1, 1, input); - part->add_pm_constraint(COMPARE_EQ, PM_COMBINE_DIM, combine_dim); - part->add_pm_constraint(COMPARE_EQ, PM_COMBINE_DEGREE, num_parts); - return part; -} - -void Graph::print_strategy_computation_graph( - std::unordered_map const &strategy) const { - DotFile dot(std::cout); - this->export_strategy_computation_graph(strategy, dot); -} - -void Graph::export_strategy_computation_graph( - std::unordered_map const &strategy, - std::string const &out_filename) const { - DotFile dot(out_filename); - - this->export_strategy_computation_graph(strategy, dot); -} - -void Graph::export_strategy_computation_graph( - std::unordered_map const &strategy, - DotFile &dot) const { - using FlexFlow::PCG::Utils::GraphStructure; - - GraphStructure s; - - for (auto const &node : s.get_nodes(*this)) { - // Add node - if (strategy.find(node) == strategy.end()) { - // Check FusedParallel node here and print out the detailed information - if (node.ptr->op_type == OperatorType::OP_FUSED_PARALLEL) { - RecordFormatter rf; - std::vector rows{}; - - FusedParallelOp *fused_op = (FusedParallelOp *)node.ptr; - for (int i = 0; i < fused_op->num_parallel_ops; i++) { - RecordFormatter row{}; - ParallelOpInfo op_info = fused_op->parallel_ops[i]; - std::string op_type_str = get_operator_type_name(op_info.op_type); - row << op_type_str << "dim: " + std::to_string(op_info.parallel_dim) - << "degree: " + std::to_string(op_info.parallel_degree); - rows.emplace_back(row); - } - rf << node.to_string(); - for (auto &r : rows) { - rf << r; - } - dot.add_record_node(node, rf); - } else { - dot.add_node(node, {{"label", node.to_string()}}); - } - } else { - RecordFormatter rf, meta_row, machine_view_row, runtime_code, memory_code, - runtime_cost_row, memory_cost_row; - MachineView mv = strategy.at(node); - std::ostringstream oss; - CostMetrics op_cost = - this->model->simulator->measure_operator_cost(node.ptr, mv); - switch (node.ptr->op_type) { - case OP_REPARTITION: { - Repartition *rp = (Repartition *)node.ptr; - meta_row << std::to_string(rp->repartition_dim) - << std::to_string(rp->repartition_degree); - break; - } - case OP_COMBINE: { - Combine *c = (Combine *)node.ptr; - meta_row << std::to_string(c->combine_dim) - << std::to_string(c->combine_degree); - break; - } - case OP_REPLICATE: { - Replicate *r = (Replicate *)node.ptr; - meta_row << std::to_string(r->replicate_dim) - << std::to_string(r->replicate_degree); - break; - } - case OP_REDUCTION: { - Reduction *r = (Reduction *)node.ptr; - meta_row << std::to_string(r->reduction_dim) - << std::to_string(r->reduction_degree); - break; - } - default: { - if (mv.ndims == 0) { - meta_row << "N/A"; - } else { - for (int i = 0; i < mv.ndims; i++) { - meta_row << std::to_string(mv.dim[i]); - } - } - } - } - - // Fetch machine view information - for (int device_id : mv.device_ids()) { - machine_view_row << std::to_string(device_id); - } - rf << node.to_string() << std::to_string(node.guid) << meta_row - << machine_view_row; - - // get memory cost - if (this->model->config.include_costs_dot_graph) { - float input_mem = (float)op_cost.inputs_memory; - if (node.ptr->numInputs > 0) { - input_mem /= (*node.ptr->inputs)->get_total_num_parts(); - } - float output_mem = (float)op_cost.outputs_memory; - if (node.ptr->numOutputs > 0) { - output_mem /= (*node.ptr->outputs)->get_total_num_parts(); - } - float weight_mem = (float)op_cost.weights_memory; - if (node.ptr->numWeights > 0) { - weight_mem /= (*node.ptr->weights)->get_total_num_parts(); - } - - runtime_code << "fwd" - << "bwd" - << "sync" - << "secs"; - runtime_cost_row << op_cost.forward_time << op_cost.backward_time - << op_cost.sync_time; - memory_code << "in" - << "out" - << "weight" - << "bytes"; - memory_cost_row << input_mem << output_mem << weight_mem; - rf << runtime_code << runtime_cost_row << memory_code - << memory_cost_row; - } - - dot.add_record_node(node, rf); - } - - // Add edges - for (auto const &edge : s.get_incoming_edges(*this, node)) { - dot.add_edge(s.get_src(*this, edge), s.get_dst(*this, edge)); - } - } - - dot.close(); -} - -template -void create_mapping_xfers( - FFModel *model, - int degree, - std::vector &xfers, - tl::optional> dims = tl::nullopt) { - std::vector records; - T::construct_output_mappings(records); - std::unordered_map output_mappings; - - std::unordered_set all_dims; - for (ParallelDimMappingRecord const &record : records) { - assert(record.input_idx == 0); - assert(record.get_type() == MappingRecordType::INPUT_OUTPUT); - assert(record.output_idx == 0); - assert(record.operation.has_value()); - - all_dims.insert(record.input_dim); - output_mappings.insert({record.input_dim, record}); - } - - if (dims.has_value()) { - all_dims = dims.value(); - } - - for (int const input_dim : all_dims) { - int output_dim = output_mappings.at(input_dim).output_dim; - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - - OpX *original_op = subst->create_opx(input, NULL /*matchOpX*/); - subst->srcOps.push_back(original_op); - - OpX *pre; - std::string pre_name; - switch (output_mappings.at(input_dim).operation.value()) { - case MappingOperation::PARTITION: - pre = subst->create_repartition(input, input_dim, degree); - pre_name = "partition"; - break; - case MappingOperation::REPLICATE: - pre = subst->create_replicate(input, input_dim, degree); - pre_name = "replicate"; - break; - } - subst->dstOps.push_back(pre); - - OpX *new_op = - subst->create_opx(pre->outputs[0], original_op /*matchOpX*/); - subst->dstOps.push_back(new_op); - - OpX *post; - std::string post_name; - switch (output_mappings.at(input_dim).operation.value()) { - case MappingOperation::PARTITION: - post = subst->create_combine(new_op->outputs[0], output_dim, degree); - post_name = "combine"; - break; - case MappingOperation::REPLICATE: - post = subst->create_reduction(new_op->outputs[0], output_dim, degree); - post_name = "reduce"; - break; - } - subst->dstOps.push_back(post); - - subst->map_output(original_op->outputs[0], post->outputs[0]); - - std::ostringstream oss; - std::string op_type_name = get_operator_type_name(new_op->type); - std::transform(op_type_name.begin(), - op_type_name.end(), - op_type_name.begin(), - [](unsigned char c) { return std::tolower(c); }); - oss << "mapping::" << pre_name << "_" << op_type_name << "_" << post_name - << "[" - << "input_dim=" << input_dim << ",degree=" << degree << "]"; - subst->name = oss.str(); - - xfers.push_back(subst); - } -} - -std::string GraphXfer::get_name() const { - if (this->name.has_value()) { - return this->name.value(); - } else { - std::ostringstream oss; - oss << "unknown_xfer(" << this << ")"; - return oss.str(); - } -} - -/* int get_num_outputs(sl::Operator const &op) { */ -/* switch (op.op_type) { */ -/* case OP_SPLIT: */ -/* return op.at(PM_NUM_OUTPUTS).value(); */ -/* default: */ -/* return 1; */ -/* } */ -/* } */ - -/* int get_num_inputs(sl::Operator const &op) { */ -/* switch (op.op_type) { */ -/* case OP_EW_ADD: // binary ops */ -/* case OP_EW_SUB: */ -/* case OP_EW_MUL: */ -/* case OP_EW_DIV: */ -/* case OP_EW_EQUAL: */ -/* case OP_EW_GREATER: */ -/* case OP_EW_LESS: */ -/* case OP_EW_MAX: */ -/* case OP_EW_MIN: */ -/* return 2; */ -/* case OP_SPLIT: */ -/* return 1; */ -/* case OP_LINEAR: */ -/* return 1; */ -/* case OP_CONV2D: */ -/* return 1; */ -/* case OP_RELU: */ -/* case OP_IDENTITY: */ -/* case OP_SIGMOID: */ -/* case OP_TANH: */ -/* case OP_ELU: */ -/* return 1; */ -/* case OP_CONCAT: */ -/* return op.at(PM_NUM_INPUTS).value(); */ -/* case OP_INPUT: */ -/* return 0; */ -/* case OP_REPARTITION: */ -/* case OP_COMBINE: */ -/* case OP_REPLICATE: */ -/* case OP_REDUCTION: */ -/* case OP_PIPELINE: */ -/* return 1; */ -/* default: */ -/* throw std::runtime_error("Unknown num_inputs for operator " + */ -/* get_operator_type_name(op.op_type)); */ -/* } */ -/* } */ - -OpX *create_opx(sl::Operator const &op, - int parallel_degree, - TensorX const &input1, - TensorX const &input2, - TensorX const &input3, - TensorX const &input4) { - int num_inputs = get_num_inputs(op); - int num_outputs = get_num_outputs(op); - - OpX *opx = new OpX( - op.op_type, num_inputs, num_outputs, input1, input2, input3, input4); - for (sl::Parameter const &p : op.para) { - if (p.key == PM_PARALLEL_DEGREE) { - tl::optional degree_key = tl::nullopt; - switch (op.op_type) { - case OP_REPARTITION: - degree_key = PM_REPARTITION_DEGREE; - break; - case OP_COMBINE: - degree_key = PM_COMBINE_DEGREE; - break; - case OP_REDUCTION: - degree_key = PM_REDUCTION_DEGREE; - break; - case OP_REPLICATE: - degree_key = PM_REPLICATE_DEGREE; - break; - } - - if (degree_key.has_value()) { - // Assume the generator only consider a parallel degree of 2 - assert(p.value == 2); - opx->add_pm_constraint(COMPARE_EQ, degree_key.value(), parallel_degree); - } - } else if (p.key == PM_PARALLEL_DIM) { - tl::optional dim_key = tl::nullopt; - switch (op.op_type) { - case OP_REPARTITION: - dim_key = PM_REPARTITION_DIM; - break; - case OP_COMBINE: - dim_key = PM_COMBINE_DIM; - break; - case OP_REDUCTION: - dim_key = PM_REDUCTION_DIM; - break; - case OP_REPLICATE: - dim_key = PM_REPLICATE_DIM; - break; - } - - if (dim_key.has_value()) { - opx->add_pm_constraint(COMPARE_EQ, dim_key.value(), p.value); - } - } else if (p.key == PM_PAD) { - opx->add_pm_constraint(COMPARE_EQ, PM_PADDING_H, p.value); - opx->add_pm_constraint(COMPARE_EQ, PM_PADDING_W, p.value); - } else { - opx->add_pm_constraint(COMPARE_EQ, p.key, p.value); - } - } - - return opx; -} - -OpX *find_opx_with_type(std::vector const &src_ops, - OperatorType op_type) { - OpX *matchOpX = nullptr; - for (size_t k = 0; k < src_ops.size(); k++) { - if (src_ops[k]->type == op_type) { - assert(matchOpX == nullptr); - matchOpX = src_ops[k]; - } - } - assert(matchOpX != nullptr); - return matchOpX; -} - -std::vector - create_rule_graph(GraphXfer &xfer, - std::vector const &ops, - std::function const &get_input_tensor, - std::vector *const src_ops, - int parallel_degree) { - std::vector rule_graph; - - for (int i = 0; i < ops.size(); i++) { - sl::Operator const &op = ops[i]; - std::array inputs; - std::fill(inputs.begin(), inputs.end(), TensorX::NO_TX); - - for (int j = 0; j < op.input.size(); j++) { - int opId = op.input[j].opId; - int tsId = op.input[j].tsId; - if (opId < 0) { - inputs[j] = get_input_tensor(opId, tsId); - } else { - inputs[j] = rule_graph[opId]->outputs[tsId]; - } - } - - // We need the matched OpX for constructing conv2d/pool2d/linear - OpX *opx = nullptr; - switch (ops[i].op_type) { - case OP_CONV2D: { - OpX *matchOpX = src_ops == nullptr - ? nullptr - : find_opx_with_type(*src_ops, ops[i].op_type); - opx = xfer.create_conv2d(inputs[0], matchOpX); - break; - } - case OP_POOL2D: { - OpX *matchOpX = src_ops == nullptr - ? nullptr - : find_opx_with_type(*src_ops, ops[i].op_type); - opx = xfer.create_pool2d(inputs[0], matchOpX); - break; - } - default: - opx = create_opx(ops[i], - parallel_degree, - inputs[0], - inputs[1], - inputs[2], - inputs[3]); - } - rule_graph.push_back(opx); - } - - return rule_graph; -} - -void create_xfer(GraphXfer &xfer, sl::Rule const &r, int parallel_degree) { - std::unordered_map, TensorX> input_tensors; - std::function get_input_tensor = - [&xfer, &input_tensors](int opId, int tsId) -> TensorX { - if (input_tensors.find({opId, tsId}) == input_tensors.end()) { - input_tensors[{opId, tsId}] = xfer.new_tensor(); - } - return input_tensors.at({opId, tsId}); - }; - - xfer.srcOps = create_rule_graph( - xfer, r.srcOp, get_input_tensor, nullptr, parallel_degree); - xfer.dstOps = create_rule_graph( - xfer, r.dstOp, get_input_tensor, &xfer.srcOps, parallel_degree); - xfer.name = r.name; - if (xfer.srcOps.size() == 1) { - printf("Here!\n"); - } - - for (sl::MapOutput const &m : r.mappedOutput) { - TensorX srcTensorX = xfer.srcOps[m.srcOpId]->outputs[m.srcTsId]; - TensorX dstTensorX = xfer.dstOps[m.dstOpId]->outputs[m.dstTsId]; - xfer.map_output(srcTensorX, dstTensorX); - } -} - -bool check_opxes_have_same_type_and_constraints(OpX const &src_opx, - OpX const &dst_opx) { - if (src_opx.type != dst_opx.type) { - return false; - } - if (src_opx.pmConstraints.size() != dst_opx.pmConstraints.size()) { - return false; - } - if (src_opx.tnConstraints.size() != dst_opx.tnConstraints.size()) { - return false; - } - for (auto const &c1 : src_opx.pmConstraints) { - bool found_same = false; - for (auto const &c2 : dst_opx.pmConstraints) { - if (c1.comp == c2.comp && c1.para == c2.para && c1.value == c2.value) { - found_same = true; - } - } - if (!found_same) { - return false; - } - } - for (auto const &c1 : src_opx.tnConstraints) { - bool found_same = false; - for (auto const &c2 : dst_opx.tnConstraints) { - if (c1.singlePara && c2.singlePara) { - if (c1.comp == c2.comp && c1.para1 == c2.para1 && c1.dim1 == c2.dim1 && - c1.value == c2.value) { - found_same = true; - } - } else if ((!c1.singlePara) && (!c2.singlePara)) { - if (c1.comp == c2.comp && c1.para1 == c2.para1 && - c1.para2 == c2.para2 && c1.dim1 == c2.dim1 && c1.dim2 == c2.dim2) { - found_same = true; - } - } - } - if (!found_same) { - return false; - } - } - - return true; -} - -std::vector create_xfers(FFModel *model, - sl::RuleCollection const &rules, - int parallel_degree) { - std::vector xfers; - for (sl::Rule const &r : rules.rules) { - GraphXfer *xfer = new GraphXfer(model); - create_xfer(*xfer, r, parallel_degree); - if (xfer->srcOps.size() == 1 && xfer->dstOps.size() == 1) { - delete xfer; - continue; - } - // Pruning redundant xfer - bool found_same_xfer = false; - for (auto const &old_xfer : xfers) { - bool same = true; - if (old_xfer->srcOps.size() != xfer->srcOps.size()) { - same = false; - continue; - } - for (size_t i = 0; i < old_xfer->srcOps.size(); i++) { - if (!check_opxes_have_same_type_and_constraints(*old_xfer->srcOps[i], - *xfer->srcOps[i])) { - same = false; - } - } - if (!same) { - continue; - } - if (old_xfer->dstOps.size() != xfer->dstOps.size()) { - same = false; - continue; - } - for (size_t i = 0; i < old_xfer->dstOps.size(); i++) { - if (!check_opxes_have_same_type_and_constraints(*old_xfer->dstOps[i], - *xfer->dstOps[i])) { - same = false; - } - } - if (same) { - found_same_xfer = true; - break; - } - } - if (!found_same_xfer && xfer->srcOps.size() == 1) { - xfers.push_back(xfer); - } else { - delete (xfer); - } - } - return xfers; -} - -GraphSearchHelper::GraphSearchHelper(FFModel *model) - : model(model), config(model->config), mem_config(1.0) { - this->logger = std::unique_ptr(new RecursiveLogger("gs")); - generate_all_pcg_xfers(); -} - -void GraphSearchHelper::clear_cache() { - cached_optimized_graphs.clear(); -} - -void GraphSearchHelper::load_graph_substitutions( - std::vector &xfers) const { - xfers = all_pcg_xfers; -} - -void GraphSearchHelper::generate_all_pcg_xfers() { - std::vector all_parallel_degrees, single_node_parallel_degrees; - auto const &config = this->model->config; - int workersPerNode = - config.search_num_workers.value_or(config.workersPerNode); - int numNodes = config.search_num_nodes.value_or(config.numNodes); - log_xfers.debug() << "Generating parallel degrees for workersPerNode " - << workersPerNode << " and numNodes " << numNodes; - for (int i = 2; i <= workersPerNode; i++) { - if (workersPerNode % i == 0) { - single_node_parallel_degrees.push_back(i); - all_parallel_degrees.push_back(i); - } - } - for (int i = 2; i <= numNodes; i++) { - if (numNodes % i == 0) { - all_parallel_degrees.push_back(i * workersPerNode); - } - } - { - std::ostringstream oss; - oss << "Generating all_pcg_xfers for all parallel degrees: "; - for (int parallel_degree : all_parallel_degrees) { - oss << parallel_degree << " "; - } - - log_xfers.debug() << oss.str(); - } - - for (auto const &it : single_node_parallel_degrees) { - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_SIGMOID, false)); - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_NONE, false)); - if (16 % it == 0) { - all_pcg_xfers.push_back( - create_replicate_attention_reduce(this->model, 16 /*num_heads*/, it)); - } - } - for (auto const &it : all_parallel_degrees) { - all_pcg_xfers.push_back( - create_partition_attention_combine(this->model, 16 /*num_heads*/, it)); - } - - if (config.substitution_json_path.has_value()) { - // Currently only consider a subset of all_parallel_degrees - std::vector considered_parallel_degrees; - considered_parallel_degrees.push_back(workersPerNode); - if (numNodes > 1) { - considered_parallel_degrees.push_back(numNodes * workersPerNode); - } - sl::RuleCollection rule_collection = sl::load_rule_collection_from_path( - config.substitution_json_path.value()); - for (int degree : considered_parallel_degrees) { - std::vector xfers = - create_xfers(this->model, rule_collection, degree); - all_pcg_xfers.insert(all_pcg_xfers.end(), xfers.begin(), xfers.end()); - } - } else { - // Manual substitutions - for (int num_dims = 3; num_dims <= 4; num_dims++) { - all_pcg_xfers.push_back( - create_linear_relu_merge(this->model, num_dims, true)); - all_pcg_xfers.push_back( - create_linear_relu_merge(this->model, num_dims, false)); - } - for (int const degree : all_parallel_degrees) { - create_mapping_xfers(this->model, degree, all_pcg_xfers); - create_mapping_xfers(this->model, degree, all_pcg_xfers); - create_mapping_xfers(this->model, degree, all_pcg_xfers); - } - for (auto const &it : all_parallel_degrees) { - // rewrites for the inception model - for (int i = 3; i <= 6; i++) { - all_pcg_xfers.push_back(create_combine_inception( - this->model, i - 1 /*num_convs*/, 5 /*num_dims*/, it)); - all_pcg_xfers.push_back(create_combine_concat( - this->model, i /*num_inputs*/, 5 /*num_dims*/, it)); - } - // all_pcg_xfers.push_back(create_partition_conv2d_combine(this->model, - // 5/*num_dims*/, it)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_SIGMOID, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_NONE, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_SIGMOID, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_NONE, false)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 1 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 2 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 3 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 4 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_relu_combine( - this->model, 3 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_relu_combine( - this->model, 4 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back( - create_partition_softmax_combine(this->model, - 0 /*softmax_dim*/, - 1 /*parallel_dims*/, - it /*num_parts*/)); - for (int num_combines = 1; num_combines < 5; num_combines++) { - all_pcg_xfers.push_back(leading_relu_branch_combine( - this->model, 3 /*parallel_dim*/, it /*num_parts*/, num_combines)); - all_pcg_xfers.push_back(leading_relu_branch_partition( - this->model, 3 /*parallel_dim*/, it /*num_parts*/, num_combines)); - } - { - std::unordered_set concat_num_inputs; - for (size_t i = 0; i < this->model->operators.size(); i++) { - if (this->model->operators[i]->op_type == OP_CONCAT) { - concat_num_inputs.insert(this->model->operators[i]->numInputs); - } - } - for (auto const &it2 : concat_num_inputs) { - all_pcg_xfers.push_back( - create_partition_concat_combine(this->model, - it2 /*num_inputs*/, - 0 /*concat_dim*/, - 1 /*parallel_dims*/, - it /*num_parts*/)); - all_pcg_xfers.push_back( - create_partition_concat_combine(this->model, - it2 /*num_inputs*/, - 2 /*concat_dim*/, - 3 /*parallel_dims*/, - it /*num_parts*/)); - } - } - } - } -} - -Graph *GraphSearchHelper::construct_graph() { - Graph *graph = new Graph(this->model); - std::unordered_map op_to_node_map; - for (FlexFlow::Op const *dstOp : this->model->operators) { - Node dstNode; - dstNode.ptr = dstOp; - dstNode.guid = this->model->node_global_guid++; - op_to_node_map[dstOp] = dstNode; - for (int j = 0; j < dstOp->numInputs; j++) { - FlexFlow::Op const *srcOp = dstOp->inputs[j]->owner_op; - assert(op_to_node_map.find(srcOp) != op_to_node_map.end()); - Node srcNode = op_to_node_map[srcOp]; - graph->add_edge(srcNode, dstNode, dstOp->inputs[j]->owner_idx, j); - } - } - - return graph; -} - -/** - * @brief Unity search algorithm main entrance. - * - * @param[in] budget Not used - * @param[in] only_data_parallel Not used - * @param[out] best_graph The best possible PCG after optimization - * @param[out] optimal_views The corresponding device placement views of the - * best graph - */ -void GraphSearchHelper::graph_optimize( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views) { - // Construct graph structure - this->logger->debug() << "Starting graph optimization"; - - Graph *graph = this->construct_graph(); - graph->duplicate_input_nodes(); - std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - Node sink_node = graph->find_sink_node(); - GraphOptimizeResult optimal = - this->generic_sequence_optimize( - graph, - sink_node, - tl::nullopt /*output_shape*/, - tl::nullopt /*input_shape*/); - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal cost: " << optimal.cost << std::endl; - SimplificationSettings settings; - settings.fuse_parallel_ops = true; - settings.remove_noops = true; - settings.remove_trailing_parallel_ops = true; - settings.simplify_parallel_ops = true; - best_graph = std::unique_ptr(new Graph(optimal.graph.value())); - best_graph->simplify(settings); - std::unordered_map duplicated_optimal_views = - best_graph->optimal_views(); - std::unordered_map deduplication_map = - best_graph->deduplicate_input_nodes(); - std::unordered_map real_optimal_views; - for (auto const &kv : duplicated_optimal_views) { - if (deduplication_map.find(kv.first) != deduplication_map.end()) { - real_optimal_views[deduplication_map.at(kv.first)] = kv.second; - } else { - real_optimal_views[kv.first] = kv.second; - } - } - best_graph->print_strategy_computation_graph(optimal.views); - optimal_views = real_optimal_views; -} - -/** - * @brief Experimental DP algorithm to optimize PCG with the consideration of - * memory usage. This is to avoid polluting the current Unity search algorithm - * above. And this should be merged to GraphSearchHelper::graph_optimize - * eventually. - * - * @param[in] budget Not used - * @param[in] only_data_parallel Not used - * @param[out] best_graph The best possible PCG after optimization - * @param[out] optimal_views The corresponding device placement views of the - * best graph - * @param[out] search_result The performance result of the search - */ -void GraphSearchHelper::graph_optimize_with_memory( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views, - MemorySearchResult &search_result) { - this->logger->debug() - << "Starting graph optimization with memory consideration"; - - // Construct graph structure - Graph *graph = this->construct_graph(); - - // The input nodes may need to be duplicated because the PCG was constructed - // to have one input node for one input, but the actual execution graph should - // have the distributed version of inputs (i.e. multiple nodes). - graph->duplicate_input_nodes(); - - // Export an empty schedule if needed. - std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - Node sink_node = graph->find_sink_node(); - - auto const start = std::chrono::system_clock::now(); - GraphOptimizeResultWithMemory optimal = - this->generic_sequence_optimize_with_memory< - GraphOptimizeResultWithMemory>( - graph, sink_node, tl::nullopt, tl::nullopt); - auto const end = std::chrono::system_clock::now(); - - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal run time cost: " << optimal.cost - << ", Memory usage: " << optimal.mem_cost - << " | run_time_cost_factor: " - << this->mem_config.run_time_cost_factor << std::endl; - - // Save the search performance results to the output argument - search_result.run_time_cost = optimal.cost; - search_result.memory_cost = optimal.mem_cost.num; - search_result.search_time = - std::chrono::duration_cast(end - start) - .count(); - - // Further simplify the "optimal" graph/schedule to have a more efficient - // graph and more accurate cost. - best_graph = std::unique_ptr(new Graph(optimal.graph.value())); - SimplificationSettings settings; - // Simplify to consider parallel op fusion - settings.fuse_parallel_ops = true; - settings.remove_noops = true; - settings.remove_trailing_parallel_ops = true; - settings.simplify_parallel_ops = true; - best_graph->simplify(settings); - - // Get the real optimal machine views. - std::unordered_map duplicated_optimal_views = - best_graph->optimal_views(); - std::unordered_map deduplication_map = - best_graph->deduplicate_input_nodes(); - std::unordered_map real_optimal_views; - for (auto const &kv : duplicated_optimal_views) { - if (deduplication_map.find(kv.first) != deduplication_map.end()) { - real_optimal_views[deduplication_map.at(kv.first)] = kv.second; - } else { - real_optimal_views[kv.first] = kv.second; - } - } - std::cout << "Dot graph of searched strategy:" << std::endl; - best_graph->print_strategy_computation_graph(optimal.views); - std::cout << std::endl; - - optimal_views = real_optimal_views; -} - -void GraphSearchHelper::graph_optimize_no_split( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views) { - // Construct graph structure - this->logger->debug() << "Starting graph optimization without split"; - - Graph *graph = this->construct_graph(); - std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - SimplificationSettings settings; - settings.simplify_parallel_ops = true; - best_graph = this->base_optimize(graph, settings); - optimal_views = best_graph->optimal_views(); - - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal cost: " << best_graph->optimal_cost() << std::endl; -} - -static void graph_log_representation(Graph const *graph, - RecursiveLogger &logger) { - using FlexFlow::PCG::Utils::topo_sort; - - std::vector topo_sorted; - topo_sort(*graph, &topo_sorted); - std::ostringstream oss; - for (Node const &n : topo_sorted) { - logger.spew() << n.to_string(); - } -} - -void GraphSearchHelper::update_mem_optim_config( - MemoryOptimConfig const &new_config) { - mem_config = new_config; -} - -void GraphSearchHelper::find_rewrite_matches( - Graph const *graph, std::vector &matches) const { - std::vector xfers; - this->load_graph_substitutions(xfers); - - for (GraphXfer *xfer : xfers) { - log_xfer_matches.debug() - << "Finding matches for xfer: " << xfer->get_name(); - xfer->find_matches(graph, matches); - } - log_xfer_matches.debug() << "Finished finding xfer matches"; -} - -tl::optional - GraphSearchHelper::find_split_node(Graph const *graph, - int base_optimize_threshold) const { - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::MultisourceGraphStructure; - using FlexFlow::PCG::Utils::nodes; - using FlexFlow::PCG::Utils::post_dominators; - using FlexFlow::PCG::Utils::roots; - - TAG_ENTER(this->logger); - - int graph_size = nodes(*graph).size(); - this->logger->debug() << "Finding split node for graph (size " << graph_size - << ") with threshold " << base_optimize_threshold; - - if (graph_size <= base_optimize_threshold) { - this->logger->debug() - << "Graph size underneath threshold. Returning nullopt"; - return tl::nullopt; - } - - std::vector edges = get_edges(*graph); - std::unordered_map edge_scores; - - for (Edge const &e : edges) { - edge_scores[e] = 0; - } - - std::vector matches; - this->find_rewrite_matches(graph, matches); - this->logger->debug() << "Found " << matches.size() << " rewrite matches"; - { - TAG_ENTER(this->logger); - for (GraphXferMatch const &match : matches) { - auto msg = this->logger->spew(); - msg << match.get_xfer()->get_name() << " : "; - std::unordered_set nodes = match.get_nodes(); - for (Node const &node : nodes) { - msg << node.to_string() << " "; - } - } - } - - for (GraphXferMatch const &match : matches) { - for (Edge const &e : edges) { - if (match.containsEdge(graph, e)) { - edge_scores[e]++; - } - } - } - - this->logger->debug() << "Edge weights: "; - - { - TAG_ENTER(this->logger); - for (Edge const &e : edges) { - this->logger->debug() << e.srcOp.to_string() << "/" << e.srcIdx << " -> " - << e.dstOp.to_string() << "/" << e.dstIdx << " : " - << edge_scores.at(e); - } - } - - std::unordered_map> post_dominator_map = - post_dominators>(*graph); - Node source_node; - { - std::unordered_set source_nodes = roots(*graph); - if (source_nodes.size() != 1) { - source_nodes = roots>(*graph); - } - assert(source_nodes.size() == 1); - source_node = *source_nodes.begin(); - } - std::unordered_set possible_bottlenecks = - post_dominator_map.at(source_node); - Node sink_node = graph->find_sink_node(); - - int best_weight = 0; - tl::optional best = tl::nullopt; - int best_size = graph_size; - { - TAG_ENTER(this->logger); - - for (Node const &possible_bottleneck : possible_bottlenecks) { - if (possible_bottleneck == sink_node || - possible_bottleneck == source_node) { - continue; - } - - int weight = 0; - for (Edge const &e : graph->outEdges.at(possible_bottleneck)) { - weight += edge_scores.at(e); - } - this->logger->debug() - << "Potential bottleneck node " << possible_bottleneck.to_string() - << " has weight " << weight; - if (weight < best_weight) { - best_weight = weight; - best = possible_bottleneck; - } else if (weight == best_weight) { - // break ties by trying to choosing the split that produces the - // pre_graph with size closest to the threshold, favoring everything - // with smaller size over everything with larger size - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(possible_bottleneck); - int current_size = nodes(*pre_graph).size(); - - bool best_is_under = best_size <= base_optimize_threshold; - bool current_is_under = current_size <= base_optimize_threshold; - - bool condition1 = current_is_under && !best_is_under; - bool condition2 = - current_is_under && best_is_under && current_size > best_size; - bool condition3 = - !current_is_under && !best_is_under && current_size < best_size; - - if (condition1 || condition2 || condition3) { - best_weight = weight; - best = possible_bottleneck; - best_size = current_size; - } - } - } - } - - return best; -} - -/** - * @brief Base case of Unity's DP search algorithm. - * - * @param r_graph Graph to be optimized - * @param simplification_settings Settings to simplify the PCG - * @return std::unique_ptr Optimized PCG - */ -std::unique_ptr GraphSearchHelper::base_optimize( - Graph const *r_graph, - SimplificationSettings const &simplification_settings) { - // Construct graph substitutions - TAG_ENTER(this->logger); - - this->logger->debug() << "Optimizing base graph: "; - { - TAG_ENTER(this->logger); - /* graph_log_representation(r_graph, *this->logger); */ - // r_graph->print_dot(); - } - this->logger->debug() << "Starting cost: " << r_graph->optimal_cost(); - - std::vector xfers; - this->load_graph_substitutions(xfers); - - Graph *graph = new Graph(*r_graph); - - std::priority_queue, GraphCompare> candidates; - std::unordered_set hashmap; - candidates.push(graph); - hashmap.insert(graph->hash()); - Graph *best_graph = new Graph(*graph); - float best_cost = best_graph->optimal_cost(); - int counter = 0; - float const alpha = this->model->config.search_alpha; - - int budget = model->config.search_budget; - if (budget == 0) { - log_xfers.warning() - << "Base search budget is set to 0. This is probably not what you want " - "(use the --budget flag to set the base search budget)"; - } - for (int iter = 0; iter < budget || budget == -1; iter++) { - log_xfers.spew() << "Considering " << candidates.size() << " candidates"; - if (candidates.empty()) { - break; - } - - Graph *cur_graph = candidates.top(); - candidates.pop(); - if (cur_graph->optimal_cost() < best_graph->optimal_cost()) { - delete best_graph; - best_graph = cur_graph; - best_cost = cur_graph->optimal_cost(); - } else if (cur_graph->optimal_cost() > best_cost * alpha) { - continue; - } - - log_xfers.info("[%d] cur_cost(%.4lf) best_cost(%.4lf) candidates.size(%zu)", - counter, - cur_graph->optimal_cost(), - best_cost, - candidates.size()); - - log_xfers.debug() << "Considering " << xfers.size() << " possible xfers"; - for (size_t i = 0; i < xfers.size(); i++) { - int num_matches_found = 0, num_matches_rejected = 0; - log_xfers.debug() << "Considering xfer: " << xfers[i]->get_name(); - xfers[i]->run(0, - cur_graph, - candidates, - hashmap, - best_cost * alpha, - 1000, - simplification_settings, - num_matches_found, - num_matches_rejected); - log_xfers.debug() << "Rejected [ " << num_matches_rejected << " / " - << num_matches_found << " ] matches"; - /* std::cout << "." << std::flush; */ - } - /* std::cout << std::endl; */ - if (best_graph != cur_graph) { - delete cur_graph; - } - } - - this->logger->debug() << "Optimized cost: " << best_graph->optimal_cost(); - // best_graph->print_dot(); - return std::unique_ptr(best_graph); -} - -/** - * @brief Experimental. Base case of Unity's DP search algorithm with - * memory consideration. - * - * @param r_graph Graph to be optimized - * @param simplification_settings Settings to simplify the resulting PCG - * @return std::unique_ptr Optimized PCG - */ -std::unique_ptr GraphSearchHelper::base_optimize_with_memory( - Graph const *r_graph, - SimplificationSettings const &simplification_settings) { - TAG_ENTER(this->logger); - this->logger->debug() << "Optimizing base graph with memory: "; - { - TAG_ENTER(this->logger); - /* graph_log_representation(r_graph, *this->logger); */ - // r_graph->print_dot(); - } - this->logger->debug() << "Starting cost: " - << r_graph->optimal_cost_with_memory( - mem_config.run_time_cost_factor); - - // Construct graph substitutions - std::vector xfers; - this->load_graph_substitutions(xfers); - - // Prepare for the search - std::priority_queue, GraphCompareWithMemory> - candidates(GraphCompareWithMemory{mem_config.run_time_cost_factor}); - std::unordered_set hashmap; - - Graph *graph = new Graph(*r_graph); - candidates.push(graph); - hashmap.insert(graph->hash()); - - Graph *best_graph = new Graph(*graph); - float best_cost = - best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - - int counter = 0; - float const alpha = this->model->config.search_alpha; - int budget = model->config.search_budget; - if (budget == 0) { - log_xfers.warning() - << "Base search budget is set to 0. This is probably not what you want " - "(use the --budget flag to set the base search budget)"; - } - - // Actual exploration - for (int iter = 0; iter < budget || budget == -1; iter++) { - log_xfers.spew() << "Considering " << candidates.size() - << " candidates in base_optimize_with_memory"; - if (candidates.empty()) { - break; - } - - Graph *cur_graph = candidates.top(); - candidates.pop(); - if (cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor) < - best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor)) { - delete best_graph; - best_graph = cur_graph; - best_cost = - cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - } else if (cur_graph->optimal_cost_with_memory( - mem_config.run_time_cost_factor) > best_cost * alpha) { - continue; - } - - log_xfers.info( - "[%d] cur_cost(%.4lf) best_cost(%.4lf) candidates.size(%zu)", - counter, - cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor), - best_cost, - candidates.size()); - - log_xfers.debug() << "Considering " << xfers.size() - << " possible xfers in base_optimize_with_memory"; - for (size_t i = 0; i < xfers.size(); i++) { - int num_matches_found = 0, num_matches_rejected = 0; - log_xfers.debug() << "Considering xfer: " << xfers[i]->get_name(); - xfers[i]->run(0, - cur_graph, - candidates, - hashmap, - best_cost * alpha, - 1000, - simplification_settings, - num_matches_found, - num_matches_rejected); - log_xfers.debug() << "Rejected [ " << num_matches_rejected << " / " - << num_matches_found << " ] matches"; - } - - if (best_graph != cur_graph) { - delete cur_graph; - } - } - - this->logger->debug() - << "Optimized cost at the end of base_optimize_with_memory: " - << best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - - return std::unique_ptr(best_graph); -} - -size_t gs_dp_state_hash(Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - size_t key = graph->hash(); - hash_combine(key, sink_node.ptr); - hash_combine(key, output_shape); - hash_combine(key, input_shape); - return key; -} - -float GraphSearchHelper::sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - return this->generic_sequence_optimize( - graph, sink_node, output_shape, input_shape); -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache(size_t hash) const { - if (this->cached_optimized_graphs.find(hash) == - this->cached_optimized_graphs.end()) { - return tl::nullopt; - } else { - return this->cached_optimized_graphs.at(hash); - } -} - -template <> -float GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - return optimized->generic_optimal_cost(); -} - -template <> -GraphCostResult GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - return optimized->generic_optimal_cost(); -} - -template <> -GraphOptimizeResult GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - GraphOptimizeResult result; - result.graph = *optimized; - GraphCostResult gcr = optimized->generic_optimal_cost(); - result.cost = gcr.cost; - result.views = gcr.views; - return result; -} - -template <> -GraphOptimizeResultWithMemory - GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - GraphOptimizeResultWithMemory result; - result.graph = *optimized; - GraphCostResultWithMemory gcr = - optimized->generic_optimal_cost(); - result.cost = gcr.cost; - result.views = gcr.views; - result.mem_cost = gcr.mem_cost; - return result; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return tl::nullopt; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return tl::nullopt; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return tl::nullopt; -} - -template <> -void GraphSearchHelper::try_cache_result(size_t hash, - float const &value) { - this->cached_optimized_graphs[hash] = value; -} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphCostResult const &value) {} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphOptimizeResult const &value) {} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphOptimizeResultWithMemory const &value) {} - -/** - * @brief Get the cost/result of PCG if sequentially split it. - * - * @details This function is to combine the search results from DP sub-problems. - * The sub-problems are solved by generic_sequence_optimize(). - */ -template -T GraphSearchHelper::execute_sequence_split( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape) { - return sequence_cost( - this->generic_sequence_optimize( - pre_graph.get(), bottleneck, bottleneck_output_shape, input_shape), - this->generic_sequence_optimize( - post_graph.get(), sink_node, output_shape, bottleneck_output_shape)); -} - -/** - * @brief Experimental. Consider memory usage when spliting the PCG during the - * DP search. This should be merged with execute_sequence_split(). - */ -template -T GraphSearchHelper::execute_sequence_split_with_memory( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape) { - return sequence_cost( - this->generic_sequence_optimize_with_memory( - pre_graph.get(), bottleneck, bottleneck_output_shape, input_shape), - this->generic_sequence_optimize_with_memory( - post_graph.get(), sink_node, output_shape, bottleneck_output_shape)); -} - -/** - * @brief Top level DP search procedure for Unity. - */ -template -T GraphSearchHelper::generic_sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - /* int starting_depth = this->logger->get_depth(); */ - - TAG_ENTER(this->logger); - - size_t hash = gs_dp_state_hash(graph, sink_node, output_shape, input_shape); - tl::optional cached = this->try_get_cost_from_cache(hash); - if (cached.has_value()) { - this->logger->spew() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->spew() << "Retrieved value from cache: " << cached.value(); - } - - /* this->logger->check_same_as(starting_depth); */ - return cached.value(); - } - - this->logger->debug() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - T return_value; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->debug() << "Graph hash: " << std::setw(32) - << std::setfill('0') << graph->hash(); - if (input_shape.has_value()) { - this->logger->debug() << "Input shape: " << input_shape.value(); - } else { - this->logger->debug() << "Input shape: "; - } - if (output_shape.has_value()) { - this->logger->debug() << "Output shape: " << output_shape.value(); - } else { - this->logger->debug() << "Output shape: "; - } - - tl::optional bottleneck = - this->find_split_node(graph, this->config.base_optimize_threshold); - - if (!bottleneck.has_value()) { - this->logger->debug() << "Applying base case"; - Graph to_optimize(*graph); - if (input_shape.has_value()) { - Node input_node = - this->model->get_or_create_input_node(input_shape.value()); - Node noop_node = - this->model->get_or_create_noop_node(input_node.ptr->outputs[0]); - Graph input_graph(this->model); - Edge e(input_node, noop_node, 0, 0); - input_graph.add_edge(e); - - Node old_source_node = graph->find_source_node(); - ParallelTensorShape old_source_output_shape = - old_source_node.ptr->outputs[0]->get_shape(); - input_graph.reshape_output_tensor(old_source_output_shape); - - Node new_sink_node = input_graph.find_sink_node(); - assert(new_sink_node.ptr->numOutputs == 1); - assert(new_sink_node.ptr->outputs[0]->get_shape() == - old_source_output_shape); - - to_optimize.replace_subgraph({old_source_node}, input_graph); - } - SimplificationSettings settings; - if (output_shape.has_value()) { - to_optimize.reshape_output_tensor(output_shape.value()); - Node sink_node = to_optimize.find_sink_node(); - Node noop_node = - this->model->get_or_create_noop_node(sink_node.ptr->outputs[0]); - to_optimize.add_edge(sink_node, noop_node, 0, 0); - } else { - settings.remove_trailing_parallel_ops = true; - } - settings.simplify_parallel_ops = true; - std::unique_ptr optimized = - this->base_optimize(&to_optimize, settings); - return_value = get_optimal_cost( - std::move(optimized)); // optimized->generic_optimal_cost(); - } else { - this->logger->debug() << "Applying recursive case on bottleneck " - << bottleneck.value().guid; - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(bottleneck.value()); - - MachineResource resources(this->model->config); - std::vector valid_machine_views = - this->model->search->get_valid_machine_views(bottleneck.value().ptr, - resources); - - float best_cost = std::numeric_limits::infinity(); - tl::optional best_shape = tl::nullopt; - { - TAG_ENTER(this->logger); - for (ParallelTensorShape const &bottleneck_output_shape : - this->possible_split_output_tensor_shapes(bottleneck.value())) { - this->logger->debug() - << "Considering boundary shape " << bottleneck_output_shape; - float current_cost; - { - TAG_ENTER(this->logger); - // TODO @lockshaw we really should create the merged graph here - // since it's possible though unlikely for there to be hidden - // transfer costs between modules due to device assignment changes - // across the boundaries - - // We wait to add the communication nodes between boundaries so we - // don't accidentally split on them and keep processing the pure - // computation graph The bottleneck node is kept in the postgraph - // purely as a placeholder and will be replaced with an Input/NoOp - // sequence before any rewrites are actually performed - // this->logger->debug() << "Finding cost of pre_graph (" << - // bottleneck_output_shape << ")"; float pre_cost = - // this->generic_sequence_optimize(pre_graph.get(), - // bottleneck.value(), bottleneck_output_shape, input_shape); - // this->logger->debug() << "Cost of pre_graph (" << - // bottleneck_output_shape << "): " << pre_cost; - // this->logger->debug() << "Finding cost of post_graph (" << - // bottleneck_output_shape << ")"; float post_cost = - // this->generic_sequence_optimize(post_graph.get(), - // sink_node, output_shape, bottleneck_output_shape); - // this->logger->debug() << "Cost of post_graph (" << - // bottleneck_output_shape << "): " << post_cost; float current_cost - // = pre_cost + post_cost; - current_cost = - this->execute_sequence_split(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - bottleneck_output_shape); - - if (current_cost < best_cost) { - best_cost = current_cost; - best_shape = bottleneck_output_shape; - } - } - this->logger->debug() << "Boundary shape " << bottleneck_output_shape - << " has cost: " << current_cost; - } - } - - if (best_shape.has_value()) { - this->logger->debug() - << "Best intermediate shape found: " << best_shape.value(); - } else { - this->logger->debug() << "No valid intermediate shapes found"; - } - - if (best_cost != std::numeric_limits::infinity()) { - return_value = this->execute_sequence_split(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - best_shape.value()); - } - } - - this->try_cache_result(hash, return_value); - } - return return_value; -} - -/** - * @brief Top level DP search procedure for Unity with the consideration of - * memory usage. - * - * @tparam T Returned type - * @param graph Pre-optimization PCG - * @param sink_node Sink node of the PCG - * @param output_shape ??? - * @param input_shape ??? - * @return T Optimal result - */ -template -T GraphSearchHelper::generic_sequence_optimize_with_memory( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - TAG_ENTER(this->logger); - - // Try to find the result from cache first. But this will only get the cached - // result if the returned type is float. The float number means the best run - // time cost with only machine quantity (without distinguishing machine - // identities). - size_t hash = gs_dp_state_hash(graph, sink_node, output_shape, input_shape); - tl::optional cached = this->try_get_cost_from_cache(hash); - if (cached.has_value()) { - this->logger->spew() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->spew() << "Retrieved value from cache: " << cached.value(); - } - return cached.value(); - } - - // Couldn't find the result from cache. Try to optimize and get one. - this->logger->debug() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - T return_value; - { - // Print out debug information - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->debug() << "Graph hash: " << std::setw(32) - << std::setfill('0') << graph->hash(); - if (input_shape.has_value()) { - this->logger->debug() << "Input shape: " << input_shape.value(); - } else { - this->logger->debug() << "Input shape: "; - } - if (output_shape.has_value()) { - this->logger->debug() << "Output shape: " << output_shape.value(); - } else { - this->logger->debug() << "Output shape: "; - } - - // Find the node to sequentially split the PCG. - // Decide if the search reaches the base condition by this. - tl::optional bottleneck = - this->find_split_node(graph, this->config.base_optimize_threshold); - - if (!bottleneck.has_value()) { - this->logger->debug() << "Applying base case"; - - // Construct the PCG to optimize based on input_shape and output_shape - // information. - Graph to_optimize(*graph); - if (input_shape.has_value()) { - Node input_node = - this->model->get_or_create_input_node(input_shape.value()); - Node noop_node = - this->model->get_or_create_noop_node(input_node.ptr->outputs[0]); - Graph input_graph(this->model); - Edge e(input_node, noop_node, 0, 0); - input_graph.add_edge(e); - - Node old_source_node = graph->find_source_node(); - ParallelTensorShape old_source_output_shape = - old_source_node.ptr->outputs[0]->get_shape(); - input_graph.reshape_output_tensor(old_source_output_shape); - - Node new_sink_node = input_graph.find_sink_node(); - assert(new_sink_node.ptr->numOutputs == 1); - assert(new_sink_node.ptr->outputs[0]->get_shape() == - old_source_output_shape); - - to_optimize.replace_subgraph({old_source_node}, input_graph); - } - SimplificationSettings settings; - if (output_shape.has_value()) { - to_optimize.reshape_output_tensor(output_shape.value()); - Node sink_node = to_optimize.find_sink_node(); - Node noop_node = - this->model->get_or_create_noop_node(sink_node.ptr->outputs[0]); - to_optimize.add_edge(sink_node, noop_node, 0, 0); - } else { - settings.remove_trailing_parallel_ops = true; - } - settings.simplify_parallel_ops = true; - - // Call base optimization to perform graph substitution. - std::unique_ptr optimized = - this->base_optimize_with_memory(&to_optimize, settings); - return_value = get_optimal_cost(std::move(optimized)); - } else { - this->logger->debug() << "Applying recursive case on bottleneck " - << bottleneck.value().guid; - - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(bottleneck.value()); - - MachineResource resources(this->model->config); - std::vector valid_machine_views = - this->model->search->get_valid_machine_views(bottleneck.value().ptr, - resources); - - // Try to find the best cost and corresponding best bottleneck shape. - // This search process is based on the float version of - // execute_sequence_split_with_memory(). - float best_cost = std::numeric_limits::infinity(); - tl::optional best_shape = tl::nullopt; - { - TAG_ENTER(this->logger); - for (auto const &bottleneck_output_shape : - this->possible_split_output_tensor_shapes(bottleneck.value())) { - this->logger->debug() - << "Considering boundary shape " << bottleneck_output_shape; - float current_cost; - { - TAG_ENTER(this->logger); - // Get the cost from execute_sequence_split_with_memory by - // only changing bottleneck_output_shape. - current_cost = this->execute_sequence_split_with_memory( - pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - bottleneck_output_shape); - - if (current_cost < best_cost) { - best_cost = current_cost; - best_shape = bottleneck_output_shape; - } - } - this->logger->debug() << "Boundary shape " << bottleneck_output_shape - << " has cost: " << current_cost; - } - } - - if (best_shape.has_value()) { - this->logger->debug() - << "Best intermediate shape found: " << best_shape.value(); - } else { - this->logger->debug() << "No valid intermediate shapes found"; - } - - // ? What if best_cost is infinity ? - if (best_cost != std::numeric_limits::infinity()) { - // Get the return value of correct type with previously found - // best_shape. - return_value = - this->execute_sequence_split_with_memory(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - best_shape.value()); - } - } - // Try to cache the float result - this->try_cache_result(hash, return_value); - } - return return_value; -} - -std::vector - GraphSearchHelper::possible_split_output_tensor_shapes( - Node const &source_node) const { - TAG_ENTER(this->logger); - - this->logger->debug() << "Finding possible output tensor shapes for node " - << source_node.guid; - assert(source_node.ptr->numOutputs == 1); - ParallelTensor output_tensor = source_node.ptr->outputs[0]; - for (int i = 0; i < output_tensor->num_dims; i++) { - assert(output_tensor->dims[i].degree == 1); - } - - std::vector without_replicas; - - int num_devices = this->config.numNodes * this->config.workersPerNode; - int degrees[MAX_TENSOR_DIM]; - std::fill_n(degrees, MAX_TENSOR_DIM, 1); - - ParallelTensorShape base_shape; - base_shape.num_dims = output_tensor->num_dims; - for (int i = 0; i < output_tensor->num_dims; i++) { - base_shape.dims[i].degree = 1; - base_shape.dims[i].size = output_tensor->dims[i].size; - } - without_replicas.push_back(base_shape); - - { - TAG_ENTER(this->logger); - while (true) { - bool is_done = true; - for (int i = 0; i < output_tensor->num_dims; i++) { - degrees[i] *= 2; - if (degrees[i] > num_devices) { - degrees[i] = 1; - } else { - is_done = false; - break; - } - } - std::ostringstream oss; - for (int i = 0; i < output_tensor->num_dims; i++) { - oss << degrees[i] << " "; - } - this->logger->spew() << "Considering: " << oss.str(); - if (is_done) { - break; - } - - bool is_valid = true; - int total_degree = 1; - ParallelTensorShape shape; - shape.num_dims = output_tensor->num_dims; - for (int i = 0; i < output_tensor->num_dims; i++) { - total_degree *= degrees[i]; - shape.dims[i].degree = degrees[i]; - shape.dims[i].size = output_tensor->dims[i].size; - if (shape.dims[i].size % shape.dims[i].degree != 0) { - is_valid = false; - } - } - if (total_degree <= num_devices && is_valid) { - without_replicas.push_back(shape); - } - } - } - - this->logger->debug() << "Found " << without_replicas.size() - << " possible tensor output shapes without replicas"; - this->logger->debug() << "They are:"; - { - TAG_ENTER(this->logger); - for (auto const &shape : without_replicas) { - this->logger->debug() << shape; - } - } - return without_replicas; -} - -void GraphSearchHelper::subgraph_optimize(Graph *subgraph) {} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - return this->create_conv2d(input, matchOpX); -} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - OpX *pool = new OpX(OP_POOL2D, 1, 1, input); - pool->matchOpX = matchOpX; - return pool; -} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - OpX *flat = new OpX(OP_FLAT, 1, 1, input); - flat->matchOpX = matchOpX; - return flat; -} - -GraphXfer *create_partition_linear_combine(FFModel *model, - int num_dims, - int num_parts, - ActiMode activation, - bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *linear1 = subst->create_linear( - input, NULL /*matchOpX*/, num_dims, activation, use_bias); - OpX *repartition = subst->create_repartition(input, num_dims - 2, num_parts); - OpX *linear2 = subst->create_linear(repartition->outputs[0], - linear1 /*matchOpX*/, - num_dims, - activation, - use_bias); - OpX *combine = - subst->create_combine(linear2->outputs[0], num_dims - 2, num_parts); - subst->map_output(linear1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(linear1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(linear2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_linear_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts - << ",activation=" << activation << ",use_bias=" << use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_conv2d_combine(FFModel *model, - int num_dims, - int num_parts) { - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *conv1 = subst->create_conv2d(input, NULL /*matchOpX*/); - OpX *repartition = subst->create_repartition(input, num_dims - 2, num_parts); - OpX *conv2 = - subst->create_conv2d(repartition->outputs[0], conv1 /*matchOpX*/); - OpX *combine = - subst->create_combine(conv2->outputs[0], num_dims - 2, num_parts); - subst->map_output(conv1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(conv1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(conv2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_conv2d_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_inception(FFModel *model, - int num_convs, - int num_dims, - int num_parts) { - // 3 convs and 1 pool2d - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *src_combine = subst->create_combine(input, num_dims - 2, num_parts); - subst->srcOps.push_back(src_combine); - std::vector src_convs; - for (int i = 0; i < num_convs; i++) { - OpX *conv = - subst->create_conv2d(src_combine->outputs[0], NULL /*matchOpX*/); - src_convs.push_back(conv); - subst->srcOps.push_back(conv); - } - OpX *src_pool = - subst->create_pool2d(src_combine->outputs[0], NULL /*matchOpX*/); - subst->srcOps.push_back(src_pool); - // dst ops - std::vector dst_convs; - for (int i = 0; i < num_convs; i++) { - OpX *conv = subst->create_conv2d(input, src_convs[i] /*matchOpX*/); - OpX *comb = - subst->create_combine(conv->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(conv); - subst->dstOps.push_back(comb); - subst->map_output(src_convs[i]->outputs[0], comb->outputs[0]); - } - OpX *dst_pool = subst->create_pool2d(input, src_pool /*matchOpX*/); - OpX *dst_comb = - subst->create_combine(dst_pool->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(dst_pool); - subst->dstOps.push_back(dst_comb); - subst->map_output(src_pool->outputs[0], dst_comb->outputs[0]); - subst->name = "create_combine_inceptionA"; - return subst; -} - -GraphXfer *create_combine_concat(FFModel *model, - int num_inputs, - int num_dims, - int num_parts) { - // assert 5D - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - std::vector inputs, concat_inputs; - std::vector combines; - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(subst->new_tensor()); - combines.push_back( - subst->create_combine(inputs[i], num_dims - 2, num_parts)); - concat_inputs.push_back(combines[i]->outputs[0]); - subst->srcOps.push_back(combines[i]); - } - OpX *concat1 = subst->create_concat( - concat_inputs.data(), num_inputs, NULL /*matchOpX*/, 2); - subst->srcOps.push_back(concat1); - OpX *concat2 = - subst->create_concat(inputs.data(), num_inputs, concat1 /*matchOpX*/, 2); - OpX *combine = - subst->create_combine(concat2->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(concat2); - subst->dstOps.push_back(combine); - subst->map_output(concat1->outputs[0], combine->outputs[0]); - subst->name = "create_combine_concat"; - return subst; -} - -GraphXfer *create_partition_attention_combine(FFModel *model, - int num_heads, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *attn1 = subst->create_attention( - input, input, input, NULL /*matchOpX*/, num_heads); - OpX *repart = subst->create_repartition(input, 2, num_parts); - OpX *attn2 = subst->create_attention(repart->outputs[0], - repart->outputs[0], - repart->outputs[0], - attn1 /*matchOpX*/, - num_heads); - OpX *combine = subst->create_combine(attn2->outputs[0], 2, num_parts); - subst->map_output(attn1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(attn1); - subst->dstOps.push_back(repart); - subst->dstOps.push_back(attn2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_attention_combine[" - << "num_heads=" << num_heads << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_replicate_attention_reduce(FFModel *model, - int num_heads, - int num_parts) { - assert(num_heads % num_parts == 0); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *attn1 = subst->create_attention( - input, input, input, NULL /*matchOpX*/, num_heads); - OpX *repl = subst->create_replicate(input, 3, num_parts); - OpX *attn2 = subst->create_attention(repl->outputs[0], - repl->outputs[0], - repl->outputs[0], - attn1 /*matchOpX*/, - num_heads / num_parts); - OpX *reduce = subst->create_reduction(attn2->outputs[0], 3, num_parts); - subst->map_output(attn1->outputs[0], reduce->outputs[0]); - subst->srcOps.push_back(attn1); - subst->dstOps.push_back(repl); - subst->dstOps.push_back(attn2); - subst->dstOps.push_back(reduce); - - std::ostringstream oss; - oss << "replicate_attention_reduce[" - << "num_heads=" << num_heads << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_replicate_linear_combine(FFModel *model, - int num_dims, - int num_parts, - ActiMode activation, - bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *linear1 = subst->create_linear( - input, NULL /*matchOpX*/, num_dims, activation, use_bias); - OpX *replicate = subst->create_replicate(input, num_dims - 1, num_parts); - OpX *linear2 = subst->create_linear(replicate->outputs[0], - linear1 /*matchOpX*/, - num_dims, - activation, - use_bias); - OpX *combine = subst->create_combine(linear2->outputs[0], 0, num_parts); - subst->map_output(linear1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(linear1); - subst->dstOps.push_back(replicate); - subst->dstOps.push_back(linear2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "replicate_linear_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts - << ",activation=" << activation << ",use_bias=" << use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_add_combine(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input1 = subst->new_tensor(); - TensorX input2 = subst->new_tensor(); - OpX *add1 = subst->create_element_binary(input1, input2, OP_EW_ADD); - OpX *repartition1 = - subst->create_repartition(input1, parallel_dim, num_parts); - OpX *repartition2 = - subst->create_repartition(input2, parallel_dim, num_parts); - OpX *add2 = subst->create_element_binary( - repartition1->outputs[0], repartition2->outputs[0], OP_EW_ADD); - OpX *combine = - subst->create_combine(add2->outputs[0], parallel_dim, num_parts); - subst->map_output(add1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(add1); - subst->dstOps.push_back(repartition1); - subst->dstOps.push_back(repartition2); - subst->dstOps.push_back(add2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_add_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_add_partition(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input1 = subst->new_tensor(); - TensorX input2 = subst->new_tensor(); - OpX *add1 = subst->create_element_binary(input1, input2, OP_EW_ADD); - - OpX *combine1 = subst->create_combine(input1, parallel_dim, num_parts); - OpX *combine2 = subst->create_combine(input2, parallel_dim, num_parts); - OpX *add2 = subst->create_element_binary( - combine1->outputs[0], combine2->outputs[0], OP_EW_ADD); - OpX *repartition = - subst->create_repartition(add2->outputs[0], parallel_dim, num_parts); - subst->map_output(add1->outputs[0], repartition->outputs[0]); - subst->srcOps.push_back(add1); - subst->dstOps.push_back(combine1); - subst->dstOps.push_back(combine2); - subst->dstOps.push_back(add2); - subst->dstOps.push_back(repartition); - - std::ostringstream oss; - oss << "combine_add_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_relu_combine(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *relu1 = subst->create_element_unary(input, OP_RELU); - - OpX *partition = subst->create_repartition(input, parallel_dim, num_parts); - OpX *relu2 = subst->create_element_unary(partition->outputs[0], OP_RELU); - OpX *combine = - subst->create_combine(relu2->outputs[0], parallel_dim, num_parts); - - subst->map_output(relu1->outputs[0], combine->outputs[0]); - - subst->srcOps.push_back(relu1); - - subst->dstOps.push_back(partition); - subst->dstOps.push_back(relu2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_relu_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_relu_partition(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *relu1 = subst->create_element_unary(input, OP_RELU); - - OpX *combine = subst->create_combine(input, parallel_dim, num_parts); - OpX *relu2 = subst->create_element_unary(combine->outputs[0], OP_RELU); - OpX *partition = - subst->create_repartition(relu2->outputs[0], parallel_dim, num_parts); - - subst->map_output(relu1->outputs[0], partition->outputs[0]); - - subst->srcOps.push_back(relu1); - - subst->dstOps.push_back(combine); - subst->dstOps.push_back(relu2); - subst->dstOps.push_back(partition); - - std::ostringstream oss; - oss << "combine_relu_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_concat_combine(FFModel *model, - int num_inputs, - int concat_dim, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - assert(num_inputs <= MAX_NUM_INPUTS); - TensorX inputs[MAX_NUM_INPUTS]; - for (int i = 0; i < num_inputs; i++) { - inputs[i] = subst->new_tensor(); - } - OpX *concat = - subst->create_concat(inputs, num_inputs, NULL /*matchOpX*/, concat_dim); - subst->srcOps.push_back(concat); - TensorX new_inputs[MAX_NUM_INPUTS]; - for (int i = 0; i < num_inputs; i++) { - OpX *repartition = - subst->create_repartition(inputs[i], parallel_dim, num_parts); - new_inputs[i] = repartition->outputs[0]; - subst->dstOps.push_back(repartition); - } - OpX *concat2 = subst->create_concat( - new_inputs, num_inputs, concat /*matchOpX*/, concat_dim); - subst->dstOps.push_back(concat2); - OpX *combine = - subst->create_combine(concat2->outputs[0], parallel_dim, num_parts); - subst->dstOps.push_back(combine); - subst->map_output(concat->outputs[0], combine->outputs[0]); - - std::ostringstream oss; - oss << "partition_concat_combine[" - << "num_inputs=" << num_inputs << ",concat_dim=" << concat_dim - << ",parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_softmax_combine(FFModel *model, - int softmax_dim, - int parallel_dim, - int num_parts) { - assert(parallel_dim != softmax_dim); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *softmax1 = subst->create_softmax(input, softmax_dim); - OpX *repartition = subst->create_repartition(input, parallel_dim, num_parts); - OpX *softmax2 = subst->create_softmax(repartition->outputs[0], softmax_dim); - OpX *combine = - subst->create_combine(softmax2->outputs[0], parallel_dim, num_parts); - subst->map_output(softmax1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(softmax1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(softmax2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_softmax_combine[" - << "softmax_dim=" << softmax_dim << ",parallel_dim=" << parallel_dim - << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_softmax_partition(FFModel *model, - int softmax_dim, - int parallel_dim, - int num_parts) { - assert(parallel_dim != softmax_dim); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *softmax1 = subst->create_softmax(input, softmax_dim); - OpX *combine = subst->create_combine(input, parallel_dim, num_parts); - OpX *softmax2 = subst->create_softmax(combine->outputs[0], softmax_dim); - OpX *repartition = - subst->create_repartition(softmax2->outputs[0], parallel_dim, num_parts); - subst->map_output(softmax1->outputs[0], repartition->outputs[0]); - subst->srcOps.push_back(softmax1); - subst->dstOps.push_back(combine); - subst->dstOps.push_back(softmax2); - subst->dstOps.push_back(repartition); - - std::ostringstream oss; - oss << "combine_softmax_partition[" - << "softmax_dim=" << softmax_dim << ",parallel_dim=" << parallel_dim - << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *leading_relu_branch_combine(FFModel *model, - int parallel_dim, - int num_parts, - int num_combines) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_partition = - subst->create_repartition(input, parallel_dim, num_parts); - std::vector old_combines; - for (int i = 0; i < num_combines; i++) { - old_combines.push_back( - subst->create_combine(input, parallel_dim, num_parts)); - } - - OpX *new_partition = - subst->create_repartition(input, parallel_dim, num_parts); - std::vector new_noops; - for (int i = 0; i < num_combines; i++) { - new_noops.push_back(subst->create_noop(input)); - } - - subst->map_output(old_partition->outputs[0], new_partition->outputs[0]); - for (int i = 0; i < num_combines; i++) { - subst->map_output(old_combines[i]->outputs[0], new_noops[i]->outputs[0]); - } - - subst->srcOps.push_back(old_partition); - subst->srcOps.insert( - subst->srcOps.end(), old_combines.begin(), old_combines.end()); - subst->dstOps.push_back(new_partition); - subst->dstOps.insert(subst->dstOps.end(), new_noops.begin(), new_noops.end()); - - std::ostringstream oss; - oss << "leading_relu_branch_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts - << ",num_combines=" << num_combines << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *leading_relu_branch_partition(FFModel *model, - int parallel_dim, - int num_parts, - int num_partitions) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_combine = subst->create_combine(input, parallel_dim, num_parts); - std::vector old_partitions; - for (int i = 0; i < num_partitions; i++) { - old_partitions.push_back( - subst->create_repartition(input, parallel_dim, num_parts)); - } - - OpX *new_combine = subst->create_combine(input, parallel_dim, num_parts); - std::vector new_noops; - for (int i = 0; i < num_partitions; i++) { - new_noops.push_back(subst->create_noop(input)); - } - - subst->map_output(old_combine->outputs[0], new_combine->outputs[0]); - for (int i = 0; i < num_partitions; i++) { - subst->map_output(old_partitions[i]->outputs[0], new_noops[i]->outputs[0]); - } - - subst->srcOps.push_back(old_combine); - subst->srcOps.insert( - subst->srcOps.end(), old_partitions.begin(), old_partitions.end()); - subst->dstOps.push_back(new_combine); - subst->dstOps.insert(subst->dstOps.end(), new_noops.begin(), new_noops.end()); - - std::ostringstream oss; - oss << "leading_relu_branch_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts - << ",num_partitions=" << num_partitions << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer * - create_linear_relu_merge(FFModel *model, int num_dims, bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_linear = - subst->create_linear(input, nullptr, num_dims, AC_MODE_NONE, use_bias); - OpX *old_relu = subst->create_relu(old_linear->outputs[0]); - - OpX *new_linear = - subst->create_linear(input, old_linear, num_dims, AC_MODE_RELU, use_bias); - - subst->map_output(old_relu->outputs[0], new_linear->outputs[0]); - subst->srcOps.push_back(old_linear); - subst->srcOps.push_back(old_relu); - subst->dstOps.push_back(new_linear); - - std::ostringstream oss; - oss << "linear_relu_merge[" - << "num_dims=" << num_dims << ",use_bias=" << use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -} // namespace ffc - -using PCG::Edge; -using PCG::Graph; -using PCG::Node; - -/** - * @brief Optimize the graph stored in FFModel. - * - * @param[in] budget The search budget - * @param[in] only_data_parallel True if only doing data parallel training - * @param[out] best_graph The searched best graph - * @param[out] optimal_views The corresponding machine view of the best_graph - * @param[in] perform_memory_search True if we want to consider memory during - * the search - * @param[in] new_config Memory optimization config to use if this is a memory - * search - * @param[out] search_result The performance result of this search - */ -void FFModel::graph_optimize( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views, - bool perform_memory_search, - MemoryOptimConfig new_config, - MemorySearchResult &search_result) { - if (perform_memory_search) { - this->graph_search->update_mem_optim_config(new_config); - this->graph_search->graph_optimize_with_memory( - budget, only_data_parallel, best_graph, optimal_views, search_result); - } else { - this->graph_search->graph_optimize( - budget, only_data_parallel, best_graph, optimal_views); - } -} - -bool FFModel::convert_graph_to_operators( - Graph const *graph, - std::unordered_map const &optimal_views) { - // Clear operators - operators.clear(); - std::unordered_map todos; - std::unordered_map node_to_op; - std::vector queue; - for (auto const &it : graph->inEdges) { - auto const &inList = it.second; - if (inList.size() == 0) { - queue.push_back(it.first); - } else { - todos[it.first] = (int)inList.size(); - } - } - size_t index = 0; - while (index < queue.size()) { - Node node = queue[index++]; - assert(node.ptr != NULL); - auto const &inList = graph->inEdges.find(node)->second; - ParallelTensor inputs[MAX_NUM_INPUTS]; - int num_inputs = 0; - for (auto const &e : inList) { - inputs[e.dstIdx] = node_to_op[e.srcOp]->outputs[e.srcIdx]; - assert(e.dstIdx < (int)inList.size()); - num_inputs++; - } - Op *new_op = NULL; - switch (node.ptr->op_type) { - case OP_INPUT: { - NoOp *noop = (NoOp *)node.ptr; - new_op = new NoOp( - *this, OP_INPUT, noop->input_tensor_guid, node.ptr->outputs[0]); - break; - } - case OP_CONCAT: { - Concat *concat = (Concat *)node.ptr; - new_op = new Concat( - *this, (int)inList.size(), inputs, concat->legion_axis, NULL); - break; - } - case OP_AGGREGATE: { - Aggregate *aggr = (Aggregate *)node.ptr; - new_op = new Aggregate(*this, inputs, aggr->n, aggr->lambda_bal, NULL); - break; - } - case OP_SPLIT: { - Split *split = (Split *)node.ptr; - std::vector splits; - for (int i = 0; i < split->numOutputs; i++) { - splits.push_back(split->outputs[i]->dims[split->legion_axis].size); - } - new_op = new Split(*this, inputs[0], splits, split->legion_axis, NULL); - break; - } - case OP_EMBEDDING: { - new_op = new Embedding(*this, *(Embedding *)node.ptr, inputs[0], true); - break; - } - case OP_EW_ADD: - case OP_EW_SUB: - case OP_EW_MUL: - case OP_EW_MAX: - case OP_EW_MIN: { - assert(inList.size() == 2); - ElementBinary *eb = (ElementBinary *)node.ptr; - new_op = new ElementBinary( - *this, eb->op_type, inputs[0], inputs[1], eb->inplace_a, NULL); - break; - } - case OP_POOL2D: { - new_op = new Pool2D(*this, *(Pool2D *)node.ptr, inputs[0]); - break; - } - case OP_CONV2D: { - new_op = new Conv2D(*this, *(Conv2D *)node.ptr, inputs[0], true); - break; - } - case OP_DROPOUT: { - new_op = new Dropout(*this, *(Dropout *)node.ptr, inputs[0]); - break; - } - case OP_LINEAR: { - new_op = new Linear(*this, *(Linear *)node.ptr, inputs[0], true); - break; - } - case OP_MULTIHEAD_ATTENTION: { - assert(inList.size() == 3); - MultiHeadAttention *attn = (MultiHeadAttention *)node.ptr; - new_op = new MultiHeadAttention( - *this, *attn, inputs[0], inputs[1], inputs[2], true); - break; - break; - } - case OP_SOFTMAX: { - assert(inList.size() == 1); - Softmax *softmax = (Softmax *)node.ptr; - new_op = new Softmax(*this, inputs[0], softmax->dim, NULL); - break; - } - case OP_COMBINE: { - assert(inList.size() == 1); - Combine *combine = (Combine *)node.ptr; - new_op = new Combine( - *this, inputs[0], combine->combine_dim, combine->combine_degree); - break; - } - case OP_REPARTITION: { - assert(inList.size() == 1); - Repartition *repart = (Repartition *)node.ptr; - new_op = new Repartition(*this, - inputs[0], - repart->repartition_dim, - repart->repartition_degree); - break; - } - case OP_REPLICATE: { - assert(inList.size() == 1); - Replicate *replicate = (Replicate *)node.ptr; - new_op = new Replicate(*this, - inputs[0], - replicate->replicate_dim, - replicate->replicate_degree); - break; - } - case OP_REDUCTION: { - assert(inList.size() == 1); - Reduction *reduction = (Reduction *)node.ptr; - new_op = new Reduction(*this, - inputs[0], - reduction->reduction_dim, - reduction->reduction_degree); - break; - } - case OP_FUSED_PARALLEL: { - assert(inList.size() == 1); - FusedParallelOp *fused = (FusedParallelOp *)node.ptr; - std::vector parallel_ops; - for (int i = 0; i < fused->num_parallel_ops; i++) { - parallel_ops.push_back(fused->parallel_ops[i]); - } - new_op = new FusedParallelOp(*this, inputs[0], parallel_ops); - break; - } - default: { - new_op = node.ptr->materialize(*this, inputs, num_inputs); - break; - } - } - // Set machine view for the output tensors of this operator - assert(optimal_views.find(node) != optimal_views.end()); - MachineView view = optimal_views.find(node)->second; - for (int i = 0; i < new_op->numOutputs; i++) { - new_op->outputs[i]->machine_view = view; - } - // Set machine view for the weight tensors of this operator - for (int i = 0; i < new_op->numWeights; i++) { - new_op->weights[i]->machine_view = view; - } - node_to_op[node] = new_op; - operators.push_back(new_op); - // Decrease the todos - auto const &outList = graph->outEdges.find(node)->second; - for (auto const &it : outList) { - todos[it.dstOp] -= 1; - if (todos[it.dstOp] == 0) { - queue.push_back(it.dstOp); - } - } - } - assert(queue.size() == graph->inEdges.size()); - // Remove the final parallel operators - while (operators[operators.size() - 1]->is_parallel_op()) { - Op *op = operators[operators.size() - 1]; - if (op->op_type == OP_REDUCTION) { - break; - } - if (op->op_type == OP_FUSED_PARALLEL) { - FusedParallelOp *fused_op = (FusedParallelOp *)op; - bool has_reduction = false; - for (int i = 0; i < fused_op->num_parallel_ops; i++) { - if (fused_op->parallel_ops[i].op_type == OP_REDUCTION) { - has_reduction = true; - } - } - if (has_reduction) { - break; - } - } - operators.pop_back(); - } - return true; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/old/substitution.h b/lib/compiler/src/old/substitution.h deleted file mode 100644 index 95a59e952c..0000000000 --- a/lib/compiler/src/old/substitution.h +++ /dev/null @@ -1,309 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef _FLEXFLOW_SUBSTITUTION_H_ -#define _FLEXFLOW_SUBSTITUTION_H_ -#include "graph.h" -#include "substitutions/substitutions.h" -#include "tl/optional.hpp" -#include "utils/recursive_logger.h" -#include -#include - -namespace FlexFlow { -namespace ffc { - -/* struct PMConstraint { */ -/* PMConstraint(Compare comp, PMParameter para, int value); */ -/* Compare comp; */ -/* PMParameter para; */ -/* int value; */ -/* }; */ - -struct TNConstraint { - TNConstraint(Compare comp, TNParameter para, DIMParameter dim, int value); - TNConstraint(Compare comp, - TNParameter para1, - DIMParameter dim1, - TNParameter para2, - DIMParameter dim2); - bool singlePara; - Compare comp; - TNParameter para1, para2; - DIMParameter dim1, dim2; - int value; -}; - -/* class Op; */ -/* class OpX; */ -/* class GraphXfer; */ - -struct TensorX { - static const TensorX NO_TX; - TensorX(void) : op(NULL), idx(0) {} - TensorX(OpX *_op, int _idx) : op(_op), idx(_idx) {} - tl::optional - to_tensor(GraphXfer const *xfer) const; - OpX *op; - int idx; - - bool operator==(TensorX const &other) const; - bool operator!=(TensorX const &other) const; -}; - -struct TensorXCompare { - bool operator()(TensorX const &a, TensorX const &b) const { - if (a.op != b.op) { - return a.op < b.op; - } - return a.idx < b.idx; - }; -}; - -/* class OpX { */ -/* public: */ -/* OpX(OperatorType type, */ -/* int numInputs, */ -/* int numOutputs, */ -/* TensorX const &input1 = TensorX::NO_TX, */ -/* TensorX const &input2 = TensorX::NO_TX, */ -/* TensorX const &input3 = TensorX::NO_TX, */ -/* TensorX const &input4 = TensorX::NO_TX); */ -/* OpX(OperatorType type, */ -/* int num_inputs, */ -/* int num_outputs, */ -/* TensorX const *inputs); */ -/* bool add_pm_constraint(Compare, PMParameter para, int value); */ -/* bool add_input_constraint(Compare, TNParameter, DIMParameter, int); */ -/* bool add_input_constraint( */ -/* Compare, TNParameter, DIMParameter, TNParameter, DIMParameter); */ -/* bool get_pm_constraint(PMParameter para, int &value) const; */ - -/* public: */ -/* OperatorType type; */ -/* Node mapOp; */ -/* OpX const *matchOpX; */ -/* std::vector inputs, weights, outputs; */ -/* std::vector pmConstraints; */ -/* std::vector tnConstraints; */ -/* }; */ - -OpX *create_opx(substitutions::Operator const &op, - int parallel_degree, - TensorX const &input1 = TensorX::NO_TX, - TensorX const &input2 = TensorX::NO_TX, - TensorX const &input3 = TensorX::NO_TX, - TensorX const &input4 = TensorX::NO_TX); -void create_xfer(GraphXfer &xfer, - substitutions::Rule const &r, - int parallel_degree); -std::vector - create_xfers(substitutions::RuleCollection const &rules, - int parallel_degree); - -class GraphCompare { -public: - bool operator()(Graph *lhs, Graph *rhs) { - return lhs->optimal_cost() > rhs->optimal_cost(); - } -}; - -class GraphXferMatch { -public: - GraphXferMatch(GraphXfer const *); - - void add_mapping(Node const &, OpX *); - void add_mapping(OpX *, Node const &); - void add_input_mapping(int, std::pair const &); - void add_output_mapping(TensorX const &, TensorX const &); - OpX *at(Node const &) const; - Node at(OpX *) const; - void set_graph(Graph const *); - - bool containsNode(Graph const *, Node const &) const; - bool containsEdge(Graph const *, Edge const &) const; - - GraphXfer const *get_xfer() const; - std::unordered_set get_nodes() const; - -private: - std::map nodeToOpX; - std::map opXToNode; - std::map mappedOutputs; - size_t graph_hash; - GraphXfer const *xfer; -}; - -/* class GraphXfer { */ -/* public: */ -/* GraphXfer(); */ -/* TensorX new_tensor(void); */ -/* bool can_match(OpX *srcOp, Node const &op, Graph const *graph); */ -/* void match(OpX *srcOp, Node const &op, Graph const *graph); */ -/* void unmatch(OpX *srcOp, Node const &op, Graph const *graph); */ -/* // Compute Ops */ -/* template */ -/* OpX *create_opx(TensorX const &input, OpX const *matchOpX); */ - -/* OpX *create_noop(TensorX const &input); */ -/* OpX *create_concat(TensorX const *inputs, */ -/* int num_inputs, */ -/* OpX const *match_opx, */ -/* int concat_dim); */ -/* OpX *create_element_binary(TensorX const &input1, */ -/* TensorX const &input2, */ -/* OperatorType op_type); */ -/* OpX *create_element_unary(TensorX const &input, OperatorType op_type); */ -/* OpX *create_relu(TensorX const &input); */ -/* OpX *create_linear(TensorX const &input, */ -/* OpX const *match_opx, */ -/* int num_dims, */ -/* ActiMode acti_mode, */ -/* bool use_bias); */ -/* OpX *create_conv2d(TensorX const &input, OpX const *match_opx); */ -/* OpX *create_pool2d(TensorX const &input, OpX const *match_opx); */ -/* OpX *create_attention(TensorX const &query, */ -/* TensorX const &key, */ -/* TensorX const &value, */ -/* OpX const *match_opx, */ -/* int num_heads); */ -/* OpX *create_softmax(TensorX const &input, int softmax_dim); */ -/* // Parallel Ops */ -/* OpX *create_repartition(TensorX const &input, */ -/* int repartition_dim, */ -/* int num_parts); */ -/* OpX *create_replicate(TensorX const &input, int replicate_dim, int - * num_parts); */ -/* OpX *create_reduction(TensorX const &input, int reduction_dim, int - * num_parts); */ -/* OpX *create_combine(TensorX const &input, int combine_dim, int num_parts); - */ -/* bool map_output(TensorX const &src, TensorX const &dst); */ - -/* Graph *create_new_graph(Graph const *graph, */ -/* SimplificationSettings const &settings); */ -/* bool create_new_operator(OpX const *opx, Node &op); */ - -/* std::string get_name() const; */ - -/* void run(int depth, */ -/* Graph *graph, */ -/* std::priority_queue, GraphCompare> - * &, */ -/* std::unordered_set &, */ -/* float threshold, */ -/* int maxNumOps, */ -/* SimplificationSettings const &simplification_settings, */ -/* int &num_matches_found, */ -/* int &num_matches_rejected); */ - -/* void find_matches(Graph const *, std::vector &matches); */ -/* GraphXferMatch get_match_record(Graph const *) const; */ - -/* private: */ -/* void find_matches(int depth, */ -/* Graph const *graph, */ -/* std::vector &matches); */ - -/* public: */ -/* tl::optional name = tl::nullopt; */ -/* int tensorId; */ -/* std::map mappedOps; */ -/* std::multimap> mappedInputs; */ -/* std::map mappedOutputs; */ -/* std::vector srcOps; */ -/* std::vector dstOps; */ -/* }; */ - -struct SubstitutionMatch { - std::unordered_map node_assignment; - std::unordered_map edge_assignment; -}; - -std::unordered_set - find_matches(SubstitutionPattern const &pattern, - ParallelComputationGraph const &pcg); - -class GraphSearchHelper { -public: - GraphSearchHelper(); - void graph_optimize(size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views); - void graph_optimize_no_split( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views); - -private: - template - T generic_sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape); - - float sequence_optimize(Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape); - - template - T execute_sequence_split( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape); - void generate_all_pcg_xfers(); - void load_graph_substitutions(std::vector &xfers) const; - Graph *construct_graph(); - void subgraph_optimize(Graph *subgraph); - - std::unique_ptr - base_optimize(Graph const *, - SimplificationSettings const &simplification_settings); - - std::vector - possible_split_output_tensor_shapes(Node const &) const; - - void find_rewrite_matches(Graph const *graph, - std::vector &matches) const; - tl::optional find_split_node(Graph const *graph, - int base_optimize_threshold) const; - - template - tl::optional try_get_cost_from_cache(size_t hash) const; - - template - void try_cache_result(size_t hash, T const &value); - - template - T get_optimal_cost(std::unique_ptr optimized) const; - -private: - std::unordered_map cached_optimized_graphs; - std::vector all_pcg_xfers; - std::unique_ptr logger; -}; - -} // namespace ffc -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 86fdd88d92..9d648ed99b 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -20,7 +20,7 @@ std::unordered_set Strategy graph_optimize(ComputationGraph &cg, - ICostEstimator const &cost_estimator, + CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( Operator const &, MachineSpecification const &)> const @@ -35,12 +35,8 @@ Strategy DeduplicatedPriorityQueue, StrategyRuntimeCmp> candidates; - Strategy initial_result(pcg, - optimal_cost(pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs)); + OptimalCostResult initial_pcg_result = optimal_cost(pcg, allowed_machine_views, cost_estimator, resources, cached_subgraph_costs); + Strategy initial_result{pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; Strategy best_result = initial_result; candidates.push(initial_result); @@ -50,7 +46,7 @@ Strategy Strategy const ¤t_result = candidates.top(); candidates.pop(); - if (StrategyRuntimeCmp(current_result, best_result)) { + if (StrategyRuntimeCmp{}(current_result, best_result)) { best_result = current_result; } else if (current_result.runtime > best_result.runtime * opt_config.alpha) { @@ -64,9 +60,9 @@ Strategy cost_estimator, resources, cached_subgraph_costs); - Strategy new_result(new_pcg, c.machine_mapping, c.runtime); + Strategy new_result{new_pcg, c.machine_mapping, c.runtime}; if (new_result.runtime <= opt_config.threshold && - new_result.pcg.query_nodes({}).size() <= opt_config.max_num_ops) { + get_nodes(new_pcg.value()).size() <= opt_config.max_num_ops) { candidates.push(new_result); } } diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 1a5c2bc3f8..b482e851d8 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -26,6 +26,8 @@ struct MachineView : public use_visitable_cmp { StridedRectangle rect; }; +FF_VISITABLE_STRUCT(MachineView, start, rect); + std::size_t num_dims(MachineView const &); std::size_t num_devices(MachineView const &); DeviceType get_device_type(MachineView const &); @@ -43,7 +45,4 @@ MachineView make_1d_machine_view(device_id_t start, size_t interval_size); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::MachineView, start, rect); -MAKE_VISIT_HASHABLE(::FlexFlow::MachineView); - #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 7e332933c7..2342cd08fa 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -15,6 +15,16 @@ struct ParallelComputationGraph }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ParallelComputationGraph); +bool operator==(ParallelComputationGraph const &, ParallelComputationGraph const &); + } // namespace FlexFlow +namespace std { + +template <> +struct hash { + size_t operator()(FlexFlow::ParallelComputationGraph const &g) const; +}; +} + #endif diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index a52906c612..98471a8fbd 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -28,4 +28,12 @@ SubParallelComputationGraph } // namespace FlexFlow +namespace std{ +template <> +struct hash { + size_t operator()(FlexFlow::Substitution const &) const; +}; + +}; + #endif diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 5b2e5093bd..3a1444a0f5 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -108,6 +108,9 @@ std::unordered_set get_node_edges(UndirectedGraphView const &, std::unordered_set get_outputs(MultiDiGraphView const &); std::unordered_set get_inputs(MultiDiGraphView const &); +std::unordered_set get_open_outputs(OpenMultiDiGraphView const &); +std::unordered_set get_open_inputs(OpenMultiDiGraphView const &); + std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); std::unordered_set get_incoming_edges(DiGraphView const &, diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 15da6ce2cb..8bfcda9d0f 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -8,6 +8,7 @@ namespace FlexFlow { template struct INodeLabelledMultiDiGraphView : virtual public IMultiDiGraphView { + INodeLabelledMultiDiGraphView() = default; INodeLabelledMultiDiGraphView(INodeLabelledMultiDiGraphView const &) = delete; INodeLabelledMultiDiGraphView & operator=(INodeLabelledMultiDiGraphView const &) = delete; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 4d8c790400..2cbaaf44fd 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -9,8 +9,8 @@ template struct INodeLabelledOpenMultiDiGraphView : virtual INodeLabelledMultiDiGraphView, virtual IOpenMultiDiGraphView { - INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = - delete; + INodeLabelledOpenMultiDiGraphView() = default; + INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = delete; INodeLabelledOpenMultiDiGraphView & operator=(INodeLabelledOpenMultiDiGraphView const &) = delete; }; @@ -82,12 +82,12 @@ struct NodeLabelledOpenMultiDiGraph } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr()->query_nodes(); + return get_ptr()->query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdge const &q) const { - return get_ptr()->query_edges(); + return get_ptr()->query_edges(q); } Node add_node(NodeLabel const &l) { diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 4a4c81aef9..8c8a8b1a1b 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -26,6 +26,10 @@ struct OutputLabelledOpenMultiDiSubgraphView return g.at(n); } + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + return g.at(i); + } + EdgeLabel const &at(MultiDiOutput const &o) const override { return g.at(o); } @@ -39,11 +43,17 @@ struct OutputLabelledOpenMultiDiSubgraphView return SubgraphView(g, nodes).query_edges(q); } + OutputLabelledOpenMultiDiSubgraphView* clone() const override { + return new OutputLabelledOpenMultiDiSubgraphView(g, nodes); + } + private: OutputLabelledOpenMultiDiGraphView const &g; std::unordered_set const &nodes; }; +// CHECK_NOT_ABSTRACT(OutputLabelledOpenMultiDiSubgraphView); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index a9aa1e5251..9b35cdc883 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -15,6 +15,7 @@ struct IOutputLabelledMultiDiGraphView operator=(IOutputLabelledMultiDiGraphView const &) = delete; virtual OutputLabel const &at(MultiDiOutput const &) = 0; + using INodeLabelledMultiDiGraphView::at; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); @@ -119,10 +120,10 @@ struct OutputLabelledMultiDiGraph } std::unordered_set query_nodes(NodeQuery const &q) const { - return this->ptr->query_nodes(q); + return get_ptr()->query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return get_ptr()->query_edges(q); } template diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index b3ecb5f273..28dba47bce 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN #define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN -#include "node_labelled.h" +#include "node_labelled_open.h" #include "utils/graph/adjacency_openmultidigraph.h" namespace FlexFlow { @@ -59,6 +59,7 @@ struct OutputLabelledOpenMultiDiGraphView protected: using NodeLabelledOpenMultiDiGraphView< NodeLabel>::NodeLabelledOpenMultiDiGraphView; + OutputLabelledOpenMultiDiGraphView(cow_ptr_t ptr) : GraphView(ptr) {} private: cow_ptr_t get_ptr() const { @@ -70,7 +71,7 @@ struct OutputLabelledOpenMultiDiGraphView template EdgeLabel at(OutputLabelledOpenMultiDiGraphView const &g, OpenMultiDiEdge const &e) { - return visit([&](auto const e) { return g.at(e); }, e); + return visit([&](auto const &e) { return g.at(e); }, e); } template @@ -173,6 +174,11 @@ struct OutputLabelledOpenMultiDiGraph cow_ptr_t ol; }; +template +void add_label(OutputLabelledOpenMultiDiGraph &g, OpenMultiDiEdge const &e, EdgeLabel const &l) { + visit([&](const auto &e) { g.add_label(e, l); }, e); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 0426a73e73..5a227d46ec 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -3,6 +3,7 @@ #include "node_labelled.h" #include "standard_labelled.h" +#include "output_labelled_open.h" namespace FlexFlow { @@ -70,7 +71,7 @@ CHECK_NOT_ABSTRACT(ViewMultiDiGraphAsOutputLabelled Impl materialize_output_labelled_multidigraph_view( - IOutputLabelledMultiDiGraphView const &g) { + OutputLabelledMultiDiGraphView const &g) { Impl result; for (Node const &n : get_nodes(g)) { result.add_node_unsafe(n); From c015efb50979f3f97ec81908868078296223968e Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 15 Nov 2023 17:12:53 -0500 Subject: [PATCH 02/32] unity dp works --- lib/compiler/CMakeLists.txt | 1 + lib/compiler/include/compiler/cost_estimate.h | 4 +- .../include/compiler/machine_mapping.h | 2 +- .../include/compiler/unity_algorithm.h | 3 +- lib/compiler/src/graph_utils.cc | 16 + lib/compiler/src/machine_mapping.cc | 26 +- lib/compiler/src/unity_algorithm.cc | 18 +- lib/compiler/test/CMakeLists.txt | 5 +- .../test/{ => src}/test_cost_estimator.h | 0 lib/compiler/test/{ => src}/test_generator.h | 2 +- .../test/src/test_labelled_open_graph.cc | 77 +++++ .../test/{ => src}/test_machine_mapping.cc | 2 +- lib/compiler/test/src/test_open_graph.cc | 80 +++++ lib/compiler/test/src/test_optimal_cost.cc | 60 ++++ .../test/{ => src}/test_unity_algorithm.cc | 0 lib/compiler/test/test_disjoint_set.cc | 19 -- lib/compiler/test/test_dominators.cc | 322 ------------------ lib/compiler/test/test_dot.cc | 23 -- lib/compiler/test/test_dp.cc | 54 --- lib/compiler/test/test_labelled_open_graph.cc | 76 ----- lib/compiler/test/test_machine_view.cc | 33 -- lib/compiler/test/test_open_graph.cc | 102 ------ lib/compiler/test/test_optimal_cost.cc | 24 -- lib/compiler/test/test_parallel_config.cc | 25 -- lib/compiler/test/test_random_utils.cc | 47 --- lib/compiler/test/test_substitution_loader.cc | 144 -------- lib/op-attrs/src/get_output_shapes.cc | 6 + lib/pcg/include/pcg/machine_specification.h | 19 +- lib/pcg/include/pcg/machine_view.h | 6 +- lib/pcg/include/pcg/operator.h | 3 +- lib/pcg/include/pcg/strided_rectangle.h | 23 +- lib/pcg/src/machine_view.cc | 4 +- lib/pcg/src/operator.cc | 2 +- lib/pcg/src/parallel_computation_graph.cc | 37 ++ lib/pcg/src/strided_rectangle.cc | 6 +- lib/utils/include/utils/graph/digraph.h | 2 +- .../utils/graph/labelled/node_labelled.h | 8 +- .../utils/graph/labelled/node_labelled_open.h | 7 +- .../include/utils/graph/labelled/open_views.h | 39 +++ .../utils/graph/labelled/output_labelled.h | 30 +- .../graph/labelled/output_labelled_open.h | 20 +- .../utils/graph/labelled/standard_labelled.h | 28 +- .../utils/graph/labelled/unordered_label.h | 3 +- .../include/utils/graph/labelled/views.h | 22 +- lib/utils/include/utils/graph/multidigraph.h | 2 +- lib/utils/include/utils/graph/open_graphs.h | 9 +- lib/utils/include/utils/graph/undirected.h | 2 +- lib/utils/src/graph/algorithms.cc | 35 +- lib/utils/src/graph/digraph.cc | 7 +- lib/utils/src/graph/multidigraph.cc | 7 +- lib/utils/src/graph/node.cc | 2 +- lib/utils/src/graph/open_graphs.cc | 27 +- lib/utils/src/graph/serialparallel.cc | 13 +- lib/utils/src/graph/undirected.cc | 6 +- lib/utils/src/graph/views.cc | 24 +- 55 files changed, 552 insertions(+), 1012 deletions(-) rename lib/compiler/test/{ => src}/test_cost_estimator.h (100%) rename lib/compiler/test/{ => src}/test_generator.h (98%) create mode 100644 lib/compiler/test/src/test_labelled_open_graph.cc rename lib/compiler/test/{ => src}/test_machine_mapping.cc (95%) create mode 100644 lib/compiler/test/src/test_open_graph.cc create mode 100644 lib/compiler/test/src/test_optimal_cost.cc rename lib/compiler/test/{ => src}/test_unity_algorithm.cc (100%) delete mode 100644 lib/compiler/test/test_disjoint_set.cc delete mode 100644 lib/compiler/test/test_dominators.cc delete mode 100644 lib/compiler/test/test_dot.cc delete mode 100644 lib/compiler/test/test_dp.cc delete mode 100644 lib/compiler/test/test_labelled_open_graph.cc delete mode 100644 lib/compiler/test/test_machine_view.cc delete mode 100644 lib/compiler/test/test_open_graph.cc delete mode 100644 lib/compiler/test/test_optimal_cost.cc delete mode 100644 lib/compiler/test/test_parallel_config.cc delete mode 100644 lib/compiler/test/test_random_utils.cc delete mode 100644 lib/compiler/test/test_substitution_loader.cc create mode 100644 lib/pcg/src/parallel_computation_graph.cc diff --git a/lib/compiler/CMakeLists.txt b/lib/compiler/CMakeLists.txt index 45c369fcdf..6610834eed 100644 --- a/lib/compiler/CMakeLists.txt +++ b/lib/compiler/CMakeLists.txt @@ -18,3 +18,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) \ No newline at end of file diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimate.h index 27f963db50..3791292529 100644 --- a/lib/compiler/include/compiler/cost_estimate.h +++ b/lib/compiler/include/compiler/cost_estimate.h @@ -16,10 +16,11 @@ struct ICostEstimator { MachineView const &src, MachineView const &dst) const = 0; + ICostEstimator() = default; ICostEstimator(ICostEstimator const &) = delete; ICostEstimator &operator=(ICostEstimator const &) = delete; - virtual ~ICostEstimator(); + virtual ~ICostEstimator() = default; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); @@ -44,6 +45,7 @@ struct CostEstimator { } private: + CostEstimator(std::shared_ptr implementation_ptr) : implementation_ptr(implementation_ptr) {} std::shared_ptr implementation_ptr; }; diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index e8d7457fbf..9f9d97937d 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -15,7 +15,7 @@ struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); static bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - + req> machine_views; }; FF_VISITABLE_STRUCT(MachineMapping, machine_views); diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index fc068d48c5..81e8375948 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -37,6 +37,7 @@ Strategy } // namespace FlexFlow VISITABLE_STRUCT(FlexFlow::Strategy, pcg, machine_mapping, runtime); + namespace std { template <> @@ -44,6 +45,6 @@ struct hash { size_t operator()(FlexFlow::Strategy const &) const; }; -}; +} #endif diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index d7f15e0796..04e96c66ed 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -7,6 +7,22 @@ SerialParallelDecomposition return get_serial_parallel_decomposition(pcg.value()); } +ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { + NOT_IMPLEMENTED(); +} + +SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { + auto g = pcg.value(); + auto g_ = view_output_labelled_as_output_labelled_open(g); + auto subpcg = materialize_output_labelled_open_multidigraph_view< + AdjacencyOpenMultiDiGraph, + UnorderedLabelling, + UnorderedLabelling, + UnorderedLabelling + >(g_); + return subpcg; +} + std::vector get_sorted_node_input_edges(ParallelComputationGraph const &pcg, Node const &n) { diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index fb04f57eac..671c59a94f 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -95,7 +95,7 @@ float estimate_cost(SubParallelComputationGraphView const &g, CostEstimator const &estimator, MachineMapping const &device_mapping, std::unordered_map const &frontier_machine_views) { - NOT_IMPLEMENTED(); + return 0.1; } void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { @@ -122,8 +122,8 @@ struct OptimalCost { SubParallelComputationGraphView const &g; CostEstimator const &cost_estimator; MachineSpecification const &resource; - std::unordered_map const &given_machine_views; - std::unordered_map const &frontier_machine_views; + std::unordered_map given_machine_views; + std::unordered_map frontier_machine_views; std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views; @@ -138,7 +138,6 @@ struct OptimalCost { if (cached_result) { return cached_result.value(); } - OptimalCostResult result = this->optimal_cost(t); cached_subgraph_costs.save(state, result); @@ -161,14 +160,15 @@ struct OptimalCost { Node split_point = get_only(post_graph_sources); OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); - + OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { - auto new_given_machine_views = merge_maps(given_machine_views, std::unordered_map{{split_point, mv}}); - auto new_frontier_machine_views = merge_maps(frontier_machine_views, - std::unordered_map{{split_edge, mv}}); + std::unordered_map new_given_machine_views = given_machine_views; + new_given_machine_views.emplace(split_point, mv); + std::unordered_map new_frontier_machine_views = frontier_machine_views; + new_frontier_machine_views.emplace(split_edge, mv); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( visit(OptimalCost(pre_graph, @@ -269,14 +269,16 @@ OptimalCostResult CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - return visit(OptimalCost(pcg_to_subpcg(g), + SerialParallelDecomposition sp_decomposition = get_serial_parallel_decomposition(g); + SubParallelComputationGraph subpcg = pcg_to_subpcg(g); + return visit(OptimalCost(subpcg, cost_estimator, resources, - {}, - {}, + std::unordered_map{}, + std::unordered_map{}, allowed_machine_views, cached_subgraph_costs), - get_serial_parallel_decomposition(g)); + sp_decomposition); } } // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 9d648ed99b..16671b080a 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -10,7 +10,9 @@ bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { } std::unordered_set - get_all_substitutions(ParallelComputationGraph const &pcg); + get_all_substitutions(ParallelComputationGraph const &pcg) { + NOT_IMPLEMENTED(); +} std::unordered_set apply_substitution(ParallelComputationGraph const &pcg, @@ -73,3 +75,17 @@ Strategy } } // namespace FlexFlow + +namespace std { + +size_t hash::operator()(FlexFlow::Strategy const &s) const { + size_t h = 0; + + hash_combine(h, s.pcg); + // hash_combine(h, s.machine_mapping); + hash_combine(h, s.runtime); + + return h; +} + +} diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index dbbd0a63ec..cc64b15f7d 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -2,10 +2,13 @@ ff_add_test_executable( NAME compiler-test SRC_PATTERNS - src/*.cc + src/test_labelled_open_graph.cc + src/test_open_graph.cc + src/test_optimal_cost.cc PRIVATE_INCLUDE src/ DEPS + utils compiler doctest utils-test-common diff --git a/lib/compiler/test/test_cost_estimator.h b/lib/compiler/test/src/test_cost_estimator.h similarity index 100% rename from lib/compiler/test/test_cost_estimator.h rename to lib/compiler/test/src/test_cost_estimator.h diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/src/test_generator.h similarity index 98% rename from lib/compiler/test/test_generator.h rename to lib/compiler/test/src/test_generator.h index 374bb89455..23a79abbe0 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_TEST_GENERATOR_H #include "compiler/machine_mapping.h" -#include "compiler/sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" #include "pcg/computation_graph.h" #include "rapidcheck.h" diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc new file mode 100644 index 0000000000..78ea1ece55 --- /dev/null +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -0,0 +1,77 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "rapidcheck.h" + +using namespace FlexFlow; + +TEST_CASE("get_subgraph_open_graph") { + auto g = OpenMultiDiGraph::create(); + + int t0 = 100000; + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); + Node n4 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + NodePort p2 = g.add_node_port(); + NodePort p3 = g.add_node_port(); + NodePort p4 = g.add_node_port(); + NodePort p5 = g.add_node_port(); + NodePort p6 = g.add_node_port(); + NodePort p7 = g.add_node_port(); + NodePort p8 = g.add_node_port(); + NodePort p9 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e1{n2, p2, n0, p0}; + MultiDiEdge e2{n3, p5, n1, p3}; + MultiDiEdge e3{n3, p6, n2, p4}; + MultiDiEdge e4{n4, p8, n3, p7}; + OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + std::unordered_set node_set0{n3, n4}; + + auto subgraph0 = get_subgraph(g, node_set0); + auto subgraph1 = get_subgraph(g, node_set0); + auto subgraph2 = get_subgraph(g, node_set0); + auto subgraph3 = get_subgraph(g, node_set0); + + CHECK(get_nodes(subgraph0) == node_set0); + CHECK(get_nodes(subgraph1) == node_set0); + CHECK(get_nodes(subgraph2) == node_set0); + CHECK(get_nodes(subgraph3) == node_set0); + + std::unordered_set input_set{split_edge(e2).second, + split_edge(e3).second}; + std::unordered_set output_set{e5}; + + CHECK(bool(get_open_inputs(subgraph0) == input_set)); + CHECK(bool(get_open_inputs(subgraph1) == input_set)); + CHECK(bool(get_open_inputs(subgraph2).empty())); + CHECK(bool(get_open_inputs(subgraph3).empty())); + + CHECK(bool(get_open_outputs(subgraph0) == output_set)); + CHECK(bool(get_open_outputs(subgraph1).empty())); + CHECK(bool(get_open_outputs(subgraph2) == output_set)); + CHECK(bool(get_open_outputs(subgraph3).empty())); + + CHECK(bool(get_edges(subgraph0) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); + CHECK(bool(get_edges(subgraph1) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + CHECK(bool(get_edges(subgraph2) == std::unordered_set{e4, e5})); + CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); +} diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc similarity index 95% rename from lib/compiler/test/test_machine_mapping.cc rename to lib/compiler/test/src/test_machine_mapping.cc index 4436a992d3..779f8134d9 100644 --- a/lib/compiler/test/test_machine_mapping.cc +++ b/lib/compiler/test/src/test_machine_mapping.cc @@ -1,4 +1,4 @@ -#include "doctest.h" +#include "doctest/doctest.h" #include "test_generator.h" TEST_CASE("MachineMapping::combine") { diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc new file mode 100644 index 0000000000..ea1108c291 --- /dev/null +++ b/lib/compiler/test/src/test_open_graph.cc @@ -0,0 +1,80 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "utils/graph/algorithms.h" + +using namespace FlexFlow; + +TEST_CASE("get_source_sink_open_graph") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + int s0 = 100000; + + Node n0 = g.add_node(); + NodePort p0 = g.add_node_port(); + InputMultiDiEdge e0{n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; + g.add_edge(e0); + + CHECK(bool(get_closed_sources(g) == std::unordered_set{})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{})); +} + +TEST_CASE("get_source_sink_open_graph:unconnected") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + int s0 = 100000; + int t0 = s0 + 1; + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; + OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; + g.add_edge(e0); + g.add_edge(e1); + + /* + g: ->n0 + n1-> + */ + + CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); +} + +TEST_CASE("get_cut") { + auto g = OpenMultiDiGraph::create(); + + std::vector ns; + for (int i = 0; i < 5; ++i) { + ns.push_back(g.add_node()); + } + + MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; + MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; + MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; + OutputMultiDiEdge e5{ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; + CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, e2})); + + GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; + CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, e4})); +} diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc new file mode 100644 index 0000000000..87f9d06342 --- /dev/null +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -0,0 +1,60 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "test_cost_estimator.h" + +using namespace FlexFlow; + +/* +Tests whether optimal_cost can give a valid result given random PCG, trivial +allowed machine views, trivial cost estimator and random machine specification. +*/ +// TEST_CASE("optimal_cost") { +// auto test_allowed_machine_views = [](Operator const &, +// MachineSpecification const &) { +// return std::unordered_set{make_1d_machine_view(0, 1, 1)}; +// }; +// rc::check([](ParallelComputationGraph const &g, +// MachineSpecification const &machine_spec) { +// OptimalCostCache cached_subgraph_costs; +// OptimalCostResult result = optimal_cost(g, +// test_allowed_machine_views, +// TestCostEstimator{}, +// machine_spec, +// cached_subgraph_costs); +// RC_ASSERT(result.runtime > 0); +// RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); +// }); +// } + +TEST_CASE("optimal_cost_0") { + auto pcg = OutputLabelledMultiDiGraph::template create< + AdjacencyMultiDiGraph, + UnorderedLabelling, + UnorderedLabelling + >(); + + Node n0 = pcg.add_node(Operator(InputAttrs{}, "input")); + Node n1 = pcg.add_node(Operator(LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, "linear")); + + MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; + pcg.add_edge(e); + pcg.add_output(e, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); + + auto test_allowed_machine_views = [](Operator const &, + MachineSpecification const &) { + return std::unordered_set{make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + }; + + CostEstimator estimator = CostEstimator::create(); + + MachineSpecification machine_spec{1, 1, 1, 1, 1}; + + OptimalCostCache cached_results; + + OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), test_allowed_machine_views, estimator, machine_spec, cached_results); + + CHECK(bool(result.runtime > 0)); +} \ No newline at end of file diff --git a/lib/compiler/test/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc similarity index 100% rename from lib/compiler/test/test_unity_algorithm.cc rename to lib/compiler/test/src/test_unity_algorithm.cc diff --git a/lib/compiler/test/test_disjoint_set.cc b/lib/compiler/test/test_disjoint_set.cc deleted file mode 100644 index 796605f53f..0000000000 --- a/lib/compiler/test/test_disjoint_set.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "flexflow/utils/disjoint_set.h" -#include "gtest/gtest.h" - -TEST(disjoint_set, basic) { - int ctr = 0; - int a = ctr++, b = ctr++, c = ctr++, d = ctr++, e = ctr++, f = ctr++; - - disjoint_set ds; - ds.m_union(a, b); - ds.m_union(b, c); - ds.m_union(e, f); - ds.m_union(d, a); - - assert(ds.find(a) == ds.find(b)); - assert(ds.find(a) == ds.find(c)); - assert(ds.find(a) == ds.find(d)); - assert(ds.find(e) == ds.find(f)); - assert(ds.find(e) != ds.find(a)); -} diff --git a/lib/compiler/test/test_dominators.cc b/lib/compiler/test/test_dominators.cc deleted file mode 100644 index 60ac33696f..0000000000 --- a/lib/compiler/test/test_dominators.cc +++ /dev/null @@ -1,322 +0,0 @@ -#include "flexflow/basic_graph.h" -#include "flexflow/dominators.h" -#include "flexflow/utils/hash-utils.h" -#include "gtest/gtest.h" - -using namespace FlexFlow::PCG::Utils; - -namespace FlexFlow::PCG::Utils { -template <> -struct invalid_node<::BasicGraph, GraphStructure<::BasicGraph>> { - int operator()() const { - return -1; - } -}; -} // namespace FlexFlow::PCG::Utils - -TEST(pred_succ_cessors, basic) { - BasicGraph g; - g.add_node(0); - g.add_node(1); - g.add_node(2); - g.add_node(3); - g.add_node(4); - - g.add_edge(0, 2); - g.add_edge(1, 2); - g.add_edge(2, 3); - g.add_edge(2, 4); - - using AnswerMap = std::unordered_map>; - - AnswerMap expected_predecessors; - - expected_predecessors = {{0, {}}, {1, {}}, {2, {0, 1}}, {3, {2}}, {4, {2}}}; - - AnswerMap expected_successors = { - {0, {2}}, {1, {2}}, {2, {3, 4}}, {3, {}}, {4, {}}}; - - std::unordered_set answer; - for (auto const &kv : expected_predecessors) { - answer.clear(); - predecessors>(g, kv.first, &answer); - EXPECT_EQ(kv.second, answer) - << "^^^ Predecessors for node " << kv.first << std::endl; - } - for (auto const &kv : expected_successors) { - answer.clear(); - successors>(g, kv.first, &answer); - EXPECT_EQ(kv.second, answer) - << "^^^ Successors for node " << kv.first << std::endl; - } -} - -TEST(topo_sort, basic) { - BasicGraph g; - g.add_nodes({0, 1, 2, 3}); - g.add_edges({{3, 1}, {3, 0}, {1, 0}, {0, 2}}); - - std::vector topo_answer = {3, 1, 0, 2}; - - std::vector topo_result; - topo_sort(g, &topo_result); - EXPECT_EQ(topo_result, topo_answer); -} - -BasicGraph get_dominator_test_graph() { - BasicGraph g; - g.add_nodes({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); - g.add_edges({{1, 2}, - {1, 7}, - {2, 3}, - {2, 4}, - {3, 6}, - {4, 5}, - {4, 6}, - {5, 6}, - {6, 8}, - {7, 8}, - {8, 9}, - {8, 10}, - {9, 11}, - {10, 11}}); - - return g; -} - -TEST(dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map> answer = {{1, {1}}, - {2, {1, 2}}, - {3, {1, 2, 3}}, - {4, {1, 2, 4}}, - {5, {1, 2, 4, 5}}, - {6, {1, 2, 6}}, - {7, {1, 7}}, - {8, {1, 8}}, - {9, {1, 8, 9}}, - {10, {1, 8, 10}}, - {11, {1, 8, 11}}}; - - EXPECT_EQ(dominators(g), answer); -} - -TEST(post_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map> answer = {{1, {1, 8, 11}}, - {2, {2, 6, 8, 11}}, - {3, {3, 6, 8, 11}}, - {4, {4, 6, 8, 11}}, - {5, {5, 6, 8, 11}}, - {6, {6, 8, 11}}, - {7, {7, 8, 11}}, - {8, {8, 11}}, - {9, {9, 11}}, - {10, {10, 11}}, - {11, {11}}}; - - EXPECT_EQ(post_dominators(g), answer); -} - -TEST(imm_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map answer = {{1, 1}, // no immediate dominator - {2, 1}, - {3, 2}, - {4, 2}, - {5, 4}, - {6, 2}, - {7, 1}, - {8, 1}, - {9, 8}, - {10, 8}, - {11, 8}}; - - EXPECT_EQ(imm_dominators(g), answer); -} - -TEST(imm_post_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map answer = { - {1, 8}, - {2, 6}, - {3, 6}, - {4, 6}, - {5, 6}, - {6, 8}, - {7, 8}, - {8, 11}, - {9, 11}, - {10, 11}, - {11, 11} // no immediate post - // dominator - }; - - EXPECT_EQ(imm_post_dominators(g), answer); -} - -TEST(imm_post_dominators, multisource) { - BasicGraph g; - - g.add_nodes({1, 2, 3, 4, 5}); - g.add_edges({{1, 3}, {2, 3}, {3, 4}, {3, 5}}); - - std::unordered_map answer = { - {-1, 3}, {1, 3}, {2, 3}, {3, 3}, {4, 4}, {5, 5}}; - - auto result = - imm_post_dominators>( - g); - EXPECT_EQ(result, answer); -} - -TEST(transitive_reduction, basic) { - BasicGraph g({1, 2, 3}, {{1, 2}, {2, 3}, {1, 3}}); - - BasicGraph answer({1, 2, 3}, {{1, 2}, {2, 3}}); - - auto result = transitive_reduction(g); - - EXPECT_EQ(result, answer); -} - -TEST(transitive_reduction, medium) { - BasicGraph g({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {1, 5}, - {2, 3}, - {2, 4}, - {2, 6}, - {3, 4}, - {4, 5}, - {4, 6}, - {5, 6}, - }); - - BasicGraph answer({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {2, 3}, - {3, 4}, - {4, 5}, - {5, 6}, - }); - - auto result = transitive_reduction(g); - - EXPECT_EQ(result, answer); -} - -TEST(inplace_transitive_reduction, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {1, 5}, - {2, 3}, - {2, 4}, - {2, 6}, - {3, 4}, - {4, 5}, - {4, 6}, - {5, 6}, - }); - - BasicGraph answer({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {2, 3}, - {3, 4}, - {4, 5}, - {5, 6}, - }); - - inplace_transitive_reduction(g); - - EXPECT_EQ(g, answer); -} - -TEST(roots, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - { - {1, 3}, - {2, 3}, - {3, 4}, - {3, 5}, - {3, 6}, - }); - - std::unordered_set answer{1, 2}; - - auto result = roots(g); - - EXPECT_EQ(result, answer); -} - -TEST(leaves, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 3}, {2, 3}, {3, 4}, {3, 5}, {3, 6}}); - - std::unordered_set answer{4, 5, 6}; - - auto result = leaves(g); - - EXPECT_EQ(result, answer); -} - -TEST(descendants, directed) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 2}, {2, 3}, {2, 4}, {3, 5}, {4, 5}}); - - std::unordered_set answer{2, 3, 4, 5}; - - auto result = descendants(g, 2); - - EXPECT_EQ(result, answer); -} - -TEST(descendants, undirected) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 2}, {2, 3}, {2, 4}, {3, 5}, {4, 5}}); - - std::unordered_set answer{1, 2, 3, 4, 5}; - - auto result = - descendants>(g, 2); - - EXPECT_EQ(result, answer); -} - -TEST(weakly_connected_components, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, {{1, 3}, {2, 3}, {4, 5}, {5, 4}}); - - std::unordered_set component1{1, 2, 3}; - std::unordered_set component2{4, 5}; - std::unordered_set component3{6}; - auto result = weakly_connected_components(g); - - EXPECT_EQ(result.size(), 3); - bool component1_found = false; - bool component2_found = false; - bool component3_found = false; - for (std::unordered_set &component : result) { - if (component.size() == component1.size()) { - component1_found = true; - EXPECT_EQ(component, component1); - } else if (component.size() == component2.size()) { - component2_found = true; - EXPECT_EQ(component, component2); - } else if (component.size() == component3.size()) { - component3_found = true; - EXPECT_EQ(component, component3); - } - } - - EXPECT_TRUE(component1_found); - EXPECT_TRUE(component2_found); - EXPECT_TRUE(component3_found); -} diff --git a/lib/compiler/test/test_dot.cc b/lib/compiler/test/test_dot.cc deleted file mode 100644 index 3212971255..0000000000 --- a/lib/compiler/test/test_dot.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "flexflow/utils/dot/record_formatter.h" -#include "gtest/gtest.h" - -TEST(record_formatters, basic) { - RecordFormatter rf, rf2, rf3; - std::ostringstream oss; - oss << "Wo" - << "rld"; - rf << "Hello" - << "World" - << (rf2 << "Inner" - << "World" - << (rf3 << "Even" - << "More" - << "Inner World")) - << "Goodbye" << oss; - - std::ostringstream oss_final; - oss_final << rf; - EXPECT_EQ(oss_final.str(), - "{ Hello | World | { Inner | World | { Even | More | Inner World } " - "} | Goodbye | World }"); -} diff --git a/lib/compiler/test/test_dp.cc b/lib/compiler/test/test_dp.cc deleted file mode 100644 index 01e4189839..0000000000 --- a/lib/compiler/test/test_dp.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -struct TestCostEstimator : public ICostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - MachineView const &mv) const override { - return 0.1; - } - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const override { - return 1; - } -}; - -TEST_CASE("optimal_cost") { - auto g(NodeLabelledMultiDiGraph::create< - UnorderedNodeLabelledMultiDiGraph>()); - - Node n0 = g.add_node(InputAttrs()); - Node n1 = g.add_node(RepartitionAttrs(ff_dim_t(0), 2)); - Node n2 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 0)); - Node n3 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 1)); - Node n4 = g.add_node(ConcatAttrs(ff_dim_t(1))); - Node n5 = g.add_node(CombineAttrs(ff_dim_t(0), 2)); - - MultiDiEdge e0(n0, n1, 0, 0); - MultiDiEdge e1(n1, n2, 0, 0); - MultiDiEdge e2(n1, n3, 1, 0); - MultiDiEdge e3(n2, n4, 0, 0); - MultiDiEdge e4(n3, n4, 0, 1); - MultiDiEdge e5(n4, n5, 0, 0); - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - - OptimizerPCG pcg = infer_tensor_shape(g); - auto allowed_machine_views = [](PCGOperatorAttrs const &, - MachineResource const &) { - // TODO - return std::unordered_set{}; - }; - MachineResource resource(1, 1, 2); - Strategy s = - optimal_cost(pcg, allowed_machine_views, TestCostEstimator{}, resource); - - // TODO: check result -} diff --git a/lib/compiler/test/test_labelled_open_graph.cc b/lib/compiler/test/test_labelled_open_graph.cc deleted file mode 100644 index 7d85514816..0000000000 --- a/lib/compiler/test/test_labelled_open_graph.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -TEST_CASE("get_subgraph_labelled_open_graph") { - auto g = LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>(); - - int t0 = 100000; - - Node n0 = g.add_node(0); - Node n1 = g.add_node(1); - Node n2 = g.add_node(2); - Node n3 = g.add_node(3); - Node n4 = g.add_node(4); - - MultiDiEdge e0(n0, n1, 0, 0); - MultiDiEdge e1(n0, n2, 1, 0); - MultiDiEdge e2(n1, n3, 0, 0); - MultiDiEdge e3(n2, n3, 0, 1); - MultiDiEdge e4(n3, n4, 0, 0); - OutputMultiDiEdge e5({n4.value(), t0}, n4, 0); - - g.add_edge(e0, 0); - g.add_edge(e1, 1); - g.add_edge(e2, 2); - g.add_edge(e3, 3); - g.add_edge(e4, 4); - g.add_edge(e5, 5); - - auto subgraph0 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::INCLUDE_INPUTS, - OutputSettings::INCLUDE_OUTPUTS); - auto subgraph1 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::INCLUDE_INPUTS, - OutputSettings::EXCLUDE_OUTPUTS); - auto subgraph2 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::EXCLUDE_INPUTS, - OutputSettings::INCLUDE_OUTPUTS); - auto subgraph3 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::EXCLUDE_INPUTS, - OutputSettings::EXCLUDE_OUTPUTS); - - CHECK(get_nodes(subgraph0) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph1) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph2) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph3) == std::unordered_set{n3, n4}); - - std::unordered_set input_set{split_edge(e2).second, - split_edge(e3).second}; - std::unordered_set output_set{e5}; - - CHECK(get_inputs(subgraph0) == input_set); - CHECK(get_inputs(subgraph1) == input_set); - CHECK(get_inputs(subgraph2).empty()); - CHECK(get_inputs(subgraph3).empty()); - - CHECK(get_outputs(subgraph0) == output_set); - CHECK(get_outputs(subgraph1).empty()); - CHECK(get_outputs(subgraph2) == output_set); - CHECK(get_outputs(subgraph3).empty()); - - CHECK(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5}); - CHECK(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4}); - CHECK(get_edges(subgraph2) == std::unordered_set{e4, e5}); - CHECK(get_edges(subgraph3) == std::unordered_set{e4}); -} diff --git a/lib/compiler/test/test_machine_view.cc b/lib/compiler/test/test_machine_view.cc deleted file mode 100644 index eea084db48..0000000000 --- a/lib/compiler/test/test_machine_view.cc +++ /dev/null @@ -1,33 +0,0 @@ -#include "flexflow/config.h" -#include "flexflow/machine_view.h" -#include "gtest/gtest.h" - -using namespace Legion; -using namespace FlexFlow; - -TEST(machine_view_get_domain, basic) { - MachineView mv; - mv.ndims = 1; - mv.start_device_id = 2; - mv.dim[0] = 2; - mv.stride[0] = 1; - - Domain d; - d.dim = 1; - d.rect_data[0] = 0; - d.rect_data[0 + d.dim] = - 1; // Domain is includes, MachineView is exclusive on hi - - EXPECT_EQ(mv.get_domain(), d); -} - -TEST(machine_view_get_device_id, basic) { - MachineView mv; - mv.ndims = 1; - mv.start_device_id = 2; - mv.dim[0] = 2; - mv.stride[0] = 1; - - EXPECT_EQ(mv.get_device_id({0}), 2); - EXPECT_EQ(mv.get_device_id({1}), 3); -} diff --git a/lib/compiler/test/test_open_graph.cc b/lib/compiler/test/test_open_graph.cc deleted file mode 100644 index d96cdec467..0000000000 --- a/lib/compiler/test/test_open_graph.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -TEST_CASE("get_source_sink_open_graph:basic") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - - int s0 = 100000; - - Node n0 = g.add_node(); - - g.add_edge(InputMultiDiEdge({s0, n0.value()}, n0, 0)); - - CHECK(get_closed_sources(g) == std::unordered_set{}); - CHECK(get_closed_sinks(g) == std::unordered_set{n0}); - - CHECK(get_open_sources(g) == std::unordered_set{n0}); - CHECK(get_open_sinks(g) == std::unordered_set{}); -} - -TEST_CASE("get_source_sink_open_graph:unconnected") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - int s0 = 100000; - int t0 = s0 + 1; - - Node n0 = g.add_node(); - Node n1 = g.add_node(); - - g.add_edge(InputMultiDiEdge({s0, n0.value()}, n0, 0)); - g.add_edge(OutputMultiDiEdge({n1.value(), t0}, n1, 0)); - - /* - g: ->n0 - n1-> - */ - - CHECK(get_closed_sources(g) == std::unordered_set{n1}); - CHECK(get_closed_sinks(g) == std::unordered_set{n0}); - - CHECK(get_open_sources(g) == std::unordered_set{n0}); - CHECK(get_open_sinks(g) == std::unordered_set{n1}); -} - -TEST_CASE("get_source_sink_open_graph:complex") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - int s0 = 100000; - int s1 = s0 + 1; - int t0 = s1 + 1; - int t1 = t0 + 1; - - std::vector ns; - for (int i = 0; i < 8; ++i) { - ns.push_back(g.add_node()); - } - - g.add_edge(InputMultiDiEdge({s0, ns[0].value()}, ns[0], 0)); - g.add_edge(MultiDiEdge(ns[0], ns[1], 0, 0)); - g.add_edge(OutputMultiDiEdge({ns[1].value(), t0}, ns[1], 0)); - g.add_edge(OutputMultiDiEdge({ns[1].value(), t1}, ns[1], 1)); - - g.add_edge(MultiDiEdge(ns[2], ns[3], 0, 0)); - g.add_edge(MultiDiEdge(ns[2], ns[4], 1, 0)); - g.add_edge(MultiDiEdge(ns[4], ns[3], 0, 1)); - g.add_edge(OutputMultiDiEdge({ns[3].value(), t1}, ns[3], 0)); - - g.add_edge(InputMultiDiEdge({s0, ns[5].value()}, ns[5], 0)); - g.add_edge(InputMultiDiEdge({s1, ns[5].value()}, ns[5], 1)); - g.add_edge(MultiDiEdge(ns[5], ns[6], 0, 0)); - g.add_edge(MultiDiEdge(ns[6], ns[7], 0, 0)); - - CHECK(get_closed_sources(g) == std::unordered_set{ns[2]}); - CHECK(get_closed_sinks(g) == std::unordered_set{ns[7]}); - - CHECK(get_open_sources(g) == std::unordered_set{ns[1], ns[5]}); - CHECK(get_open_sinks(g) == std::unordered_set{ns[1], ns[3]}); -} - -TEST_CASE("get_cut") { - auto g = LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>; - - std::vector ns = add_nodes(g, 5); - - int t0 = 100000; - - MultiDiEdge e0(ns[0], ns[1], 0, 0); - MultiDiEdge e1(ns[1], ns[2], 0, 0); - MultiDiEdge e2(ns[1], ns[3], 1, 0); - MultiDiEdge e3(ns[2], ns[4], 0, 0); - MultiDiEdge e4(ns[3], ns[4], 0, 1); - OutputMultiDiEdge e5({ns[4].value(), t0}, ns[4], 0); - - GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; - CHECK(get_cut(g, gs0) == std::unordered_set{e1, e2}); - - GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; - CHECK(get_cut(g, gs1) == std::unordered_set{e3, e4}); -} diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc deleted file mode 100644 index 2d9414ba27..0000000000 --- a/lib/compiler/test/test_optimal_cost.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "test_cost_estimator.h" -#include "test_generator.h" - -/* -Tests whether optimal_cost can give a valid result given random PCG, trivial -allowed machine views, trivial cost estimator and random machine specification. -*/ -TEST_CASE("optimal_cost") { - auto test_allowed_machine_views = [](Operator const &, - MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(0, 1, 1)}; - }; - rc::check([](ParallelComputationGraph const &g, - MachineSpecification const &machine_spec) { - OptimalCostCache cached_subgraph_costs; - OptimalCostResult result = optimal_cost(g, - test_allowed_machine_views, - TestCostEstimator{}, - machine_spec, - cached_subgraph_costs); - RC_ASSERT(result.runtime > 0); - RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); - }); -} diff --git a/lib/compiler/test/test_parallel_config.cc b/lib/compiler/test/test_parallel_config.cc deleted file mode 100644 index 843879bb0d..0000000000 --- a/lib/compiler/test/test_parallel_config.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "flexflow/config.h" -#include "flexflow/model.h" -#include "gtest/gtest.h" - -using namespace FlexFlow; - -TEST(change_data_parallel_dimensionality, basic_reduce) { - ParallelConfig pc = get_basic_data_parallel_config(8, 4); - - ParallelConfig expected = get_basic_data_parallel_config(8, 2); - - ParallelConfig result = pc.change_data_parallel_dimensionality(2); - - EXPECT_EQ(result, expected); -} - -TEST(change_data_parallel_dimensionality, basic_expand) { - ParallelConfig pc = get_basic_data_parallel_config(8, 2); - - ParallelConfig expected = get_basic_data_parallel_config(8, 4); - - ParallelConfig result = pc.change_data_parallel_dimensionality(4); - - EXPECT_EQ(result, expected); -} diff --git a/lib/compiler/test/test_random_utils.cc b/lib/compiler/test/test_random_utils.cc deleted file mode 100644 index c7b4f9e5c2..0000000000 --- a/lib/compiler/test/test_random_utils.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "flexflow/utils/random_utils.h" -#include "gtest/gtest.h" - -TEST(select_random, basic) { - std::vector values{1, 2, 3, 4}; - std::vector weights{0.1, 0.2, 0.3, 0.4}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.05), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.25), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 3); - EXPECT_EQ(select_random_determistic(values, weights, 0.9), 4); -} - -TEST(select_random, bounds) { - std::vector values{1, 2, 3}; - std::vector weights{0.2, 0.3, 0.5}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.0), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.2), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 3); - EXPECT_EQ(select_random_determistic(values, weights, 1.0), 3); -} - -TEST(select_random, singleton) { - std::vector values{1}; - std::vector weights{1.0}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.0), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 1); - EXPECT_EQ(select_random_determistic(values, weights, 1.0), 1); -} - -TEST(select_random, empty) { - std::vector values{}; - std::vector weights{}; - EXPECT_THROW(select_random_determistic(values, weights, 0.5), - std::invalid_argument); -} - -TEST(select_random, unnormalized_weights) { - std::vector values{1, 2, 3}; - std::vector weights{1.0, 2.0, 2.0}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.1), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.9), 3); -} diff --git a/lib/compiler/test/test_substitution_loader.cc b/lib/compiler/test/test_substitution_loader.cc deleted file mode 100644 index b0531b598a..0000000000 --- a/lib/compiler/test/test_substitution_loader.cc +++ /dev/null @@ -1,144 +0,0 @@ -#include "flexflow/substitution.h" -#include "flexflow/substitution_loader.h" -#include "gtest/gtest.h" - -namespace sl = FlexFlow::substitution_loader; -// using namespace FlexFlow::substitution_loader; -using json = nlohmann::json; -using FlexFlow::PCG::create_xfer; -using FlexFlow::PCG::create_xfers; -using FlexFlow::PCG::GraphXfer; - -TEST(substitution_loader, basic) { - // Yes, I know this substitution is not correct. It's just for testing. - - sl::Rule example_rule; - - example_rule.name = "test_rule"; - - sl::Tensor input_tensor1; - input_tensor1.opId = -1; - input_tensor1.tsId = 0; - - sl::Tensor input_tensor2; - input_tensor2.opId = -2; - input_tensor2.tsId = 0; - - sl::Operator srcOp1; - srcOp1.op_type = OP_EW_ADD; - srcOp1.input = {input_tensor1, input_tensor2}; - srcOp1.para = {}; - - sl::Tensor srcOp1Output; - srcOp1Output.opId = 0; - srcOp1Output.tsId = 0; - - sl::Parameter activation_constraint; - activation_constraint.key = PM_ACTI; - activation_constraint.value = AC_MODE_NONE; - - sl::Operator srcOp2; - srcOp2.op_type = OP_LINEAR; - srcOp2.input = {srcOp1Output}; - srcOp2.para = {activation_constraint}; - - sl::Operator dstOp1; - dstOp1.op_type = OP_LINEAR; - dstOp1.input = {input_tensor1}; - dstOp1.para = {activation_constraint}; - - sl::Tensor dstOp1Output; - dstOp1Output.opId = 0; - dstOp1Output.tsId = 0; - - sl::Operator dstOp2; - dstOp2.op_type = OP_LINEAR; - dstOp2.input = {input_tensor2}; - dstOp2.para = {activation_constraint}; - - sl::Tensor dstOp2Output; - dstOp2Output.opId = 1; - dstOp2Output.tsId = 0; - - sl::Operator dstOp3; - dstOp3.op_type = OP_EW_ADD; - dstOp3.input = {dstOp1Output, dstOp2Output}; - dstOp3.para = {}; - - sl::MapOutput map_output; - map_output.srcOpId = 1; - map_output.srcTsId = 0; - map_output.dstOpId = 2; - map_output.dstTsId = 0; - - example_rule.srcOp = {srcOp1, srcOp2}; - example_rule.dstOp = {dstOp1, dstOp2, dstOp3}; - example_rule.mappedOutput = {map_output}; - - GraphXfer *xfer = new GraphXfer(nullptr); - create_xfer(*xfer, example_rule, 2); - - EXPECT_EQ(xfer->name, "test_rule"); - - EXPECT_EQ(xfer->srcOps.size(), 2); - EXPECT_EQ(xfer->srcOps[0]->type, OP_EW_ADD); - EXPECT_EQ(xfer->srcOps[1]->type, OP_LINEAR); - EXPECT_EQ(xfer->srcOps[0]->inputs.size(), 2); - EXPECT_NE(xfer->srcOps[0]->inputs[0], xfer->srcOps[0]->inputs[1]); - EXPECT_EQ(xfer->srcOps[0]->outputs.size(), 1); - EXPECT_EQ(xfer->srcOps[1]->inputs.size(), 1); - EXPECT_EQ(xfer->srcOps[0]->outputs[0], xfer->srcOps[1]->inputs[0]); - EXPECT_EQ(xfer->srcOps[1]->outputs.size(), 1); - - EXPECT_EQ(xfer->dstOps.size(), 3); - EXPECT_EQ(xfer->dstOps[0]->type, OP_LINEAR); - EXPECT_EQ(xfer->dstOps[1]->type, OP_LINEAR); - EXPECT_EQ(xfer->dstOps[2]->type, OP_EW_ADD); - EXPECT_EQ(xfer->dstOps[0]->inputs.size(), 1); - EXPECT_EQ(xfer->dstOps[0]->outputs.size(), 1); - EXPECT_EQ(xfer->dstOps[0]->inputs[0], xfer->srcOps[0]->inputs[0]); - EXPECT_EQ(xfer->dstOps[1]->inputs.size(), 1); - EXPECT_EQ(xfer->dstOps[1]->outputs.size(), 1); - EXPECT_EQ(xfer->dstOps[1]->inputs[0], xfer->srcOps[0]->inputs[1]); - EXPECT_EQ(xfer->dstOps[2]->inputs.size(), 2); - EXPECT_EQ(xfer->dstOps[2]->inputs[0], xfer->dstOps[0]->outputs[0]); - EXPECT_EQ(xfer->dstOps[2]->inputs[1], xfer->dstOps[1]->outputs[0]); - EXPECT_NE(xfer->dstOps[2]->inputs[0], xfer->dstOps[2]->inputs[1]); - EXPECT_EQ(xfer->dstOps[2]->outputs.size(), 1); - - EXPECT_EQ(xfer->mappedOutputs.size(), 1); - EXPECT_NE(xfer->srcOps[1]->outputs[0], xfer->dstOps[2]->outputs[0]); - EXPECT_EQ(xfer->mappedOutputs.at(xfer->srcOps[1]->outputs[0]), - xfer->dstOps[2]->outputs[0]); -} - -TEST(substitution_loader, operator_deserialization) { - json j = { - {"_t", "Operator"}, - {"input", - std::vector{{{"_t", "Tensor"}, {"opId", -2}, {"tsId", 0}}, - {{"_t", "Tensor"}, {"opId", -3}, {"tsId", 0}}}}, - {"para", std::vector{}}, - {"type", "OP_EW_ADD"}, - }; - - sl::Operator o; - from_json(j, o); - - EXPECT_EQ(o.op_type, OP_EW_ADD); - EXPECT_EQ(o.input.size(), 2); - EXPECT_EQ(o.input[0].opId, -2); - EXPECT_EQ(o.input[0].tsId, 0); - EXPECT_EQ(o.input[1].opId, -3); - EXPECT_EQ(o.input[1].tsId, 0); - EXPECT_EQ(o.para.size(), 0); -} - -// TEST(substitution_loader, load_full_file) { -// sl::RuleCollection collection = -// sl::load_rule_collection_from_path("tests/unit/graph_subst_3_v2.json"); -// EXPECT_EQ(collection.rules.size(), 640); - -// std::vector xfers = create_xfers(nullptr, collection, 2); -// EXPECT_EQ(xfers.size(), 640); -// } diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc index d649856152..9d007e2f45 100644 --- a/lib/op-attrs/src/get_output_shapes.cc +++ b/lib/op-attrs/src/get_output_shapes.cc @@ -5,6 +5,12 @@ namespace FlexFlow { ParallelTensorShape as_parallel(TensorShape const &); std::vector as_parallel(std::vector const &); +std::vector get_output_shapes( + PCGOperatorAttrs const &op_params, + std::vector const &input_tensor_shapes) { + NOT_IMPLEMENTED(); +} + // TensorShape get_output_shape(AggregateAttrs const &attrs, // TensorShape const &gate_preds, // TensorShape const &gate_assign, diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index 55f80e3cc0..1b2a02b070 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -11,22 +11,21 @@ struct BandwidthNetworkModelConfig int bandwidth; }; -struct MachineSpecification : public use_visitable_cmp { +struct MachineSpecification { int num_nodes; int num_cpus_per_node; int num_gpus_per_node; float inter_node_bandwidth; - float intra_node_bandwidth; + req intra_node_bandwidth; }; -} // namespace FlexFlow +FF_VISITABLE_STRUCT(MachineSpecification, + num_nodes, + num_cpus_per_node, + num_gpus_per_node, + inter_node_bandwidth, + intra_node_bandwidth); -VISITABLE_STRUCT(::FlexFlow::MachineSpecification, - num_nodes, - num_cpus_per_node, - num_gpus_per_node, - inter_node_bandwidth, - intra_node_bandwidth); -MAKE_VISIT_HASHABLE(::FlexFlow::MachineSpecification); +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index b482e851d8..afd4206eb1 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -12,9 +12,9 @@ namespace FlexFlow { -struct MachineView : public use_visitable_cmp { - MachineView() = delete; - MachineView(device_id_t const &, StridedRectangle const &); +struct MachineView { + // MachineView() = delete; + // MachineView(device_id_t const &, StridedRectangle const &); std::vector device_ids() const; diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h index c7a49bb57e..d09e25dcf3 100644 --- a/lib/pcg/include/pcg/operator.h +++ b/lib/pcg/include/pcg/operator.h @@ -17,11 +17,12 @@ struct Operator : public use_visitable_cmp { public: PCGOperatorAttrs attrs; + optional name; }; } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::Operator, attrs); +VISITABLE_STRUCT(::FlexFlow::Operator, attrs, name); MAKE_VISIT_HASHABLE(::FlexFlow::Operator); namespace FlexFlow { diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 28331f441c..25f85ffc48 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -17,7 +17,7 @@ struct side_size_t : public strong_typedef { using strong_typedef::strong_typedef; }; -struct StridedRectangleSide : public use_visitable_cmp { +struct StridedRectangleSide { public: StridedRectangleSide() = delete; StridedRectangleSide(num_points_t const &, int stride); @@ -32,13 +32,15 @@ struct StridedRectangleSide : public use_visitable_cmp { public: num_points_t num_points; - int stride; + req stride; }; -struct StridedRectangle : public use_visitable_cmp { +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, num_points, stride); + +struct StridedRectangle { public: - StridedRectangle() = delete; - StridedRectangle(std::vector const &); + // StridedRectangle() = delete; + // StridedRectangle(std::vector const &); size_t at(FFOrdered const &) const; StridedRectangleSide at(ff_dim_t const &) const; @@ -47,6 +49,9 @@ struct StridedRectangle : public use_visitable_cmp { public: FFOrdered sides; }; + +FF_VISITABLE_STRUCT(StridedRectangle, sides); + } // namespace FlexFlow MAKE_TYPEDEF_HASHABLE(::FlexFlow::num_points_t); @@ -55,10 +60,10 @@ MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, "num_points"); MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); -VISITABLE_STRUCT(::FlexFlow::StridedRectangleSide, num_points, stride); -MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangleSide); +// VISITABLE_STRUCT(::FlexFlow::StridedRectangleSide, num_points, stride); +// MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangleSide); -VISITABLE_STRUCT(::FlexFlow::StridedRectangle, sides); -MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangle); +// VISITABLE_STRUCT(::FlexFlow::StridedRectangle, sides); +// MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangle); #endif diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc index 9edfb09a8e..688ba1628f 100644 --- a/lib/pcg/src/machine_view.cc +++ b/lib/pcg/src/machine_view.cc @@ -3,8 +3,8 @@ namespace FlexFlow { -MachineView::MachineView(device_id_t const &start, StridedRectangle const &rect) - : start(start), rect(rect) {} +// MachineView::MachineView(device_id_t const &start, StridedRectangle const &rect) +// : start(start), rect(rect) {} static StridedRectangle make_1d_rect(int start, int stop, int stride) { assert(stop > start); diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc index 8c79c67464..5cba8584c9 100644 --- a/lib/pcg/src/operator.cc +++ b/lib/pcg/src/operator.cc @@ -4,7 +4,7 @@ namespace FlexFlow { Operator::Operator(PCGOperatorAttrs const &attrs, optional const &name) - : attrs(attrs) {} + : attrs(attrs), name(name) {} Operator::operator PCGOperatorAttrs() const { return attrs; diff --git a/lib/pcg/src/parallel_computation_graph.cc b/lib/pcg/src/parallel_computation_graph.cc new file mode 100644 index 0000000000..609b10edd2 --- /dev/null +++ b/lib/pcg/src/parallel_computation_graph.cc @@ -0,0 +1,37 @@ +#include "pcg/parallel_computation_graph.h" +#include "utils/graph/algorithms.h" + +namespace FlexFlow { + +bool operator==(ParallelComputationGraph const &lhs, ParallelComputationGraph const &rhs) { + return std::hash{}(lhs) == std::hash{}(rhs); +} + +} + +namespace std { + +size_t hash::operator()(FlexFlow::ParallelComputationGraph const &g) const { + using namespace FlexFlow; + + size_t h = 0; + + std::vector ordered_nodes = get_topological_ordering(g.value()); + hash_combine(h, ordered_nodes.size()); + + std::unordered_map node_index; + for (int i = 0; i < ordered_nodes.size(); ++i) { + node_index[ordered_nodes[i]] = i; + hash_combine(h, g.value().at(ordered_nodes[i])); + } + + for (MultiDiEdge const &edge : get_edges(g.value())) { + hash_combine(h, node_index.at(edge.src)); + hash_combine(h, node_index.at(edge.dst)); + hash_combine(h, g.value().at(edge)); + } + + return h; +} + +} \ No newline at end of file diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 29dcae6151..2792db65fe 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -30,8 +30,8 @@ side_size_t StridedRectangleSide::get_size() const { NOT_IMPLEMENTED(); } -StridedRectangle::StridedRectangle( - std::vector const &sides) - : sides(sides) {} +// StridedRectangle::StridedRectangle( +// std::vector const &sides) +// : sides(sides) {} } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph.h b/lib/utils/include/utils/graph/digraph.h index 4d0014596e..bfe6884c57 100644 --- a/lib/utils/include/utils/graph/digraph.h +++ b/lib/utils/include/utils/graph/digraph.h @@ -33,7 +33,7 @@ struct DiGraphView : virtual public GraphView { using GraphView::GraphView; private: - IDiGraphView &get_ptr() const; + IDiGraphView const &get_ptr() const; friend struct GraphInternal; }; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index bf037105b5..822973e149 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -54,9 +54,9 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { using MultiDiGraphView::MultiDiGraphView; private: - Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -118,7 +118,7 @@ struct NodeLabelledMultiDiGraph : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index b292a4ef0d..9d83cebac6 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -53,9 +53,8 @@ struct NodeLabelledOpenMultiDiGraphView using NodeLabelledMultiDiGraphView::NodeLabelledMultiDiGraphView; private: - Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -119,7 +118,7 @@ struct NodeLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl) {} Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 8c8a8b1a1b..501aa9caa4 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -54,6 +54,45 @@ struct OutputLabelledOpenMultiDiSubgraphView // CHECK_NOT_ABSTRACT(OutputLabelledOpenMultiDiSubgraphView); +template +struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMultiDiGraphView { + ViewOutputLabelledAsOutputLabelledOpen(OutputLabelledMultiDiGraphView const &g) : g(g) {} + + NodeLabel const &at(Node const &n) const override { + return g.at(n); + } + + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + assert(false); + } + + EdgeLabel const &at(MultiDiOutput const &o) const override { + return g.at(o); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); + } + + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return transform(g.query_edges(q.standard_edge_query), + [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); + } + + ViewOutputLabelledAsOutputLabelledOpen* clone() const override { + return new ViewOutputLabelledAsOutputLabelledOpen(g); + } + +private: + OutputLabelledMultiDiGraphView const &g; +}; + +template +OutputLabelledOpenMultiDiGraphView view_output_labelled_as_output_labelled_open(OutputLabelledMultiDiGraphView const &g) { + return OutputLabelledOpenMultiDiGraphView::template create>(g); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index d0f94414b7..9b3d982e75 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -14,7 +14,7 @@ struct IOutputLabelledMultiDiGraphView IOutputLabelledMultiDiGraphView & operator=(IOutputLabelledMultiDiGraphView const &) = delete; - virtual OutputLabel const &at(MultiDiOutput const &) = 0; + virtual OutputLabel const &at(MultiDiOutput const &) const = 0; using INodeLabelledMultiDiGraphView::at; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); @@ -31,11 +31,11 @@ struct OutputLabelledMultiDiGraphView OutputLabelledMultiDiGraphView & operator=(OutputLabelledMultiDiGraphView const &) = default; - NodeLabel const &at(Node const &n) const { + virtual NodeLabel const &at(Node const &n) const { return get_ptr().at(n); } - OutputLabel const &at(MultiDiOutput const &o) const { + virtual OutputLabel const &at(MultiDiOutput const &o) const { return get_ptr().at(o); } @@ -56,13 +56,11 @@ struct OutputLabelledMultiDiGraphView } protected: - OutputLabelledMultiDiGraphView(cow_ptr_t ptr) - : NodeLabelledMultiDiGraphView(ptr) {} + using NodeLabelledMultiDiGraphView::NodeLabelledMultiDiGraphView; private: - Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -81,7 +79,7 @@ struct OutputLabelledMultiDiGraph Node add_node(NodeLabel const &l) { Node n = get_ptr().add_node(); - nl->add_label(n, l); + nl.get_mutable()->add_label(n, l); return n; } @@ -93,12 +91,12 @@ struct OutputLabelledMultiDiGraph return nl.get_mutable()->get_label(n); } - NodeLabel const &at(Node const &n) const { + NodeLabel const &at(Node const &n) const override { return nl->get_label(n); } void add_output(MultiDiOutput const &o, OutputLabel const &l) { - ol->add_label(o, l); + ol.get_mutable()->add_label(o, l); }; void add_edge(MultiDiOutput const &o, MultiDiInput const &i) { @@ -110,16 +108,17 @@ struct OutputLabelledMultiDiGraph } OutputLabel &at(MultiDiOutput const &o) { - return ol->get_label(o); + return ol.get_mutable()->get_label(o); } - OutputLabel const &at(MultiDiOutput const &o) const { + OutputLabel const &at(MultiDiOutput const &o) const override { return ol->get_label(o); } std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } @@ -139,12 +138,11 @@ struct OutputLabelledMultiDiGraph OutputLabelledMultiDiGraph(cow_ptr_t ptr, cow_ptr_t nl, cow_ptr_t ol) - : OutputLabelledMultiDiGraphView(ptr), nl(nl), - ol(ol) {} + : GraphView(ptr), nl(nl), ol(ol) {} private: Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 3d2ac9d601..986d337a57 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -17,7 +17,8 @@ struct IOutputLabelledOpenMultiDiGraphView template struct OutputLabelledOpenMultiDiGraphView - : virtual NodeLabelledOpenMultiDiGraphView { + : virtual NodeLabelledOpenMultiDiGraphView, + virtual OutputLabelledMultiDiGraphView { private: using Interface = IOutputLabelledOpenMultiDiGraphView; @@ -59,12 +60,10 @@ struct OutputLabelledOpenMultiDiGraphView protected: using NodeLabelledOpenMultiDiGraphView< NodeLabel>::NodeLabelledOpenMultiDiGraphView; - OutputLabelledOpenMultiDiGraphView(cow_ptr_t ptr) : GraphView(ptr) {} private: - Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -96,6 +95,11 @@ struct OutputLabelledOpenMultiDiGraph return n; } + void add_node_unsafe(Node const &n, NodeLabel const &l) { + get_ptr().add_node_unsafe(n); + nl.get_mutable()->add_label(n, l); + } + NodePort add_node_port() { return get_ptr().add_node_port(); } @@ -121,14 +125,14 @@ struct OutputLabelledOpenMultiDiGraph } EdgeLabel &at(MultiDiOutput const &o) { - return ol->get_label(o); + return ol.get_mutable()->get_label(o); } EdgeLabel const &at(MultiDiOutput const &o) const override { return ol->get_label(o); } EdgeLabel &at(InputMultiDiEdge const &e) { - return il->get_label(e); + return il.get_mutable()->get_label(e); } EdgeLabel const &at(InputMultiDiEdge const &e) const override { @@ -165,7 +169,7 @@ struct OutputLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl), il(il), ol(ol) {} Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 70e5e87f93..ae9b02c911 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -32,19 +32,19 @@ struct LabelledMultiDiGraphView operator=(LabelledMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr()->at(n); + return get_ptr().at(n); } EdgeLabel const &at(MultiDiEdge const &e) const { - return get_ptr()->at(e); + return get_ptr().at(e); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr()->query_nodes(q); + return get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr()->query_edges(q); + return get_ptr().query_edges(q); } template @@ -58,8 +58,8 @@ struct LabelledMultiDiGraphView protected: LabelledMultiDiGraphView(cow_ptr_t ptr) : NodeLabelledMultiDiGraphView(ptr) {} - cow_ptr_t get_ptr() const { - return cow_ptr_t(static_cast(*GraphView::ptr)); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -84,11 +84,11 @@ struct LabelledMultiDiGraph } NodePort add_node_port() { - return this->get_ptr()->add_node_port(); + return this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl->get_label(n); + return nl.get_mutable()->get_label(n); } NodeLabel const &at(Node const &n) const { @@ -96,20 +96,20 @@ struct LabelledMultiDiGraph } void add_edge(MultiDiEdge const &e, EdgeLabel const &l) { - return this->get_ptr()->add_edge(e, l); + return this->get_ptr().add_edge(e, l); } EdgeLabel &at(MultiDiEdge const &e) { - return el->get_label(e); + return el.get_mutable()->get_label(e); } EdgeLabel const &at(MultiDiEdge const &e) const { return el->get_label(e); } std::unordered_set query_nodes(NodeQuery const &q) const { - return this->get_ptr()->query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return this->get_ptr()->query_edges(q); + return this->get_ptr().query_edges(q); } template @@ -129,8 +129,8 @@ struct LabelledMultiDiGraph cow_ptr_t el) : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} - cow_ptr_t get_ptr() const { - return cow_ptr_t(static_cast(*GraphView::ptr)); + Interface& get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/unordered_label.h b/lib/utils/include/utils/graph/labelled/unordered_label.h index 94c4bffe11..230e286ef8 100644 --- a/lib/utils/include/utils/graph/labelled/unordered_label.h +++ b/lib/utils/include/utils/graph/labelled/unordered_label.h @@ -19,7 +19,8 @@ struct UnorderedLabelling : virtual public ILabelling { } void add_label(Elem const &e, Label const &l) { - label_map.insert({e, l}); + auto p = std::make_pair(e, l); + label_map.insert(p); } UnorderedLabelling *clone() const { diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 5a227d46ec..82a45a2ad0 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -52,7 +52,7 @@ struct ViewMultiDiGraphAsOutputLabelled return node_label(n); } - virtual OutputLabel &at(MultiDiOutput const &o) override { + virtual OutputLabel const &at(MultiDiOutput const &o) const override { return output_label(o); } @@ -86,6 +86,26 @@ Impl materialize_output_labelled_multidigraph_view( return result; } +template +OutputLabelledOpenMultiDiGraph materialize_output_labelled_open_multidigraph_view(OutputLabelledOpenMultiDiGraphView const &g) { + OutputLabelledOpenMultiDiGraph result = OutputLabelledOpenMultiDiGraph::template create(); + for (Node const &n : get_nodes(g)) { + result.add_node_unsafe(n, g.at(n)); + } + for (OpenMultiDiEdge const &e : get_edges(g)) { + result.add_edge(e); + if (is_input_edge(e)) { + InputMultiDiEdge input_edge = get(e); + result.add_label(input_edge, g.at(input_edge)); + } else { + MultiDiOutput output = is_standard_edge(e) ? static_cast(get(e)) : static_cast(get(e)); + auto tensor = g.at(output); + result.add_label(output, tensor); + } + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.h index cfb2c7db21..d5d72bbbd7 100644 --- a/lib/utils/include/utils/graph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph.h @@ -30,7 +30,7 @@ struct MultiDiGraphView : virtual DiGraphView { using DiGraphView::DiGraphView; private: - IMultiDiGraphView &get_ptr() const; + IMultiDiGraphView const &get_ptr() const; friend struct GraphInternal; }; diff --git a/lib/utils/include/utils/graph/open_graphs.h b/lib/utils/include/utils/graph/open_graphs.h index 1f8a3692fa..703ad6778f 100644 --- a/lib/utils/include/utils/graph/open_graphs.h +++ b/lib/utils/include/utils/graph/open_graphs.h @@ -34,7 +34,7 @@ struct OpenMultiDiGraphView : virtual MultiDiGraphView { using MultiDiGraphView::MultiDiGraphView; private: - IOpenMultiDiGraphView &get_ptr() const; + IOpenMultiDiGraphView const &get_ptr() const; friend struct GraphInternal; }; @@ -50,6 +50,7 @@ struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { Node add_node(); void add_node_unsafe(Node const &); void remove_node_unsafe(Node const &); + NodePort add_node_port(); void add_edge(Edge const &); void remove_edge(Edge const &); @@ -60,7 +61,7 @@ struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { static typename std::enable_if::value, OpenMultiDiGraph>::type create() { - return make_cow_ptr(); + return OpenMultiDiGraph(make_cow_ptr()); } private: @@ -96,7 +97,7 @@ struct UpwardOpenMultiDiGraphView : virtual MultiDiGraphView { private: using MultiDiGraphView::MultiDiGraphView; - IUpwardOpenMultiDiGraphView &get_ptr() const; + IUpwardOpenMultiDiGraphView const &get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UpwardOpenMultiDiGraphView); @@ -158,7 +159,7 @@ struct DownwardOpenMultiDiGraphView : virtual MultiDiGraphView { private: using MultiDiGraphView::MultiDiGraphView; - Interface &get_ptr() const; + Interface const &get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DownwardOpenMultiDiGraphView); diff --git a/lib/utils/include/utils/graph/undirected.h b/lib/utils/include/utils/graph/undirected.h index 24cd07caa9..b32b6e3572 100644 --- a/lib/utils/include/utils/graph/undirected.h +++ b/lib/utils/include/utils/graph/undirected.h @@ -55,7 +55,7 @@ struct UndirectedGraphView : virtual GraphView { friend struct GraphInternal; private: - IUndirectedGraphView &get_ptr() const; + IUndirectedGraphView const &get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraphView); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index d62989d65b..d5407adbae 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -164,7 +164,9 @@ DiGraphView apply_contraction(DiGraphView const &g, for (auto const &kv : nodes) { Node from = kv.first; Node into = kv.second; - contractedView = contract_node(contractedView, from, into); + if (from != into) { + contractedView = contract_node(contractedView, from, into); + } } return contractedView; } @@ -347,6 +349,13 @@ std::unordered_set }); } +std::unordered_set get_open_outputs(OpenMultiDiGraphView const &g) { + return transform(g.query_edges(OutputMultiDiEdgeQuery::all()), [](OpenMultiDiEdge const &e) { return get(e); }); +} +std::unordered_set get_open_inputs(OpenMultiDiGraphView const &g) { + return transform(g.query_edges(InputMultiDiEdgeQuery::all()), [](OpenMultiDiEdge const &e) { return get(e); }); +} + std::unordered_map> get_predecessors(DiGraphView const &g, std::unordered_set const &nodes) { @@ -757,4 +766,28 @@ std::unordered_set> return components; } +std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return get_incoming_edges(g, n).size() == 0; + }); +} + +std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return get_outgoing_edges(g, n).size() == 0; + }); +} + +std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return !g.query_edges(InputMultiDiEdgeQuery::all().with_dst_nodes({n})).empty(); + }); +} + +std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return !g.query_edges(OutputMultiDiEdgeQuery::all().with_src_nodes({n})).empty(); + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index ff65df1cf6..1e2f562c19 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -13,9 +13,8 @@ std::unordered_set return get_ptr().query_edges(query); } -IDiGraphView &DiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IDiGraphView const &DiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node DiGraph::add_node() { @@ -48,6 +47,6 @@ std::unordered_set } IDiGraph &DiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 07d5837b1e..7bbe4cae67 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -23,9 +23,8 @@ std::unordered_set return this->get_ptr().query_edges(q); } -IMultiDiGraphView &MultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node MultiDiGraph::add_node() { @@ -66,7 +65,7 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph &MultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index a9635aa553..836b5513e9 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -53,7 +53,7 @@ std::unordered_set Graph::query_nodes(NodeQuery const &q) const { } IGraph &Graph::get_ptr() const { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 0acda5e6f6..9bbb1bfa3d 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -21,9 +21,8 @@ std::unordered_set return this->get_ptr().query_edges(q); } -IOpenMultiDiGraphView &OpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node OpenMultiDiGraph::add_node() { @@ -51,8 +50,12 @@ std::unordered_set return this->get_ptr().query_edges(q); } +NodePort OpenMultiDiGraph::add_node_port() { + return get_ptr().add_node_port(); +} + IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -67,9 +70,9 @@ std::unordered_set return get_ptr().query_edges(q); } -IUpwardOpenMultiDiGraphView &UpwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IUpwardOpenMultiDiGraphView const &UpwardOpenMultiDiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } Node UpwardOpenMultiDiGraph::add_node() { @@ -98,7 +101,7 @@ std::unordered_set UpwardOpenMultiDiGraph::query_edges( } IUpwardOpenMultiDiGraph &UpwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -113,9 +116,9 @@ std::unordered_set return this->get_ptr().query_edges(q); } -IDownwardOpenMultiDiGraphView &DownwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IDownwardOpenMultiDiGraphView const &DownwardOpenMultiDiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } Node DownwardOpenMultiDiGraph::add_node() { @@ -150,7 +153,7 @@ std::unordered_set } IDownwardOpenMultiDiGraph &DownwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 3b3a1b0aed..41ecf3c436 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -72,7 +72,7 @@ std::unordered_set if (include_src == SourceSettings::INCLUDE_SOURCE_NODES) { result = set_union(result, srcs); } - if (include_sink == SinkSettings::EXCLUDE_SINK_NODES) { + if (include_sink == SinkSettings::INCLUDE_SINK_NODES) { result = set_union(result, sinks); } return result; @@ -103,12 +103,12 @@ SplitAST sp_decomposition(DiGraphView const &g) { sources, {bottleneck.value()}, SourceSettings::INCLUDE_SOURCE_NODES, - SinkSettings::INCLUDE_SINK_NODES)), + SinkSettings::EXCLUDE_SINK_NODES)), sp_decomposition(source_to_sink_subgraph( g, {bottleneck.value()}, sinks, - SourceSettings::EXCLUDE_SOURCE_NODES, + SourceSettings::INCLUDE_SOURCE_NODES, SinkSettings::INCLUDE_SINK_NODES))); } else { return parallel_decomposition(g); @@ -195,6 +195,13 @@ struct ToFinalAST { variant to_final_ast(SplitAST const &ast) { return visit(ToFinalAST{}, ast); } + +SerialParallelDecomposition + get_serial_parallel_decomposition(DiGraphView const &g) { + SplitAST ast = sp_decomposition(g); + return to_final_ast(ast); +} + struct GetNodes { template std::unordered_set operator()(T const &t) { diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index 414b350a89..166a9efa36 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -26,7 +26,7 @@ void UndirectedGraph::remove_edge(UndirectedEdge const &e) { } IUndirectedGraph &UndirectedGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -50,8 +50,8 @@ std::unordered_set return this->get_ptr().query_nodes(q); } -IUndirectedGraphView &UndirectedGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( +IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index 062dca6858..5a8c6e9f93 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -469,7 +469,7 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), inputs(inputs) {} + : g(g), nodes(nodes), inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)) {} UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { return new UpwardOpenMultiDiSubgraphView(g, nodes); @@ -477,11 +477,11 @@ UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { std::unordered_set UpwardOpenMultiDiSubgraphView::query_edges( OpenMultiDiEdgeQuery const &q) const { - std::unordered_set result = - g.query_edges(OpenMultiDiEdgeQuery( - q.input_edge_query.with_dst_nodes(nodes), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - OutputMultiDiEdgeQuery::none())); + OpenMultiDiEdgeQuery subgraph_query( + q.input_edge_query.with_dst_nodes(nodes), + q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), + OutputMultiDiEdgeQuery::none()); + std::unordered_set result = g.query_edges(subgraph_query); extend(result, query_edge(inputs, q.input_edge_query.with_dst_nodes(nodes))); return result; } @@ -493,16 +493,16 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes) {} + : g(g), nodes(nodes), outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} std::unordered_set DownwardOpenMultiDiSubgraphView::query_edges( OpenMultiDiEdgeQuery const &q) const { - std::unordered_set result = - g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - q.output_edge_query.with_src_nodes(nodes))); + OpenMultiDiEdgeQuery subgraph_query( + InputMultiDiEdgeQuery::none(), + q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), + q.output_edge_query.with_src_nodes(nodes)); + std::unordered_set result = g.query_edges(subgraph_query); extend(result, query_edge(outputs, q.output_edge_query.with_src_nodes(nodes))); return result; From 6211b84e84a8ba471783d7b7c4f6854b4d59c884 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 15 Nov 2023 17:13:45 -0500 Subject: [PATCH 03/32] format --- lib/compiler/include/compiler/cost_estimate.h | 3 +- .../include/compiler/machine_mapping.h | 18 ++--- .../include/compiler/unity_algorithm.h | 2 +- lib/compiler/src/graph_utils.cc | 9 ++- lib/compiler/src/machine_mapping.cc | 67 +++++++++++-------- lib/compiler/src/unity_algorithm.cc | 11 ++- lib/compiler/test/src/test_generator.h | 2 +- .../test/src/test_labelled_open_graph.cc | 11 +-- lib/compiler/test/src/test_open_graph.cc | 6 +- lib/compiler/test/src/test_optimal_cost.cc | 31 +++++---- lib/op-attrs/src/get_output_shapes.cc | 2 +- .../include/pcg/parallel_computation_graph.h | 5 +- lib/pcg/include/pcg/strided_rectangle.h | 4 +- lib/pcg/src/machine_view.cc | 3 +- lib/pcg/src/parallel_computation_graph.cc | 13 ++-- .../include/substitutions/substitution.h | 4 +- lib/utils/include/utils/graph/algorithms.h | 6 +- .../utils/graph/labelled/node_labelled.h | 6 +- .../utils/graph/labelled/node_labelled_open.h | 6 +- .../include/utils/graph/labelled/open_views.h | 21 ++++-- .../utils/graph/labelled/output_labelled.h | 5 +- .../graph/labelled/output_labelled_open.h | 9 +-- .../utils/graph/labelled/standard_labelled.h | 2 +- .../include/utils/graph/labelled/views.h | 25 +++++-- lib/utils/src/graph/algorithms.cc | 20 ++++-- lib/utils/src/graph/multidigraph.cc | 3 +- lib/utils/src/graph/open_graphs.cc | 6 +- lib/utils/src/graph/views.cc | 6 +- 28 files changed, 189 insertions(+), 117 deletions(-) diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimate.h index 3791292529..557f51a7ca 100644 --- a/lib/compiler/include/compiler/cost_estimate.h +++ b/lib/compiler/include/compiler/cost_estimate.h @@ -45,7 +45,8 @@ struct CostEstimator { } private: - CostEstimator(std::shared_ptr implementation_ptr) : implementation_ptr(implementation_ptr) {} + CostEstimator(std::shared_ptr implementation_ptr) + : implementation_ptr(implementation_ptr) {} std::shared_ptr implementation_ptr; }; diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 9f9d97937d..aa9152dcd6 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -9,13 +9,14 @@ namespace FlexFlow { -using SubParallelComputationGraphView = OutputLabelledOpenMultiDiGraphView; +using SubParallelComputationGraphView = + OutputLabelledOpenMultiDiGraphView; struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); static bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - + req> machine_views; }; FF_VISITABLE_STRUCT(MachineMapping, machine_views); @@ -24,11 +25,10 @@ struct OptimalCostState { SerialParallelDecomposition subgraph; req resource; // req> given_machine_views; - // req> frontier_machine_views; + // req> + // frontier_machine_views; }; -FF_VISITABLE_STRUCT(OptimalCostState, - subgraph, - resource); +FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, @@ -72,9 +72,11 @@ namespace std { template <> struct hash> { - size_t operator()(std::unordered_map const &g) const; + size_t operator()( + std::unordered_map const &g) + const; }; -}; +}; // namespace std #endif diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 81e8375948..a87bddcc3a 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -45,6 +45,6 @@ struct hash { size_t operator()(FlexFlow::Strategy const &) const; }; -} +} // namespace std #endif diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 04e96c66ed..e0134d6dd8 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -15,11 +15,10 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { auto g = pcg.value(); auto g_ = view_output_labelled_as_output_labelled_open(g); auto subpcg = materialize_output_labelled_open_multidigraph_view< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling - >(g_); + AdjacencyOpenMultiDiGraph, + UnorderedLabelling, + UnorderedLabelling, + UnorderedLabelling>(g_); return subpcg; } diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 671c59a94f..2bdd7de1e2 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -94,7 +94,8 @@ GraphSplit float estimate_cost(SubParallelComputationGraphView const &g, CostEstimator const &estimator, MachineMapping const &device_mapping, - std::unordered_map const &frontier_machine_views) { + std::unordered_map const + &frontier_machine_views) { return 0.1; } @@ -103,19 +104,20 @@ void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { } struct OptimalCost { - OptimalCost( - SubParallelComputationGraphView const &g, - CostEstimator const &cost_estimator, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const &frontier_machine_views, - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimalCostCache &cached_subgraph_costs) + OptimalCost(SubParallelComputationGraphView const &g, + CostEstimator const &cost_estimator, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views, + std::function( + Operator const &, MachineSpecification const &)> const + &allowed_machine_views, + OptimalCostCache &cached_subgraph_costs) : g(g), cost_estimator(cost_estimator), resource(resource), given_machine_views(restrict_keys(given_machine_views, get_nodes(g))), - frontier_machine_views(restrict_keys(frontier_machine_views, get_edges(g))), + frontier_machine_views( + restrict_keys(frontier_machine_views, get_edges(g))), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} @@ -131,7 +133,8 @@ struct OptimalCost { template OptimalCostResult operator()(T const &t) const { - OptimalCostState state{t, resource/*, given_machine_views, frontier_machine_views*/}; + OptimalCostState state{ + t, resource /*, given_machine_views, frontier_machine_views*/}; optional cached_result = cached_subgraph_costs.load(state); @@ -150,8 +153,10 @@ struct OptimalCost { SerialParallelDecomposition post_decompn = decomposed.second; GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); - SubParallelComputationGraphView pre_graph = get_subgraph(g, graph_split.first); - SubParallelComputationGraphView post_graph = get_subgraph(g, graph_split.second); + SubParallelComputationGraphView pre_graph = + get_subgraph(g, graph_split.first); + SubParallelComputationGraphView post_graph = + get_subgraph(g, graph_split.second); std::unordered_set post_graph_sources = get_closed_sources(post_graph); @@ -165,9 +170,11 @@ struct OptimalCost { for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { - std::unordered_map new_given_machine_views = given_machine_views; + std::unordered_map new_given_machine_views = + given_machine_views; new_given_machine_views.emplace(split_point, mv); - std::unordered_map new_frontier_machine_views = frontier_machine_views; + std::unordered_map + new_frontier_machine_views = frontier_machine_views; new_frontier_machine_views.emplace(split_edge, mv); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( @@ -198,8 +205,10 @@ struct OptimalCost { SerialParallelDecomposition decompn2 = decomposed.second; GraphSplit graph_split = get_graph_split(decompn1, decompn2); - SubParallelComputationGraphView g1 = get_subgraph(g, graph_split.first), - g2 = get_subgraph(g, graph_split.second); + SubParallelComputationGraphView g1 = get_subgraph( + g, graph_split.first), + g2 = get_subgraph( + g, graph_split.second); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( visit(OptimalCost(g1, @@ -225,16 +234,16 @@ struct OptimalCost { visit(OptimalCost(g1, cost_estimator, resource_split.first, - given_machine_views, - frontier_machine_views, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn1), visit(OptimalCost(g2, cost_estimator, resource_split.second, - given_machine_views, - frontier_machine_views, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn2))); @@ -248,13 +257,16 @@ struct OptimalCost { assert(contains(allowed_machine_views(g.at(node), resource), source_machine_view.value())); MachineMapping mv_map{given_machine_views}; - return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}; + return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), + mv_map}; } else { OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { MachineMapping mv_map{{{node, mv}}}; - minimize_runtime(optimal_result, - {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}); + minimize_runtime( + optimal_result, + {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), + mv_map}); } return optimal_result; } @@ -269,7 +281,8 @@ OptimalCostResult CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - SerialParallelDecomposition sp_decomposition = get_serial_parallel_decomposition(g); + SerialParallelDecomposition sp_decomposition = + get_serial_parallel_decomposition(g); SubParallelComputationGraph subpcg = pcg_to_subpcg(g); return visit(OptimalCost(subpcg, cost_estimator, diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 16671b080a..c89bf04b25 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -37,8 +37,13 @@ Strategy DeduplicatedPriorityQueue, StrategyRuntimeCmp> candidates; - OptimalCostResult initial_pcg_result = optimal_cost(pcg, allowed_machine_views, cost_estimator, resources, cached_subgraph_costs); - Strategy initial_result{pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; + OptimalCostResult initial_pcg_result = optimal_cost(pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); + Strategy initial_result{ + pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; Strategy best_result = initial_result; candidates.push(initial_result); @@ -88,4 +93,4 @@ size_t hash::operator()(FlexFlow::Strategy const &s) const { return h; } -} +} // namespace std diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index 23a79abbe0..6566c8c2de 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_TEST_GENERATOR_H #include "compiler/machine_mapping.h" -#include "substitutions/sub_parallel_computation_graph.h" #include "pcg/computation_graph.h" #include "rapidcheck.h" +#include "substitutions/sub_parallel_computation_graph.h" using namespace FlexFlow; diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index 78ea1ece55..82c247e0d2 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -67,11 +67,12 @@ TEST_CASE("get_subgraph_open_graph") { CHECK(bool(get_open_outputs(subgraph3).empty())); CHECK(bool(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5})); + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); CHECK(bool(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4})); - CHECK(bool(get_edges(subgraph2) == std::unordered_set{e4, e5})); + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + CHECK(bool(get_edges(subgraph2) == + std::unordered_set{e4, e5})); CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); } diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc index ea1108c291..00cb4ca890 100644 --- a/lib/compiler/test/src/test_open_graph.cc +++ b/lib/compiler/test/src/test_open_graph.cc @@ -11,7 +11,8 @@ TEST_CASE("get_source_sink_open_graph") { Node n0 = g.add_node(); NodePort p0 = g.add_node_port(); - InputMultiDiEdge e0{n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; + InputMultiDiEdge e0{ + n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; g.add_edge(e0); CHECK(bool(get_closed_sources(g) == std::unordered_set{})); @@ -63,7 +64,8 @@ TEST_CASE("get_cut") { MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; - OutputMultiDiEdge e5{ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; + OutputMultiDiEdge e5{ + ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; g.add_edge(e0); g.add_edge(e1); diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 87f9d06342..7eeb118c57 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -27,25 +27,28 @@ allowed machine views, trivial cost estimator and random machine specification. // } TEST_CASE("optimal_cost_0") { - auto pcg = OutputLabelledMultiDiGraph::template create< - AdjacencyMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling - >(); + auto pcg = + OutputLabelledMultiDiGraph::template create< + AdjacencyMultiDiGraph, + UnorderedLabelling, + UnorderedLabelling>(); Node n0 = pcg.add_node(Operator(InputAttrs{}, "input")); - Node n1 = pcg.add_node(Operator(LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, "linear")); + Node n1 = pcg.add_node(Operator( + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, + "linear")); MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; pcg.add_edge(e); pcg.add_output(e, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); auto test_allowed_machine_views = [](Operator const &, MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + return std::unordered_set{ + make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; }; CostEstimator estimator = CostEstimator::create(); @@ -54,7 +57,11 @@ TEST_CASE("optimal_cost_0") { OptimalCostCache cached_results; - OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), test_allowed_machine_views, estimator, machine_spec, cached_results); + OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), + test_allowed_machine_views, + estimator, + machine_spec, + cached_results); CHECK(bool(result.runtime > 0)); -} \ No newline at end of file +} diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc index 9d007e2f45..c20d4be34c 100644 --- a/lib/op-attrs/src/get_output_shapes.cc +++ b/lib/op-attrs/src/get_output_shapes.cc @@ -8,7 +8,7 @@ std::vector as_parallel(std::vector const &); std::vector get_output_shapes( PCGOperatorAttrs const &op_params, std::vector const &input_tensor_shapes) { - NOT_IMPLEMENTED(); + NOT_IMPLEMENTED(); } // TensorShape get_output_shape(AggregateAttrs const &attrs, diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 2342cd08fa..39a69a80ab 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -15,7 +15,8 @@ struct ParallelComputationGraph }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ParallelComputationGraph); -bool operator==(ParallelComputationGraph const &, ParallelComputationGraph const &); +bool operator==(ParallelComputationGraph const &, + ParallelComputationGraph const &); } // namespace FlexFlow @@ -25,6 +26,6 @@ template <> struct hash { size_t operator()(FlexFlow::ParallelComputationGraph const &g) const; }; -} +} // namespace std #endif diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 25f85ffc48..179fff080f 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -35,7 +35,9 @@ struct StridedRectangleSide { req stride; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, num_points, stride); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, + num_points, + stride); struct StridedRectangle { public: diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc index 688ba1628f..f146482141 100644 --- a/lib/pcg/src/machine_view.cc +++ b/lib/pcg/src/machine_view.cc @@ -3,7 +3,8 @@ namespace FlexFlow { -// MachineView::MachineView(device_id_t const &start, StridedRectangle const &rect) +// MachineView::MachineView(device_id_t const &start, StridedRectangle const +// &rect) // : start(start), rect(rect) {} static StridedRectangle make_1d_rect(int start, int stop, int stride) { diff --git a/lib/pcg/src/parallel_computation_graph.cc b/lib/pcg/src/parallel_computation_graph.cc index 609b10edd2..011c40eb4c 100644 --- a/lib/pcg/src/parallel_computation_graph.cc +++ b/lib/pcg/src/parallel_computation_graph.cc @@ -3,15 +3,18 @@ namespace FlexFlow { -bool operator==(ParallelComputationGraph const &lhs, ParallelComputationGraph const &rhs) { - return std::hash{}(lhs) == std::hash{}(rhs); +bool operator==(ParallelComputationGraph const &lhs, + ParallelComputationGraph const &rhs) { + return std::hash{}(lhs) == + std::hash{}(rhs); } -} +} // namespace FlexFlow namespace std { -size_t hash::operator()(FlexFlow::ParallelComputationGraph const &g) const { +size_t hash::operator()( + FlexFlow::ParallelComputationGraph const &g) const { using namespace FlexFlow; size_t h = 0; @@ -34,4 +37,4 @@ size_t hash::operator()(FlexFlow::ParallelCo return h; } -} \ No newline at end of file +} // namespace std diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 98471a8fbd..8dbe4e66cf 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -28,12 +28,12 @@ SubParallelComputationGraph } // namespace FlexFlow -namespace std{ +namespace std { template <> struct hash { size_t operator()(FlexFlow::Substitution const &) const; }; -}; +}; // namespace std #endif diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index fc3d219e57..cee5445190 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -106,8 +106,10 @@ std::unordered_set get_node_edges(UndirectedGraphView const &, std::unordered_set get_outputs(MultiDiGraphView const &); std::unordered_set get_inputs(MultiDiGraphView const &); -std::unordered_set get_open_outputs(OpenMultiDiGraphView const &); -std::unordered_set get_open_inputs(OpenMultiDiGraphView const &); +std::unordered_set + get_open_outputs(OpenMultiDiGraphView const &); +std::unordered_set + get_open_inputs(OpenMultiDiGraphView const &); std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 822973e149..64de380f9c 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -55,8 +55,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -118,8 +117,7 @@ struct NodeLabelledMultiDiGraph : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} Interface &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 9d83cebac6..ff069060ca 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -10,7 +10,8 @@ struct INodeLabelledOpenMultiDiGraphView : virtual INodeLabelledMultiDiGraphView, virtual IOpenMultiDiGraphView { INodeLabelledOpenMultiDiGraphView() = default; - INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = delete; + INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = + delete; INodeLabelledOpenMultiDiGraphView & operator=(INodeLabelledOpenMultiDiGraphView const &) = delete; }; @@ -118,8 +119,7 @@ struct NodeLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl) {} Interface &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 501aa9caa4..8323fcf9dc 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -43,7 +43,7 @@ struct OutputLabelledOpenMultiDiSubgraphView return SubgraphView(g, nodes).query_edges(q); } - OutputLabelledOpenMultiDiSubgraphView* clone() const override { + OutputLabelledOpenMultiDiSubgraphView *clone() const override { return new OutputLabelledOpenMultiDiSubgraphView(g, nodes); } @@ -55,8 +55,11 @@ struct OutputLabelledOpenMultiDiSubgraphView // CHECK_NOT_ABSTRACT(OutputLabelledOpenMultiDiSubgraphView); template -struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMultiDiGraphView { - ViewOutputLabelledAsOutputLabelledOpen(OutputLabelledMultiDiGraphView const &g) : g(g) {} +struct ViewOutputLabelledAsOutputLabelledOpen + : virtual IOutputLabelledOpenMultiDiGraphView { + ViewOutputLabelledAsOutputLabelledOpen( + OutputLabelledMultiDiGraphView const &g) + : g(g) {} NodeLabel const &at(Node const &n) const override { return g.at(n); @@ -77,10 +80,10 @@ struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMulti std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const override { return transform(g.query_edges(q.standard_edge_query), - [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); + [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); } - ViewOutputLabelledAsOutputLabelledOpen* clone() const override { + ViewOutputLabelledAsOutputLabelledOpen *clone() const override { return new ViewOutputLabelledAsOutputLabelledOpen(g); } @@ -89,8 +92,12 @@ struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMulti }; template -OutputLabelledOpenMultiDiGraphView view_output_labelled_as_output_labelled_open(OutputLabelledMultiDiGraphView const &g) { - return OutputLabelledOpenMultiDiGraphView::template create>(g); +OutputLabelledOpenMultiDiGraphView + view_output_labelled_as_output_labelled_open( + OutputLabelledMultiDiGraphView const &g) { + return OutputLabelledOpenMultiDiGraphView:: + template create< + ViewOutputLabelledAsOutputLabelledOpen>(g); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 9b3d982e75..03f44fc5ee 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -118,7 +118,7 @@ struct OutputLabelledMultiDiGraph std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } - + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } @@ -142,8 +142,7 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; cow_ptr_t ol; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 986d337a57..17cf7fa7af 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -169,8 +169,7 @@ struct OutputLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl), il(il), ol(ol) {} Interface &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; @@ -179,8 +178,10 @@ struct OutputLabelledOpenMultiDiGraph }; template -void add_label(OutputLabelledOpenMultiDiGraph &g, OpenMultiDiEdge const &e, EdgeLabel const &l) { - visit([&](const auto &e) { g.add_label(e, l); }, e); +void add_label(OutputLabelledOpenMultiDiGraph &g, + OpenMultiDiEdge const &e, + EdgeLabel const &l) { + visit([&](auto const &e) { g.add_label(e, l); }, e); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index ae9b02c911..d47d7fdbc0 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -129,7 +129,7 @@ struct LabelledMultiDiGraph cow_ptr_t el) : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} - Interface& get_ptr() const { + Interface &get_ptr() const { return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 82a45a2ad0..84bb20d327 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_VIEWS_H #include "node_labelled.h" -#include "standard_labelled.h" #include "output_labelled_open.h" +#include "standard_labelled.h" namespace FlexFlow { @@ -86,9 +86,21 @@ Impl materialize_output_labelled_multidigraph_view( return result; } -template -OutputLabelledOpenMultiDiGraph materialize_output_labelled_open_multidigraph_view(OutputLabelledOpenMultiDiGraphView const &g) { - OutputLabelledOpenMultiDiGraph result = OutputLabelledOpenMultiDiGraph::template create(); +template +OutputLabelledOpenMultiDiGraph + materialize_output_labelled_open_multidigraph_view( + OutputLabelledOpenMultiDiGraphView const &g) { + OutputLabelledOpenMultiDiGraph result = + OutputLabelledOpenMultiDiGraph::template create< + Impl, + NodeLabelImpl, + InputLabelImpl, + OutputLabelImpl>(); for (Node const &n : get_nodes(g)) { result.add_node_unsafe(n, g.at(n)); } @@ -98,7 +110,10 @@ OutputLabelledOpenMultiDiGraph materialize_output_labell InputMultiDiEdge input_edge = get(e); result.add_label(input_edge, g.at(input_edge)); } else { - MultiDiOutput output = is_standard_edge(e) ? static_cast(get(e)) : static_cast(get(e)); + MultiDiOutput output = + is_standard_edge(e) + ? static_cast(get(e)) + : static_cast(get(e)); auto tensor = g.at(output); result.add_label(output, tensor); } diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index d5407adbae..72242709e2 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -349,11 +349,17 @@ std::unordered_set }); } -std::unordered_set get_open_outputs(OpenMultiDiGraphView const &g) { - return transform(g.query_edges(OutputMultiDiEdgeQuery::all()), [](OpenMultiDiEdge const &e) { return get(e); }); +std::unordered_set + get_open_outputs(OpenMultiDiGraphView const &g) { + return transform( + g.query_edges(OutputMultiDiEdgeQuery::all()), + [](OpenMultiDiEdge const &e) { return get(e); }); } -std::unordered_set get_open_inputs(OpenMultiDiGraphView const &g) { - return transform(g.query_edges(InputMultiDiEdgeQuery::all()), [](OpenMultiDiEdge const &e) { return get(e); }); +std::unordered_set + get_open_inputs(OpenMultiDiGraphView const &g) { + return transform( + g.query_edges(InputMultiDiEdgeQuery::all()), + [](OpenMultiDiEdge const &e) { return get(e); }); } std::unordered_map> @@ -780,13 +786,15 @@ std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { return filter(get_nodes(g), [&](Node const &n) { - return !g.query_edges(InputMultiDiEdgeQuery::all().with_dst_nodes({n})).empty(); + return !g.query_edges(InputMultiDiEdgeQuery::all().with_dst_nodes({n})) + .empty(); }); } std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { return filter(get_nodes(g), [&](Node const &n) { - return !g.query_edges(OutputMultiDiEdgeQuery::all().with_src_nodes({n})).empty(); + return !g.query_edges(OutputMultiDiEdgeQuery::all().with_src_nodes({n})) + .empty(); }); } diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 7bbe4cae67..d0ed98b29b 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -24,7 +24,8 @@ std::unordered_set } IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } Node MultiDiGraph::add_node() { diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 9bbb1bfa3d..d545b45a31 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,8 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } Node OpenMultiDiGraph::add_node() { @@ -116,7 +117,8 @@ std::unordered_set return this->get_ptr().query_edges(q); } -IDownwardOpenMultiDiGraphView const &DownwardOpenMultiDiGraphView::get_ptr() const { +IDownwardOpenMultiDiGraphView const & + DownwardOpenMultiDiGraphView::get_ptr() const { return *std::dynamic_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index 5a8c6e9f93..bf4f7351c0 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -469,7 +469,8 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)) {} + : g(g), nodes(nodes), + inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)) {} UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { return new UpwardOpenMultiDiSubgraphView(g, nodes); @@ -493,7 +494,8 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} + : g(g), nodes(nodes), + outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} std::unordered_set DownwardOpenMultiDiSubgraphView::query_edges( From fb58a99913681970c01233a139ea3bfb34c00fea Mon Sep 17 00:00:00 2001 From: wmdi Date: Tue, 23 Jan 2024 22:46:15 -0500 Subject: [PATCH 04/32] fmt --- lib/utils/include/utils/graph/labelled/node_labelled.h | 9 +++------ .../include/utils/graph/labelled/node_labelled_open.h | 9 +++------ lib/utils/include/utils/graph/labelled/output_labelled.h | 9 +++------ .../include/utils/graph/labelled/output_labelled_open.h | 9 +++------ .../include/utils/graph/labelled/standard_labelled.h | 9 +++------ lib/utils/src/graph/multidigraph.cc | 3 +-- lib/utils/src/graph/undirected.cc | 3 ++- 7 files changed, 18 insertions(+), 33 deletions(-) diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 109855965d..ded049f224 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -54,8 +54,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -117,13 +116,11 @@ struct NodeLabelledMultiDiGraph : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index c77d75c37a..fab6695070 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -55,8 +55,7 @@ struct NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -120,13 +119,11 @@ struct NodeLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index a675632a55..58b4ef23fd 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -60,8 +60,7 @@ struct OutputLabelledMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -143,13 +142,11 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 18ffbf569d..231d70db74 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -63,8 +63,7 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -174,13 +173,11 @@ struct OutputLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl), il(il), ol(ol) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 1a98701811..8af69e18fc 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -60,8 +60,7 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -132,13 +131,11 @@ struct LabelledMultiDiGraph : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 41ae3e1aa3..771e01e573 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -66,8 +66,7 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index ab13cb5ef7..b1e8be7f14 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -56,7 +56,8 @@ std::unordered_set } IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } } // namespace FlexFlow From 02937e1e584d110d2e1e89301889c393ea6526d2 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 24 Jan 2024 16:32:36 -0500 Subject: [PATCH 05/32] fix --- lib/compiler/src/graph_utils.cc | 8 +++----- .../include/utils/graph/labelled/output_labelled.h | 10 ++++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index e0134d6dd8..12d5c99a34 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -12,14 +12,12 @@ ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { } SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { - auto g = pcg.value(); - auto g_ = view_output_labelled_as_output_labelled_open(g); - auto subpcg = materialize_output_labelled_open_multidigraph_view< + return materialize_output_labelled_open_multidigraph_view< AdjacencyOpenMultiDiGraph, UnorderedLabelling, UnorderedLabelling, - UnorderedLabelling>(g_); - return subpcg; + UnorderedLabelling>( + view_output_labelled_as_output_labelled_open(pcg.value())); } std::vector diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 58b4ef23fd..f3cf14022b 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -39,11 +39,12 @@ struct OutputLabelledMultiDiGraphView return get_ptr().at(o); } - std::unordered_set query_nodes(NodeQuery const &q) const { + virtual std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } - std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { + virtual std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } @@ -115,11 +116,12 @@ struct OutputLabelledMultiDiGraph return ol->get_label(o); } - std::unordered_set query_nodes(NodeQuery const &q) const { + std::unordered_set query_nodes(NodeQuery const &q) const override { return get_ptr().query_nodes(q); } - std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { + std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const override { return get_ptr().query_edges(q); } From 6402ed0a538a232a16d8d634f39c131e1ae9a495 Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 25 Jan 2024 15:10:25 -0500 Subject: [PATCH 06/32] add substitutions, compiler, and their unit tests to CI --- .github/workflows/per-lib-check.yml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index f21621b265..4cbcdf6afb 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -83,3 +83,28 @@ jobs: run: | cd build make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) kernels + + - name: Build substitutions + run: | + cd build + make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) substitutions + + - name: Build compiler + run: | + cd build + make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) compiler + + - name: Build substitutions-test + run: | + cd build + make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) substitutions-test + + - name: Build compiler-test + run: | + cd build + make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) compiler-test + + - name: Unit tests + run: | + cd build + ctest From 0c45f61f114414215c5eff7d39006f07735e2fe7 Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 25 Jan 2024 16:04:47 -0500 Subject: [PATCH 07/32] disable runtime unit test --- lib/runtime/CMakeLists.txt | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/lib/runtime/CMakeLists.txt b/lib/runtime/CMakeLists.txt index 49b052ec2b..fd5b4991ef 100644 --- a/lib/runtime/CMakeLists.txt +++ b/lib/runtime/CMakeLists.txt @@ -17,18 +17,18 @@ ff_add_library( pcg ) -ff_add_test_executable( - NAME - runtime-test - SRC_PATTERNS - test/src/*.cc - PUBLIC_INCLUDE - include/ - PRIVATE_INCLUDE - test/src/ src/ - DEPS - runtime - doctest -) +# ff_add_test_executable( +# NAME +# runtime-test +# SRC_PATTERNS +# test/src/*.cc +# PUBLIC_INCLUDE +# include/ +# PRIVATE_INCLUDE +# test/src/ src/ +# DEPS +# runtime +# doctest +# ) add_subdirectory(ffi) From 95fa427b8680e68acf33cccf96bbc66bb37cd1fa Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 15 Feb 2024 17:06:11 -0500 Subject: [PATCH 08/32] minor fix --- lib/compiler/src/machine_mapping.cc | 34 +++++- lib/compiler/src/unity_algorithm.cc | 8 +- lib/compiler/test/CMakeLists.txt | 4 +- .../test/src/test_labelled_open_graph.cc | 2 +- lib/compiler/test/src/test_open_graph.cc | 10 +- lib/compiler/test/src/test_optimal_cost.cc | 6 +- lib/pcg/include/pcg/machine_view.h | 3 - lib/pcg/include/pcg/operator.h | 19 ++-- lib/pcg/include/pcg/parallel_tensor.h | 2 + lib/pcg/include/pcg/strided_rectangle.h | 9 -- lib/pcg/src/machine_view.cc | 4 - lib/pcg/src/operator.cc | 6 +- lib/pcg/src/parallel_tensor.cc | 4 + lib/pcg/src/strided_rectangle.cc | 4 - lib/substitutions/src/substitution.cc | 106 +++++++++--------- lib/utils/include/utils/graph/algorithms.h | 1 + .../include/utils/graph/labelled/open_views.h | 2 - .../graph/labelled/output_labelled_open.h | 45 +++++--- .../utils/graph/labelled/unordered_label.h | 3 +- lib/utils/include/utils/variant.h | 8 ++ lib/utils/src/graph/algorithms.cc | 10 +- lib/utils/src/graph/open_graphs.cc | 2 +- lib/utils/src/graph/views.cc | 10 +- 23 files changed, 166 insertions(+), 136 deletions(-) diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 2bdd7de1e2..3cabd972bf 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -96,7 +96,39 @@ float estimate_cost(SubParallelComputationGraphView const &g, MachineMapping const &device_mapping, std::unordered_map const &frontier_machine_views) { - return 0.1; + float cost = 0; + for (Node const &node : get_nodes(g)) { + std::unordered_set incoming_edges = + get_incoming_edges(g, node); + std::vector inputs = + transform(as_vector(incoming_edges), + [&](UpwardOpenMultiDiEdge const &input_edge) { + return g.at(input_edge).get_shape(); + }); + cost += estimator.estimate_cost( + g.at(node).attrs, inputs, device_mapping.machine_views.at(node)); + } + + for (OpenMultiDiEdge const &edge : get_edges(g)) { + if (holds_alternative(edge)) { + cost += estimator.estimate_cost( + g.at(edge).get_shape(), + frontier_machine_views.at(edge), + device_mapping.machine_views.at(get(edge).dst)); + } else if (holds_alternative(edge)) { + cost += estimator.estimate_cost( + g.at(edge).get_shape(), + device_mapping.machine_views.at(get(edge).src), + frontier_machine_views.at(edge)); + } else { + assert(holds_alternative(edge)); + cost += estimator.estimate_cost( + g.at(edge).get_shape(), + device_mapping.machine_views.at(get(edge).src), + device_mapping.machine_views.at(get(edge).dst)); + } + } + return cost; } void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index c89bf04b25..3363aecc2f 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -9,11 +9,17 @@ bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { return lhs.runtime < rhs.runtime; } +/* + * Gets all substitutions applicable to a PCG + */ std::unordered_set get_all_substitutions(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); } +/* + * Applies a substitution to all possible positions in PCG + */ std::unordered_set apply_substitution(ParallelComputationGraph const &pcg, Substitution const &) { @@ -53,7 +59,7 @@ Strategy Strategy const ¤t_result = candidates.top(); candidates.pop(); - if (StrategyRuntimeCmp{}(current_result, best_result)) { + if (current_result.runtime < best_result.runtime) { best_result = current_result; } else if (current_result.runtime > best_result.runtime * opt_config.alpha) { diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index cc64b15f7d..cbd7e233c0 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -2,9 +2,7 @@ ff_add_test_executable( NAME compiler-test SRC_PATTERNS - src/test_labelled_open_graph.cc - src/test_open_graph.cc - src/test_optimal_cost.cc + src/*.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index 82c247e0d2..1cae9a0cd1 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,7 +4,7 @@ using namespace FlexFlow; -TEST_CASE("get_subgraph_open_graph") { +TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { auto g = OpenMultiDiGraph::create(); int t0 = 100000; diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc index 00cb4ca890..7436f213d7 100644 --- a/lib/compiler/test/src/test_open_graph.cc +++ b/lib/compiler/test/src/test_open_graph.cc @@ -7,8 +7,6 @@ using namespace FlexFlow; TEST_CASE("get_source_sink_open_graph") { OpenMultiDiGraph g = OpenMultiDiGraph::create(); - int s0 = 100000; - Node n0 = g.add_node(); NodePort p0 = g.add_node_port(); InputMultiDiEdge e0{ @@ -25,9 +23,6 @@ TEST_CASE("get_source_sink_open_graph") { TEST_CASE("get_source_sink_open_graph:unconnected") { OpenMultiDiGraph g = OpenMultiDiGraph::create(); - int s0 = 100000; - int t0 = s0 + 1; - Node n0 = g.add_node(); Node n1 = g.add_node(); @@ -54,10 +49,7 @@ TEST_CASE("get_source_sink_open_graph:unconnected") { TEST_CASE("get_cut") { auto g = OpenMultiDiGraph::create(); - std::vector ns; - for (int i = 0; i < 5; ++i) { - ns.push_back(g.add_node()); - } + std::vector ns = add_nodes(g, 5); MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 7eeb118c57..a6cd88a006 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -33,10 +33,10 @@ TEST_CASE("optimal_cost_0") { UnorderedLabelling, UnorderedLabelling>(); - Node n0 = pcg.add_node(Operator(InputAttrs{}, "input")); - Node n1 = pcg.add_node(Operator( + Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); + Node n1 = pcg.add_node(Operator{ LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, - "linear")); + "linear"}); MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; pcg.add_edge(e); diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index afd4206eb1..7521cd209a 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -13,9 +13,6 @@ namespace FlexFlow { struct MachineView { - // MachineView() = delete; - // MachineView(device_id_t const &, StridedRectangle const &); - std::vector device_ids() const; device_id_t at(FFOrdered const &coord) const; diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h index d09e25dcf3..3eb7fb2a43 100644 --- a/lib/pcg/include/pcg/operator.h +++ b/lib/pcg/include/pcg/operator.h @@ -2,31 +2,26 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H #include "op-attrs/operator_attrs.h" -#include "utils/optional.h" #include "utils/stack_string.h" #include "utils/visitable.h" +#include + namespace FlexFlow { -struct Operator : public use_visitable_cmp { +struct Operator { public: - Operator() = delete; - Operator(PCGOperatorAttrs const &attrs, optional const &name); - operator PCGOperatorAttrs() const; public: PCGOperatorAttrs attrs; - optional name; + req> name; }; -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::Operator, attrs, name); -MAKE_VISIT_HASHABLE(::FlexFlow::Operator); +FF_VISITABLE_STRUCT(Operator, attrs, name); -namespace FlexFlow { static_assert(is_well_behaved_value_type::value, ""); -} + +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h index eadc83d9fd..4594e849cf 100644 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ b/lib/pcg/include/pcg/parallel_tensor.h @@ -47,6 +47,8 @@ struct ParallelTensor : public use_visitable_cmp { optional sync_type = nullopt, optional initializer = nullopt); + ParallelTensorShape get_shape() const; + public: ParallelTensorDims dims; DataType data_type; diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 179fff080f..d123d7c6ac 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -41,9 +41,6 @@ FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, struct StridedRectangle { public: - // StridedRectangle() = delete; - // StridedRectangle(std::vector const &); - size_t at(FFOrdered const &) const; StridedRectangleSide at(ff_dim_t const &) const; size_t num_dims() const; @@ -62,10 +59,4 @@ MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, "num_points"); MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); -// VISITABLE_STRUCT(::FlexFlow::StridedRectangleSide, num_points, stride); -// MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangleSide); - -// VISITABLE_STRUCT(::FlexFlow::StridedRectangle, sides); -// MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangle); - #endif diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc index f146482141..46f87833f0 100644 --- a/lib/pcg/src/machine_view.cc +++ b/lib/pcg/src/machine_view.cc @@ -3,10 +3,6 @@ namespace FlexFlow { -// MachineView::MachineView(device_id_t const &start, StridedRectangle const -// &rect) -// : start(start), rect(rect) {} - static StridedRectangle make_1d_rect(int start, int stop, int stride) { assert(stop > start); assert(stride > 0); diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc index 5cba8584c9..81e7326a76 100644 --- a/lib/pcg/src/operator.cc +++ b/lib/pcg/src/operator.cc @@ -2,9 +2,9 @@ namespace FlexFlow { -Operator::Operator(PCGOperatorAttrs const &attrs, - optional const &name) - : attrs(attrs), name(name) {} +// Operator::Operator(PCGOperatorAttrs const &attrs, +// std::optional const &name) +// : attrs(attrs), name(name) {} Operator::operator PCGOperatorAttrs() const { return attrs; diff --git a/lib/pcg/src/parallel_tensor.cc b/lib/pcg/src/parallel_tensor.cc index a8d7b15ea9..8cc79d7293 100644 --- a/lib/pcg/src/parallel_tensor.cc +++ b/lib/pcg/src/parallel_tensor.cc @@ -10,4 +10,8 @@ ParallelTensor::ParallelTensor(ParallelTensorDims const &dims, : dims(dims), data_type(data_type), sync_type(sync_type), initializer(initializer), create_gradients(create_gradients) {} +ParallelTensorShape ParallelTensor::get_shape() const { + return ParallelTensorShape(dims, data_type); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 2792db65fe..7f612b743b 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -30,8 +30,4 @@ side_size_t StridedRectangleSide::get_size() const { NOT_IMPLEMENTED(); } -// StridedRectangle::StridedRectangle( -// std::vector const &sides) -// : sides(sides) {} - } // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 8e99624acb..635083b780 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -119,27 +119,27 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::OP_TYPE)); switch (op_type) { case Op::BATCHMATMUL: - return Operator( + return Operator{ BatchMatmulAttrs{ get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, - nullopt); + std::nullopt}; case Op::BATCHNORM: - return Operator( + return Operator{ BatchNormAttrs{get(assignments.at(OperatorAttributeKey::RELU))}, - nullopt); + std::nullopt}; case Op::CAST: - return Operator(CastAttrs{get( + return Operator{CastAttrs{get( assignments.at(OperatorAttributeKey::DATA_TYPE))}, - nullopt); + std::nullopt}; case Op::CONCAT: - return Operator( + return Operator{ ConcatAttrs{ get(assignments.at(OperatorAttributeKey::AXIS)), get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, - nullopt); + std::nullopt}; case Op::CONV2D: - return Operator( + return Operator{ Conv2DAttrs{ get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), get(assignments.at(OperatorAttributeKey::KERNEL_H)), @@ -151,13 +151,13 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::GROUPS)), get(assignments.at(OperatorAttributeKey::ACTIVATION)), get(assignments.at(OperatorAttributeKey::USE_BIAS))}, - nullopt); + std::nullopt}; case Op::DROPOUT: - return Operator( + return Operator{ DropoutAttrs{get(assignments.at(OperatorAttributeKey::RATE)), get( assignments.at(OperatorAttributeKey::SEED))}, - nullopt); + std::nullopt}; case Op::EW_ADD: case Op::EW_DIV: case Op::EW_EQUAL: @@ -167,7 +167,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EW_MIN: case Op::EW_MUL: case Op::EW_SUB: - return Operator( + return Operator{ ElementBinaryAttrs{ op_type, get(assignments.at(OperatorAttributeKey::DATA_TYPE)), @@ -175,44 +175,44 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_LHS)), get( assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, - nullopt); + std::nullopt}; case Op::SCALAR_ADD: case Op::SCALAR_FLOOR_DIV: case Op::SCALAR_MULTIPLY: case Op::SCALAR_SUB: case Op::SCALAR_TRUE_DIV: - return Operator( + return Operator{ ElementScalarUnaryAttrs{ op_type, get(assignments.at(OperatorAttributeKey::SCALAR))}, - nullopt); + std::nullopt}; case Op::EMBEDDING: - return Operator( + return Operator{ EmbeddingAttrs{ get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), get(assignments.at(OperatorAttributeKey::AGGR)), get(assignments.at(OperatorAttributeKey::OP_TYPE))}, - nullopt); + std::nullopt}; case Op::FLAT: - return Operator(FlatAttrs{}, nullopt); + return Operator{FlatAttrs{}, std::nullopt}; case Op::GATHER: - return Operator( + return Operator{ GatherAttrs{get(assignments.at(OperatorAttributeKey::DIM))}, - nullopt); + std::nullopt}; case Op::INPUT: - return Operator(InputAttrs{}, nullopt); + return Operator{InputAttrs{}, std::nullopt}; case Op::LAYERNORM: - return Operator( + return Operator{ LayerNormAttrs{ get>( assignments.at(OperatorAttributeKey::AXES)), get( assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), get(assignments.at(OperatorAttributeKey::EPSILON))}, - nullopt); + std::nullopt}; case Op::LINEAR: - return Operator( + return Operator{ LinearAttrs{ get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), get(assignments.at(OperatorAttributeKey::USE_BIAS)), @@ -220,9 +220,9 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::ACTIVATION)), get>( assignments.at(OperatorAttributeKey::REGULARIZER))}, - nullopt); + std::nullopt}; case Op::MULTIHEAD_ATTENTION: - return Operator( + return Operator{ MultiHeadAttentionAttrs{ get(assignments.at(OperatorAttributeKey::EMBED_DIM)), get(assignments.at(OperatorAttributeKey::NUM_HEADS)), @@ -232,11 +232,11 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::BIAS)), get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, - nullopt); + std::nullopt}; case Op::NOOP: - return Operator(NoopAttrs{}, nullopt); + return Operator{NoopAttrs{}, std::nullopt}; case Op::POOL2D: - return Operator( + return Operator{ Pool2DAttrs{ get(assignments.at(OperatorAttributeKey::KERNEL_H)), get(assignments.at(OperatorAttributeKey::KERNEL_W)), @@ -247,7 +247,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::POOL_TYPE)), get( assignments.at(OperatorAttributeKey::ACTIVATION))}, - nullopt); + std::nullopt}; case Op::REDUCE_ARGMAX: case Op::REDUCE_ARGMIN: case Op::REDUCE_MAX: @@ -255,65 +255,65 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::REDUCE_MIN: case Op::REDUCE_PROD: case Op::REDUCE_SUM: - return Operator( + return Operator{ ReduceAttrs{ get>( assignments.at(OperatorAttributeKey::AXES)), op_type, get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, - nullopt); + std::nullopt}; case Op::REVERSE: - return Operator(ReverseAttrs{get( + return Operator{ReverseAttrs{get( assignments.at(OperatorAttributeKey::AXIS))}, - nullopt); + std::nullopt}; case Op::RESHAPE: - return Operator(ReshapeAttrs{get( + return Operator{ReshapeAttrs{get( assignments.at(OperatorAttributeKey::SHAPE))}, - nullopt); + std::nullopt}; case Op::SPLIT: - return Operator( + return Operator{ SplitAttrs{get>( assignments.at(OperatorAttributeKey::SPLITS)), get(assignments.at(OperatorAttributeKey::AXIS))}, - nullopt); + std::nullopt}; case Op::SOFTMAX: - return Operator(SoftmaxAttrs{get( + return Operator{SoftmaxAttrs{get( assignments.at(OperatorAttributeKey::DIM))}, - nullopt); + std::nullopt}; case Op::TOPK: - return Operator( + return Operator{ TopKAttrs{get(assignments.at(OperatorAttributeKey::K)), get(assignments.at(OperatorAttributeKey::SORTED))}, - nullopt); + std::nullopt}; case Op::TRANSPOSE: - return Operator( + return Operator{ TransposeAttrs{get>( assignments.at(OperatorAttributeKey::PERMUTATION))}, - nullopt); + std::nullopt}; case Op::COMBINE: - return Operator( + return Operator{ CombineAttrs{ get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + std::nullopt}; case Op::REDUCTION: - return Operator( + return Operator{ ReductionAttrs{ get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + std::nullopt}; case Op::REPARTITION: - return Operator( + return Operator{ RepartitionAttrs{ get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + std::nullopt}; case Op::REPLICATE: - return Operator( + return Operator{ ReplicateAttrs{ get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + std::nullopt}; default: mk_runtime_error("Unknown Operator"); } diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index cee5445190..12aa2dccb0 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -23,6 +23,7 @@ std::vector add_nodes(Graph &, int); std::vector add_nodes(UndirectedGraph &, int); std::vector add_nodes(DiGraph &, int); std::vector add_nodes(MultiDiGraph &, int); +std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes); std::vector add_node_ports(MultiDiGraph &, int); diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 8323fcf9dc..a24c2b940b 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -52,8 +52,6 @@ struct OutputLabelledOpenMultiDiSubgraphView std::unordered_set const &nodes; }; -// CHECK_NOT_ABSTRACT(OutputLabelledOpenMultiDiSubgraphView); - template struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMultiDiGraphView { diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 231d70db74..1c1b28c6d6 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -40,6 +40,16 @@ struct OutputLabelledOpenMultiDiGraphView return get_ptr().at(o); } + template + EdgeLabel const &at(variant const &e) const { + return visit([&](auto const &e) -> auto const & { return this->at(e); }, e); + } + + template + EdgeLabel &at(variant const &e) { + return visit([&](auto const &e) -> auto & { return this->at(e); }, e); + } + std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } @@ -85,52 +95,52 @@ struct OutputLabelledOpenMultiDiGraph Node add_node(NodeLabel const &l) { Node n = get_ptr().add_node(); - nl.get_mutable()->add_label(n, l); + this->node_labelling.get_mutable()->add_label(n, l); return n; } void add_node_unsafe(Node const &n, NodeLabel const &l) { - get_ptr().add_node_unsafe(n); - nl.get_mutable()->add_label(n, l); + this->get_ptr().add_node_unsafe(n); + this->node_labelling.get_mutable()->add_label(n, l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); + return this->node_labelling.get_mutable()->get_label(n); } NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->node_labelling->get_label(n); } void add_label(MultiDiOutput const &o, EdgeLabel const &l) { - ol.get_mutable()->add_label(o, l); + this->output_labelling.get_mutable()->add_label(o, l); }; void add_label(InputMultiDiEdge const &e, EdgeLabel const &l) { - il.get_mutable()->add_label(e, l); + this->input_labelling.get_mutable()->add_label(e, l); } void add_edge(OpenMultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } EdgeLabel &at(MultiDiOutput const &o) { - return ol.get_mutable()->get_label(o); + return this->output_labelling.get_mutable()->get_label(o); } EdgeLabel const &at(MultiDiOutput const &o) const { - return ol->get_label(o); + return this->output_labelling->get_label(o); } EdgeLabel &at(InputMultiDiEdge const &e) { - return il.get_mutable()->get_label(e); + return this->input_labelling.get_mutable()->get_label(e); } EdgeLabel const &at(InputMultiDiEdge const &e) const { - return il->get_label(e); + return this->input_labelling->get_label(e); } template @@ -170,7 +180,8 @@ struct OutputLabelledOpenMultiDiGraph cow_ptr_t nl, cow_ptr_t il, cow_ptr_t ol) - : GraphView(ptr), nl(nl), il(il), ol(ol) {} + : GraphView(ptr), node_labelling(nl), input_labelling(il), + output_labelling(ol) {} Interface &get_ptr() { return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); @@ -180,9 +191,9 @@ struct OutputLabelledOpenMultiDiGraph return *std::dynamic_pointer_cast(GraphView::ptr.get()); } - cow_ptr_t nl; - cow_ptr_t il; - cow_ptr_t ol; + cow_ptr_t node_labelling; + cow_ptr_t input_labelling; + cow_ptr_t output_labelling; }; template diff --git a/lib/utils/include/utils/graph/labelled/unordered_label.h b/lib/utils/include/utils/graph/labelled/unordered_label.h index 230e286ef8..94c4bffe11 100644 --- a/lib/utils/include/utils/graph/labelled/unordered_label.h +++ b/lib/utils/include/utils/graph/labelled/unordered_label.h @@ -19,8 +19,7 @@ struct UnorderedLabelling : virtual public ILabelling { } void add_label(Elem const &e, Label const &l) { - auto p = std::make_pair(e, l); - label_map.insert(p); + label_map.insert({e, l}); } UnorderedLabelling *clone() const { diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index b1a1dc1081..bb78719c9e 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -198,6 +198,14 @@ auto narrow(Container const &c) -> decltype(transform( return transform(c, [](VariantIn const &i) { return narrow(i); }); } +template ::value>> +auto narrow(Container const &c) { + return transform(c, [](VariantIn const &e) { return get(e); }); +} + template add_nodes(MultiDiGraph &g, int num_nodes) { return add_nodes_impl(g, num_nodes); } +std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes) { + return add_nodes_impl(g, num_nodes); +} + std::vector add_node_ports(MultiDiGraph &g, int num_node_ports) { std::vector node_ports; for (int i = 0; i < num_node_ports; i++) { @@ -786,15 +790,13 @@ std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { return filter(get_nodes(g), [&](Node const &n) { - return !g.query_edges(InputMultiDiEdgeQuery::all().with_dst_nodes({n})) - .empty(); + return !get_incoming_edges(g, n).empty(); }); } std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { return filter(get_nodes(g), [&](Node const &n) { - return !g.query_edges(OutputMultiDiEdgeQuery::all().with_src_nodes({n})) - .empty(); + return !get_outgoing_edges(g, n).empty(); }); } diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 5ab5858fd2..8355713506 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -52,7 +52,7 @@ std::unordered_set } NodePort OpenMultiDiGraph::add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() { diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index bf4f7351c0..dc823f7da4 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -469,8 +469,9 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), - inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)) {} + : g(g), nodes(nodes) { + inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); +} UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { return new UpwardOpenMultiDiSubgraphView(g, nodes); @@ -494,8 +495,9 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), - outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} + : g(g), nodes(nodes) { + outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); +} std::unordered_set DownwardOpenMultiDiSubgraphView::query_edges( From 1f7e2b6aebe0e22cbb8c64db07667d359f980553 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 18 Feb 2024 17:27:51 -0500 Subject: [PATCH 09/32] (not compilable) visitable issue for OptimalCostState --- lib/compiler/include/compiler/machine_mapping.h | 10 +++++----- lib/compiler/src/machine_mapping.cc | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index aa9152dcd6..7299404d90 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -23,12 +23,12 @@ FF_VISITABLE_STRUCT(MachineMapping, machine_views); struct OptimalCostState { SerialParallelDecomposition subgraph; - req resource; - // req> given_machine_views; - // req> - // frontier_machine_views; + MachineSpecification resource; + std::unordered_map given_machine_views; + req> + frontier_machine_views; }; -FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource); +FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource, given_machine_views, frontier_machine_views); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 3cabd972bf..fc89ff1306 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -166,7 +166,7 @@ struct OptimalCost { template OptimalCostResult operator()(T const &t) const { OptimalCostState state{ - t, resource /*, given_machine_views, frontier_machine_views*/}; + t, resource , given_machine_views, frontier_machine_views}; optional cached_result = cached_subgraph_costs.load(state); From ffa7f79ae98eff4e0fd4134da256c3d45553f571 Mon Sep 17 00:00:00 2001 From: Bob Chen <70640928+Bob-Chen222@users.noreply.github.com> Date: Fri, 23 Feb 2024 00:55:52 -0500 Subject: [PATCH 10/32] first try on docs --- docs/doxygen/Doxyfile | 1 + .../substitutions/graph_pattern_match.h | 47 +++++++++++++++++++ .../include/substitutions/operator_pattern.h | 21 +++++++++ .../substitutions/parallel_tensor_pattern.h | 12 +++++ 4 files changed, 81 insertions(+) diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index b38bfc12b5..fe2f69c500 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -881,6 +881,7 @@ INPUT += $(FF_HOME)/include INPUT += $(FF_HOME)/nmt INPUT += $(FF_HOME)/python INPUT += $(FF_HOME)/src +INPUT += $(FF_HOME)/lib/substitutions/include # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h index bf6d6b6921..d441d5390e 100644 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/graph_pattern_match.h @@ -6,32 +6,79 @@ namespace FlexFlow { +/** + * @struct MultiDiGraphPatternMatch + * @brief MultiDiGraphPatternMatch is a struct that describes a mapping from how an open graph is matched with + * a PCG graph. + * To apply a substitution to a PCG, we should first match the pattern graph to a subgraph of the PCG. MultiDiGraphPatternMatch describes the match, + * which consists of a node_assignment that describes how the GraphPattern node mapped to PCG node and an edge_assignment that describes + * how the GraphPattern edge mapped to PCG edge. + */ struct MultiDiGraphPatternMatch { using PatternNode = Node; using PCGNode = Node; + + /** + * @see OpenMultiDiEdge + */ using PatternEdge = OpenMultiDiEdge; using PCGEdge = OpenMultiDiEdge; + /** + * @brief node_assignment is a bidirectional map from PatternNode to PCGNode + */ bidict node_assignment; + + /** + * @brief edge_assignment is a bidirectional map from PatternEdge to PCGEdge + */ bidict edge_assignment; }; +/** + * @struct MatchSplit + * @brief MatchSplit is a struct that describes a split of a MultiDiGraphPatternMatch into two sub MultiDiGraphPatternMatch + * + */ struct MatchSplit { MultiDiGraphPatternMatch prefix_submatch; MultiDiGraphPatternMatch postfix_submatch; }; +/** + * @struct MatchAdditionalCriterion + * @brief The additional conditions need to be satisfied other than geometric properties of the graph. + */ struct MatchAdditionalCriterion { std::function node_criterion; std::function edge_criterion; }; +/** + * @brief pattern_matches checks if the pattern graph matches the graph with additional conditions defined by additional_criterion. + * @param pattern The pattern graph + * @param graph The graph to be matched + * @param match The mapping between the pattern graph and the graph + * @param additional_criterion The additional conditions need to be satisfied other than geometric properties of the graph. + * @return true if the pattern graph matches the graph, false otherwise. + * @details function is used to check whether the generated match from pattern to graph is valid or not. It is used in find_pattern_matches to check against all the enumerated matches + * and filter out the invalid ones. + */ bool pattern_matches(OpenMultiDiGraphView const &pattern, OpenMultiDiGraphView const &graph, MultiDiGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion); +/** + * @brief generate all valid matches from pattern to a subgraph of graph + * @param pattern + * @param graph + * @param additional_criterion + * @return std::vector + * + * @details Given a pattern and a graph, find all the valid matches between the pattern and the graph with additional conditions defined by additional_criterion. + */ std::vector find_pattern_matches(OpenMultiDiGraphView const &pattern, OpenMultiDiGraphView const &graph, diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 9392a7876e..31805438f9 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -11,6 +11,11 @@ namespace FlexFlow { +/** + * @brief OperatorAttributeKey is an enum class that represents the keys of the attributes of an Operator. + * Each operator has a set of attributes that describe its behavior. OperatorAttributeKey is used to retrieve the value of an attribute or the expression of an attribute stored + * in an attribute map. + */ enum class OperatorAttributeKey { OP_TYPE, // AnyOp USE_BIAS, @@ -70,6 +75,11 @@ enum class OperatorAttributeKey { NUM_INPUTS }; +/** + * @brief OperatorAttributeValue is a variant that represents the concrete value of an attribute of an Operator. + * The OperatorAttributeValue is evalutated from AttributeExpr + * The datatype of the value corresponds to the datatype of the attributekey listed in OperatorAttributeKey. + */ using OperatorAttributeValue = variant, index); FF_VISITABLE_STRUCT(ListSize, attribute_key); +/** + * @todo: need to better understand what is constraints and pattern + * + */ using OperatorAttributeConstraint = AttributeConstraint; using OperatorPattern = AttributePattern; + +/** + * @brief Given a specific attribute of an Operator, evaluate the expression of the attribute and return the value of the attribute. + * @param attrs + * @param expr + * @return optional + */ optional evaluate_attribute_expr(Operator const &attrs, AttributeExpr const &expr); diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index d07a1da23b..2c22ff878f 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -6,6 +6,11 @@ namespace FlexFlow { +/** + * @brief TensorAttributeKey is an enum class that represents the keys of the attributes of a Tensor(matrix). + * DIM_SIZES describes the length along each dimension of the tensor + * DIM_DEGREES describes the number of partitions along each dimension of the tensor for data parallelism computation + */ enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; using TensorAttributeValue = variant>; @@ -16,6 +21,13 @@ using TensorAttributeConstraint = using ParallelTensorPattern = AttributePattern; +/** + * @brief evaluate_attribute_expr evaluates the attribute expression for a given ParallelTensor + * + * @param tensor_shape, which describes the attributes of a ParallelTensor + * @param expr, which describes the specific attribute expression to be evaluated + * @return optional + */ optional evaluate_attribute_expr(ParallelTensor const &tensor_shape, AttributeExpr const &expr); From a9a64020d9d7295065d66914ce3ffe301ed75aeb Mon Sep 17 00:00:00 2001 From: wmdi Date: Tue, 27 Feb 2024 15:43:22 -0500 Subject: [PATCH 11/32] fix machine mapping hash & refactor dp algorithm --- .../include/compiler/machine_mapping.h | 9 +- lib/compiler/src/graph_utils.cc | 10 +- lib/compiler/src/graph_utils.h | 3 +- lib/compiler/src/machine_mapping.cc | 206 ++++++++++-------- lib/compiler/src/unity_algorithm.cc | 6 +- lib/utils/include/utils/containers.decl.h | 2 +- lib/utils/include/utils/containers.h | 2 +- lib/utils/include/utils/fmt.h | 8 + lib/utils/include/utils/hash-utils.h | 4 +- lib/utils/test/src/test_hash.cc | 18 ++ 10 files changed, 156 insertions(+), 112 deletions(-) create mode 100644 lib/utils/test/src/test_hash.cc diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 7299404d90..185f2706ef 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -25,10 +25,13 @@ struct OptimalCostState { SerialParallelDecomposition subgraph; MachineSpecification resource; std::unordered_map given_machine_views; - req> - frontier_machine_views; + req> frontier_machine_views; }; -FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource, given_machine_views, frontier_machine_views); +FF_VISITABLE_STRUCT(OptimalCostState, + subgraph, + resource, + given_machine_views, + frontier_machine_views); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 12d5c99a34..069ae4a41f 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -11,13 +11,9 @@ ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { NOT_IMPLEMENTED(); } -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { - return materialize_output_labelled_open_multidigraph_view< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling>( - view_output_labelled_as_output_labelled_open(pcg.value())); +SubParallelComputationGraphView + pcg_to_subpcg(ParallelComputationGraph const &pcg) { + return view_output_labelled_as_output_labelled_open(pcg.value()); } std::vector diff --git a/lib/compiler/src/graph_utils.h b/lib/compiler/src/graph_utils.h index 88515ef950..711a253b61 100644 --- a/lib/compiler/src/graph_utils.h +++ b/lib/compiler/src/graph_utils.h @@ -9,7 +9,8 @@ SerialParallelDecomposition get_serial_parallel_decomposition(ParallelComputationGraph const &pcg); ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); +SubParallelComputationGraphView + pcg_to_subpcg(ParallelComputationGraph const &g); // NOTE(@wmdi): I think we should have the following interfaces in the graph // library eventually. diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index fc89ff1306..5ce988b951 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -45,12 +45,10 @@ bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, optional OptimalCostCache::load(OptimalCostState const &state) const { - auto it = cache.find(state); - // if (contains_key(cache, state)) { - // // auto result = cache.at(state); - // OptimalCostResult result = OptimalCostResult::infinity(); - // return make_optional(result); - // } + if (contains_key(cache, state)) { + OptimalCostResult result = cache.at(state); + return make_optional(result); + } return nullopt; } @@ -135,51 +133,74 @@ void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { minimize(m1, m2, OptimalCostRuntimeCmp{}); } -struct OptimalCost { - OptimalCost(SubParallelComputationGraphView const &g, - CostEstimator const &cost_estimator, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views, - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimalCostCache &cached_subgraph_costs) - : g(g), cost_estimator(cost_estimator), resource(resource), - given_machine_views(restrict_keys(given_machine_views, get_nodes(g))), - frontier_machine_views( - restrict_keys(frontier_machine_views, get_edges(g))), +struct MachineMappingSearcher { + MachineMappingSearcher( + CostEstimator cost_estimator, + std::function( + Operator const &, MachineSpecification const &)> const + &allowed_machine_views, + OptimalCostCache &cached_subgraph_costs) + : cost_estimator(cost_estimator), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} - SubParallelComputationGraphView const &g; - CostEstimator const &cost_estimator; - MachineSpecification const &resource; - std::unordered_map given_machine_views; - std::unordered_map frontier_machine_views; - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views; + CostEstimator cost_estimator; + std::function(Operator const &, + MachineSpecification const &)> + allowed_machine_views; OptimalCostCache &cached_subgraph_costs; - template - OptimalCostResult operator()(T const &t) const { - OptimalCostState state{ - t, resource , given_machine_views, frontier_machine_views}; - optional cached_result = - cached_subgraph_costs.load(state); + struct OptimalCostFunctor { + OptimalCostFunctor( + MachineMappingSearcher *searcher, + SubParallelComputationGraphView const &g, + MachineSpecification resource, + std::unordered_map given_machine_views, + std::unordered_map frontier_machine_views) + : searcher(searcher), g(g), resource(resource), + given_machine_views(given_machine_views), + frontier_machine_views(frontier_machine_views) {} + + MachineMappingSearcher *searcher; + SubParallelComputationGraphView const &g; + MachineSpecification resource; + std::unordered_map given_machine_views; + std::unordered_map frontier_machine_views; + + template + OptimalCostResult operator()(T const &t) { + OptimalCostState state{ + t, resource, given_machine_views, frontier_machine_views}; + optional cached_result = + searcher->cached_subgraph_costs.load(state); + + if (cached_result) { + return cached_result.value(); + } + OptimalCostResult result = searcher->optimal_cost( + t, g, resource, given_machine_views, frontier_machine_views); - if (cached_result) { - return cached_result.value(); + searcher->cached_subgraph_costs.save(state, result); + return result; } - OptimalCostResult result = this->optimal_cost(t); - - cached_subgraph_costs.save(state, result); - return result; + }; + + OptimalCostResult + optimal_cost(SubParallelComputationGraphView const &g, + MachineSpecification resource, + SerialParallelDecomposition const &sp_decomposition) { + return visit(OptimalCostFunctor(this, g, resource, {}, {}), + sp_decomposition); } - OptimalCostResult optimal_cost(Serial const &serial) const { + OptimalCostResult optimal_cost( + Serial const &serial, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { + auto decomposed = decompose(serial); SerialParallelDecomposition pre_decompn = decomposed.first; SerialParallelDecomposition post_decompn = decomposed.second; @@ -210,28 +231,30 @@ struct OptimalCost { new_frontier_machine_views.emplace(split_edge, mv); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( - visit(OptimalCost(pre_graph, - cost_estimator, - resource, - given_machine_views, - new_frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + pre_graph, + resource, + given_machine_views, + new_frontier_machine_views), pre_decompn), - visit(OptimalCost(post_graph, - cost_estimator, - resource, - new_given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + post_graph, + resource, + new_given_machine_views, + frontier_machine_views), post_decompn))); } return optimal_result; } - OptimalCostResult optimal_cost(Parallel const ¶llel) const { + OptimalCostResult optimal_cost( + Parallel const ¶llel, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { auto decomposed = decompose(parallel); SerialParallelDecomposition decompn1 = decomposed.first; SerialParallelDecomposition decompn2 = decomposed.second; @@ -243,48 +266,46 @@ struct OptimalCost { g, graph_split.second); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - visit(OptimalCost(g1, - cost_estimator, - resource, - given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g1, + resource, + given_machine_views, + frontier_machine_views), decompn1), - visit(OptimalCost(g2, - cost_estimator, - resource, - given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g2, + resource, + given_machine_views, + frontier_machine_views), decompn2)); for (auto const &resource_split : get_resource_split(resource)) { minimize_runtime(optimal_result, OptimalCostResult::parallel_combine( - visit(OptimalCost(g1, - cost_estimator, - resource_split.first, - given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g1, + resource_split.first, + given_machine_views, + frontier_machine_views), decompn1), - visit(OptimalCost(g2, - cost_estimator, - resource_split.second, - given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g2, + resource_split.second, + given_machine_views, + frontier_machine_views), decompn2))); } return optimal_result; } - OptimalCostResult optimal_cost(Node const &node) const { + OptimalCostResult optimal_cost( + Node const &node, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { if (contains_key(given_machine_views, node)) { assert(contains(allowed_machine_views(g.at(node), resource), source_machine_view.value())); @@ -315,15 +336,10 @@ OptimalCostResult OptimalCostCache &cached_subgraph_costs) { SerialParallelDecomposition sp_decomposition = get_serial_parallel_decomposition(g); - SubParallelComputationGraph subpcg = pcg_to_subpcg(g); - return visit(OptimalCost(subpcg, - cost_estimator, - resources, - std::unordered_map{}, - std::unordered_map{}, - allowed_machine_views, - cached_subgraph_costs), - sp_decomposition); + SubParallelComputationGraphView subpcg = pcg_to_subpcg(g); + MachineMappingSearcher searcher( + cost_estimator, allowed_machine_views, cached_subgraph_costs); + return searcher.optimal_cost(subpcg, resources, sp_decomposition); } } // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 3363aecc2f..9fcde4dcca 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -13,7 +13,7 @@ bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { * Gets all substitutions applicable to a PCG */ std::unordered_set - get_all_substitutions(ParallelComputationGraph const &pcg) { + get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); } @@ -37,7 +37,7 @@ Strategy ParallelComputationGraph pcg = cg_to_pcg(cg); - std::unordered_set subs = get_all_substitutions(pcg); + std::unordered_set subs = get_all_applicable_substitutions(pcg); OptimalCostCache cached_subgraph_costs; DeduplicatedPriorityQueue, StrategyRuntimeCmp> @@ -93,7 +93,7 @@ size_t hash::operator()(FlexFlow::Strategy const &s) const { size_t h = 0; hash_combine(h, s.pcg); - // hash_combine(h, s.machine_mapping); + hash_combine(h, s.machine_mapping); hash_combine(h, s.runtime); return h; diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 8ad65a4488..430da61ff9 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -109,7 +109,7 @@ template std::vector values(C const &c); template -std::unordered_set> +std::unordered_set> items(C const &c); template diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 679586ba69..1d0151c38a 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -228,7 +228,7 @@ std::vector values(C const &c) { } template -std::unordered_set> +std::unordered_set> items(C const &c) { return {c.begin(), c.end()}; } diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index ddf5b00355..c44cb88b61 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -6,6 +6,8 @@ #include "utils/test_types.h" #include "utils/type_traits_core.h" +#include + namespace FlexFlow { template @@ -26,6 +28,12 @@ struct already_has_ostream_operator : std::true_type {}; template <> struct already_has_ostream_operator : std::true_type {}; +template <> +struct already_has_ostream_operator> : std::true_type {}; + +template <> +struct already_has_ostream_operator : std::true_type {}; + // This will create an error /* template diff --git a/lib/utils/include/utils/hash-utils.h b/lib/utils/include/utils/hash-utils.h index 923c8df840..d56ff34644 100644 --- a/lib/utils/include/utils/hash-utils.h +++ b/lib/utils/include/utils/hash-utils.h @@ -4,6 +4,8 @@ #include "containers.h" #include "hash-utils-core.h" +using namespace FlexFlow; + namespace std { template struct hash> { @@ -18,7 +20,7 @@ struct hash> { template struct hash> { size_t operator()(std::unordered_map const &m) const { - return get_std_hash(items(m)); + return get_std_hash(::FlexFlow::items(m)); } }; diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc new file mode 100644 index 0000000000..f0d907b741 --- /dev/null +++ b/lib/utils/test/src/test_hash.cc @@ -0,0 +1,18 @@ +#include "test/utils/doctest.h" +#include "utils/hash-utils.h" + +using namespace FlexFlow; + +TEST_CASE("hash:unordered_map") { + std::unordered_map map1{{1, 2}}; + std::unordered_map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({1, 2}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); +} From d8bbcb883103c9ca046ff700be9e1655f80e4892 Mon Sep 17 00:00:00 2001 From: wmdi Date: Tue, 27 Feb 2024 16:34:09 -0500 Subject: [PATCH 12/32] minor fix --- lib/compiler/include/compiler/unity_algorithm.h | 13 ++----------- lib/compiler/src/unity_algorithm.cc | 14 -------------- lib/compiler/test/src/test_labelled_open_graph.cc | 2 -- lib/compiler/test/src/test_optimal_cost.cc | 1 + lib/pcg/include/pcg/operator.h | 2 +- lib/pcg/src/operator.cc | 4 ---- .../include/utils/graph/labelled/node_labelled.h | 9 ++++++--- .../utils/graph/labelled/node_labelled_open.h | 9 ++++++--- .../include/utils/graph/labelled/output_labelled.h | 9 ++++++--- .../utils/graph/labelled/output_labelled_open.h | 9 ++++++--- .../utils/graph/labelled/standard_labelled.h | 9 ++++++--- lib/utils/src/graph/digraph.cc | 7 ++++--- lib/utils/src/graph/multidigraph.cc | 7 ++++--- lib/utils/src/graph/node.cc | 4 ++-- lib/utils/src/graph/open_graphs.cc | 14 +++++++------- lib/utils/src/graph/undirected.cc | 6 +++--- 16 files changed, 54 insertions(+), 65 deletions(-) diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index a87bddcc3a..7d7a7a74dc 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -14,6 +14,8 @@ struct Strategy { req runtime; }; +FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); + struct StrategyRuntimeCmp { bool operator()(Strategy const &, Strategy const &); }; @@ -36,15 +38,4 @@ Strategy } // namespace FlexFlow -VISITABLE_STRUCT(FlexFlow::Strategy, pcg, machine_mapping, runtime); - -namespace std { - -template <> -struct hash { - size_t operator()(FlexFlow::Strategy const &) const; -}; - -} // namespace std - #endif diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 9fcde4dcca..c9666851db 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -86,17 +86,3 @@ Strategy } } // namespace FlexFlow - -namespace std { - -size_t hash::operator()(FlexFlow::Strategy const &s) const { - size_t h = 0; - - hash_combine(h, s.pcg); - hash_combine(h, s.machine_mapping); - hash_combine(h, s.runtime); - - return h; -} - -} // namespace std diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index 1cae9a0cd1..a360d86ee7 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -7,8 +7,6 @@ using namespace FlexFlow; TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { auto g = OpenMultiDiGraph::create(); - int t0 = 100000; - Node n0 = g.add_node(); Node n1 = g.add_node(); Node n2 = g.add_node(); diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index a6cd88a006..c5f74ff392 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -4,6 +4,7 @@ using namespace FlexFlow; +// Rapidcheck infrastructures for graphs does not work for now /* Tests whether optimal_cost can give a valid result given random PCG, trivial allowed machine views, trivial cost estimator and random machine specification. diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h index 3eb7fb2a43..bb9a4cf5e4 100644 --- a/lib/pcg/include/pcg/operator.h +++ b/lib/pcg/include/pcg/operator.h @@ -20,7 +20,7 @@ struct Operator { FF_VISITABLE_STRUCT(Operator, attrs, name); -static_assert(is_well_behaved_value_type::value, ""); +static_assert(is_well_behaved_value_type::value); } // namespace FlexFlow diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc index 81e7326a76..9d36ae1b25 100644 --- a/lib/pcg/src/operator.cc +++ b/lib/pcg/src/operator.cc @@ -2,10 +2,6 @@ namespace FlexFlow { -// Operator::Operator(PCGOperatorAttrs const &attrs, -// std::optional const &name) -// : attrs(attrs), name(name) {} - Operator::operator PCGOperatorAttrs() const { return attrs; } diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index ded049f224..1ecd87226c 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -54,7 +54,8 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -116,11 +117,13 @@ struct NodeLabelledMultiDiGraph : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index fab6695070..2162ee0384 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -55,7 +55,8 @@ struct NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; @@ -119,11 +120,13 @@ struct NodeLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index f3cf14022b..882fca8df0 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -61,7 +61,8 @@ struct OutputLabelledMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; @@ -144,11 +145,13 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 1c1b28c6d6..23dd9c190c 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -73,7 +73,8 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; @@ -184,11 +185,13 @@ struct OutputLabelledOpenMultiDiGraph output_labelling(ol) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t node_labelling; diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 8af69e18fc..3c69d62ae9 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -60,7 +60,8 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -131,11 +132,13 @@ struct LabelledMultiDiGraph : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index bdfe5ff599..dda9eef5e0 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -14,7 +14,8 @@ std::unordered_set } IDiGraphView const &DiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } Node DiGraph::add_node() { @@ -47,11 +48,11 @@ std::unordered_set } IDiGraph &DiGraph::get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); } IDiGraph const &DiGraph::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 771e01e573..99a7ea86fa 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -24,7 +24,7 @@ std::unordered_set } IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } @@ -66,11 +66,12 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index 72caa3136e..9854afffbf 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -53,11 +53,11 @@ std::unordered_set Graph::query_nodes(NodeQuery const &q) const { } IGraph const &Graph::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast(GraphView::ptr.get()); } IGraph &Graph::get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 8355713506..c32ff6ded5 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } @@ -56,7 +56,7 @@ NodePort OpenMultiDiGraph::add_node_port() { } IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } @@ -77,7 +77,7 @@ std::unordered_set } IUpwardOpenMultiDiGraphView const &UpwardOpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } @@ -107,12 +107,12 @@ std::unordered_set UpwardOpenMultiDiGraph::query_edges( } IUpwardOpenMultiDiGraph const &UpwardOpenMultiDiGraph::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } IUpwardOpenMultiDiGraph &UpwardOpenMultiDiGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } @@ -129,7 +129,7 @@ std::unordered_set IDownwardOpenMultiDiGraphView const & DownwardOpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } @@ -165,7 +165,7 @@ std::unordered_set } IDownwardOpenMultiDiGraph &DownwardOpenMultiDiGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index b1e8be7f14..ce42cfe22c 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -26,12 +26,12 @@ void UndirectedGraph::remove_edge(UndirectedEdge const &e) { } IUndirectedGraph const &UndirectedGraph::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } IUndirectedGraph &UndirectedGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } @@ -56,7 +56,7 @@ std::unordered_set } IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } From 09d3152ef80177118cd1ae51111c723f4f7482c7 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 28 Feb 2024 15:10:01 -0500 Subject: [PATCH 13/32] fix variant issue --- lib/utils/include/utils/containers.decl.h | 2 ++ lib/utils/include/utils/containers.h | 10 +++++++++ lib/utils/include/utils/variant.h | 10 ++++----- lib/utils/src/graph/algorithms.cc | 25 ++++++++--------------- lib/utils/src/graph/serialparallel.cc | 4 ++-- lib/utils/src/graph/views.cc | 11 +++++----- 6 files changed, 33 insertions(+), 29 deletions(-) diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 430da61ff9..fd35afe3fc 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -293,6 +293,8 @@ T reversed(T const &t); template std::vector value_all(std::vector> const &v); +template +std::unordered_set value_all(std::unordered_set> const &v); template std::vector subvec(std::vector const &v, diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 1d0151c38a..99c29564fb 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -675,6 +675,16 @@ std::vector value_all(std::vector> const &v) { }); } +template +std::unordered_set value_all(std::unordered_set> const &v) { + return transform(v, [](optional const &element) { + return unwrap(element, [] { + throw mk_runtime_error( + "Encountered element without value in call to value_all"); + }); + }); +} + template std::vector subvec(std::vector const &v, optional const &maybe_start, diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index bb78719c9e..132b7e66f4 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -58,7 +58,7 @@ struct elements_satisfy> : elements_satisfy_impl {}; template -struct is_in_variant; +struct is_in_variant : std::false_type {}; template struct is_in_variant> : std::true_type {}; template @@ -182,7 +182,7 @@ auto widen(Container const &c) -> decltype(transform( template < typename VariantOut, typename VariantIn, - typename = std::enable_if::value>> + typename = std::enable_if_t::value>> optional narrow(VariantIn const &v) { return visit(VariantNarrowFunctor{}, v); } @@ -191,7 +191,7 @@ template < typename VariantOut, typename Container, typename VariantIn = typename Container::value_type, - typename = std::enable_if::value>> + typename = std::enable_if_t::value>> auto narrow(Container const &c) -> decltype(transform( c, std::declval(VariantIn const &)>>())) { @@ -201,7 +201,7 @@ auto narrow(Container const &c) -> decltype(transform( template ::value>> + typename = std::enable_if_t::value>> auto narrow(Container const &c) { return transform(c, [](VariantIn const &e) { return get(e); }); } @@ -210,7 +210,7 @@ template , VariantIn>::value>> optional> narrow(VariantIn const &v) { return visit(VariantNarrowFunctor>{}, v); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index 3b9877f71b..1667ddfce8 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -9,6 +9,7 @@ #include "utils/graph/traversal.h" #include "utils/graph/undirected.h" #include "utils/graph/views.h" +#include "utils/variant.h" #include #include #include @@ -256,7 +257,7 @@ std::unordered_set get_node_edges(UndirectedGraphView const &g, std::unordered_set get_outputs(MultiDiGraphView const &g) { return transform(get_edges(g), [&](MultiDiEdge const &e) -> MultiDiOutput { - return MultiDiOutput(e); + return static_cast(e); }); } @@ -333,37 +334,27 @@ std::unordered_map> std::unordered_set get_outgoing_edges(OpenMultiDiGraphView const &g, Node const &n) { - return transform(g.query_edges(OpenMultiDiEdgeQuery( + return value_all(narrow(g.query_edges(OpenMultiDiEdgeQuery( InputMultiDiEdgeQuery::none(), MultiDiEdgeQuery::all().with_src_nodes({n}), - OutputMultiDiEdgeQuery::all().with_src_nodes({n}))), - [](OpenMultiDiEdge const &e) { - return narrow(e).value(); - }); + OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); } std::unordered_set get_incoming_edges(OpenMultiDiGraphView const &g, Node const &n) { - return transform(g.query_edges(OpenMultiDiEdgeQuery( + return value_all(narrow(g.query_edges(OpenMultiDiEdgeQuery( InputMultiDiEdgeQuery::all().with_dst_nodes({n}), MultiDiEdgeQuery::all().with_dst_nodes({n}), - OutputMultiDiEdgeQuery::none())), - [](OpenMultiDiEdge const &e) { - return narrow(e).value(); - }); + OutputMultiDiEdgeQuery::none())))); } std::unordered_set get_open_outputs(OpenMultiDiGraphView const &g) { - return transform( - g.query_edges(OutputMultiDiEdgeQuery::all()), - [](OpenMultiDiEdge const &e) { return get(e); }); + return narrow(g.query_edges(OutputMultiDiEdgeQuery::all())); } std::unordered_set get_open_inputs(OpenMultiDiGraphView const &g) { - return transform( - g.query_edges(InputMultiDiEdgeQuery::all()), - [](OpenMultiDiEdge const &e) { return get(e); }); + return narrow(g.query_edges(InputMultiDiEdgeQuery::all())); } std::unordered_map> diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 41ecf3c436..8b179d31de 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -178,11 +178,11 @@ struct ToFinalAST { variant operator()(SplitASTNode const &node) { if (node.type == SplitType::SERIAL) { return Serial{transform(node.children, [](SplitAST const &s) { - return narrow(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } else { return Parallel{transform(node.children, [](SplitAST const &s) { - return narrow(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } } diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index dc823f7da4..a1308cffbb 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -445,9 +445,10 @@ std::unordered_set OpenMultiDiSubgraphView::OpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), - inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)), - outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} + : g(g), nodes(nodes) { + this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); + this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); + } std::unordered_set OpenMultiDiSubgraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { @@ -470,7 +471,7 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); + this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); } UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { @@ -496,7 +497,7 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); + this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); } std::unordered_set From a150d3a90536ba6583276d52787b35fc15ba7f1d Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 28 Feb 2024 15:15:00 -0500 Subject: [PATCH 14/32] fmt --- lib/utils/src/graph/algorithms.cc | 20 +++++++++++--------- lib/utils/src/graph/views.cc | 6 +++--- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index 1667ddfce8..777d3d55d2 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -334,23 +334,25 @@ std::unordered_map> std::unordered_set get_outgoing_edges(OpenMultiDiGraphView const &g, Node const &n) { - return value_all(narrow(g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), - MultiDiEdgeQuery::all().with_src_nodes({n}), - OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); + return value_all( + narrow(g.query_edges(OpenMultiDiEdgeQuery( + InputMultiDiEdgeQuery::none(), + MultiDiEdgeQuery::all().with_src_nodes({n}), + OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); } std::unordered_set get_incoming_edges(OpenMultiDiGraphView const &g, Node const &n) { - return value_all(narrow(g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::all().with_dst_nodes({n}), - MultiDiEdgeQuery::all().with_dst_nodes({n}), - OutputMultiDiEdgeQuery::none())))); + return value_all(narrow(g.query_edges( + OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery::all().with_dst_nodes({n}), + MultiDiEdgeQuery::all().with_dst_nodes({n}), + OutputMultiDiEdgeQuery::none())))); } std::unordered_set get_open_outputs(OpenMultiDiGraphView const &g) { - return narrow(g.query_edges(OutputMultiDiEdgeQuery::all())); + return narrow( + g.query_edges(OutputMultiDiEdgeQuery::all())); } std::unordered_set get_open_inputs(OpenMultiDiGraphView const &g) { diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index a1308cffbb..af15b0d6aa 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -446,9 +446,9 @@ std::unordered_set OpenMultiDiSubgraphView::OpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); - this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); - } + this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); + this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); +} std::unordered_set OpenMultiDiSubgraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { From 2eb3fdfd2ab9efc7ff217c715303e69cd6da5f9a Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 11 Mar 2024 18:50:35 -0400 Subject: [PATCH 15/32] fix --- lib/compiler/test/src/test_generator.h | 313 +++++++++--------- lib/compiler/test/src/test_machine_mapping.cc | 34 +- lib/compiler/test/src/test_unity_algorithm.cc | 39 +-- 3 files changed, 194 insertions(+), 192 deletions(-) diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index 6566c8c2de..c14743347a 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -8,161 +8,162 @@ using namespace FlexFlow; -/* - Generates computation graphs with trivial layers and tensors, which are used - for tests focusing on graph structures. -*/ -ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { - return materialize_output_labelled_multidigraph_view( - ViewMultiDiGraphAsOutputLabelled( - g, - [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, - [](Tensor(MultiDiOutput const &)) { - return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; - })); -} - -/* - Generates parallel computation graphs with trivial layers and tensors, which - are used for tests focusing on graph structures. -*/ -ParallelComputationGraph - test_parallel_computation_graph(MultiDiGraphView const &g) { - return materialize_output_labelled_multidigraph_view( - ViewMultiDiGraphAsOutputLabelled( - g, - [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, - [](Operator(MultiDiOutput const &)) { - return ParallelTensor(ParallelTensorDims(TensorDims({})), - DataType::FLOAT); - })); -} - -rc::Gen small_integer_generator() { - return rc::gen::inRange(1, 4); -} - -namespace rc { - -Gen serialParallelMultiDiGraph() { - return gen::map(gen::arbitrary(), - multidigraph_from_sp_decomposition); -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::map(gen::cast(serialParallelMultiDiGraph()), - test_computataion_graph); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::map(gen::cast(serialParallelMultiDiGraph()), - test_parallel_computation_graph); - } -}; - -template <> -struct Arbitrary> { - static Gen> arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_node) { - return is_node - ? gen::cast>(gen::arbitrary()) - : gen::cast>(gen::arbitrary()); - }); - } -}; - -template <> -struct Arbitrary> { - static Gen> arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_node) { - return is_node - ? gen::cast>(gen::arbitrary()) - : gen::cast>( - gen::arbitrary()); - }); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&Serial::children, - gen::container>>( - gen::arbitrary>()))); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&Parallel::children, - gen::container>>( - gen::arbitrary>()))); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_serial) { - return is_serial ? gen::construct( - gen::arbitrary()) - : gen::construct( - gen::arbitrary()); - }); - } -}; - -template -struct Arbitrary { - static Gen< - std::enable_if, Tag>::value>::type> - arbitrary() { - return gen::construct(gen::arbitrary()); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::apply(make_1d_machine_view, - gen::arbitrary, - gen::arbitrary, - small_integer_generator()); - } -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&MachineMapping::machine_views, - gen::container>( - gen::arbitrary(), gen::arbitrary()))); - } -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), - gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), - gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), - gen::set(&MachineSpecification::inter_node_bandwidth, - gen::nonZero()), - gen::set(&MachineSpecification::intra_node_bandwidth, - gen::nonZero())); - } -} - -} // namespace rc +// Rapidcheck does not work for now +// /* +// Generates computation graphs with trivial layers and tensors, which are used +// for tests focusing on graph structures. +// */ +// ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { +// return materialize_output_labelled_multidigraph_view( +// ViewMultiDiGraphAsOutputLabelled( +// g, +// [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, +// [](Tensor(MultiDiOutput const &)) { +// return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; +// })); +// } + +// /* +// Generates parallel computation graphs with trivial layers and tensors, which +// are used for tests focusing on graph structures. +// */ +// ParallelComputationGraph +// test_parallel_computation_graph(MultiDiGraphView const &g) { +// return materialize_output_labelled_multidigraph_view( +// ViewMultiDiGraphAsOutputLabelled( +// g, +// [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, +// [](Operator(MultiDiOutput const &)) { +// return ParallelTensor(ParallelTensorDims(TensorDims({})), +// DataType::FLOAT); +// })); +// } + +// rc::Gen small_integer_generator() { +// return rc::gen::inRange(1, 4); +// } + +// namespace rc { + +// Gen serialParallelMultiDiGraph() { +// return gen::map(gen::arbitrary(), +// multidigraph_from_sp_decomposition); +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::map(gen::cast(serialParallelMultiDiGraph()), +// test_computataion_graph); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::map(gen::cast(serialParallelMultiDiGraph()), +// test_parallel_computation_graph); +// } +// }; + +// template <> +// struct Arbitrary> { +// static Gen> arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_node) { +// return is_node +// ? gen::cast>(gen::arbitrary()) +// : gen::cast>(gen::arbitrary()); +// }); +// } +// }; + +// template <> +// struct Arbitrary> { +// static Gen> arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_node) { +// return is_node +// ? gen::cast>(gen::arbitrary()) +// : gen::cast>( +// gen::arbitrary()); +// }); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&Serial::children, +// gen::container>>( +// gen::arbitrary>()))); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&Parallel::children, +// gen::container>>( +// gen::arbitrary>()))); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_serial) { +// return is_serial ? gen::construct( +// gen::arbitrary()) +// : gen::construct( +// gen::arbitrary()); +// }); +// } +// }; + +// template +// struct Arbitrary { +// static Gen< +// std::enable_if, Tag>::value>::type> +// arbitrary() { +// return gen::construct(gen::arbitrary()); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::apply(make_1d_machine_view, +// gen::arbitrary, +// gen::arbitrary, +// small_integer_generator()); +// } +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&MachineMapping::machine_views, +// gen::container>( +// gen::arbitrary(), gen::arbitrary()))); +// } +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), +// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), +// gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), +// gen::set(&MachineSpecification::inter_node_bandwidth, +// gen::nonZero()), +// gen::set(&MachineSpecification::intra_node_bandwidth, +// gen::nonZero())); +// } +// } + +// } // namespace rc #endif diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc index 779f8134d9..b2abc6929d 100644 --- a/lib/compiler/test/src/test_machine_mapping.cc +++ b/lib/compiler/test/src/test_machine_mapping.cc @@ -1,21 +1,21 @@ -#include "doctest/doctest.h" -#include "test_generator.h" +// #include "doctest/doctest.h" +// #include "test_generator.h" -TEST_CASE("MachineMapping::combine") { - rc::check([](MachineMapping const &m0, MachineMapping const &m1) { - RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); +// TEST_CASE("MachineMapping::combine") { +// rc::check([](MachineMapping const &m0, MachineMapping const &m1) { +// RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); - MachineMapping comb = MachineMapping::combine(m0, m1); +// MachineMapping comb = MachineMapping::combine(m0, m1); - RC_ASSERT(comb.machine_views.size() == - m0.machine_views.size() + m1.machine_views.size()); - RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); - RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); - }); -} +// RC_ASSERT(comb.machine_views.size() == +// m0.machine_views.size() + m1.machine_views.size()); +// RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); +// RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); +// }); +// } -TEST_CASE("OptimalCostResult::infinity") { - rc::check([](OptimalCostResult const &c) { - RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); - }); -} +// TEST_CASE("OptimalCostResult::infinity") { +// rc::check([](OptimalCostResult const &c) { +// RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); +// }); +// } diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index 6a0131dd77..cceecae831 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -2,22 +2,23 @@ #include "test_cost_estimator.h" #include "test_generator.h" -TEST_CASE("graph_optimize") { - rc::check([](ComputationGraph const &g, - float alpha, - int budget, - float threshold, - int max_num_ops) { - Strategy s = graph_optimize( - g, - TestCostEstimator{}, - MachineSpecification{1, 1, 4, 0.1, 0.2}, - [](Operator const &, MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(0, 1, 1)}; - }, - OptimizerConfig{alpha, budget, threshold, max_num_ops}); - RC_ASSERT(get_nodes(s.pcg).size() > 0); - RC_ASSERT(s.machine_mapping.runtime > 0); - RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); - }); -} +// Rapidcheck does not work for now +// TEST_CASE("graph_optimize") { +// rc::check([](ComputationGraph const &g, +// float alpha, +// int budget, +// float threshold, +// int max_num_ops) { +// Strategy s = graph_optimize( +// g, +// TestCostEstimator{}, +// MachineSpecification{1, 1, 4, 0.1, 0.2}, +// [](Operator const &, MachineSpecification const &) { +// return std::unordered_set{make_1d_machine_view(0, 1, 1)}; +// }, +// OptimizerConfig{alpha, budget, threshold, max_num_ops}); +// RC_ASSERT(get_nodes(s.pcg).size() > 0); +// RC_ASSERT(s.machine_mapping.runtime > 0); +// RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); +// }); +// } From 7598a923848588234262a61a83a4fa8bd0377f33 Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 11 Mar 2024 18:55:34 -0400 Subject: [PATCH 16/32] fmt --- lib/compiler/test/src/test_generator.h | 29 +++++++++++-------- lib/compiler/test/src/test_unity_algorithm.cc | 3 +- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index c14743347a..d6b8222968 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -10,8 +10,8 @@ using namespace FlexFlow; // Rapidcheck does not work for now // /* -// Generates computation graphs with trivial layers and tensors, which are used -// for tests focusing on graph structures. +// Generates computation graphs with trivial layers and tensors, which are +// used for tests focusing on graph structures. // */ // ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { // return materialize_output_labelled_multidigraph_view( @@ -24,8 +24,8 @@ using namespace FlexFlow; // } // /* -// Generates parallel computation graphs with trivial layers and tensors, which -// are used for tests focusing on graph structures. +// Generates parallel computation graphs with trivial layers and tensors, +// which are used for tests focusing on graph structures. // */ // ParallelComputationGraph // test_parallel_computation_graph(MultiDiGraphView const &g) { @@ -53,7 +53,8 @@ using namespace FlexFlow; // template <> // struct Arbitrary { // static Gen arbitrary() { -// return gen::map(gen::cast(serialParallelMultiDiGraph()), +// return +// gen::map(gen::cast(serialParallelMultiDiGraph()), // test_computataion_graph); // } // }; @@ -61,7 +62,8 @@ using namespace FlexFlow; // template <> // struct Arbitrary { // static Gen arbitrary() { -// return gen::map(gen::cast(serialParallelMultiDiGraph()), +// return +// gen::map(gen::cast(serialParallelMultiDiGraph()), // test_parallel_computation_graph); // } // }; @@ -72,7 +74,8 @@ using namespace FlexFlow; // return gen::mapcat(gen::arbitrary(), [](bool is_node) { // return is_node // ? gen::cast>(gen::arbitrary()) -// : gen::cast>(gen::arbitrary()); +// : gen::cast>(gen::arbitrary()); // }); // } // }; @@ -124,8 +127,8 @@ using namespace FlexFlow; // template // struct Arbitrary { // static Gen< -// std::enable_if, Tag>::value>::type> -// arbitrary() { +// std::enable_if, +// Tag>::value>::type> arbitrary() { // return gen::construct(gen::arbitrary()); // } // }; @@ -146,7 +149,8 @@ using namespace FlexFlow; // return gen::build( // gen::set(&MachineMapping::machine_views, // gen::container>( -// gen::arbitrary(), gen::arbitrary()))); +// gen::arbitrary(), +// gen::arbitrary()))); // } // } @@ -155,8 +159,9 @@ using namespace FlexFlow; // static Gen arbitrary() { // return gen::build( // gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), -// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), -// gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), +// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, +// 64)), gen::set(&MachineSpecification::num_gpus_per_node, +// gen::inRange(1, 16)), // gen::set(&MachineSpecification::inter_node_bandwidth, // gen::nonZero()), // gen::set(&MachineSpecification::intra_node_bandwidth, diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index cceecae831..c39b3ef14f 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -14,7 +14,8 @@ // TestCostEstimator{}, // MachineSpecification{1, 1, 4, 0.1, 0.2}, // [](Operator const &, MachineSpecification const &) { -// return std::unordered_set{make_1d_machine_view(0, 1, 1)}; +// return std::unordered_set{make_1d_machine_view(0, 1, +// 1)}; // }, // OptimizerConfig{alpha, budget, threshold, max_num_ops}); // RC_ASSERT(get_nodes(s.pcg).size() > 0); From 05c8336ef7109f112c67ec838540fb8a1b06dfb3 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 13 Mar 2024 20:56:15 -0400 Subject: [PATCH 17/32] fix --- lib/compiler/src/graph_utils.cc | 2 +- lib/compiler/src/machine_mapping.cc | 23 ++----------------- lib/op-attrs/src/attention.cc | 7 ++++++ lib/op-attrs/src/embedding.cc | 8 ++++++- .../src/parallel_dim_mapping_record_solver.cc | 8 +++++++ lib/pcg/src/strided_rectangle.cc | 4 ++++ .../include/utils/graph/labelled/open_views.h | 6 ++--- .../utils/graph/labelled/output_labelled.h | 18 +++++++-------- lib/utils/include/utils/graph/views.h | 16 ++++++------- lib/utils/src/graph/open_graphs.cc | 2 +- lib/utils/src/graph/serialparallel.cc | 2 +- 11 files changed, 50 insertions(+), 46 deletions(-) diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 069ae4a41f..3c6e44216b 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -54,7 +54,7 @@ std::unordered_map } } - assert(result.size() == get_edges(pcg).size()); + assert(result.size() == get_edges(pcg.value()).size()); return result; } diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 5ce988b951..b48e200c15 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -94,6 +94,7 @@ float estimate_cost(SubParallelComputationGraphView const &g, MachineMapping const &device_mapping, std::unordered_map const &frontier_machine_views) { + // TODO: Consider parallelism float cost = 0; for (Node const &node : get_nodes(g)) { std::unordered_set incoming_edges = @@ -106,26 +107,6 @@ float estimate_cost(SubParallelComputationGraphView const &g, cost += estimator.estimate_cost( g.at(node).attrs, inputs, device_mapping.machine_views.at(node)); } - - for (OpenMultiDiEdge const &edge : get_edges(g)) { - if (holds_alternative(edge)) { - cost += estimator.estimate_cost( - g.at(edge).get_shape(), - frontier_machine_views.at(edge), - device_mapping.machine_views.at(get(edge).dst)); - } else if (holds_alternative(edge)) { - cost += estimator.estimate_cost( - g.at(edge).get_shape(), - device_mapping.machine_views.at(get(edge).src), - frontier_machine_views.at(edge)); - } else { - assert(holds_alternative(edge)); - cost += estimator.estimate_cost( - g.at(edge).get_shape(), - device_mapping.machine_views.at(get(edge).src), - device_mapping.machine_views.at(get(edge).dst)); - } - } return cost; } @@ -308,7 +289,7 @@ struct MachineMappingSearcher { &frontier_machine_views) { if (contains_key(given_machine_views, node)) { assert(contains(allowed_machine_views(g.at(node), resource), - source_machine_view.value())); + given_machine_views.at(node))); MachineMapping mv_map{given_machine_views}; return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}; diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/attention.cc index 4b6c53897c..2c1500a477 100644 --- a/lib/op-attrs/src/attention.cc +++ b/lib/op-attrs/src/attention.cc @@ -91,7 +91,14 @@ TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, static_cast(value_shape)); return get_tensor_shape_unsafe(parallel_shape); } +TensorShape get_output_shape(MultiHeadAttentionAttrs const &, + MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} +int get_oSize(ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} } // namespace FlexFlow // Tensor FFModel::multihead_attention(const Tensor query, diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/embedding.cc index 02cbfaa031..56014fcc67 100644 --- a/lib/op-attrs/src/embedding.cc +++ b/lib/op-attrs/src/embedding.cc @@ -1,3 +1,9 @@ #include "op-attrs/ops/embedding.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc index 68686393f5..500119241d 100644 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc +++ b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc @@ -351,4 +351,12 @@ void construct_output_parallel_dims( /* return solution; */ /* } */ +ParallelDimMappingSolution solve_parallel_dim_mappings( + std::vector const &mappings, + std::vector const &input, + int numWeights, + int numOutputs) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 7f612b743b..27ef9a7f5b 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -30,4 +30,8 @@ side_size_t StridedRectangleSide::get_size() const { NOT_IMPLEMENTED(); } +size_t StridedRectangle::num_dims() const { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index a24c2b940b..494d8d9f9d 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -48,8 +48,8 @@ struct OutputLabelledOpenMultiDiSubgraphView } private: - OutputLabelledOpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OutputLabelledOpenMultiDiGraphView g; + std::unordered_set nodes; }; template @@ -86,7 +86,7 @@ struct ViewOutputLabelledAsOutputLabelledOpen } private: - OutputLabelledMultiDiGraphView const &g; + OutputLabelledMultiDiGraphView g; }; template diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 882fca8df0..9c65db4daa 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -31,20 +31,19 @@ struct OutputLabelledMultiDiGraphView OutputLabelledMultiDiGraphView & operator=(OutputLabelledMultiDiGraphView const &) = default; - virtual NodeLabel const &at(Node const &n) const { + NodeLabel const &at(Node const &n) const { return get_ptr().at(n); } - virtual OutputLabel const &at(MultiDiOutput const &o) const { + OutputLabel const &at(MultiDiOutput const &o) const { return get_ptr().at(o); } - virtual std::unordered_set query_nodes(NodeQuery const &q) const { + std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } - virtual std::unordered_set - query_edges(MultiDiEdgeQuery const &q) const { + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } @@ -93,7 +92,7 @@ struct OutputLabelledMultiDiGraph return nl.get_mutable()->get_label(n); } - NodeLabel const &at(Node const &n) const override { + NodeLabel const &at(Node const &n) const { return nl->get_label(n); } @@ -113,16 +112,15 @@ struct OutputLabelledMultiDiGraph return ol.get_mutable()->get_label(o); } - OutputLabel const &at(MultiDiOutput const &o) const override { + OutputLabel const &at(MultiDiOutput const &o) const { return ol->get_label(o); } - std::unordered_set query_nodes(NodeQuery const &q) const override { + std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } - std::unordered_set - query_edges(MultiDiEdgeQuery const &q) const override { + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } diff --git a/lib/utils/include/utils/graph/views.h b/lib/utils/include/utils/graph/views.h index 776a72e6d5..43d813bf8c 100644 --- a/lib/utils/include/utils/graph/views.h +++ b/lib/utils/include/utils/graph/views.h @@ -256,8 +256,8 @@ struct OpenMultiDiSubgraphView : public IOpenMultiDiGraphView { OpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set inputs; std::unordered_set outputs; }; @@ -274,8 +274,8 @@ struct UpwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { UpwardOpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set inputs; }; @@ -291,8 +291,8 @@ struct DownwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { DownwardOpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set outputs; }; @@ -308,8 +308,8 @@ struct ClosedMultiDiSubgraphView : public IOpenMultiDiGraphView { ClosedMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; }; UndirectedEdge to_undirected_edge(DirectedEdge const &); diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index c32ff6ded5..e0bc94ca8c 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 8b179d31de..3461e27ddf 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -19,7 +19,7 @@ Node find_sink_node(DiGraphView const &g) { optional find_bottleneck_node(DiGraphView const &g) { std::unordered_set sources = get_sources(g); - std::unordered_set sinks = get_sources(g); + std::unordered_set sinks = get_sinks(g); optional maybe_bottleneck = get_imm_post_dominator(g, sources); if (maybe_bottleneck.has_value()) { From 6962bc86b2fa13358c0e5a46c9172ee5017710c9 Mon Sep 17 00:00:00 2001 From: Bob Chen <70640928+Bob-Chen222@users.noreply.github.com> Date: Thu, 14 Mar 2024 18:58:02 -0400 Subject: [PATCH 18/32] additional doc --- .../include/substitutions/attribute_expr.h | 40 ++++++++++++++ .../include/substitutions/get_attribute.h | 4 ++ .../include/substitutions/graph_pattern.h | 34 ++++++++++++ .../substitutions/graph_pattern_match.h | 52 ++++++++++--------- .../include/substitutions/operator_pattern.h | 32 +++++++----- .../include/substitutions/output_graph.h | 33 +++++++++++- .../substitutions/parallel_tensor_pattern.h | 23 +++++--- .../sub_parallel_computation_graph.h | 7 +++ .../include/substitutions/substitution.h | 18 +++++++ 9 files changed, 199 insertions(+), 44 deletions(-) diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h index d6902d1274..c243d80a28 100644 --- a/lib/substitutions/include/substitutions/attribute_expr.h +++ b/lib/substitutions/include/substitutions/attribute_expr.h @@ -7,20 +7,55 @@ namespace FlexFlow { enum class ConstraintType { EQUAL }; +/** + * @struct ListIndexAccess + * @brief Given the attribute key, retrieve the specific value stored at index i in the attribute + * This struct will be used in EvaluateOperatorAttributeExpr and EvaluateTensorAttributeExpr, + * where we evaluate the expression and return the concrete value of the attribute stored at index i + */ template struct ListIndexAccess { T attribute_key; req index; }; +/** + * @struct ListSize + * @brief Given the type of an attribute, retrieve the size of the attribute + * Specifically, for the OperatorAttributeValue, the size of the attribute is always MAX_TENSOR_DIM + * For the TensorAttributeValue, the size of the attribute is the size of the vector that represents + * the specific attribute of tensor in PCG + */ template struct ListSize { req attribute_key; }; +/** + * @struct AttributeExpr + * @brief AttributeExpr is a representation of ways to access the attribute. + * It can be a direct value, or a list index access, or a list size. + * For example, padding of a Conv2D operator will be represented as a int, + * and the dimension of a tensor will be represented as a vector to which + * we can access the vector size with ListSize and access the specific value + * with ListIndexAccess + */ template using AttributeExpr = variant, ListSize>; + +/** + * @struct AttributeConstraint + * @brief AttributeConstraint is additional constraint imposed when doing pattern matching other than + * just matching graph topology. Specifically, given a pattern and a graph, matching solely the attribute + * type is not enough as there are other factors to consider. For example, if we want to fuse two dense + * layer, we need to match the input shape; given a dense layer, we need to make sure the input shape matches + * the output shape of the previous layer. + * + * Given an attribute expression, attribute_expr should have a relationship with attribute_value defined by + * constraint_type. Currently only EQUAL is supported, meaning that the attribute_expr should be equal to + * attribute_value after evaluation. + */ template struct AttributeConstraint { ConstraintType constraint_type; @@ -28,6 +63,11 @@ struct AttributeConstraint { V attribute_value; }; + +/** + * @struct AttributePattern + * @brief AttributePattern is a collection of attribute constraints for pattern matching to satisfy. + */ template struct AttributePattern { std::vector> attribute_constraints; diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index f35145133e..fd390e540e 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -5,6 +5,10 @@ #include "operator_pattern.h" #include "utils/optional.h" + +/** + * @brief overloading get_attribute functions for different operator attributes. + */ namespace FlexFlow { optional get_attribute(PCGOperatorAttrs const &, diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index 4f4021203b..6e0f839e28 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -8,6 +8,16 @@ namespace FlexFlow { +/** + * @struct GraphPattern + * @brief A GraphPattern is defined as an open graph with node label OperatorPattern + * and output label ParallelTensorPattern, which is refered to as the pattern graph. + * The graph structure of a GraphPattern instance defines the geometrical property + * of the input graph, while the node labels and output labels define the attribute + * property of that. To be detailed, the OperatorPattern and ParallelTensorPattern + * contains a set of constraints and the corresponding graph needs to satisfy these + * constraints in order to be considered as match. + */ struct GraphPattern : public strong_typedef< GraphPattern, @@ -16,15 +26,39 @@ struct GraphPattern using strong_typedef::strong_typedef; }; +/** + * @brief Given a pattern, split_pattern is used to split the pattern + * and recursively match the sub-patterns. + */ GraphSplit split_pattern(OpenMultiDiGraphView const &pattern); +/** + * @brief singleton_pattern is defined as a pattern that has only one node. + * A singleton pattern serves as the base case for recursive pattern matching. + */ bool is_singleton_pattern(OpenMultiDiGraphView const &); +/** + * @brief operator_satisfies checks if the operator satisfies the set of constraints. + * shown in the pattern. + */ bool operator_satisfies(Operator const ¶ms, OperatorPattern const &pattern); + +/** + * @brief parallel_tensor_satisfies checks if the parallel tensor satisfies the set of + * constraints shown in the pattern. + */ bool parallel_tensor_satisfies(ParallelTensor const ¶ms, ParallelTensorPattern const &pattern); +/** + * @brief assignment_satifies checks if the provided MultiDiGraphPatternMatch is a valid + * description of how GraphPattern can be mapped to SubParallelComputationGraph. + * + * It checkes if the node and edge assignments satisfy the constraints of the pattern and whether + * the graph topology matches. + */ bool assignment_satisfies(SubParallelComputationGraph const &, GraphPattern const &, MultiDiGraphPatternMatch const &); diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h index d441d5390e..baf3eae4c2 100644 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/graph_pattern_match.h @@ -8,11 +8,22 @@ namespace FlexFlow { /** * @struct MultiDiGraphPatternMatch - * @brief MultiDiGraphPatternMatch is a struct that describes a mapping from how an open graph is matched with - * a PCG graph. - * To apply a substitution to a PCG, we should first match the pattern graph to a subgraph of the PCG. MultiDiGraphPatternMatch describes the match, - * which consists of a node_assignment that describes how the GraphPattern node mapped to PCG node and an edge_assignment that describes - * how the GraphPattern edge mapped to PCG edge. + * @brief MultiDiGraphPatternMatch describes a specific location in an OpenMultiDiGraph where a given pattern matches. + * + * Given a graph and a pattern there can be zero, one, or multiple locations where it can match. + * + * To provide some intuition, consider matching over strings instead of graphs: given a regex pattern "a.b" and a string "acbfadbga", there are two valid match locations: + * we can either match the "acb" at the beginning of the string, or the "adb" in the middle of the string. + * MultiDiGraphPatternMatch represents the difference between the two possible locations using a bidict which maps between + * objects in the pattern and the corresponding objects in the matched data structure. For example, in the string example above, + * the two matchings would be as follows: + * "acbfadbga" "acbfadbga" + * ^^^ ^^^ + * ||| ||| + * vvv vvv + * "a.b" "a.b" + * Of course in the context of graphs there are two types of objects to be matched: nodes and edges. + * As such our match consists of not one but two bidict mappings: one for nodes (node_assignment) and one for edges (edge_assignment). */ struct MultiDiGraphPatternMatch { using PatternNode = Node; @@ -25,20 +36,21 @@ struct MultiDiGraphPatternMatch { using PCGEdge = OpenMultiDiEdge; /** - * @brief node_assignment is a bidirectional map from PatternNode to PCGNode + * @brief node_assignment describes the mapping between PatternNode and PCGNode as a part of the substitution. */ bidict node_assignment; /** - * @brief edge_assignment is a bidirectional map from PatternEdge to PCGEdge + * @brief edge_assignment describes the mapping between PatternEdge and PCGEdge as a part of the substitution. */ bidict edge_assignment; }; /** * @struct MatchSplit - * @brief MatchSplit is a struct that describes a split of a MultiDiGraphPatternMatch into two sub MultiDiGraphPatternMatch - * + * @brief MatchSplit is a struct that describes a split of a MultiDiGraphPatternMatch into + * two sub MultiDiGraphPatternMatches by dividing the nodes into half. When applying pattern + * matches, the pattern will be split into two parts and recursively matched against the graph. */ struct MatchSplit { MultiDiGraphPatternMatch prefix_submatch; @@ -48,6 +60,9 @@ struct MatchSplit { /** * @struct MatchAdditionalCriterion * @brief The additional conditions need to be satisfied other than geometric properties of the graph. + * Specifically as mentioned in attribute_expr.h, other than matching graph topology, we also need to make sure + * the attributes(eg. shape of dense layer) should be matched as well. The additional constraints + * AttributeConstraint will be imposed inside node_criterion and edge_criterion for each potential match. */ struct MatchAdditionalCriterion { std::function node_criterion; @@ -56,14 +71,9 @@ struct MatchAdditionalCriterion { }; /** - * @brief pattern_matches checks if the pattern graph matches the graph with additional conditions defined by additional_criterion. - * @param pattern The pattern graph - * @param graph The graph to be matched - * @param match The mapping between the pattern graph and the graph - * @param additional_criterion The additional conditions need to be satisfied other than geometric properties of the graph. - * @return true if the pattern graph matches the graph, false otherwise. - * @details function is used to check whether the generated match from pattern to graph is valid or not. It is used in find_pattern_matches to check against all the enumerated matches - * and filter out the invalid ones. + * @brief pattern_matches checks if the pattern graph matches the graph with additional conditions defined + * by additional_criterion. It is used as the last checking step to see if the pattern matches the graph + * attributewise inside find_pattern_matches. */ bool pattern_matches(OpenMultiDiGraphView const &pattern, OpenMultiDiGraphView const &graph, @@ -71,13 +81,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, MatchAdditionalCriterion const &additional_criterion); /** - * @brief generate all valid matches from pattern to a subgraph of graph - * @param pattern - * @param graph - * @param additional_criterion - * @return std::vector - * - * @details Given a pattern and a graph, find all the valid matches between the pattern and the graph with additional conditions defined by additional_criterion. + * @brief find_pattern_matches generate all valid matches from pattern to a subgraph of graph. */ std::vector find_pattern_matches(OpenMultiDiGraphView const &pattern, diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 31805438f9..0dae8c8d21 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -12,9 +12,16 @@ namespace FlexFlow { /** - * @brief OperatorAttributeKey is an enum class that represents the keys of the attributes of an Operator. - * Each operator has a set of attributes that describe its behavior. OperatorAttributeKey is used to retrieve the value of an attribute or the expression of an attribute stored - * in an attribute map. + * @enum OperatorAttributeKey + * @brief OperatorAttributeKey represents the keys of the attributes of an Operator. + * Specifically, each operator have a set of attributes, and each attribute will have + * a key as its name and a concrete value representation. + * The OP_TYPE is a OperatorAttributeKey is a special attribute key that represents the + * type of the Operator and will exist in every Operator. Given the OP_TYPE, the other + * attributes will be determined accordingly. + * + * For example, a batch matrix multiplication Operator will have OP_TYPE BATCH_MATMUL and + * dimensions as A_SEQ_LENGTH_DIM and B_SEQ_LENGTH_DIM */ enum class OperatorAttributeKey { OP_TYPE, // AnyOp @@ -76,9 +83,9 @@ enum class OperatorAttributeKey { }; /** - * @brief OperatorAttributeValue is a variant that represents the concrete value of an attribute of an Operator. - * The OperatorAttributeValue is evalutated from AttributeExpr - * The datatype of the value corresponds to the datatype of the attributekey listed in OperatorAttributeKey. + * @brief OperatorAttributeValue is a representation of the concrete value of an attribute of an Operator. + * The OperatorAttributeValue is evaluated from AttributeExpr. The datatype of the value corresponds to the + * datatype of the attributekey listed in OperatorAttributeKey. */ using OperatorAttributeValue = variant, FF_VISITABLE_STRUCT(ListSize, attribute_key); /** - * @todo: need to better understand what is constraints and pattern - * + * @brief OperatorAttributeConstraint is an instance of template struct AttributeConstraint. */ using OperatorAttributeConstraint = AttributeConstraint; +/** + * @brief OperatorPattern is an instance of template struct AttributePattern. + */ using OperatorPattern = AttributePattern; /** - * @brief Given a specific attribute of an Operator, evaluate the expression of the attribute and return the value of the attribute. - * @param attrs - * @param expr - * @return optional + * @brief Given a specific attribute of an Operator, evaluate the expression of the attribute + * using one of the three methods: direct value, list index access, or list size and return the + * value of the attribute. */ optional evaluate_attribute_expr(Operator const &attrs, diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index b9cf1f53f3..dbb1108b24 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -6,23 +6,52 @@ namespace FlexFlow { // NOTE(@wmdi) I am not sure whether these should be part of attribute expr. + +/** + * @struct OperatorAttrAccess + * @brief OperatorAttrAccess consists of a node and an expression attr_expr + * on the attributes of the operator associated with the node. The value of a + * NodeAttrAccess instance is the value of attr_expr evaluated on the operator + * associated with the node. + */ struct OperatorAttrAccess { Node node; AttributeExpr attr_expr; }; +/** + * @struct AttrConstant + * @brief AttrConstant is a constant value that is used as an attribute expression. + */ struct AttrConstant { OperatorAttributeValue value; }; +/** + * @brief OperatorAttributeExpr is a access to the attribute of an operator and can be + * evaluated to a concrete value. OperatorAttributeExpr is used at substitution phase. + * It will be evaluated and used to create new operator with the evaluated value. + */ using OperatorAttributeExpr = variant; -// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can -// define the assignment for each operator type. +/** + * @brief OperatorAttrAssignment is a collection of OperatorAttributeKey and + * GraphAttributeExpr pairs for a single operator. It defines how the attributes + * of a single operator is calculated from the input graph. A pair + * {operator_attribute_key, graph_attribute_expr} in the collection means the value + * of graph_attribute_expr is assigned to the attribute named operator_attribute_key + * of the operator. + */ struct OperatorAttrAssignment { std::unordered_map assignments; }; +/** + * @brief An OutputGraphExpr is defined as an open graph with node label + * OperatorAttrAssignment and output label ParallelTensorAttrAssignment, which + * defines how the operator attributes and the parallel tensor attributes of the + * output graph are derived from the input graph. + */ struct OutputGraphExpr : public strong_typedef< OutputGraphExpr, diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index 2c22ff878f..da873da40d 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -7,26 +7,37 @@ namespace FlexFlow { /** - * @brief TensorAttributeKey is an enum class that represents the keys of the attributes of a Tensor(matrix). - * DIM_SIZES describes the length along each dimension of the tensor + * @brief TensorAttributeKey is an enum class that represents the keys of the + * attributes of a Tensor(matrix). + * DIM_SIZES describes the size of each dimension of the tensor for data parallelism computation * DIM_DEGREES describes the number of partitions along each dimension of the tensor for data parallelism computation */ enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; +/** + * @brief DIM_SIZES and DIM_DEGREES are represented by + * a vector of ints that is listed as corresponding dimension + */ using TensorAttributeValue = variant>; +/** + * @brief TensorAttributeConstraint is an instance of AttributeConstraint that + * defines the contraint a tensor should satisfy when doing pattern matching. + */ using TensorAttributeConstraint = AttributeConstraint; +/** + * @brief ParallelTensor is an instance of OperatorAttributeExpr that represents + * a set of constraints pattern matching should satisfy. + */ using ParallelTensorPattern = AttributePattern; /** * @brief evaluate_attribute_expr evaluates the attribute expression for a given ParallelTensor - * - * @param tensor_shape, which describes the attributes of a ParallelTensor - * @param expr, which describes the specific attribute expression to be evaluated - * @return optional + * the ParallelTensor parameter is named tensor_shape because the numerical value will only be used + * in runtime. For the substitution phase, all that matters is the shape of the tensor. */ optional evaluate_attribute_expr(ParallelTensor const &tensor_shape, diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 0d6bfe7628..e5940007c8 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -9,6 +9,13 @@ namespace FlexFlow { +/** + * @brief SubParallelComputationGraph is defined as an open graph, which allows nodes and edges + * that are not from the same graph to be added to it. + * This definition is useful when we want to split and merge graphs when doing pattern matching. + * In contrast, the ParallelComputationGraph is defined as a closed graph and all the edges and + * nodes are within that graph. + */ using SubParallelComputationGraph = OutputLabelledOpenMultiDiGraph; diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index a52906c612..3e6c8fd3c7 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -7,6 +7,18 @@ namespace FlexFlow { +/** + * @struct Substitution + * @brief A substitution is to replace a subgraph of the PCG by a new one. + * We refer to the subgraph to be replaced as the input graph, and the new + * subgraph to replace the input graph as the output graph. + * A Substitution object describes a substitution. It consists of An + * input_graph of type GraphPattern that describes which kind of input graphs + * the substitution can be applied to; An output_graph of type OutputGraphExpr + * that describes how the output graph is computed from the input graph; and + * An input_mapping and output_maping that describes how the output graph is + * connected to the original PCG. + */ struct Substitution { using InputPatternInput = InputMultiDiEdge; using InputPatternOutput = OutputMultiDiEdge; @@ -19,8 +31,14 @@ struct Substitution { bidict output_mapping; }; +/** + * @brief is_valid_substitution checks if the substitution is valid. + * The implementation will enumerate all the possible substitutions and filter + * out all the invalid ones. + */ bool is_valid_substitution(Substitution const &); + SubParallelComputationGraph apply_substitution(SubParallelComputationGraph const &, Substitution const &, From 9345400aab5fbe5cc20fd144df63194f145da84e Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 18 Mar 2024 15:51:34 -0400 Subject: [PATCH 19/32] add more unit tests --- lib/compiler/test/CMakeLists.txt | 2 +- .../test/src/test_labelled_open_graph.cc | 190 ++++++++----- lib/compiler/test/src/test_optimal_cost.cc | 4 +- lib/substitutions/src/substitution.cc | 7 +- .../test/src/test_substitution.cc | 18 +- .../utils/graph/labelled/labelled_open.decl.h | 124 --------- .../utils/graph/labelled/labelled_open.h | 173 ------------ .../graph/labelled/labelled_open_interfaces.h | 62 ----- .../utils/graph/labelled/node_labelled.h | 48 +--- .../graph/labelled/node_labelled_interfaces.h | 36 +++ .../utils/graph/labelled/node_labelled_open.h | 53 ++-- .../utils/graph/labelled/output_labelled.h | 73 ++--- .../labelled/output_labelled_interfaces.h | 15 +- .../graph/labelled/output_labelled_open.h | 88 ++----- .../output_labelled_open_interfaces.h | 34 +++ .../utils/graph/labelled/standard_labelled.h | 59 +---- .../labelled/unordered_labelled_graphs.h | 249 ++++++++++++------ .../include/utils/graph/labelled/views.h | 8 +- .../include/utils/graph/labelled_graphs.h | 1 + lib/utils/src/graph/open_graphs.cc | 2 +- lib/utils/test/CMakeLists.txt | 24 +- lib/utils/test/src/test_cow_ptr.cc | 60 +++++ 22 files changed, 563 insertions(+), 767 deletions(-) delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open.decl.h delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open.h delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h create mode 100644 lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h create mode 100644 lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h create mode 100644 lib/utils/test/src/test_cow_ptr.cc diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index cbd7e233c0..3d35fdabfd 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -2,7 +2,7 @@ ff_add_test_executable( NAME compiler-test SRC_PATTERNS - src/*.cc + src/test_labelled_open_graph.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index a360d86ee7..a3b6319528 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,73 +4,141 @@ using namespace FlexFlow; -TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { - auto g = OpenMultiDiGraph::create(); +// TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { +// auto g = OpenMultiDiGraph::create(); - Node n0 = g.add_node(); - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - Node n4 = g.add_node(); +// Node n0 = g.add_node(); +// Node n1 = g.add_node(); +// Node n2 = g.add_node(); +// Node n3 = g.add_node(); +// Node n4 = g.add_node(); + +// NodePort p0 = g.add_node_port(); +// NodePort p1 = g.add_node_port(); +// NodePort p2 = g.add_node_port(); +// NodePort p3 = g.add_node_port(); +// NodePort p4 = g.add_node_port(); +// NodePort p5 = g.add_node_port(); +// NodePort p6 = g.add_node_port(); +// NodePort p7 = g.add_node_port(); +// NodePort p8 = g.add_node_port(); +// NodePort p9 = g.add_node_port(); + +// MultiDiEdge e0{n1, p1, n0, p0}; +// MultiDiEdge e1{n2, p2, n0, p0}; +// MultiDiEdge e2{n3, p5, n1, p3}; +// MultiDiEdge e3{n3, p6, n2, p4}; +// MultiDiEdge e4{n4, p8, n3, p7}; +// OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + +// g.add_edge(e0); +// g.add_edge(e1); +// g.add_edge(e2); +// g.add_edge(e3); +// g.add_edge(e4); +// g.add_edge(e5); + +// std::unordered_set node_set0{n3, n4}; + +// auto subgraph0 = get_subgraph(g, node_set0); +// auto subgraph1 = get_subgraph(g, node_set0); +// auto subgraph2 = get_subgraph(g, +// node_set0); auto subgraph3 = get_subgraph(g, +// node_set0); + +// CHECK(get_nodes(subgraph0) == node_set0); +// CHECK(get_nodes(subgraph1) == node_set0); +// CHECK(get_nodes(subgraph2) == node_set0); +// CHECK(get_nodes(subgraph3) == node_set0); + +// std::unordered_set input_set{split_edge(e2).second, +// split_edge(e3).second}; +// std::unordered_set output_set{e5}; + +// CHECK(bool(get_open_inputs(subgraph0) == input_set)); +// CHECK(bool(get_open_inputs(subgraph1) == input_set)); +// CHECK(bool(get_open_inputs(subgraph2).empty())); +// CHECK(bool(get_open_inputs(subgraph3).empty())); + +// CHECK(bool(get_open_outputs(subgraph0) == output_set)); +// CHECK(bool(get_open_outputs(subgraph1).empty())); +// CHECK(bool(get_open_outputs(subgraph2) == output_set)); +// CHECK(bool(get_open_outputs(subgraph3).empty())); + +// CHECK(bool(get_edges(subgraph0) == +// std::unordered_set{ +// split_edge(e2).second, split_edge(e3).second, e4, e5})); +// CHECK(bool(get_edges(subgraph1) == +// std::unordered_set{ +// split_edge(e2).second, split_edge(e3).second, e4})); +// CHECK(bool(get_edges(subgraph2) == +// std::unordered_set{e4, e5})); +// CHECK(bool(get_edges(subgraph3) == +// std::unordered_set{e4})); + +// CHECK(get_closed_sources(subgraph2) == std::unordered_set{n3}); +// } + +// TEST_CASE("view OutputLabelledMultiDiGraph as open") { +// OutputLabelledMultiDiGraph g = +// OutputLabelledMultiDiGraph::create>(); + +// Node n0 = g.add_node(0); +// Node n1 = g.add_node(1); + +// NodePort p0 = g.add_node_port(); +// NodePort p1 = g.add_node_port(); + +// MultiDiEdge e0{n1, p1, n0, p0}; + +// g.add_edge(e0); +// g.add_output(e0, 2); + +// CHECK(get_edges(g).size() == 1); + +// OutputLabelledOpenMultiDiGraphView open_graph = +// view_output_labelled_as_output_labelled_open(g); + +// CHECK(open_graph.at(n0) == 0); +// CHECK(open_graph.at(n1) == 1); +// CHECK(open_graph.at(e0) == 2); + +// // CHECK(get_edges(open_graph).size() == 1); +// } + +TEST_CASE("OutputLabelledOpenMultiDiGraph") { + OutputLabelledOpenMultiDiGraph g = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); NodePort p0 = g.add_node_port(); NodePort p1 = g.add_node_port(); - NodePort p2 = g.add_node_port(); - NodePort p3 = g.add_node_port(); - NodePort p4 = g.add_node_port(); - NodePort p5 = g.add_node_port(); - NodePort p6 = g.add_node_port(); - NodePort p7 = g.add_node_port(); - NodePort p8 = g.add_node_port(); - NodePort p9 = g.add_node_port(); MultiDiEdge e0{n1, p1, n0, p0}; - MultiDiEdge e1{n2, p2, n0, p0}; - MultiDiEdge e2{n3, p5, n1, p3}; - MultiDiEdge e3{n3, p6, n2, p4}; - MultiDiEdge e4{n4, p8, n3, p7}; - OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - g.add_edge(e5); - - std::unordered_set node_set0{n3, n4}; - - auto subgraph0 = get_subgraph(g, node_set0); - auto subgraph1 = get_subgraph(g, node_set0); - auto subgraph2 = get_subgraph(g, node_set0); - auto subgraph3 = get_subgraph(g, node_set0); - - CHECK(get_nodes(subgraph0) == node_set0); - CHECK(get_nodes(subgraph1) == node_set0); - CHECK(get_nodes(subgraph2) == node_set0); - CHECK(get_nodes(subgraph3) == node_set0); - - std::unordered_set input_set{split_edge(e2).second, - split_edge(e3).second}; - std::unordered_set output_set{e5}; - - CHECK(bool(get_open_inputs(subgraph0) == input_set)); - CHECK(bool(get_open_inputs(subgraph1) == input_set)); - CHECK(bool(get_open_inputs(subgraph2).empty())); - CHECK(bool(get_open_inputs(subgraph3).empty())); - - CHECK(bool(get_open_outputs(subgraph0) == output_set)); - CHECK(bool(get_open_outputs(subgraph1).empty())); - CHECK(bool(get_open_outputs(subgraph2) == output_set)); - CHECK(bool(get_open_outputs(subgraph3).empty())); - - CHECK(bool(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5})); - CHECK(bool(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4})); - CHECK(bool(get_edges(subgraph2) == - std::unordered_set{e4, e5})); - CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); + g.add_label(e0, 2); + + CHECK(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1); + CHECK(get_edges(g).size() == 1); } + +// TEST_CASE("OpenMultiDiGraph") { +// OpenMultiDiGraph g = OpenMultiDiGraph::create(); + +// Node n0 = g.add_node(); +// Node n1 = g.add_node(); + +// NodePort p0 = g.add_node_port(); +// NodePort p1 = g.add_node_port(); + +// MultiDiEdge e0{n1, p1, n0, p0}; + +// g.add_edge(e0); + +// CHECK(get_edges(g).size() == 1); +// } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index c5f74ff392..9d90285870 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -30,9 +30,7 @@ allowed machine views, trivial cost estimator and random machine specification. TEST_CASE("optimal_cost_0") { auto pcg = OutputLabelledMultiDiGraph::template create< - AdjacencyMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling>(); + UnorderedOutputLabelledMultiDiGraph>(); Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n1 = pcg.add_node(Operator{ diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index f846171b62..da9f303ab8 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -413,11 +413,8 @@ SubParallelComputationGraph Substitution const &substitution, MultiDiGraphPatternMatch const &match) { SubParallelComputationGraph new_pcg = - OutputLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling>(); + OutputLabelledOpenMultiDiGraph::template create< + UnorderedOutputLabelledOpenMultiDiGraph>(); bidict node_mapping; // Refactor it with global nodes for (Node const &node : get_nodes(pcg)) { if (!contains_r(match.node_assignment, node)) { diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index a33e9127cc..a8f5283eda 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -19,12 +19,10 @@ TEST_CASE("apply_substitution") { ParallelTensorPattern tensor_pattern_empty{ std::vector{}}; - auto ig = - OutputLabelledOpenMultiDiGraph:: - create, - UnorderedLabelling, - UnorderedLabelling>(); + auto ig = OutputLabelledOpenMultiDiGraph:: + create>(); Node n0 = ig.add_node(operator_pattern_n0); NodePort p0 = ig.add_node_port(); InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; @@ -60,8 +58,7 @@ TEST_CASE("apply_substitution") { {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; auto og = NodeLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling>(); + UnorderedNodeLabelledOpenMultiDiGraph>(); Node n1 = og.add_node(op_ass_n1); Node n2 = og.add_node(op_ass_n2); Node n3 = og.add_node(op_ass_n3); @@ -88,10 +85,7 @@ TEST_CASE("apply_substitution") { SubParallelComputationGraph pcg = OutputLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling>(); + UnorderedOutputLabelledOpenMultiDiGraph>(); Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n5 = pcg.add_node(Operator{ diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.decl.h b/lib/utils/include/utils/graph/labelled/labelled_open.decl.h deleted file mode 100644 index cdd22b7847..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open.decl.h +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL_H - -#include "labelled_open_interfaces.h" -#include "node_labelled.h" -#include "utils/graph/open_graphs.h" - -namespace FlexFlow { - -template -struct LabelledOpenMultiDiGraphView { -private: - using Interface = ILabelledOpenMultiDiGraphView; - -public: - LabelledOpenMultiDiGraphView() = delete; - - operator OpenMultiDiGraphView() const; - // operator MultiDiGraphView() const; - - NodeLabel const &at(Node const &n) const; - EdgeLabel const &at(MultiDiEdge const &e) const; - InputLabel const &at(InputMultiDiEdge const &e) const; - OutputLabel const &at(OutputMultiDiEdge const &e) const; - - template - static typename std::enable_if::value, - LabelledOpenMultiDiGraphView>::type - create(); - -private: - std::shared_ptr ptr; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ( - LabelledOpenMultiDiGraphView); - -template -struct LabelledOpenMultiDiGraph { -private: - using Interface = - ILabelledOpenMultiDiGraph; - -public: - LabelledOpenMultiDiGraph() = delete; - LabelledOpenMultiDiGraph(LabelledOpenMultiDiGraph const &other) = default; - LabelledOpenMultiDiGraph & - operator=(LabelledOpenMultiDiGraph const &other) = default; - - operator LabelledOpenMultiDiGraphView() const; - - operator OpenMultiDiGraphView() const; - - friend void swap(LabelledOpenMultiDiGraph &lhs, - LabelledOpenMultiDiGraph &rhs) { - using std::swap; - - swap(lhs.ptr, rhs.ptr); - } - - Node add_node(NodeLabel const &l); - NodeLabel &at(Node const &n); - - NodePort add_node_port(); - - NodeLabel const &at(Node const &n) const; - - void add_node_unsafe(Node const &n, NodeLabel const &l); - - std::unordered_set query_nodes(NodeQuery const &q) const; - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &q) const; - - void add_edge( - MultiDiEdge const &e); // We should allow adding edges without labels. For - // example, we may want to first construct a PCG - // and infer its tensor shapes later. - void add_edge(InputMultiDiEdge const &e); - void add_edge(OutputMultiDiEdge const &e); - - void add_label(MultiDiEdge const &e, EdgeLabel const &l); - void add_label(InputMultiDiEdge const &e, EdgeLabel const &l); - void add_label(OutputMultiDiEdge const &e, EdgeLabel const &l); - - void add_edge(MultiDiEdge const &e, EdgeLabel const &l); - EdgeLabel &at(MultiDiEdge const &e); - EdgeLabel const &at(MultiDiEdge const &e) const; - - void add_edge(InputMultiDiEdge const &e, InputLabel const &l); - InputLabel &at(InputMultiDiEdge const &e); - InputLabel const &at(InputMultiDiEdge const &e) const; - - void add_edge(OutputMultiDiEdge const &, OutputLabel const &); - OutputLabel &at(OutputMultiDiEdge const &); - OutputLabel const &at(OutputMultiDiEdge const &) const; - - template - static typename std::enable_if::value, - LabelledOpenMultiDiGraph>::type - create(); - -private: - LabelledOpenMultiDiGraph(cow_ptr_t ptr); - -private: - cow_ptr_t ptr; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ( - LabelledOpenMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.h b/lib/utils/include/utils/graph/labelled/labelled_open.h deleted file mode 100644 index 58fd5416f7..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open.h +++ /dev/null @@ -1,173 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_H - -#include "labelled_open.decl.h" -#include "labelled_open_interfaces.h" -#include "node_labelled.h" -#include "utils/graph/open_graph_interfaces.h" -#include "utils/graph/open_graphs.h" - -namespace FlexFlow { - -// LabelledOpenMultiDiGraphView -template -LabelledOpenMultiDiGraphView::operator OpenMultiDiGraphView() - const { - return GraphInternal::create_open_multidigraph_view(this->ptr); -} - -// template -// LabelledOpenMultiDiGraphView::operator MultiDiGraphView() const { -// return GraphInternal::create_multidigraphview(this->ptr); -// } - -template -NodeLabel const & - LabelledOpenMultiDiGraphView::at(Node const &n) const { - return this->ptr->at(n); -} - -template -EdgeLabel const &LabelledOpenMultiDiGraphView::at( - MultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -InputLabel const &LabelledOpenMultiDiGraphView::at( - InputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -OutputLabel const &LabelledOpenMultiDiGraphView::at( - OutputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -template -enable_if_t::Interface, - BaseImpl>::value, - LabelledOpenMultiDiGraphView> - LabelledOpenMultiDiGraphView::create() { - return LabelledOpenMultiDiGraphView(std::make_shared()); -} - -// LabelledOpenMultiDiGraph -template -LabelledOpenMultiDiGraph:: - operator LabelledOpenMultiDiGraphView() const { - return GraphInternal::create_labelled_open_multidigraph_view( - this->ptr); -} - -template -LabelledOpenMultiDiGraph::operator OpenMultiDiGraphView() const { - return GraphInternal::create_open_multidigraph_view(this->ptr.get()); -} - -template -Node LabelledOpenMultiDiGraph::add_node( - NodeLabel const &l) { - return this->ptr.get_mutable()->add_node(l); -} - -template -NodeLabel &LabelledOpenMultiDiGraph::at(Node const &n) { - return this->ptr->at(n); -} - -template -NodeLabel const & - LabelledOpenMultiDiGraph::at(Node const &n) const { - return this->ptr->ILabelledMultiDiGraph::at(n); -} - -template -void LabelledOpenMultiDiGraph::add_node_unsafe( - Node const &n, NodeLabel const &l) { - this->ptr->add_node_unsafe(n, l); -} - -template -std::unordered_set LabelledOpenMultiDiGraph::query_nodes( - NodeQuery const &q) const { - return this->ptr->query_nodes(q); -} - -template -std::unordered_set - LabelledOpenMultiDiGraph::query_edges( - OpenMultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - MultiDiEdge const &e, EdgeLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -EdgeLabel & - LabelledOpenMultiDiGraph::at(MultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -EdgeLabel const &LabelledOpenMultiDiGraph::at( - MultiDiEdge const &e) const { - return this->ptr->ILabelledMultiDiGraph::at(e); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - InputMultiDiEdge const &e, InputLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -InputLabel &LabelledOpenMultiDiGraph::at( - InputMultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -InputLabel const &LabelledOpenMultiDiGraph::at( - InputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - OutputMultiDiEdge const &e, OutputLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -OutputLabel &LabelledOpenMultiDiGraph::at( - OutputMultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -OutputLabel const &LabelledOpenMultiDiGraph::at( - OutputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -template -enable_if_t< - std::is_base_of::Interface, - BaseImpl>::value, - LabelledOpenMultiDiGraph> - LabelledOpenMultiDiGraph::create() { - return LabelledOpenMultiDiGraph(make_cow_ptr()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h deleted file mode 100644 index 2db654c615..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_INTERFACES_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_INTERFACES_H - -#include "standard_labelled_interfaces.h" -#include "utils/containers.h" -#include "utils/graph/open_graph_interfaces.h" - -namespace FlexFlow { - -template -struct ILabelledOpenMultiDiGraphView - : public IOpenMultiDiGraphView, - public ILabelledMultiDiGraphView { -public: - std::unordered_set - query_edges(MultiDiEdgeQuery const &q) const final { - return map_over_unordered_set( - [](OpenMultiDiEdge const &e) { return get(e); }, - IOpenMultiDiGraphView::query_edges( - static_cast(q))); - } - - using ILabelledMultiDiGraphView::at; - virtual InputLabel const &at(InputMultiDiEdge const &e) const = 0; - virtual OutputLabel const &at(OutputMultiDiEdge const &e) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT( - ILabelledOpenMultiDiGraphView); - -template -struct ILabelledOpenMultiDiGraph - : public ILabelledMultiDiGraph, - public ILabelledOpenMultiDiGraphView { -public: - virtual ILabelledOpenMultiDiGraph *clone() const = 0; - - virtual void add_edge(InputMultiDiEdge const &e, InputLabel const &label) = 0; - virtual void add_edge(OutputMultiDiEdge const &e, - OutputLabel const &label) = 0; - - virtual InputLabel const &at(InputMultiDiEdge const &e) const = 0; - virtual InputLabel &at(InputMultiDiEdge const &e) = 0; - - virtual OutputLabel const &at(OutputMultiDiEdge const &e) const = 0; - virtual OutputLabel &at(OutputMultiDiEdge const &e) = 0; - - using ILabelledMultiDiGraph::add_node; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledOpenMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 1ecd87226c..9d8874fb14 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -1,24 +1,11 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_H -#include "label_interfaces.h" +#include "node_labelled_interfaces.h" #include "utils/graph/multidigraph.h" namespace FlexFlow { -template -struct INodeLabelledMultiDiGraphView : virtual public IMultiDiGraphView { - INodeLabelledMultiDiGraphView() = default; - INodeLabelledMultiDiGraphView(INodeLabelledMultiDiGraphView const &) = delete; - INodeLabelledMultiDiGraphView & - operator=(INodeLabelledMultiDiGraphView const &) = delete; - - virtual ~INodeLabelledMultiDiGraphView() {} - - virtual NodeLabel const &at(Node const &n) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledMultiDiGraphView); - template struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: @@ -65,7 +52,6 @@ struct NodeLabelledMultiDiGraph : virtual NodeLabelledMultiDiGraphView { private: using Interface = IMultiDiGraph; - using NodeLabelIf = ILabelling; public: NodeLabelledMultiDiGraph(NodeLabelledMultiDiGraph const &) = default; @@ -73,48 +59,42 @@ struct NodeLabelledMultiDiGraph operator=(NodeLabelledMultiDiGraph const &) = default; NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); + return this->get_ptr().at(n); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(); + return this->get_ptr().query_nodes(); } std::unordered_set query_edges(MultiDiEdge const &q) const { - return get_ptr().query_edges(); + return this->get_ptr().query_edges(); } Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } void add_edge(MultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of>::value, - NodeLabelledMultiDiGraph>::type + template + static typename std::enable_if::value, + NodeLabelledMultiDiGraph>::type create() { - return NodeLabelledMultiDiGraph(make_cow_ptr(), - make_cow_ptr()); + return NodeLabelledMultiDiGraph(make_cow_ptr()); } protected: - NodeLabelledMultiDiGraph(cow_ptr_t ptr, cow_ptr_t nl) - : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} + NodeLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { return *std::reinterpret_pointer_cast( @@ -125,8 +105,6 @@ struct NodeLabelledMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t nl; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h new file mode 100644 index 0000000000..37fb4db715 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_INTERFACES_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_INTERFACES_H + +#include "utils/graph/multidigraph.h" + +namespace FlexFlow { + +template +struct INodeLabelledMultiDiGraphView : virtual public IMultiDiGraphView { + INodeLabelledMultiDiGraphView() = default; + INodeLabelledMultiDiGraphView(INodeLabelledMultiDiGraphView const &) = delete; + INodeLabelledMultiDiGraphView & + operator=(INodeLabelledMultiDiGraphView const &) = delete; + + virtual ~INodeLabelledMultiDiGraphView() {} + + virtual NodeLabel const &at(Node const &n) const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledMultiDiGraphView); + +template +struct INodeLabelledMultiDiGraph + : virtual INodeLabelledMultiDiGraphView { + virtual NodeLabel &at(Node const &) = 0; + virtual Node add_node(NodeLabel const &l) = 0; + virtual NodePort add_node_port() = 0; + virtual void add_edge(MultiDiEdge const &) = 0; + + virtual INodeLabelledMultiDiGraph *clone() const = 0; + + using INodeLabelledMultiDiGraphView::at; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 2162ee0384..826a8387cb 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -60,64 +60,65 @@ struct NodeLabelledOpenMultiDiGraphView } }; +template +struct INodeLabelledOpenMultiDiGraph + : virtual INodeLabelledOpenMultiDiGraphView { + virtual Node add_node(NodeLabel const &) = 0; + virtual NodePort add_node_port() = 0; + virtual NodeLabel &at(Node const &) = 0; + virtual void add_edge(OpenMultiDiEdge const &e) = 0; + + using INodeLabelledOpenMultiDiGraphView::at; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledOpenMultiDiGraphView); + template struct NodeLabelledOpenMultiDiGraph : virtual NodeLabelledOpenMultiDiGraphView { private: - using Interface = IOpenMultiDiGraph; - using INodeLabel = ILabelling; + using Interface = INodeLabelledOpenMultiDiGraph; public: - // NodeLabelledOpenMultiDiGraph() = delete; NodeLabelledOpenMultiDiGraph(NodeLabelledOpenMultiDiGraph const &) = default; NodeLabelledOpenMultiDiGraph & operator=(NodeLabelledOpenMultiDiGraph const &) = default; - NodeLabel const &at(Node const &n) const { - return nl->get_label(n); - } - NodeLabel &at(Node const &n) { - return nl->get_label(n); + return this->get_ptr().at(n); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdge const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl.get_mutable()->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } void add_edge(OpenMultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of>::value, - NodeLabelledOpenMultiDiGraph>::type + using NodeLabelledOpenMultiDiGraphView::at; + + template + static typename std::enable_if::value, + NodeLabelledOpenMultiDiGraph>::type create() { - return NodeLabelledOpenMultiDiGraph(make_cow_ptr(), - make_cow_ptr()); + return NodeLabelledOpenMultiDiGraph(make_cow_ptr()); } private: - NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl) - : GraphView(ptr), nl(nl) {} + NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { return *std::reinterpret_pointer_cast( @@ -128,8 +129,6 @@ struct NodeLabelledOpenMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t nl; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 9c65db4daa..c6c521c38b 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -1,24 +1,11 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_H -#include "standard_labelled.h" +#include "node_labelled.h" +#include "output_labelled_interfaces.h" namespace FlexFlow { -template -struct IOutputLabelledMultiDiGraphView - : public INodeLabelledMultiDiGraphView { - IOutputLabelledMultiDiGraphView() = default; - IOutputLabelledMultiDiGraphView(IOutputLabelledMultiDiGraphView const &) = - delete; - IOutputLabelledMultiDiGraphView & - operator=(IOutputLabelledMultiDiGraphView const &) = delete; - - virtual OutputLabel const &at(MultiDiOutput const &) const = 0; - using INodeLabelledMultiDiGraphView::at; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); - template struct OutputLabelledMultiDiGraphView : virtual public NodeLabelledMultiDiGraphView { @@ -32,19 +19,19 @@ struct OutputLabelledMultiDiGraphView operator=(OutputLabelledMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr().at(n); + return this->get_ptr().at(n); } OutputLabel const &at(MultiDiOutput const &o) const { - return get_ptr().at(o); + return this->get_ptr().at(o); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } template @@ -69,9 +56,7 @@ template struct OutputLabelledMultiDiGraph : virtual OutputLabelledMultiDiGraphView { private: - using Interface = IMultiDiGraph; - using INodeLabel = ILabelling; - using IOutputLabel = ILabelling; + using Interface = IOutputLabelledMultiDiGraph; public: OutputLabelledMultiDiGraph(OutputLabelledMultiDiGraph const &other) = default; @@ -79,67 +64,58 @@ struct OutputLabelledMultiDiGraph operator=(OutputLabelledMultiDiGraph const &other) = default; Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl.get_mutable()->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); + return this->get_ptr().at(n); } NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } void add_output(MultiDiOutput const &o, OutputLabel const &l) { - ol.get_mutable()->add_label(o, l); + this->get_ptr().add_output(o, l); }; void add_edge(MultiDiOutput const &o, MultiDiInput const &i) { - return get_ptr().add_edge(o, i); + this->get_ptr().add_edge(o, i); }; void add_edge(MultiDiEdge const &e) { - return get_ptr().add_edge(e); + this->get_ptr().add_edge(e); } OutputLabel &at(MultiDiOutput const &o) { - return ol.get_mutable()->get_label(o); + return this->get_ptr().at(o); } OutputLabel const &at(MultiDiOutput const &o) const { - return ol->get_label(o); + return this->get_ptr().at(o); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of>::value, - OutputLabelledMultiDiGraph>::type + template + static typename std::enable_if::value, + OutputLabelledMultiDiGraph>::type create() { - return OutputLabelledMultiDiGraph( - make_cow_ptr(), make_cow_ptr(), make_cow_ptr()); + return OutputLabelledMultiDiGraph(make_cow_ptr()); } private: - OutputLabelledMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t ol) - : GraphView(ptr), nl(nl), ol(ol) {} + OutputLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} private: Interface &get_ptr() { @@ -151,9 +127,6 @@ struct OutputLabelledMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t nl; - cow_ptr_t ol; }; template struct IOutputLabelledMultiDiGraphView : public INodeLabelledMultiDiGraphView { - virtual OutputLabel &at(MultiDiOutput const &) = 0; + virtual OutputLabel const &at(MultiDiOutput const &) const = 0; + + using INodeLabelledMultiDiGraphView::at; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); template struct IOutputLabelledMultiDiGraph - : public IOutputLabelledMultiDiGraphView { + : public IOutputLabelledMultiDiGraphView, + public INodeLabelledMultiDiGraph { public: virtual IOutputLabelledMultiDiGraph *clone() const = 0; virtual void add_output(MultiDiOutput const &output, OutputLabel const &label) = 0; - virtual void add_edge(MultiDiOutput const &output, - MultiDiInput const &input) = 0; - virtual NodePort add_node_ports() = 0; + virtual NodePort add_node_port() = 0; virtual NodeLabel &at(Node const &) = 0; virtual NodeLabel const &at(Node const &) const = 0; + virtual OutputLabel &at(MultiDiOutput const &) = 0; virtual OutputLabel const &at(MultiDiOutput const &) const = 0; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 23dd9c190c..24235bee4c 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -2,19 +2,10 @@ #define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN #include "node_labelled_open.h" -#include "utils/graph/adjacency_openmultidigraph.h" +#include "output_labelled_open_interfaces.h" namespace FlexFlow { -template -struct IOutputLabelledOpenMultiDiGraphView - : virtual INodeLabelledOpenMultiDiGraphView { - virtual EdgeLabel const &at(InputMultiDiEdge const &) const = 0; - virtual EdgeLabel const &at(MultiDiOutput const &) const = 0; - - using INodeLabelledOpenMultiDiGraphView::at; -}; - template struct OutputLabelledOpenMultiDiGraphView : virtual NodeLabelledOpenMultiDiGraphView, @@ -29,15 +20,15 @@ struct OutputLabelledOpenMultiDiGraphView operator=(OutputLabelledOpenMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr().at(n); + return this->get_ptr().at(n); } EdgeLabel const &at(InputMultiDiEdge const &i) const { - return get_ptr().at(i); + return this->get_ptr().at(i); } EdgeLabel const &at(MultiDiOutput const &o) const { - return get_ptr().at(o); + return this->get_ptr().at(o); } template @@ -51,12 +42,12 @@ struct OutputLabelledOpenMultiDiGraphView } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } template @@ -82,10 +73,7 @@ template struct OutputLabelledOpenMultiDiGraph : virtual OutputLabelledOpenMultiDiGraphView { private: - using Interface = IOpenMultiDiGraph; - using INodeLabel = ILabelling; - using IInputLabel = ILabelling; - using IOutputLabel = ILabelling; + using Interface = IOutputLabelledOpenMultiDiGraph; public: OutputLabelledOpenMultiDiGraph() = delete; @@ -95,14 +83,7 @@ struct OutputLabelledOpenMultiDiGraph operator=(OutputLabelledOpenMultiDiGraph const &) = default; Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - this->node_labelling.get_mutable()->add_label(n, l); - return n; - } - - void add_node_unsafe(Node const &n, NodeLabel const &l) { - this->get_ptr().add_node_unsafe(n); - this->node_labelling.get_mutable()->add_label(n, l); + return this->get_ptr().add_node(l); } NodePort add_node_port() { @@ -110,19 +91,15 @@ struct OutputLabelledOpenMultiDiGraph } NodeLabel &at(Node const &n) { - return this->node_labelling.get_mutable()->get_label(n); - } - - NodeLabel const &at(Node const &n) const { - return this->node_labelling->get_label(n); + return this->get_ptr().at(n); } void add_label(MultiDiOutput const &o, EdgeLabel const &l) { - this->output_labelling.get_mutable()->add_label(o, l); + this->get_ptr().add_label(o, l); }; void add_label(InputMultiDiEdge const &e, EdgeLabel const &l) { - this->input_labelling.get_mutable()->add_label(e, l); + this->get_ptr().add_label(e, l); } void add_edge(OpenMultiDiEdge const &e) { @@ -130,18 +107,11 @@ struct OutputLabelledOpenMultiDiGraph } EdgeLabel &at(MultiDiOutput const &o) { - return this->output_labelling.get_mutable()->get_label(o); - } - EdgeLabel const &at(MultiDiOutput const &o) const { - return this->output_labelling->get_label(o); + return this->get_ptr().at(o); } EdgeLabel &at(InputMultiDiEdge const &e) { - return this->input_labelling.get_mutable()->get_label(e); - } - - EdgeLabel const &at(InputMultiDiEdge const &e) const { - return this->input_labelling->get_label(e); + return this->get_ptr().at(e); } template @@ -155,34 +125,24 @@ struct OutputLabelledOpenMultiDiGraph } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of, - std::is_base_of>::value, - OutputLabelledOpenMultiDiGraph>::type + template + static typename std::enable_if::value, + OutputLabelledOpenMultiDiGraph>::type create() { - return OutputLabelledOpenMultiDiGraph(make_cow_ptr(), - make_cow_ptr(), - make_cow_ptr(), - make_cow_ptr()); + return OutputLabelledOpenMultiDiGraph(make_cow_ptr()); } + using OutputLabelledOpenMultiDiGraphView::at; + private: - OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t il, - cow_ptr_t ol) - : GraphView(ptr), node_labelling(nl), input_labelling(il), - output_labelling(ol) {} + OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { return *std::reinterpret_pointer_cast( @@ -193,10 +153,6 @@ struct OutputLabelledOpenMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t node_labelling; - cow_ptr_t input_labelling; - cow_ptr_t output_labelling; }; template diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h new file mode 100644 index 0000000000..501805fe2a --- /dev/null +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN_INTERFACES +#define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN_INTERFACES + +#include "node_labelled_open.h" + +namespace FlexFlow { + +template +struct IOutputLabelledOpenMultiDiGraphView + : virtual INodeLabelledOpenMultiDiGraphView { + virtual EdgeLabel const &at(InputMultiDiEdge const &) const = 0; + virtual EdgeLabel const &at(MultiDiOutput const &) const = 0; + + using INodeLabelledOpenMultiDiGraphView::at; +}; + +template +struct IOutputLabelledOpenMultiDiGraph + : virtual public IOutputLabelledOpenMultiDiGraphView { + virtual EdgeLabel &at(InputMultiDiEdge const &) = 0; + virtual EdgeLabel &at(MultiDiOutput const &) = 0; + virtual Node add_node(NodeLabel const &) = 0; + virtual NodePort add_node_port() = 0; + virtual NodeLabel &at(Node const &) = 0; + virtual void add_label(MultiDiOutput const &o, EdgeLabel const &l) = 0; + virtual void add_label(InputMultiDiEdge const &e, EdgeLabel const &l) = 0; + virtual void add_edge(OpenMultiDiEdge const &e) = 0; + + using IOutputLabelledOpenMultiDiGraphView::at; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 3c69d62ae9..e1c8e91634 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -2,23 +2,10 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_STANDARD_LABELLED_H #include "node_labelled.h" +#include "standard_labelled_interfaces.h" namespace FlexFlow { -template -struct ILabelledMultiDiGraphView - : public INodeLabelledMultiDiGraphView { - ILabelledMultiDiGraphView() = default; - ILabelledMultiDiGraphView(ILabelledMultiDiGraphView const &) = delete; - ILabelledMultiDiGraphView & - operator=(ILabelledMultiDiGraphView const &) = delete; - - virtual ~ILabelledMultiDiGraphView() = default; - - virtual EdgeLabel const &at(MultiDiEdge const &) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledMultiDiGraphView); - template struct LabelledMultiDiGraphView : virtual public NodeLabelledMultiDiGraphView { @@ -70,19 +57,14 @@ template struct LabelledMultiDiGraph : virtual LabelledMultiDiGraphView { private: - using Interface = IMultiDiGraph; - using INodeLabel = ILabelling; - using IEdgeLabel = ILabelling; + using Interface = ILabelledMultiDiGraph; public: - // LabelledMultiDiGraph() = delete; LabelledMultiDiGraph(LabelledMultiDiGraph const &other) = default; LabelledMultiDiGraph &operator=(LabelledMultiDiGraph const &other) = default; Node add_node(NodeLabel const &l) { - Node n = MultiDiGraph::add_node(); - nl->add_label(n, l); - return n; + return this->get_ptr().add_node(); } NodePort add_node_port() { @@ -90,46 +72,36 @@ struct LabelledMultiDiGraph } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); - } - - NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } void add_edge(MultiDiEdge const &e, EdgeLabel const &l) { return this->get_ptr().add_edge(e, l); } + EdgeLabel &at(MultiDiEdge const &e) { - return el.get_mutable()->get_label(e); - } - EdgeLabel const &at(MultiDiEdge const &e) const { - return el->get_label(e); + return this->get_ptr().at(e); } std::unordered_set query_nodes(NodeQuery const &q) const { return this->get_ptr().query_nodes(q); } + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of>::value, - LabelledMultiDiGraph>::type + using LabelledMultiDiGraphView::at; + + template + static typename std::enable_if::value, + LabelledMultiDiGraph>::type create() { - return LabelledMultiDiGraph( - make_cow_ptr(), make_cow_ptr(), make_cow_ptr()); + return LabelledMultiDiGraph(make_cow_ptr()); } private: - LabelledMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t el) - : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} + LabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { return *std::reinterpret_pointer_cast( @@ -140,9 +112,6 @@ struct LabelledMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t nl; - cow_ptr_t el; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h index f7af522b3c..fe396e5989 100644 --- a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h @@ -1,138 +1,227 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_UNORDERED_LABELLED_GRAPHS_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_UNORDERED_LABELLED_GRAPHS_H -#include "labelled_open_interfaces.h" -#include "node_labelled_interfaces.h" -#include "output_labelled_interfaces.h" -#include "standard_labelled_interfaces.h" -#include "utils/graph/open_graphs.h" +#include "output_labelled_open_interfaces.h" +#include "unordered_label.h" +#include "utils/graph/adjacency_openmultidigraph.h" namespace FlexFlow { template -struct UnorderedNodeLabelledMultiDiGraph - : public INodeLabelledMultiDiGraph, - protected MultiDiGraph { -public: - UnorderedNodeLabelledMultiDiGraph() = delete; +struct UnorderedNodeLabelledOpenMultiDiGraph + : public INodeLabelledOpenMultiDiGraph { - Node add_node(NodeLabel const &label) override { - Node n = MultiDiGraph::add_node(); - node_map.insert({n, label}); - return n; + UnorderedNodeLabelledOpenMultiDiGraph() + : g(OpenMultiDiGraph::create()) {} + + Node add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; } - NodeLabel &at(Node const &n) override { - return this->node_map.at(n); + NodePort add_node_port() override { + return this->g.add_node_port(); } NodeLabel const &at(Node const &n) const override { - return this->node_map.at(n); + return this->node_labelling.get_label(n); } - using MultiDiGraph::query_edges; - using MultiDiGraph::query_nodes; + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } -private: - std::unordered_map node_map; -}; + void add_edge(OpenMultiDiEdge const &e) override { + this->g.add_edge(e); + } -template -struct UnorderedLabelledMultiDiGraph - : public ILabelledMultiDiGraph, - public UnorderedNodeLabelledMultiDiGraph { - void add_edge(MultiDiEdge const &e, EdgeLabel const &label) override { - MultiDiGraph::add_edge(e); - edge_map.insert({e, label}); + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); } - EdgeLabel &at(MultiDiEdge const &n) override { - return this->edge_map.at(n); + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return g.query_edges(q); } - EdgeLabel const &at(MultiDiEdge const &n) const override { - return this->edge_map.at(n); + using INodeLabelledOpenMultiDiGraph::query_edges; + + UnorderedNodeLabelledOpenMultiDiGraph *clone() const override { + return new UnorderedNodeLabelledOpenMultiDiGraph(g, + node_labelling); } private: - std::unordered_map edge_map; -}; + UnorderedNodeLabelledOpenMultiDiGraph( + OpenMultiDiGraph const &g, + UnorderedLabelling const &node_labelling) + : g(g), node_labelling(node_labelling) {} -MultiDiOutput get_output(MultiDiEdge const &e); + OpenMultiDiGraph g; + UnorderedLabelling node_labelling; +}; +CHECK_NOT_ABSTRACT(UnorderedNodeLabelledOpenMultiDiGraph); template struct UnorderedOutputLabelledMultiDiGraph - : public IOutputLabelledMultiDiGraph, - public UnorderedNodeLabelledMultiDiGraph { -public: + : public IOutputLabelledMultiDiGraph { + + UnorderedOutputLabelledMultiDiGraph() + : g(MultiDiGraph::create()) {} + + OutputLabel const &at(MultiDiOutput const &i) const override { + return this->output_labelling.get_label(i); + } + + OutputLabel &at(MultiDiOutput const &i) override { + return this->output_labelling.get_label(i); + } + + Node add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; + } + + NodePort add_node_port() override { + return this->g.add_node_port(); + } + + NodeLabel const &at(Node const &n) const override { + return this->node_labelling.get_label(n); + } + + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } + + void add_edge(MultiDiEdge const &e) override { + this->g.add_edge(e); + } + void add_output(MultiDiOutput const &output, OutputLabel const &label) override { - this->output_map.insert({output, label}); + this->output_labelling.add_label(output, label); } - void add_edge(MultiDiEdge const &e) override { - MultiDiOutput output = get_output(e); - if (!contains_key(this->output_map, output)) { - throw mk_runtime_error("Could not find output {}", output); - } - this->add_edge(e); + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); } - void add_edge(MultiDiOutput const &output, - MultiDiInput const &input) override { - this->add_edge(MultiDiEdge{output.node, input.node, output.idx, input.idx}); + std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const override { + return g.query_edges(q); + } + + using IOutputLabelledMultiDiGraph::query_edges; + + UnorderedOutputLabelledMultiDiGraph *clone() const override { + return new UnorderedOutputLabelledMultiDiGraph( + g, node_labelling, output_labelling); } private: - std::unordered_map output_map; + UnorderedOutputLabelledMultiDiGraph( + MultiDiGraph const &g, + UnorderedLabelling const &node_labelling, + UnorderedLabelling const &output_labelling) + : g(g), node_labelling(node_labelling), + output_labelling(output_labelling) {} + + MultiDiGraph g; + UnorderedLabelling node_labelling; + UnorderedLabelling output_labelling; }; +CHECK_NOT_ABSTRACT(UnorderedOutputLabelledMultiDiGraph); -template -struct UnorderedLabelledOpenMultiDiGraph - : public ILabelledOpenMultiDiGraph, - public UnorderedLabelledMultiDiGraph { -public: - void add_edge(InputMultiDiEdge const &e, InputLabel const &label) { - this->add_edge(e); - this->input_map.insert({e, label}); +template +struct UnorderedOutputLabelledOpenMultiDiGraph + : public IOutputLabelledOpenMultiDiGraph { + + UnorderedOutputLabelledOpenMultiDiGraph() + : g(OpenMultiDiGraph::create()) {} + + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + return this->input_labelling.get_label(i); } - void add_edge(OutputMultiDiEdge const &e, OutputLabel const &label) { - this->add_edge(e); - this->output_map.insert({e, label}); + EdgeLabel &at(InputMultiDiEdge const &i) override { + return this->input_labelling.get_label(i); } - InputLabel const &at(InputMultiDiEdge const &e) const { - return this->input_map.at(e); + EdgeLabel const &at(MultiDiOutput const &i) const override { + return this->output_labelling.get_label(i); } - InputLabel &at(InputMultiDiEdge const &e) { - return this->input_map.at(e); + EdgeLabel &at(MultiDiOutput const &i) override { + return this->output_labelling.get_label(i); } - OutputLabel const &at(OutputMultiDiEdge const &e) const { - return this->output_map.at(e); + Node add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; } - OutputLabel &at(DownwardOpenMultiDiEdge const &e) { - return this->output_map.at(e); + NodePort add_node_port() override { + return this->g.add_node_port(); } - UnorderedLabelledOpenMultiDiGraph() { - NOT_IMPLEMENTED(); + NodeLabel const &at(Node const &n) const override { + return this->node_labelling.get_label(n); + } + + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } + + void add_label(MultiDiOutput const &o, EdgeLabel const &l) override { + this->output_labelling.add_label(o, l); + } + + void add_label(InputMultiDiEdge const &i, EdgeLabel const &l) override { + this->input_labelling.add_label(i, l); + } + + void add_edge(OpenMultiDiEdge const &e) override { + this->g.add_edge(e); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->g.query_nodes(q); + } + + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return this->g.query_edges(q); + } + + using IOutputLabelledOpenMultiDiGraph::query_edges; + + UnorderedOutputLabelledOpenMultiDiGraph *clone() const override { + return new UnorderedOutputLabelledOpenMultiDiGraph( + g, node_labelling, input_labelling, output_labelling); } private: - OpenMultiDiGraph base_graph; - std::unordered_map input_map; - std::unordered_map output_map; + UnorderedOutputLabelledOpenMultiDiGraph( + OpenMultiDiGraph const &g, + UnorderedLabelling const &node_labelling, + UnorderedLabelling const &input_labelling, + UnorderedLabelling const &output_labelling) + : g(g), node_labelling(node_labelling), input_labelling(input_labelling), + output_labelling(output_labelling) {} + + OpenMultiDiGraph g; + UnorderedLabelling node_labelling; + UnorderedLabelling input_labelling; + UnorderedLabelling output_labelling; }; +CHECK_NOT_ABSTRACT( + UnorderedOutputLabelledOpenMultiDiGraph); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 9c39dbf107..e31afad916 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -90,13 +90,13 @@ Impl materialize_output_labelled_multidigraph_view( } template + typename OutputLabelImpl> OutputLabelledOpenMultiDiGraph - materialize_output_labelled_open_multidigraph_view( + materialize_output_labelled_multidigraph_view( OutputLabelledOpenMultiDiGraphView const &g) { OutputLabelledOpenMultiDiGraph result = OutputLabelledOpenMultiDiGraph::template create< diff --git a/lib/utils/include/utils/graph/labelled_graphs.h b/lib/utils/include/utils/graph/labelled_graphs.h index 5c4b29038a..9cf5f0d97e 100644 --- a/lib/utils/include/utils/graph/labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled_graphs.h @@ -10,6 +10,7 @@ #include "labelled/output_labelled_open.h" #include "labelled/standard_labelled.h" #include "labelled/unordered_label.h" +#include "labelled/unordered_labelled_graphs.h" #include "labelled/views.h" #endif diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index e0bc94ca8c..c32ff6ded5 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/test/CMakeLists.txt b/lib/utils/test/CMakeLists.txt index be4b33129b..97253b4ab7 100644 --- a/lib/utils/test/CMakeLists.txt +++ b/lib/utils/test/CMakeLists.txt @@ -1,14 +1,14 @@ -# ff_add_test_executable( -# NAME -# utils-test -# SRC_PATTERNS -# src/*.cc -# PRIVATE_INCLUDE -# src/ -# DEPS -# utils -# doctest -# utils-test-common -# ) +ff_add_test_executable( + NAME + utils-test + SRC_PATTERNS + src/test_cow_ptr.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + doctest + utils-test-common +) add_subdirectory(common) diff --git a/lib/utils/test/src/test_cow_ptr.cc b/lib/utils/test/src/test_cow_ptr.cc new file mode 100644 index 0000000000..ce8516f21b --- /dev/null +++ b/lib/utils/test/src/test_cow_ptr.cc @@ -0,0 +1,60 @@ +#include "test/utils/doctest.h" +#include "utils/graph/cow_ptr_t.h" +#include +#include +#include + +using namespace FlexFlow; + +struct TestObject { + TestObject(int x) : x(x) {} + int x; + virtual TestObject *clone() const { + return new TestObject(x); + } +}; + +struct TestObjectDerived : public TestObject { + TestObjectDerived(int x, int y) : TestObject(x), y(y) {} + int y; + TestObjectDerived *clone() const override { + return new TestObjectDerived(x, y); + } +}; + +TEST_CASE("cow_ptr_t constructor") { + std::shared_ptr sp = std::make_shared(1); + cow_ptr_t p1(sp); + cow_ptr_t p2(std::make_shared(3)); + cow_ptr_t p3(TestObject(2)); + cow_ptr_t p4(p3); + cow_ptr_t p5 = p1; + CHECK(p1->x == 1); + CHECK(p2->x == 3); + CHECK(p3->x == 2); + CHECK(p4->x == p3->x); + CHECK(p5->x == p1->x); +} + +TEST_CASE("cow_ptr_t copy") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(std::make_shared(2)); + p1 = p2; + CHECK(p1->x == p2->x); +} + +TEST_CASE("cow_ptr_t cast") { + cow_ptr_t p1(std::make_shared(1, 2)); + cow_ptr_t p2(p1); + CHECK(p2->x == 1); +} + +TEST_CASE("cow_ptr_t get_mutable") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(p1); + p1.get_mutable()->x = 3; + CHECK(p1->x == 3); + CHECK(p2->x == 1); + p2.get_mutable()->x = 2; + CHECK(p1->x == 3); +} From c0015df306fca409d9b6b08edfdee548edae3a3c Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 18 Mar 2024 15:52:02 -0400 Subject: [PATCH 20/32] fmt --- .../include/utils/graph/labelled/node_labelled_interfaces.h | 2 +- lib/utils/test/src/test_cow_ptr.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h index 37fb4db715..c371a9a3bd 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h @@ -20,7 +20,7 @@ CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledMultiDiGraphView); template struct INodeLabelledMultiDiGraph - : virtual INodeLabelledMultiDiGraphView { + : virtual INodeLabelledMultiDiGraphView { virtual NodeLabel &at(Node const &) = 0; virtual Node add_node(NodeLabel const &l) = 0; virtual NodePort add_node_port() = 0; diff --git a/lib/utils/test/src/test_cow_ptr.cc b/lib/utils/test/src/test_cow_ptr.cc index ce8516f21b..62406bddec 100644 --- a/lib/utils/test/src/test_cow_ptr.cc +++ b/lib/utils/test/src/test_cow_ptr.cc @@ -16,7 +16,7 @@ struct TestObject { struct TestObjectDerived : public TestObject { TestObjectDerived(int x, int y) : TestObject(x), y(y) {} - int y; + int y; TestObjectDerived *clone() const override { return new TestObjectDerived(x, y); } From 102f5fb2ed3c0440ecb8288d0aa04789ea16f2b8 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 22 Mar 2024 13:54:22 -0700 Subject: [PATCH 21/32] Fix post-merge --- .flake/patches/doctest-template-test.patch | 50 ++++++ .flake/pkgs/fmt.nix | 73 ++++++++ .flake/pkgs/rapidcheck.nix | 48 ++++++ .github/workflows/helpers/build_libs.sh | 9 + .../helpers/{build_cuda.sh => cmake_cuda.sh} | 17 +- .github/workflows/helpers/test_libs.sh | 14 ++ .github/workflows/per-lib-check.yml | 38 ++--- CMakeLists.txt | 2 +- cmake/doctest.cmake | 9 - cmake/doctestlib.cmake | 11 ++ cmake/flexflow-utils.cmake | 4 +- cmake/fmt.cmake | 3 +- cmake/nccl.cmake | 1 + cmake/rapidcheck.cmake | 6 +- cmake/spdlog.cmake | 6 +- flake.nix | 79 ++++++--- lib/compiler/CMakeLists.txt | 3 +- lib/compiler/include/compiler/compiler.h | 4 +- .../include/compiler/machine_mapping.h | 2 +- lib/compiler/src/graph_utils.cc | 4 +- lib/compiler/src/machine_mapping.cc | 8 +- .../test/src/test_labelled_open_graph.cc | 2 + .../include/op-attrs/operator_attrs.h | 5 +- lib/pcg/include/pcg/device_id.h | 1 + lib/pcg/include/pcg/optimizer.h | 22 +-- .../include/substitutions/attribute_expr.h | 2 +- .../include/substitutions/get_attribute.h | 52 +++--- .../include/substitutions/operator_pattern.h | 6 +- .../include/substitutions/output_graph.h | 2 +- .../substitutions/parallel_tensor_pattern.h | 4 +- lib/substitutions/src/graph_pattern.cc | 88 +++++----- lib/substitutions/src/graph_pattern_match.cc | 24 +-- lib/substitutions/src/operator_attributes.cc | 110 ++++++------ lib/substitutions/src/substitution.cc | 156 +++++++++--------- lib/utils/include/utils/containers.decl.h | 8 +- lib/utils/include/utils/containers.h | 4 +- lib/utils/include/utils/dot_file.h | 7 +- .../graph/labelled/output_labelled_open.h | 4 +- lib/utils/include/utils/variant.h | 2 +- lib/utils/src/graph/open_edge.cc | 6 +- lib/utils/src/graph/serialparallel.cc | 6 +- lib/utils/test/src/test_variant.cc | 42 ++--- 42 files changed, 586 insertions(+), 358 deletions(-) create mode 100644 .flake/patches/doctest-template-test.patch create mode 100644 .flake/pkgs/fmt.nix create mode 100644 .flake/pkgs/rapidcheck.nix create mode 100755 .github/workflows/helpers/build_libs.sh rename .github/workflows/helpers/{build_cuda.sh => cmake_cuda.sh} (67%) create mode 100755 .github/workflows/helpers/test_libs.sh delete mode 100644 cmake/doctest.cmake create mode 100644 cmake/doctestlib.cmake diff --git a/.flake/patches/doctest-template-test.patch b/.flake/patches/doctest-template-test.patch new file mode 100644 index 0000000000..ca4d0d9a18 --- /dev/null +++ b/.flake/patches/doctest-template-test.patch @@ -0,0 +1,50 @@ +diff --git a/scripts/cmake/doctestAddTests.cmake b/scripts/cmake/doctestAddTests.cmake +index 3b25485..d3ba906 100644 +--- a/scripts/cmake/doctestAddTests.cmake ++++ b/scripts/cmake/doctestAddTests.cmake +@@ -56,12 +56,14 @@ foreach(line ${output}) + if("${line}" STREQUAL "===============================================================================" OR "${line}" MATCHES [==[^\[doctest\] ]==]) + continue() + endif() +- set(test ${line}) ++ set(unescaped_test ${line}) ++ # use escape commas to handle properly test cases with commas inside the name ++ string(REPLACE "," "\\," escaped_test ${unescaped_test}) + set(labels "") + if(${add_labels}) + # get test suite that test belongs to + execute_process( +- COMMAND ${TEST_EXECUTOR} "${TEST_EXECUTABLE}" --test-case=${test} --list-test-suites ++ COMMAND ${TEST_EXECUTOR} "${TEST_EXECUTABLE}" --test-case=${escaped_test} --list-test-suites + OUTPUT_VARIABLE labeloutput + RESULT_VARIABLE labelresult + WORKING_DIRECTORY "${TEST_WORKING_DIR}" +@@ -85,24 +87,22 @@ foreach(line ${output}) + + if(NOT "${junit_output_dir}" STREQUAL "") + # turn testname into a valid filename by replacing all special characters with "-" +- string(REGEX REPLACE "[/\\:\"|<>]" "-" test_filename "${test}") ++ string(REGEX REPLACE "[/\\:\"|<>]" "-" test_filename "${unescaped_test}") + set(TEST_JUNIT_OUTPUT_PARAM "--reporters=junit" "--out=${junit_output_dir}/${prefix}${test_filename}${suffix}.xml") + else() + unset(TEST_JUNIT_OUTPUT_PARAM) + endif() +- # use escape commas to handle properly test cases with commas inside the name +- string(REPLACE "," "\\," test_name ${test}) + # ...and add to script + add_command(add_test +- "${prefix}${test}${suffix}" ++ "${prefix}${unescaped_test}${suffix}" + ${TEST_EXECUTOR} + "${TEST_EXECUTABLE}" +- "--test-case=${test_name}" ++ "--test-case=${escaped_test}" + "${TEST_JUNIT_OUTPUT_PARAM}" + ${extra_args} + ) + add_command(set_tests_properties +- "${prefix}${test}${suffix}" ++ "${prefix}${unescaped_test}${suffix}" + PROPERTIES + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + ${properties} diff --git a/.flake/pkgs/fmt.nix b/.flake/pkgs/fmt.nix new file mode 100644 index 0000000000..e2677bdea2 --- /dev/null +++ b/.flake/pkgs/fmt.nix @@ -0,0 +1,73 @@ +{ lib +, stdenv +, fetchFromGitHub, fetchpatch +, cmake +, enableShared ? !stdenv.hostPlatform.isStatic + +# tests +, mpd +, openimageio +, fcitx5 +, spdlog +}: + +let + generic = { version, sha256, patches ? [ ] }: + stdenv.mkDerivation { + pname = "fmt"; + inherit version; + + outputs = [ "out" "dev" ]; + + src = fetchFromGitHub { + owner = "fmtlib"; + repo = "fmt"; + rev = version; + inherit sha256; + }; + + inherit patches; + + nativeBuildInputs = [ cmake ]; + + cmakeFlags = [ + "-DBUILD_SHARED_LIBS=${if enableShared then "ON" else "OFF"}" + ]; + + doCheck = true; + + passthru.tests = { + inherit mpd openimageio fcitx5 spdlog; + }; + + meta = with lib; { + description = "Small, safe and fast formatting library"; + longDescription = '' + fmt (formerly cppformat) is an open-source formatting library. It can be + used as a fast and safe alternative to printf and IOStreams. + ''; + homepage = "https://fmt.dev/"; + changelog = "https://github.com/fmtlib/fmt/blob/${version}/ChangeLog.rst"; + downloadPage = "https://github.com/fmtlib/fmt/"; + maintainers = [ maintainers.jdehaas ]; + license = licenses.mit; + platforms = platforms.all; + }; + }; +in +{ + fmt_8 = generic { + version = "8.1.1"; + sha256 = "sha256-leb2800CwdZMJRWF5b1Y9ocK0jXpOX/nwo95icDf308="; + }; + + fmt_9 = generic { + version = "9.1.0"; + sha256 = "sha256-rP6ymyRc7LnKxUXwPpzhHOQvpJkpnRFOt2ctvUNlYI0="; + }; + + fmt_10 = generic { + version = "10.1.1"; + sha256 = "sha256-H9+1lEaHM12nzXSmo9m8S6527t+97e6necayyjCPm1A="; + }; +} diff --git a/.flake/pkgs/rapidcheck.nix b/.flake/pkgs/rapidcheck.nix new file mode 100644 index 0000000000..3ff63207b2 --- /dev/null +++ b/.flake/pkgs/rapidcheck.nix @@ -0,0 +1,48 @@ +{ lib +, stdenv +, fetchFromGitHub +, cmake +, unstableGitUpdater +, testers +}: + +stdenv.mkDerivation (finalAttrs: { + pname = "rapidcheck"; + version = "unstable-2023-12-14"; + + src = fetchFromGitHub { + owner = "emil-e"; + repo = "rapidcheck"; + rev = "ff6af6fc683159deb51c543b065eba14dfcf329b"; + hash = "sha256-Ixz5RpY0n8Un/Pv4XoTfbs40+70iyMbkQUjDqoLaWOg="; + }; + + nativeBuildInputs = [ cmake ]; + + cmakeFlags = [ + (lib.cmakeBool "BUILD_SHARED_LIBS" (!stdenv.hostPlatform.isStatic)) + (lib.cmakeBool "RC_INSTALL_ALL_EXTRAS" true) + ]; + + passthru = { + updateScript = unstableGitUpdater { }; + tests.pkg-config = testers.testMetaPkgConfig finalAttrs.finalPackage; + }; + + meta = with lib; { + description = "A C++ framework for property based testing inspired by QuickCheck"; + inherit (finalAttrs.src.meta) homepage; + maintainers = with maintainers; [ ]; + license = licenses.bsd2; + pkgConfigModules = [ + "rapidcheck" + # Extras + "rapidcheck_boost" + "rapidcheck_boost_test" + "rapidcheck_catch" + "rapidcheck_doctest" + "rapidcheck_gtest" + ]; + platforms = platforms.all; + }; +}) diff --git a/.github/workflows/helpers/build_libs.sh b/.github/workflows/helpers/build_libs.sh new file mode 100755 index 0000000000..cc4e25cc0b --- /dev/null +++ b/.github/workflows/helpers/build_libs.sh @@ -0,0 +1,9 @@ +#! /usr/bin/env bash + +set -euo pipefail + +DIR="$(realpath -- "$(dirname "${BASH_SOURCE[0]}")")" +REPO="$(realpath -- "$DIR/../../../")" + +cd "$REPO/build-ci" +make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) "$@" diff --git a/.github/workflows/helpers/build_cuda.sh b/.github/workflows/helpers/cmake_cuda.sh similarity index 67% rename from .github/workflows/helpers/build_cuda.sh rename to .github/workflows/helpers/cmake_cuda.sh index 3524f885a7..e549859a5a 100755 --- a/.github/workflows/helpers/build_cuda.sh +++ b/.github/workflows/helpers/cmake_cuda.sh @@ -8,22 +8,21 @@ REPO="$(realpath -- "$DIR/../../../")" export FF_GPU_BACKEND="cuda" export FF_CUDA_ARCH=70 -cd "$REPO" -mkdir build -cd build + +if [[ -d "$REPO/build-ci" ]]; then + rm -rf "$REPO/build-ci" +fi +mkdir "$REPO/build-ci" +cd "$REPO/build-ci" #if [[ "${FF_GPU_BACKEND}" == "cuda" ]]; then # export FF_BUILD_ALL_EXAMPLES=ON # export FF_BUILD_UNIT_TESTS=ON #fi +IFS=" " read -r -a FLAGS <<< "$CMAKE_FLAGS" ../config/config.linux \ - -DCMAKE_CXX_COMPILER="clang++" \ - -DCMAKE_C_COMPILER="clang" \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ - -DFF_USE_EXTERNAL_LEGION=ON \ - -DFF_USE_EXTERNAL_JSON=ON \ - -DFF_USE_EXTERNAL_FMT=ON \ - -DFF_USE_EXTERNAL_SPDLOG=ON + "${FLAGS[@]}" # vim: set tabstop=2 shiftwidth=2 expandtab: diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_libs.sh new file mode 100755 index 0000000000..7662a7e601 --- /dev/null +++ b/.github/workflows/helpers/test_libs.sh @@ -0,0 +1,14 @@ +#! /usr/bin/env bash + +set -euo pipefail +set -x + +DIR="$(realpath -- "$(dirname "${BASH_SOURCE[0]}")")" +REPO="$(realpath -- "$DIR/../../../")" + +TEST_LIBS=("${@/%/-tests}") +REGEX="^$(IFS='|'; echo "${TEST_LIBS[*]}")\$" + +cd "$REPO/build-ci" +make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) "${TEST_LIBS[@]}" +ctest --progress --output-on-failure -L "$REGEX" diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 4685983ce0..f1d069f252 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -20,6 +20,9 @@ jobs: with: submodules: recursive + - name: Add helpers directory to path + run: echo "${PWD}/.github/workflows/helpers" >> $GITHUB_PATH + - name: Install nix uses: cachix/install-nix-action@v25 with: @@ -51,49 +54,36 @@ jobs: - name: Run cmake run: | - .github/workflows/helpers/build_${{ matrix.gpu_backend }}.sh + cmake_${{ matrix.gpu_backend }}.sh - name: Build utils run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) utils + build_libs.sh utils - name: Build op-attrs run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) op-attrs + build_libs.sh op-attrs - name: Build pcg run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) pcg + build_libs.sh pcg - name: Build kernels run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) kernels + build_libs.sh kernels - name: Build substitutions run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) substitutions + build_libs.sh substitutions - name: Build compiler run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) compiler - - - name: Build substitutions-test - run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) substitutions-test + build_libs.sh compiler - - name: Build compiler-test + - name: Test substitutions run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) compiler-test + test_libs.sh substitutions - - name: Unit tests + - name: Test compiler run: | - cd build - ctest + test_libs.sh compiler diff --git a/CMakeLists.txt b/CMakeLists.txt index e04aa622c2..032bf1ac55 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,7 +84,7 @@ include(nccl) include(json) include(expected) include(spdlog) -include(doctest) +include(doctestlib) # named doctestlib to avoid a name collision with doctest.cmake in rapidcheck include(visit_struct) include(CTest) include(fmt) diff --git a/cmake/doctest.cmake b/cmake/doctest.cmake deleted file mode 100644 index b2d5243574..0000000000 --- a/cmake/doctest.cmake +++ /dev/null @@ -1,9 +0,0 @@ -include(aliasing) - -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest) -include(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest/scripts/cmake/doctest.cmake) - -add_library(doctest-ff INTERFACE) -target_compile_definitions(doctest-ff INTERFACE DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS) -target_link_libraries(doctest-ff INTERFACE doctest::doctest) -alias_library(doctest doctest-ff) diff --git a/cmake/doctestlib.cmake b/cmake/doctestlib.cmake new file mode 100644 index 0000000000..5f29d94fd0 --- /dev/null +++ b/cmake/doctestlib.cmake @@ -0,0 +1,11 @@ +include(aliasing) + +if (FF_USE_EXTERNAL_DOCTEST) + find_package(doctest REQUIRED) + include(doctest) # import doctest_discover_tests +else() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest) + include(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest/scripts/cmake/doctest.cmake) +endif() + +alias_library(doctest doctest::doctest) diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index d41573acab..4cf5450942 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -118,7 +118,9 @@ function(ff_add_test_executable) ${FF_TEST_EXEC_NAME} ${FF_TEST_EXEC_DEPS}) + target_compile_definitions(${FF_TEST_EXEC_NAME} PRIVATE FF_TEST_SUITE="${FF_TEST_EXEC_NAME}") + define_ff_vars(${FF_TEST_EXEC_NAME}) ff_set_cxx_properties(${FF_TEST_EXEC_NAME}) - doctest_discover_tests(${FF_TEST_EXEC_NAME}) + doctest_discover_tests(${FF_TEST_EXEC_NAME} ADD_LABELS 1) endfunction() diff --git a/cmake/fmt.cmake b/cmake/fmt.cmake index 283caad69d..470de6a847 100644 --- a/cmake/fmt.cmake +++ b/cmake/fmt.cmake @@ -4,6 +4,5 @@ if (FF_USE_EXTERNAL_FMT) find_package(fmt REQUIRED) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/fmt) - - alias_library(fmt fmt::fmt) endif() +alias_library(fmt fmt::fmt) diff --git a/cmake/nccl.cmake b/cmake/nccl.cmake index e89bee04c6..755fe00f1b 100644 --- a/cmake/nccl.cmake +++ b/cmake/nccl.cmake @@ -8,6 +8,7 @@ else() message(STATUS "Building NCCL from source") list(TRANSFORM CUDA_GENCODE PREPEND "NVCC_GENCODE=" OUTPUT_VARIABLE NCCL_BUILD_NVCC_GENCODE) + include(ExternalProject) ExternalProject_Add(nccl_source_build SOURCE_DIR ${PROJECT_SOURCE_DIR}/deps/${NCCL_NAME} PREFIX ${CMAKE_BINARY_DIR}/deps/${NCCL_NAME} diff --git a/cmake/rapidcheck.cmake b/cmake/rapidcheck.cmake index 1ff64bd974..bf8f058e63 100644 --- a/cmake/rapidcheck.cmake +++ b/cmake/rapidcheck.cmake @@ -1 +1,5 @@ -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/rapidcheck) +if (FF_USE_EXTERNAL_RAPIDCHECK) + find_package(rapidcheck REQUIRED) +else() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/rapidcheck) +endif() diff --git a/cmake/spdlog.cmake b/cmake/spdlog.cmake index cd18944460..02021fd51e 100644 --- a/cmake/spdlog.cmake +++ b/cmake/spdlog.cmake @@ -4,6 +4,8 @@ if (FF_USE_EXTERNAL_SPDLOG) find_package(spdlog REQUIRED) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/spdlog) - - alias_library(spdlog spdlog::spdlog) endif() + +add_library(spdlog INTERFACE) +target_link_libraries(spdlog INTERFACE spdlog::spdlog) +target_compile_definitions(spdlog INTERFACE SPDLOG_FMT_EXTERNAL) diff --git a/flake.nix b/flake.nix index 3d357ca86c..540d0f9a94 100644 --- a/flake.nix +++ b/flake.nix @@ -13,7 +13,6 @@ ]; }; - # Nixpkgs / NixOS version to use. inputs = { nixpkgs.url = "nixpkgs/nixos-23.11"; flake-utils.url = "github:numtide/flake-utils"; @@ -25,51 +24,84 @@ inherit system; config.allowUnfree = true; }; + lib = pkgs.lib; mkShell = pkgs.mkShell.override { - stdenv = pkgs.llvmPackages.libcxxStdenv; + stdenv = pkgs.cudaPackages.backendStdenv; }; in - { - packages = { - legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + { + packages = { + legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + rapidcheckFull = pkgs.symlinkJoin { + name = "rapidcheckFull"; + paths = (with pkgs; [ rapidcheck.out rapidcheck.dev ]); }; + doctest = pkgs.doctest.overrideAttrs ( old: rec { + version = "2.4.9"; + src = pkgs.fetchFromGitHub { + owner = "doctest"; + repo = "doctest"; + rev = "v${version}"; + sha256 = "sha256-ugmkeX2PN4xzxAZpWgswl4zd2u125Q/ADSKzqTfnd94="; + }; + patches = [ + ./.flake/patches/doctest-template-test.patch + ]; + }); + }; - devShells = rec { - ci = mkShell { - buildInputs = (with pkgs; [ - llvmPackages_17.clang - cmakeCurses - gcc10Stdenv - gcc10 - ccache - cudatoolkit + devShells = rec { + ci = mkShell { + CMAKE_FLAGS = lib.strings.concatStringsSep " " [ + "-DFF_USE_EXTERNAL_LEGION=ON" + "-DFF_USE_EXTERNAL_NCCL=ON" + "-DFF_USE_EXTERNAL_JSON=ON" + "-DFF_USE_EXTERNAL_FMT=ON" + "-DFF_USE_EXTERNAL_SPDLOG=ON" + "-DFF_USE_EXTERNAL_DOCTEST=ON" + "-DFF_USE_EXTERNAL_RAPIDCHECK=ON" + "-DFF_USE_EXTERNAL_RANGEV3=ON" + "-DFF_USE_EXTERNAL_BOOST_PREPROCESSOR=ON" + "-DFF_USE_EXTERNAL_TYPE_INDEX=ON" + ]; + + buildInputs = builtins.concatLists [ + (with pkgs; [ zlib - pkg-config - python3 - self.packages.${system}.legion + boost nlohmann_json spdlog range-v3 - rapidcheck - doctest fmt + cmakeCurses + ccache + pkg-config + python3 + cudatoolkit cudaPackages.cuda_nvcc cudaPackages.cudnn cudaPackages.nccl cudaPackages.libcublas cudaPackages.cuda_cudart - ]) ++ (with pkgs.python3Packages; [ - ]); + ]) + (with self.packages.${system}; [ + legion + rapidcheckFull + doctest + ]) + ]; }; default = mkShell { inputsFrom = [ ci ]; - + inherit (ci) CMAKE_FLAGS; + buildInputs = builtins.concatLists [ (with pkgs; [ - clang-tools_17 + ccls gh-markdown-preview + shellcheck plantuml gdb ruff @@ -96,4 +128,3 @@ } ); } -# vim: set tabstop=2 shiftwidth=2 expandtab: diff --git a/lib/compiler/CMakeLists.txt b/lib/compiler/CMakeLists.txt index 6610834eed..a2933efa50 100644 --- a/lib/compiler/CMakeLists.txt +++ b/lib/compiler/CMakeLists.txt @@ -11,11 +11,10 @@ ff_add_library( op-attrs utils json - optional pcg spdlog substitutions ) add_subdirectory(ffi) -add_subdirectory(test) \ No newline at end of file +add_subdirectory(test) diff --git a/lib/compiler/include/compiler/compiler.h b/lib/compiler/include/compiler/compiler.h index 3a75e3a9bf..a4f7b0ecd3 100644 --- a/lib/compiler/include/compiler/compiler.h +++ b/lib/compiler/include/compiler/compiler.h @@ -12,8 +12,8 @@ enum class SearchAlgorithm { DATA_PARALLEL, }; -using SearchAlgorithmConfig = variant<>; -using SearchSolution = variant<>; +using SearchAlgorithmConfig = std::variant<>; +using SearchSolution = std::variant<>; struct SearchResult { ParallelComputationGraph pcg; diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 185f2706ef..8b21b9522f 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -53,7 +53,7 @@ class OptimalCostCache { public: OptimalCostCache() = default; - optional load(OptimalCostState const &) const; + std::optional load(OptimalCostState const &) const; void save(OptimalCostState const &, OptimalCostResult const &); private: diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 3c6e44216b..5b76beb8c0 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -125,14 +125,14 @@ std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { std::unordered_set get_nodes(Serial const &serial) { return set_union( - transform(serial.children, [](variant const child) { + transform(serial.children, [](std::variant const child) { return visit(GetNodes{}, child); })); } std::unordered_set get_nodes(Parallel const ¶llel) { return set_union( - transform(parallel.children, [](variant const child) { + transform(parallel.children, [](std::variant const child) { return visit(GetNodes{}, child); })); } diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index b48e200c15..2b08e9fe23 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -43,13 +43,13 @@ bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, return lhs.runtime < rhs.runtime; } -optional +std::optional OptimalCostCache::load(OptimalCostState const &state) const { if (contains_key(cache, state)) { OptimalCostResult result = cache.at(state); - return make_optional(result); + return std::make_optional(result); } - return nullopt; + return std::nullopt; } void OptimalCostCache::save(OptimalCostState const &state, @@ -152,7 +152,7 @@ struct MachineMappingSearcher { OptimalCostResult operator()(T const &t) { OptimalCostState state{ t, resource, given_machine_views, frontier_machine_views}; - optional cached_result = + std::optional cached_result = searcher->cached_subgraph_costs.load(state); if (cached_result) { diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index a3b6319528..dfe1f6301c 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,6 +4,7 @@ using namespace FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { // TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { // auto g = OpenMultiDiGraph::create(); @@ -142,3 +143,4 @@ TEST_CASE("OutputLabelledOpenMultiDiGraph") { // CHECK(get_edges(g).size() == 1); // } +} diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 9da787cbf8..678a049c3b 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -32,6 +32,7 @@ #include "ops/topk.h" #include "ops/transpose.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -85,8 +86,8 @@ static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); -using ParallelOperatorAttrs = std:: - variant; +using ParallelOperatorAttrs = + std::variant; using ComputationGraphAttrs = variant_join>; diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index 50c2558e39..b118d69259 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -3,6 +3,7 @@ #include "device_type.h" #include "utils/strong_typedef.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/optimizer.h b/lib/pcg/include/pcg/optimizer.h index df5bddf729..0bb3fab974 100644 --- a/lib/pcg/include/pcg/optimizer.h +++ b/lib/pcg/include/pcg/optimizer.h @@ -7,21 +7,21 @@ namespace FlexFlow { struct SGDOptimizer { - req lr; - req momentum; - req nesterov; + double lr; + double momentum; + bool nesterov; req weight_decay; }; FF_VISITABLE_STRUCT(SGDOptimizer, lr, momentum, nesterov, weight_decay); struct AdamOptimizer { - req alpha; - req beta1; - req beta2; - req weight_decay; - req epsilon; - req alpha_t; - req beta_t; + double alpha; + double beta1; + double beta2; + double weight_decay; + double epsilon; + double alpha_t; + double beta_t; req beta2_t; }; FF_VISITABLE_STRUCT(AdamOptimizer, @@ -34,7 +34,7 @@ FF_VISITABLE_STRUCT(AdamOptimizer, beta_t, beta2_t); -using Optimizer = variant; +using Optimizer = std::variant; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h index d6902d1274..0afd48b431 100644 --- a/lib/substitutions/include/substitutions/attribute_expr.h +++ b/lib/substitutions/include/substitutions/attribute_expr.h @@ -19,7 +19,7 @@ struct ListSize { }; template -using AttributeExpr = variant, ListSize>; +using AttributeExpr = std::variant, ListSize>; template struct AttributeConstraint { diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index 50c4108a67..7088730c53 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -7,57 +7,57 @@ namespace FlexFlow { -optional get_attribute(PCGOperatorAttrs const &, +std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); -optional get_attribute(BatchMatmulAttrs const &p, +std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey); -optional get_attribute(CastAttrs const &p, +std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey); -optional get_attribute(CombineAttrs const &p, +std::optional get_attribute(CombineAttrs const &p, OperatorAttributeKey); -optional get_attribute(ConcatAttrs const &p, +std::optional get_attribute(ConcatAttrs const &p, OperatorAttributeKey); -optional get_attribute(Conv2DAttrs const &p, +std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey); -optional get_attribute(ElementBinaryAttrs const &p, +std::optional get_attribute(ElementBinaryAttrs const &p, OperatorAttributeKey); -optional get_attribute(ElementUnaryAttrs const &p, +std::optional get_attribute(ElementUnaryAttrs const &p, OperatorAttributeKey); -optional get_attribute(DropoutAttrs const &p, +std::optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey); -optional get_attribute(ElementScalarUnaryAttrs const &p, +std::optional get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey); -optional get_attribute(EmbeddingAttrs const &p, +std::optional get_attribute(EmbeddingAttrs const &p, OperatorAttributeKey); -optional get_attribute(FlatAttrs const &p, +std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey); -optional get_attribute(GatherAttrs const &p, +std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey); -optional get_attribute(LayerNormAttrs const &p, +std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey); -optional get_attribute(LinearAttrs const &p, +std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey); -optional get_attribute(MultiHeadAttentionAttrs const &p, +std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); -optional get_attribute(Pool2DAttrs const &p, +std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey); -optional get_attribute(ReduceAttrs const &p, +std::optional get_attribute(ReduceAttrs const &p, OperatorAttributeKey); -optional get_attribute(ReductionAttrs const &p, +std::optional get_attribute(ReductionAttrs const &p, OperatorAttributeKey); -optional get_attribute(RepartitionAttrs const &p, +std::optional get_attribute(RepartitionAttrs const &p, OperatorAttributeKey); -optional get_attribute(ReplicateAttrs const &p, +std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey); -optional get_attribute(ReshapeAttrs const &p, +std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey); -optional get_attribute(SplitAttrs const &p, +std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey); -optional get_attribute(SoftmaxAttrs const &p, +std::optional get_attribute(SoftmaxAttrs const &p, OperatorAttributeKey); -optional get_attribute(TopKAttrs const &p, +std::optional get_attribute(TopKAttrs const &p, OperatorAttributeKey); -optional get_attribute(TransposeAttrs const &p, +std::optional get_attribute(TransposeAttrs const &p, OperatorAttributeKey); // optional get_attribute(FusedParallelOpAttrs const &p, // OperatorAttributeKey); diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 9392a7876e..35544f3003 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -70,7 +70,7 @@ enum class OperatorAttributeKey { NUM_INPUTS }; -using OperatorAttributeValue = variant, @@ -81,7 +81,7 @@ using OperatorAttributeValue = variant, - optional, + std::optional, PoolOp, TensorShape, DataType>; @@ -97,7 +97,7 @@ using OperatorAttributeConstraint = using OperatorPattern = AttributePattern; -optional +std::optional evaluate_attribute_expr(Operator const &attrs, AttributeExpr const &expr); diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index b9cf1f53f3..4ed90aed06 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -15,7 +15,7 @@ struct AttrConstant { OperatorAttributeValue value; }; -using OperatorAttributeExpr = variant; +using OperatorAttributeExpr = std::variant; // NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can // define the assignment for each operator type. diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index d07a1da23b..741554142f 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -8,7 +8,7 @@ namespace FlexFlow { enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; -using TensorAttributeValue = variant>; +using TensorAttributeValue = std::variant>; using TensorAttributeConstraint = AttributeConstraint; @@ -16,7 +16,7 @@ using TensorAttributeConstraint = using ParallelTensorPattern = AttributePattern; -optional +std::optional evaluate_attribute_expr(ParallelTensor const &tensor_shape, AttributeExpr const &expr); diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 1dba5c4af8..6f933dd300 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -9,51 +9,51 @@ namespace FlexFlow { -optional +std::optional evaluate_list_index_access(int index, - optional const &v) { + std::optional const &v) { if (!v.has_value() || - !holds_alternative>(v.value()) || - !holds_alternative>(v.value())) { - return nullopt; + !std::holds_alternative>(v.value()) || + !std::holds_alternative>(v.value())) { + return std::nullopt; } if (index >= MAX_TENSOR_DIM) { - return nullopt; + return std::nullopt; } - if (holds_alternative>(v.value())) { + if (std::holds_alternative>(v.value())) { return get>(v.value()).at(index); } else { return get>(v.value()).at(index); } } -optional +std::optional evaluate_list_index_access(int const &index, - optional const &v) { - if (!v.has_value() || !holds_alternative>(v.value())) { - return nullopt; + std::optional const &v) { + if (!v.has_value() || !std::holds_alternative>(v.value())) { + return std::nullopt; } auto vec = get>(v.value()); if (index >= vec.size()) { - return nullopt; + return std::nullopt; } return vec.at(index); } -optional - evaluate_list_size(optional const &v) { +std::optional + evaluate_list_size(std::optional const &v) { return MAX_TENSOR_DIM; } -optional - evaluate_list_size(optional const &v) { - if (!v.has_value() || !holds_alternative>(v.value())) { - return nullopt; +std::optional + evaluate_list_size(std::optional const &v) { + if (!v.has_value() || !std::holds_alternative>(v.value())) { + return std::nullopt; } return (int)get>(v.value()).size(); @@ -62,20 +62,20 @@ optional struct EvaluateOperatorAttributeExpr { EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - optional operator()(OperatorAttributeKey const &key) { + std::optional operator()(OperatorAttributeKey const &key) { return get_attribute(this->attrs.attrs, key); } - optional + std::optional operator()(ListIndexAccess const &index_access) { - optional v = + std::optional v = get_attribute(this->attrs.attrs, index_access.attribute_key); return evaluate_list_index_access(index_access.index, v); } - optional + std::optional operator()(ListSize const &list_size) { - optional v = + std::optional v = get_attribute(this->attrs.attrs, list_size.attribute_key); return evaluate_list_size(v); } @@ -84,7 +84,7 @@ struct EvaluateOperatorAttributeExpr { Operator attrs; }; -optional +std::optional evaluate_tensor_attribute_expr(ParallelTensor const &, AttributeExpr const &); @@ -93,11 +93,11 @@ struct EvaluateTensorAttributeExpr { : tensor_shape(tensor_shape) {} template - optional evaluate(T const &t) { + std::optional evaluate(T const &t) { return this->operator()(t); } - optional operator()(TensorAttributeKey key) { + std::optional operator()(TensorAttributeKey key) { switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector result; @@ -118,14 +118,14 @@ struct EvaluateTensorAttributeExpr { } } - optional + std::optional operator()(ListIndexAccess const &index_access) { - optional v = + std::optional v = this->evaluate(index_access.attribute_key); return evaluate_list_index_access(index_access.index, v); } - optional + std::optional operator()(ListSize const &list_size) { return evaluate_list_size(this->evaluate(list_size.attribute_key)); } @@ -134,29 +134,29 @@ struct EvaluateTensorAttributeExpr { ParallelTensor tensor_shape; }; -optional +std::optional evaluate_attribute_expr(ParallelTensor const &tensor_shape, AttributeExpr const &expr) { return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); } -optional +std::optional evaluate_attribute_expr(Operator const &attrs, AttributeExpr const &expr) { return visit(EvaluateOperatorAttributeExpr(attrs), expr); } template -optional satisfies(ConstraintType constraint_type, +std::optional satisfies(ConstraintType constraint_type, V const &constraint_value, - optional const &maybe_attribute_value) { + std::optional const &maybe_attribute_value) { if (!maybe_attribute_value.has_value()) { - return nullopt; + return std::nullopt; } V attr_val = maybe_attribute_value.value(); if (attr_val.index() != constraint_value.index()) { - return nullopt; + return std::nullopt; } if (constraint_type == ConstraintType::EQUAL) { @@ -166,14 +166,14 @@ optional satisfies(ConstraintType constraint_type, } } -optional satisfies(ParallelTensor const &tensor_shape, +std::optional satisfies(ParallelTensor const &tensor_shape, TensorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); return satisfies( constraint.constraint_type, constraint.attribute_value, value); } -optional satisfies(Operator const ¶ms, +std::optional satisfies(Operator const ¶ms, OperatorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(params, constraint.attribute_expr); OperatorAttributeValue v = value.value(); @@ -182,12 +182,12 @@ optional satisfies(Operator const ¶ms, } template -optional optional_all_of(Container const &container, +std::optional optional_all_of(Container const &container, Function const &func) { for (auto const &element : container) { - optional condition = func(element); + std::optional condition = func(element); if (!condition.has_value()) { - return nullopt; + return std::nullopt; } if (!condition.value()) { @@ -197,7 +197,7 @@ optional optional_all_of(Container const &container, return true; } -optional satisfies(Operator const ¶ms, +std::optional satisfies(Operator const ¶ms, OperatorPattern const &pattern) { return optional_all_of(pattern.attribute_constraints, [&](OperatorAttributeConstraint const &c) { @@ -205,7 +205,7 @@ optional satisfies(Operator const ¶ms, }); } -optional satisfies(ParallelTensor const ¶ms, +std::optional satisfies(ParallelTensor const ¶ms, ParallelTensorPattern const &pattern) { return optional_all_of( pattern.attribute_constraints, @@ -229,7 +229,7 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, for (auto const &kv : patternMatch.node_assignment) { Node patternNode = kv.first; Node pcgNode = kv.second; - optional constraintResult = + std::optional constraintResult = satisfies(pcg.at(pcgNode), pattern.value().at(patternNode)); result &= constraintResult.value_or(false); } @@ -237,7 +237,7 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, for (auto const &kv : patternMatch.edge_assignment) { OpenMultiDiEdge patternEdge = kv.first; OpenMultiDiEdge pcgEdge = kv.second; - optional constraintResult = + std::optional constraintResult = satisfies(pcg.at(pcgEdge), pattern.value().at(patternEdge)); result &= constraintResult.value_or(false); } diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc index 7114c2d8ce..f9c6b9a773 100644 --- a/lib/substitutions/src/graph_pattern_match.cc +++ b/lib/substitutions/src/graph_pattern_match.cc @@ -56,7 +56,7 @@ MatchSplit apply_split(OpenMultiDiGraphView const &pattern, } else { assert(is_standard_edge(pattern_edge)); assert(is_standard_edge(graph_edge)); - auto standard_edge = mpark::get(pattern_edge); + auto standard_edge = std::get(pattern_edge); auto divided = edge_splits.at_l(standard_edge); auto divided_graph_edge = split_edge(get(graph_edge)); handle_edge(divided.first, divided_graph_edge.first); @@ -98,7 +98,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, } UpwardOpenMultiDiEdge matched_edge = narrow(graph_matched_edge).value(); - InputMultiDiEdge input_edge = mpark::get(e); + InputMultiDiEdge input_edge = std::get(e); if (match.node_assignment.at_l(input_edge.dst) != get_dst_node(matched_edge)) { return false; @@ -109,7 +109,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, } DownwardOpenMultiDiEdge matched_edge = narrow(graph_matched_edge).value(); - OutputMultiDiEdge output_edge = mpark::get(e); + OutputMultiDiEdge output_edge = std::get(e); if (match.node_assignment.at_l(output_edge.src) != get_src_node(matched_edge)) { return false; @@ -148,7 +148,7 @@ bool src_compare(T const &lhs, T const &rhs) { return get_src_idx(lhs) < get_src_idx(rhs); } -optional +std::optional get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, OpenMultiDiGraphView const &graph, Node const &graph_node) { @@ -170,11 +170,11 @@ optional get_outgoing_edges(pattern, pattern_node); if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { - return nullopt; + return std::nullopt; } if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { - return nullopt; + return std::nullopt; } std::vector incoming_ordered = @@ -198,7 +198,7 @@ optional node_port_mapping.emplace(graph_port, pattern_port); } else { if (pattern_port != node_port_mapping.at(graph_port)) { - return nullopt; + return std::nullopt; } } match.edge_assignment.equate(widen(pattern_edge), @@ -217,7 +217,7 @@ optional node_port_mapping.insert({graph_port, pattern_port}); } else { if (pattern_port != node_port_mapping.at(graph_port)) { - return nullopt; + return std::nullopt; } } match.edge_assignment.equate(widen(pattern_edge), @@ -228,7 +228,7 @@ optional return match; } -optional unsplit_matches( +std::optional unsplit_matches( MultiDiGraphPatternMatch const &prefix, MultiDiGraphPatternMatch const &postfix, bidict> const @@ -248,7 +248,7 @@ optional unsplit_matches( if (output_graph_edge == input_graph_edge) { result.edge_assignment.equate(standard_edge, output_graph_edge); } else { - return nullopt; + return std::nullopt; } } @@ -272,7 +272,7 @@ std::vector std::vector matches; if (is_singleton_pattern(pattern)) { for (Node const &graph_node : get_nodes(graph)) { - optional candidate = + std::optional candidate = get_candidate_singleton_match(pattern, graph, graph_node); if (candidate.has_value() && pattern_matches( @@ -290,7 +290,7 @@ std::vector auto edge_splits = get_edge_splits(pattern, split); for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { - optional unsplit = + std::optional unsplit = unsplit_matches(prefix_match, postfix_match, edge_splits); if (unsplit.has_value()) { matches.push_back(unsplit.value()); diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/operator_attributes.cc index 3922b091a7..76533507a3 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/operator_attributes.cc @@ -3,25 +3,25 @@ namespace FlexFlow { -optional get_attribute(BatchMatmulAttrs const &p, +std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(CastAttrs const &p, +std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.dtype; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(CombineAttrs const &p, +std::optional get_attribute(CombineAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: @@ -29,21 +29,21 @@ optional get_attribute(CombineAttrs const &p, case OperatorAttributeKey::PARALLEL_DIM: return p.combine_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ConcatAttrs const &p, +std::optional get_attribute(ConcatAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(Conv2DAttrs const &p, +std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: @@ -65,43 +65,43 @@ optional get_attribute(Conv2DAttrs const &p, case OperatorAttributeKey::USE_BIAS: return p.use_bias; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementBinaryAttrs const &p, +std::optional get_attribute(ElementBinaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementUnaryAttrs const &p, +std::optional get_attribute(ElementUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementScalarUnaryAttrs const &p, +std::optional get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(DropoutAttrs const &p, +std::optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(EmbeddingAttrs const &p, +std::optional get_attribute(EmbeddingAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: @@ -113,37 +113,37 @@ optional get_attribute(EmbeddingAttrs const &p, case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(FlatAttrs const &p, +std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(GatherAttrs const &p, +std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(LayerNormAttrs const &p, +std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(LinearAttrs const &p, +std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OUT_CHANNELS: @@ -159,11 +159,11 @@ optional get_attribute(LinearAttrs const &p, case OperatorAttributeKey::REGULARIZER: return p.regularizer; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(MultiHeadAttentionAttrs const &p, +std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::NUM_HEADS: @@ -171,11 +171,11 @@ optional get_attribute(MultiHeadAttentionAttrs const &p, case OperatorAttributeKey::USE_BIAS: return p.bias; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(Pool2DAttrs const &p, +std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: @@ -195,19 +195,19 @@ optional get_attribute(Pool2DAttrs const &p, case OperatorAttributeKey::ACTIVATION: return p.activation; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReduceAttrs const &p, +std::optional get_attribute(ReduceAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReductionAttrs const &p, +std::optional get_attribute(ReductionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: @@ -215,11 +215,11 @@ optional get_attribute(ReductionAttrs const &p, case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.reduction_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(RepartitionAttrs const &p, +std::optional get_attribute(RepartitionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: @@ -227,11 +227,11 @@ optional get_attribute(RepartitionAttrs const &p, case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.repartition_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReplicateAttrs const &p, +std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: @@ -239,53 +239,53 @@ optional get_attribute(ReplicateAttrs const &p, case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.replicate_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReshapeAttrs const &p, +std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(SplitAttrs const &p, +std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(SoftmaxAttrs const &p, +std::optional get_attribute(SoftmaxAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(TopKAttrs const &p, +std::optional get_attribute(TopKAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(TransposeAttrs const &p, +std::optional get_attribute(TransposeAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PERMUTATION: return p.perm; default: - return nullopt; + return std::nullopt; } } @@ -293,7 +293,7 @@ struct GetAttribute { GetAttribute(OperatorAttributeKey key) : key(key) {} template - optional operator()(T const &t) { + std::optional operator()(T const &t) { return get_attribute(t, this->key); } @@ -303,17 +303,17 @@ struct GetAttribute { struct GetOpType { template - optional operator()(T const &t) { + std::optional operator()(T const &t) { return get_op_type(t); } }; -optional get_attribute(PCGOperatorAttrs const &p, +std::optional get_attribute(PCGOperatorAttrs const &p, OperatorAttributeKey key) { if (key == OperatorAttributeKey::OP_TYPE) { - return visit(GetOpType{}, p); + return std::visit(GetOpType{}, p); } - return visit(GetAttribute(key), p); + return std::visit(GetAttribute(key), p); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 31659b88fc..4f6572948a 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -113,49 +113,49 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.emplace(key, value); } assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); - assert(holds_alternative( + assert(std::holds_alternative( assignments.at(OperatorAttributeKey::OP_TYPE))); OperatorType op_type = - get(assignments.at(OperatorAttributeKey::OP_TYPE)); + std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); switch (op_type) { case Op::BATCHMATMUL: return Operator{ BatchMatmulAttrs{ - get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), - get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, + std::get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), + std::get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, std::nullopt}; case Op::BATCHNORM: return Operator{ - BatchNormAttrs{get(assignments.at(OperatorAttributeKey::RELU))}, + BatchNormAttrs{std::get(assignments.at(OperatorAttributeKey::RELU))}, std::nullopt}; case Op::CAST: - return Operator{CastAttrs{get( + return Operator{CastAttrs{std::get( assignments.at(OperatorAttributeKey::DATA_TYPE))}, std::nullopt}; case Op::CONCAT: return Operator{ ConcatAttrs{ - get(assignments.at(OperatorAttributeKey::AXIS)), - get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, + std::get(assignments.at(OperatorAttributeKey::AXIS)), + std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, std::nullopt}; case Op::CONV2D: return Operator{ Conv2DAttrs{ - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::KERNEL_H)), - get(assignments.at(OperatorAttributeKey::KERNEL_W)), - get(assignments.at(OperatorAttributeKey::STRIDE_H)), - get(assignments.at(OperatorAttributeKey::STRIDE_W)), - get(assignments.at(OperatorAttributeKey::PADDING_H)), - get(assignments.at(OperatorAttributeKey::PADDING_W)), - get(assignments.at(OperatorAttributeKey::GROUPS)), - get(assignments.at(OperatorAttributeKey::ACTIVATION)), - get(assignments.at(OperatorAttributeKey::USE_BIAS))}, + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + std::get(assignments.at(OperatorAttributeKey::GROUPS)), + std::get(assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, std::nullopt}; case Op::DROPOUT: return Operator{ - DropoutAttrs{get(assignments.at(OperatorAttributeKey::RATE)), - get( + DropoutAttrs{std::get(assignments.at(OperatorAttributeKey::RATE)), + std::get( assignments.at(OperatorAttributeKey::SEED))}, std::nullopt}; case Op::EW_ADD: @@ -170,10 +170,10 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, return Operator{ ElementBinaryAttrs{ op_type, - get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - get( + std::get(assignments.at(OperatorAttributeKey::DATA_TYPE)), + std::get( assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - get( + std::get( assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, std::nullopt}; case Op::SCALAR_ADD: @@ -184,7 +184,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, return Operator{ ElementScalarUnaryAttrs{ op_type, - get(assignments.at(OperatorAttributeKey::SCALAR))}, + std::get(assignments.at(OperatorAttributeKey::SCALAR))}, std::nullopt}; case Op::EXP: case Op::IDENTITY: @@ -197,63 +197,63 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EMBEDDING: return Operator{ EmbeddingAttrs{ - get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::AGGR)), - get(assignments.at(OperatorAttributeKey::OP_TYPE))}, + std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::AGGR)), + std::get(assignments.at(OperatorAttributeKey::OP_TYPE))}, std::nullopt}; case Op::FLAT: return Operator{FlatAttrs{}, std::nullopt}; case Op::GATHER: return Operator{ - GatherAttrs{get(assignments.at(OperatorAttributeKey::DIM))}, + GatherAttrs{std::get(assignments.at(OperatorAttributeKey::DIM))}, std::nullopt}; case Op::INPUT: return Operator{InputAttrs{}, std::nullopt}; case Op::LAYERNORM: return Operator{ LayerNormAttrs{ - get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), - get( + std::get( assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), - get(assignments.at(OperatorAttributeKey::EPSILON))}, + std::get(assignments.at(OperatorAttributeKey::EPSILON))}, std::nullopt}; case Op::LINEAR: return Operator{ LinearAttrs{ - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::USE_BIAS)), - get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - get(assignments.at(OperatorAttributeKey::ACTIVATION)), - get>( + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), + std::get(assignments.at(OperatorAttributeKey::DATA_TYPE)), + std::get(assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get>( assignments.at(OperatorAttributeKey::REGULARIZER))}, std::nullopt}; case Op::MULTIHEAD_ATTENTION: return Operator{ MultiHeadAttentionAttrs{ - get(assignments.at(OperatorAttributeKey::EMBED_DIM)), - get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - get(assignments.at(OperatorAttributeKey::VDIM)), - get(assignments.at(OperatorAttributeKey::DROPOUT)), - get(assignments.at(OperatorAttributeKey::BIAS)), - get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, + std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), + std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + std::get(assignments.at(OperatorAttributeKey::VDIM)), + std::get(assignments.at(OperatorAttributeKey::DROPOUT)), + std::get(assignments.at(OperatorAttributeKey::BIAS)), + std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), + std::get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, std::nullopt}; case Op::NOOP: return Operator{NoopAttrs{}, std::nullopt}; case Op::POOL2D: return Operator{ Pool2DAttrs{ - get(assignments.at(OperatorAttributeKey::KERNEL_H)), - get(assignments.at(OperatorAttributeKey::KERNEL_W)), - get(assignments.at(OperatorAttributeKey::STRIDE_H)), - get(assignments.at(OperatorAttributeKey::STRIDE_W)), - get(assignments.at(OperatorAttributeKey::PADDING_H)), - get(assignments.at(OperatorAttributeKey::PADDING_W)), - get(assignments.at(OperatorAttributeKey::POOL_TYPE)), - get( + std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), + std::get( assignments.at(OperatorAttributeKey::ACTIVATION))}, std::nullopt}; case Op::REDUCE_ARGMAX: @@ -265,65 +265,65 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::REDUCE_SUM: return Operator{ ReduceAttrs{ - get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), op_type, - get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, + std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, std::nullopt}; case Op::REVERSE: - return Operator{ReverseAttrs{get( + return Operator{ReverseAttrs{std::get( assignments.at(OperatorAttributeKey::AXIS))}, std::nullopt}; case Op::RESHAPE: - return Operator{ReshapeAttrs{get( + return Operator{ReshapeAttrs{std::get( assignments.at(OperatorAttributeKey::SHAPE))}, std::nullopt}; case Op::SPLIT: return Operator{ - SplitAttrs{get>( + SplitAttrs{std::get>( assignments.at(OperatorAttributeKey::SPLITS)), - get(assignments.at(OperatorAttributeKey::AXIS))}, + std::get(assignments.at(OperatorAttributeKey::AXIS))}, std::nullopt}; case Op::SOFTMAX: - return Operator{SoftmaxAttrs{get( + return Operator{SoftmaxAttrs{std::get( assignments.at(OperatorAttributeKey::DIM))}, std::nullopt}; case Op::TOPK: return Operator{ - TopKAttrs{get(assignments.at(OperatorAttributeKey::K)), - get(assignments.at(OperatorAttributeKey::SORTED))}, + TopKAttrs{std::get(assignments.at(OperatorAttributeKey::K)), + std::get(assignments.at(OperatorAttributeKey::SORTED))}, std::nullopt}; case Op::TRANSPOSE: return Operator{ - TransposeAttrs{get>( + TransposeAttrs{std::get>( assignments.at(OperatorAttributeKey::PERMUTATION))}, std::nullopt}; case Op::COMBINE: return Operator{ CombineAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REDUCTION: return Operator{ ReductionAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REPARTITION: return Operator{ RepartitionAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REPLICATE: return Operator{ ReplicateAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; default: - mk_runtime_error("Unknown Operator"); + throw mk_runtime_error("Unknown Operator"); } } @@ -435,23 +435,23 @@ SubParallelComputationGraph } for (OpenMultiDiEdge const &output_edge : get_edges(substitution.output_graph_expr.value())) { - if (holds_alternative(output_edge)) { - InputMultiDiEdge e = get(output_edge); + if (std::holds_alternative(output_edge)) { + InputMultiDiEdge e = std::get(output_edge); OpenMultiDiEdge original_edge = match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, original_edge, output_edge); - } else if (holds_alternative(output_edge)) { - OutputMultiDiEdge e = get(output_edge); + } else if (std::holds_alternative(output_edge)) { + OutputMultiDiEdge e = std::get(output_edge); OpenMultiDiEdge original_edge = match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, original_edge, output_edge); } else { - assert(holds_alternative(output_edge)); - MultiDiEdge e = get(output_edge); + assert(std::holds_alternative(output_edge)); + MultiDiEdge e = std::get(output_edge); new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), new_pcg.add_node_port(), node_mapping.at_l(e.src), diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index ed47226297..40ac0a4a1c 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -2,11 +2,11 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H #include "utils/bidict.h" -#include "utils/optional.decl" #include "utils/required_core.h" #include "utils/type_traits_core.h" #include #include +#include namespace FlexFlow { @@ -108,7 +108,7 @@ template std::vector values(C const &c); template -std::unordered_set> +std::unordered_set> items(C const &c); template @@ -291,10 +291,10 @@ template T reversed(T const &t); template -std::vector value_all(std::vector> const &v); +std::vector value_all(std::vector> const &v); template -std::unordered_set value_all(std::unordered_set> const &v); +std::unordered_set value_all(std::unordered_set> const &v); template std::vector subvec(std::vector const &v, diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 750c43abee..1606eb0605 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -674,8 +674,8 @@ std::vector value_all(std::vector> const &v) { } template -std::unordered_set value_all(std::unordered_set> const &v) { - return transform(v, [](optional const &element) { +std::unordered_set value_all(std::unordered_set> const &v) { + return transform(v, [](std::optional const &element) { return unwrap(element, [] { throw mk_runtime_error( "Encountered element without value in call to value_all"); diff --git a/lib/utils/include/utils/dot_file.h b/lib/utils/include/utils/dot_file.h index 6cdc78f6d4..6cf06d12a7 100644 --- a/lib/utils/include/utils/dot_file.h +++ b/lib/utils/include/utils/dot_file.h @@ -10,6 +10,7 @@ #include #include #include +#include template class DotFile { @@ -28,16 +29,16 @@ class DotFile { return s.str(); } bool has_ostream() const { - return this->owned_fstream.has_value() || this->out.has_value(); + return this->owned_fstream.has_value() || this->out != nullptr; } std::ostream &get_ostream() { bool has_owned_stream = this->owned_fstream.has_value(); - bool has_stream_ref = this->out.has_value(); + bool has_stream_ref = (this->out != nullptr); assert(has_owned_stream != has_stream_ref); if (has_owned_stream) { return this->owned_fstream.value(); } else if (has_stream_ref) { - return this->out.value(); + return *this->out; } else { throw std::runtime_error("No ostream value set"); } diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 1ccf881d97..027c3243b9 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -32,12 +32,12 @@ struct OutputLabelledOpenMultiDiGraphView } template - EdgeLabel const &at(variant const &e) const { + EdgeLabel const &at(std::variant const &e) const { return visit([&](auto const &e) -> auto const & { return this->at(e); }, e); } template - EdgeLabel &at(variant const &e) { + EdgeLabel &at(std::variant const &e) { return visit([&](auto const &e) -> auto & { return this->at(e); }, e); } diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index 420f9736d1..feb263335a 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -42,7 +42,7 @@ struct elements_satisfy> : elements_satisfy_impl {}; template -struct is_in_variant; +struct is_in_variant : std::false_type {}; template struct is_in_variant> : std::true_type {}; template diff --git a/lib/utils/src/graph/open_edge.cc b/lib/utils/src/graph/open_edge.cc index b12f87dd1c..1b571d5c6c 100644 --- a/lib/utils/src/graph/open_edge.cc +++ b/lib/utils/src/graph/open_edge.cc @@ -3,15 +3,15 @@ namespace FlexFlow { bool is_input_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } bool is_output_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } bool is_standard_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery( diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 4a6a056d59..f1c9e41005 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -142,7 +142,7 @@ SplitASTNode::SplitASTNode(SplitType type, struct FlattenAST { void add_flattened_child_to_parent(SplitASTNode &parent, SplitAST const &child) { - if (holds_alternative(child)) { + if (std::holds_alternative(child)) { parent.children.push_back(child); return; } @@ -178,11 +178,11 @@ struct ToFinalAST { std::variant operator()(SplitASTNode const &node) { if (node.type == SplitType::SERIAL) { return Serial{transform(node.children, [](SplitAST const &s) { - return narrow>(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } else { return Parallel{transform(node.children, [](SplitAST const &s) { - return narrow>(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } } diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 031defd417..1494f0ac27 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -3,61 +3,61 @@ TEST_CASE("widen and narrow functions") { SUBCASE("widen function") { - variant v1 = 42; - variant result = widen>(v1); - variant expected = 42; + std::variant v1 = 42; + std::variant result = widen>(v1); + std::variant expected = 42; CHECK(result == expected); } SUBCASE("narrow function fail") { - variant v2 = + std::variant v2 = 3.14; // this is a doule, because 3.14 default to double - optional> result = narrow>(v2); - optional> expected = float(3.14); + std::optional> result = narrow>(v2); + std::optional> expected = float(3.14); CHECK(!result.has_value()); // result should be empty due to narrowing } SUBCASE("narrow function success") { - variant v2 = + std::variant v2 = 3.14; // this is a doule, because 3.14 default to double - optional> result = narrow>(v2); - optional> expected = 3.14; + std::optional> result = narrow>(v2); + std::optional> expected = 3.14; CHECK(result == expected); // } SUBCASE("cast function") { - variant v3 = 42; - optional> result = cast>(v3); - optional> expected = 42; + std::variant v3 = 42; + std::optional> result = cast>(v3); + std::optional> expected = 42; CHECK(result == expected); } } TEST_CASE("Narrow and cast variants") { - variant original_variant = 42; + std::variant original_variant = 42; // narrow - optional> narrow_result = - narrow>(original_variant); + std::optional> narrow_result = + narrow>(original_variant); CHECK(narrow_result.has_value()); // assert narrow has value // cast - optional> cast_result = - cast>(narrow_result.value()); + std::optional> cast_result = + cast>(narrow_result.value()); CHECK(cast_result.has_value()); // assert cast has value CHECK(get(cast_result.value()) == 42); } TEST_CASE("casting and widening a variant") { - variant smaller_variant = 42; - variant wider_variant; + std::variant smaller_variant = 42; + std::variant wider_variant; // Perform the cast operation - optional> cast_result = cast>(smaller_variant); + std::optional> cast_result = cast>(smaller_variant); REQUIRE(cast_result); // Ensure the cast was successful // Perform the widening operation - wider_variant = widen>(cast_result.value()); + wider_variant = widen>(cast_result.value()); // Check the result CHECK(get(wider_variant) == 42); From d6e10bb0d579f2328e9ce6d355205ca69bc1a6dc Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 22 Mar 2024 17:11:35 -0700 Subject: [PATCH 22/32] Add shell hook for sapling development --- flake.nix | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flake.nix b/flake.nix index 540d0f9a94..595fedac46 100644 --- a/flake.nix +++ b/flake.nix @@ -53,6 +53,10 @@ devShells = rec { ci = mkShell { + shellHook = '' + export PATH="$HOME/ff/.scripts/:$HOME/ff/.modules/proj/bin/:$PATH" + ''; + CMAKE_FLAGS = lib.strings.concatStringsSep " " [ "-DFF_USE_EXTERNAL_LEGION=ON" "-DFF_USE_EXTERNAL_NCCL=ON" From 95fb4cc529a7de643bd4e2af532a2aa88a81f60f Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Fri, 22 Mar 2024 17:24:18 -0700 Subject: [PATCH 23/32] changed from nullopt to std::nullopt --- lib/substitutions/test/src/test_substitution.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index a8f5283eda..552d46a98f 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -89,7 +89,7 @@ TEST_CASE("apply_substitution") { Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n5 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, "linear"}); NodePort p4 = pcg.add_node_port(); NodePort p5 = pcg.add_node_port(); From c09147959670d605f87c8020ef73b0632f0a1faf Mon Sep 17 00:00:00 2001 From: wmdi Date: Sat, 23 Mar 2024 15:48:50 -0400 Subject: [PATCH 24/32] fix cast issue --- lib/compiler/test/CMakeLists.txt | 2 +- .../test/src/test_labelled_open_graph.cc | 220 ++++++++---------- .../utils/graph/labelled/node_labelled.h | 6 +- .../utils/graph/labelled/node_labelled_open.h | 6 +- .../utils/graph/labelled/output_labelled.h | 6 +- .../graph/labelled/output_labelled_open.h | 6 +- .../utils/graph/labelled/standard_labelled.h | 6 +- lib/utils/src/graph/digraph.cc | 6 +- lib/utils/src/graph/multidigraph.cc | 6 +- lib/utils/src/graph/node.cc | 4 +- lib/utils/src/graph/open_graphs.cc | 18 +- lib/utils/src/graph/undirected.cc | 6 +- 12 files changed, 138 insertions(+), 154 deletions(-) diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index 3d35fdabfd..cbd7e233c0 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -2,7 +2,7 @@ ff_add_test_executable( NAME compiler-test SRC_PATTERNS - src/test_labelled_open_graph.cc + src/*.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index a3b6319528..74071160cb 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,113 +4,85 @@ using namespace FlexFlow; -// TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { -// auto g = OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); -// Node n2 = g.add_node(); -// Node n3 = g.add_node(); -// Node n4 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); -// NodePort p2 = g.add_node_port(); -// NodePort p3 = g.add_node_port(); -// NodePort p4 = g.add_node_port(); -// NodePort p5 = g.add_node_port(); -// NodePort p6 = g.add_node_port(); -// NodePort p7 = g.add_node_port(); -// NodePort p8 = g.add_node_port(); -// NodePort p9 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; -// MultiDiEdge e1{n2, p2, n0, p0}; -// MultiDiEdge e2{n3, p5, n1, p3}; -// MultiDiEdge e3{n3, p6, n2, p4}; -// MultiDiEdge e4{n4, p8, n3, p7}; -// OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// std::unordered_set node_set0{n3, n4}; - -// auto subgraph0 = get_subgraph(g, node_set0); -// auto subgraph1 = get_subgraph(g, node_set0); -// auto subgraph2 = get_subgraph(g, -// node_set0); auto subgraph3 = get_subgraph(g, -// node_set0); - -// CHECK(get_nodes(subgraph0) == node_set0); -// CHECK(get_nodes(subgraph1) == node_set0); -// CHECK(get_nodes(subgraph2) == node_set0); -// CHECK(get_nodes(subgraph3) == node_set0); - -// std::unordered_set input_set{split_edge(e2).second, -// split_edge(e3).second}; -// std::unordered_set output_set{e5}; - -// CHECK(bool(get_open_inputs(subgraph0) == input_set)); -// CHECK(bool(get_open_inputs(subgraph1) == input_set)); -// CHECK(bool(get_open_inputs(subgraph2).empty())); -// CHECK(bool(get_open_inputs(subgraph3).empty())); - -// CHECK(bool(get_open_outputs(subgraph0) == output_set)); -// CHECK(bool(get_open_outputs(subgraph1).empty())); -// CHECK(bool(get_open_outputs(subgraph2) == output_set)); -// CHECK(bool(get_open_outputs(subgraph3).empty())); - -// CHECK(bool(get_edges(subgraph0) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4, e5})); -// CHECK(bool(get_edges(subgraph1) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4})); -// CHECK(bool(get_edges(subgraph2) == -// std::unordered_set{e4, e5})); -// CHECK(bool(get_edges(subgraph3) == -// std::unordered_set{e4})); - -// CHECK(get_closed_sources(subgraph2) == std::unordered_set{n3}); -// } - -// TEST_CASE("view OutputLabelledMultiDiGraph as open") { -// OutputLabelledMultiDiGraph g = -// OutputLabelledMultiDiGraph::create>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_output(e0, 2); - -// CHECK(get_edges(g).size() == 1); - -// OutputLabelledOpenMultiDiGraphView open_graph = -// view_output_labelled_as_output_labelled_open(g); - -// CHECK(open_graph.at(n0) == 0); -// CHECK(open_graph.at(n1) == 1); -// CHECK(open_graph.at(e0) == 2); - -// // CHECK(get_edges(open_graph).size() == 1); -// } +TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { + auto g = OpenMultiDiGraph::create(); -TEST_CASE("OutputLabelledOpenMultiDiGraph") { - OutputLabelledOpenMultiDiGraph g = - OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); + Node n4 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + NodePort p2 = g.add_node_port(); + NodePort p3 = g.add_node_port(); + NodePort p4 = g.add_node_port(); + NodePort p5 = g.add_node_port(); + NodePort p6 = g.add_node_port(); + NodePort p7 = g.add_node_port(); + NodePort p8 = g.add_node_port(); + NodePort p9 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e1{n2, p2, n0, p0}; + MultiDiEdge e2{n3, p5, n1, p3}; + MultiDiEdge e3{n3, p6, n2, p4}; + MultiDiEdge e4{n4, p8, n3, p7}; + OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + std::unordered_set node_set0{n3, n4}; + + auto subgraph0 = get_subgraph(g, node_set0); + auto subgraph1 = get_subgraph(g, node_set0); + auto subgraph2 = get_subgraph(g, + node_set0); auto subgraph3 = get_subgraph(g, + node_set0); + + CHECK(get_nodes(subgraph0) == node_set0); + CHECK(get_nodes(subgraph1) == node_set0); + CHECK(get_nodes(subgraph2) == node_set0); + CHECK(get_nodes(subgraph3) == node_set0); + + std::unordered_set input_set{split_edge(e2).second, + split_edge(e3).second}; + std::unordered_set output_set{e5}; + + CHECK(bool(get_open_inputs(subgraph0) == input_set)); + CHECK(bool(get_open_inputs(subgraph1) == input_set)); + CHECK(bool(get_open_inputs(subgraph2).empty())); + CHECK(bool(get_open_inputs(subgraph3).empty())); + + CHECK(bool(get_open_outputs(subgraph0) == output_set)); + CHECK(bool(get_open_outputs(subgraph1).empty())); + CHECK(bool(get_open_outputs(subgraph2) == output_set)); + CHECK(bool(get_open_outputs(subgraph3).empty())); + + CHECK(bool(get_edges(subgraph0) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); + CHECK(bool(get_edges(subgraph1) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + CHECK(bool(get_edges(subgraph2) == + std::unordered_set{e4, e5})); + CHECK(bool(get_edges(subgraph3) == + std::unordered_set{e4})); + + CHECK(get_closed_sources(subgraph2) == std::unordered_set{n3}); +} + +TEST_CASE("view OutputLabelledMultiDiGraph as open") { + OutputLabelledMultiDiGraph g = + OutputLabelledMultiDiGraph::create>(); Node n0 = g.add_node(0); Node n1 = g.add_node(1); @@ -121,24 +93,36 @@ TEST_CASE("OutputLabelledOpenMultiDiGraph") { MultiDiEdge e0{n1, p1, n0, p0}; g.add_edge(e0); - g.add_label(e0, 2); + g.add_output(e0, 2); - CHECK(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1); CHECK(get_edges(g).size() == 1); + + OutputLabelledOpenMultiDiGraphView open_graph = + view_output_labelled_as_output_labelled_open(g); + + CHECK(open_graph.at(n0) == 0); + CHECK(open_graph.at(n1) == 1); + CHECK(open_graph.at(e0) == 2); + + CHECK(get_edges(open_graph).size() == 1); } -// TEST_CASE("OpenMultiDiGraph") { -// OpenMultiDiGraph g = OpenMultiDiGraph::create(); +TEST_CASE("OutputLabelledOpenMultiDiGraph") { + OutputLabelledOpenMultiDiGraph g = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); -// MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e0{n1, p1, n0, p0}; -// g.add_edge(e0); + g.add_edge(e0); + g.add_label(e0, 2); -// CHECK(get_edges(g).size() == 1); -// } + CHECK(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1); + CHECK(get_edges(g).size() == 1); +} diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 9d8874fb14..9aed91f107 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -41,7 +41,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -97,12 +97,12 @@ struct NodeLabelledMultiDiGraph NodeLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 826a8387cb..0fea57cab7 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -55,7 +55,7 @@ struct NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -121,12 +121,12 @@ struct NodeLabelledOpenMultiDiGraph NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index c6c521c38b..8aab0320b5 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -47,7 +47,7 @@ struct OutputLabelledMultiDiGraphView private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -119,12 +119,12 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 24235bee4c..2be56cb477 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -64,7 +64,7 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -145,12 +145,12 @@ struct OutputLabelledOpenMultiDiGraph OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index e1c8e91634..c6d1521471 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -47,7 +47,7 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -104,12 +104,12 @@ struct LabelledMultiDiGraph LabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index dda9eef5e0..ecad1db3f0 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -14,7 +14,7 @@ std::unordered_set } IDiGraphView const &DiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -48,11 +48,11 @@ std::unordered_set } IDiGraph &DiGraph::get_ptr() { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } IDiGraph const &DiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 99a7ea86fa..41ae3e1aa3 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -24,7 +24,7 @@ std::unordered_set } IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -66,12 +66,12 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index 9854afffbf..72caa3136e 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -53,11 +53,11 @@ std::unordered_set Graph::query_nodes(NodeQuery const &q) const { } IGraph const &Graph::get_ptr() const { - return *std::reinterpret_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } IGraph &Graph::get_ptr() { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index c32ff6ded5..387dd7e75b 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -56,12 +56,12 @@ NodePort OpenMultiDiGraph::add_node_port() { } IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } IOpenMultiDiGraph const &OpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -77,7 +77,7 @@ std::unordered_set } IUpwardOpenMultiDiGraphView const &UpwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -107,12 +107,12 @@ std::unordered_set UpwardOpenMultiDiGraph::query_edges( } IUpwardOpenMultiDiGraph const &UpwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } IUpwardOpenMultiDiGraph &UpwardOpenMultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -129,7 +129,7 @@ std::unordered_set IDownwardOpenMultiDiGraphView const & DownwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -165,12 +165,12 @@ std::unordered_set } IDownwardOpenMultiDiGraph &DownwardOpenMultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } IDownwardOpenMultiDiGraph const &DownwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index ce42cfe22c..b1e8be7f14 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -26,12 +26,12 @@ void UndirectedGraph::remove_edge(UndirectedEdge const &e) { } IUndirectedGraph const &UndirectedGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } IUndirectedGraph &UndirectedGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -56,7 +56,7 @@ std::unordered_set } IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } From 54c604af83abf974d90c3c08401f930be72f242f Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 23 Mar 2024 22:10:53 -0700 Subject: [PATCH 25/32] Fix spdlog cmake issue --- cmake/spdlog.cmake | 7 ++++--- flake.nix | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cmake/spdlog.cmake b/cmake/spdlog.cmake index 02021fd51e..5ba1d6cc15 100644 --- a/cmake/spdlog.cmake +++ b/cmake/spdlog.cmake @@ -6,6 +6,7 @@ else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/spdlog) endif() -add_library(spdlog INTERFACE) -target_link_libraries(spdlog INTERFACE spdlog::spdlog) -target_compile_definitions(spdlog INTERFACE SPDLOG_FMT_EXTERNAL) +add_library(ff_spdlog INTERFACE) +target_link_libraries(ff_spdlog INTERFACE spdlog::spdlog) +target_compile_definitions(ff_spdlog INTERFACE SPDLOG_FMT_EXTERNAL) +alias_library(spdlog ff_spdlog) diff --git a/flake.nix b/flake.nix index 595fedac46..d402d3c271 100644 --- a/flake.nix +++ b/flake.nix @@ -103,7 +103,7 @@ buildInputs = builtins.concatLists [ (with pkgs; [ - ccls + clang-tools gh-markdown-preview shellcheck plantuml From 8b914cf1e79655ab0f0ab2f5accbdb8282847480 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 23 Mar 2024 22:17:37 -0700 Subject: [PATCH 26/32] Re-remove submodules --- deps/any | 1 - deps/boost_preprocessor | 1 - deps/googletest | 1 - deps/invoke | 1 - deps/nameof | 1 - deps/optional | 1 - deps/pybind11 | 1 - deps/variant | 1 - 8 files changed, 8 deletions(-) delete mode 160000 deps/any delete mode 160000 deps/boost_preprocessor delete mode 160000 deps/googletest delete mode 160000 deps/invoke delete mode 160000 deps/nameof delete mode 160000 deps/optional delete mode 160000 deps/pybind11 delete mode 160000 deps/variant diff --git a/deps/any b/deps/any deleted file mode 160000 index e88b1bfc16..0000000000 --- a/deps/any +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e88b1bfc160fa9b01e6174dd29c812eeeece3be9 diff --git a/deps/boost_preprocessor b/deps/boost_preprocessor deleted file mode 160000 index 667e87b339..0000000000 --- a/deps/boost_preprocessor +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 667e87b3392db338a919cbe0213979713aca52e3 diff --git a/deps/googletest b/deps/googletest deleted file mode 160000 index 2fe3bd994b..0000000000 --- a/deps/googletest +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2fe3bd994b3189899d93f1d5a881e725e046fdc2 diff --git a/deps/invoke b/deps/invoke deleted file mode 160000 index 2c1eabc2e2..0000000000 --- a/deps/invoke +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2c1eabc2e20ab02961f95c704ff0c0818671ddd1 diff --git a/deps/nameof b/deps/nameof deleted file mode 160000 index 8aeb677413..0000000000 --- a/deps/nameof +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8aeb6774132a01765d8c8679d016b728acd069f5 diff --git a/deps/optional b/deps/optional deleted file mode 160000 index c28fcf74d2..0000000000 --- a/deps/optional +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c28fcf74d207fc667c4ed3dbae4c251ea551c8c1 diff --git a/deps/pybind11 b/deps/pybind11 deleted file mode 160000 index 8de7772cc7..0000000000 --- a/deps/pybind11 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8de7772cc72daca8e947b79b83fea46214931604 diff --git a/deps/variant b/deps/variant deleted file mode 160000 index 23cb94f027..0000000000 --- a/deps/variant +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 23cb94f027d4ef33bf48133acc2695c7e5c6f1e7 From 189f32303c9b19739837fed921de94785b1c6c10 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 24 Mar 2024 16:24:01 -0400 Subject: [PATCH 27/32] minor fix & fmt --- ...ive_logger.cc => recursive_logger.cc.todo} | 0 ...rsive_logger.h => recursive_logger.h.todo} | 0 .../test/src/test_labelled_open_graph.cc | 36 +++--- lib/compiler/test/src/test_optimal_cost.cc | 2 +- .../include/op-attrs/operator_attrs.h | 4 +- .../include/substitutions/get_attribute.h | 56 ++++----- .../include/substitutions/operator_pattern.h | 31 ++--- lib/substitutions/src/graph_pattern.cc | 20 +-- lib/substitutions/src/operator_attributes.cc | 56 ++++----- lib/substitutions/src/substitution.cc | 117 ++++++++++-------- lib/utils/include/utils/containers.decl.h | 2 +- lib/utils/include/utils/dot_file.h | 2 +- .../utils/graph/labelled/node_labelled.h | 9 +- .../utils/graph/labelled/node_labelled_open.h | 9 +- .../utils/graph/labelled/output_labelled.h | 9 +- .../graph/labelled/output_labelled_open.h | 9 +- .../utils/graph/labelled/standard_labelled.h | 9 +- lib/utils/include/utils/variant.h | 13 +- lib/utils/src/graph/digraph.cc | 3 +- lib/utils/src/graph/multidigraph.cc | 3 +- lib/utils/test/src/test_variant.cc | 15 ++- 21 files changed, 203 insertions(+), 202 deletions(-) rename lib/compiler/src/utils/{recursive_logger.cc => recursive_logger.cc.todo} (100%) rename lib/compiler/src/utils/{recursive_logger.h => recursive_logger.h.todo} (100%) diff --git a/lib/compiler/src/utils/recursive_logger.cc b/lib/compiler/src/utils/recursive_logger.cc.todo similarity index 100% rename from lib/compiler/src/utils/recursive_logger.cc rename to lib/compiler/src/utils/recursive_logger.cc.todo diff --git a/lib/compiler/src/utils/recursive_logger.h b/lib/compiler/src/utils/recursive_logger.h.todo similarity index 100% rename from lib/compiler/src/utils/recursive_logger.h rename to lib/compiler/src/utils/recursive_logger.h.todo diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index 74071160cb..c59d7ee78a 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -1,6 +1,6 @@ #include "compiler/unity_algorithm.h" #include "doctest/doctest.h" -#include "rapidcheck.h" +// #include "rapidcheck.h" using namespace FlexFlow; @@ -42,14 +42,13 @@ TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { auto subgraph0 = get_subgraph(g, node_set0); auto subgraph1 = get_subgraph(g, node_set0); - auto subgraph2 = get_subgraph(g, - node_set0); auto subgraph3 = get_subgraph(g, - node_set0); + auto subgraph2 = get_subgraph(g, node_set0); + auto subgraph3 = get_subgraph(g, node_set0); - CHECK(get_nodes(subgraph0) == node_set0); - CHECK(get_nodes(subgraph1) == node_set0); - CHECK(get_nodes(subgraph2) == node_set0); - CHECK(get_nodes(subgraph3) == node_set0); + CHECK(bool(get_nodes(subgraph0) == node_set0)); + CHECK(bool(get_nodes(subgraph1) == node_set0)); + CHECK(bool(get_nodes(subgraph2) == node_set0)); + CHECK(bool(get_nodes(subgraph3) == node_set0)); std::unordered_set input_set{split_edge(e2).second, split_edge(e3).second}; @@ -73,16 +72,15 @@ TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { split_edge(e2).second, split_edge(e3).second, e4})); CHECK(bool(get_edges(subgraph2) == std::unordered_set{e4, e5})); - CHECK(bool(get_edges(subgraph3) == - std::unordered_set{e4})); + CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); - CHECK(get_closed_sources(subgraph2) == std::unordered_set{n3}); + CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); } TEST_CASE("view OutputLabelledMultiDiGraph as open") { OutputLabelledMultiDiGraph g = - OutputLabelledMultiDiGraph::create>(); + OutputLabelledMultiDiGraph::create< + UnorderedOutputLabelledMultiDiGraph>(); Node n0 = g.add_node(0); Node n1 = g.add_node(1); @@ -95,14 +93,14 @@ TEST_CASE("view OutputLabelledMultiDiGraph as open") { g.add_edge(e0); g.add_output(e0, 2); - CHECK(get_edges(g).size() == 1); + CHECK(bool(get_edges(g).size() == 1)); OutputLabelledOpenMultiDiGraphView open_graph = view_output_labelled_as_output_labelled_open(g); - CHECK(open_graph.at(n0) == 0); - CHECK(open_graph.at(n1) == 1); - CHECK(open_graph.at(e0) == 2); + CHECK(bool(open_graph.at(n0) == 0)); + CHECK(bool(open_graph.at(n1) == 1)); + CHECK(bool(open_graph.at(e0) == 2)); CHECK(get_edges(open_graph).size() == 1); } @@ -123,6 +121,6 @@ TEST_CASE("OutputLabelledOpenMultiDiGraph") { g.add_edge(e0); g.add_label(e0, 2); - CHECK(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1); - CHECK(get_edges(g).size() == 1); + CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); + CHECK(bool(get_edges(g).size() == 1)); } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 9d90285870..5f5f7d093e 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -34,7 +34,7 @@ TEST_CASE("optimal_cost_0") { Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n1 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, "linear"}); MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 678a049c3b..b63563cd67 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -86,8 +86,8 @@ static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); -using ParallelOperatorAttrs = - std::variant; +using ParallelOperatorAttrs = std:: + variant; using ComputationGraphAttrs = variant_join>; diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index 7088730c53..0e6dd4c69b 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -8,57 +8,57 @@ namespace FlexFlow { std::optional get_attribute(PCGOperatorAttrs const &, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(CastAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(CombineAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ConcatAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(Conv2DAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(DropoutAttrs const &p, - OperatorAttributeKey); -std::optional get_attribute(ElementScalarUnaryAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); +std::optional + get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey); std::optional get_attribute(EmbeddingAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(FlatAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(GatherAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(LayerNormAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(LinearAttrs const &p, - OperatorAttributeKey); -std::optional get_attribute(MultiHeadAttentionAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); +std::optional + get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); std::optional get_attribute(Pool2DAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ReduceAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ReductionAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(RepartitionAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ReplicateAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ReshapeAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(SplitAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(SoftmaxAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(TopKAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(TransposeAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); // optional get_attribute(FusedParallelOpAttrs const &p, // OperatorAttributeKey); diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 35544f3003..8fc4ebefc2 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -70,21 +70,22 @@ enum class OperatorAttributeKey { NUM_INPUTS }; -using OperatorAttributeValue = std::variant, - stack_vector, - OperatorType, - Activation, - ff_dim_t, - unsigned long long, - AggregateOp, - stack_vector, - std::optional, - PoolOp, - TensorShape, - DataType>; +using OperatorAttributeValue = + std::variant, + stack_vector, + OperatorType, + Activation, + ff_dim_t, + unsigned long long, + AggregateOp, + stack_vector, + std::optional, + PoolOp, + TensorShape, + DataType>; FF_VISITABLE_STRUCT(ListIndexAccess, attribute_key, diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 6f933dd300..296a975626 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -14,7 +14,8 @@ std::optional std::optional const &v) { if (!v.has_value() || !std::holds_alternative>(v.value()) || - !std::holds_alternative>(v.value())) { + !std::holds_alternative>( + v.value())) { return std::nullopt; } @@ -62,7 +63,8 @@ std::optional struct EvaluateOperatorAttributeExpr { EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - std::optional operator()(OperatorAttributeKey const &key) { + std::optional + operator()(OperatorAttributeKey const &key) { return get_attribute(this->attrs.attrs, key); } @@ -148,8 +150,8 @@ std::optional template std::optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - std::optional const &maybe_attribute_value) { + V const &constraint_value, + std::optional const &maybe_attribute_value) { if (!maybe_attribute_value.has_value()) { return std::nullopt; } @@ -167,14 +169,14 @@ std::optional satisfies(ConstraintType constraint_type, } std::optional satisfies(ParallelTensor const &tensor_shape, - TensorAttributeConstraint const &constraint) { + TensorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); return satisfies( constraint.constraint_type, constraint.attribute_value, value); } std::optional satisfies(Operator const ¶ms, - OperatorAttributeConstraint const &constraint) { + OperatorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(params, constraint.attribute_expr); OperatorAttributeValue v = value.value(); return satisfies( @@ -183,7 +185,7 @@ std::optional satisfies(Operator const ¶ms, template std::optional optional_all_of(Container const &container, - Function const &func) { + Function const &func) { for (auto const &element : container) { std::optional condition = func(element); if (!condition.has_value()) { @@ -198,7 +200,7 @@ std::optional optional_all_of(Container const &container, } std::optional satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { + OperatorPattern const &pattern) { return optional_all_of(pattern.attribute_constraints, [&](OperatorAttributeConstraint const &c) { return satisfies(params, c); @@ -206,7 +208,7 @@ std::optional satisfies(Operator const ¶ms, } std::optional satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { + ParallelTensorPattern const &pattern) { return optional_all_of( pattern.attribute_constraints, [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/operator_attributes.cc index 76533507a3..8bd8688194 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/operator_attributes.cc @@ -4,7 +4,7 @@ namespace FlexFlow { std::optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -12,7 +12,7 @@ std::optional get_attribute(BatchMatmulAttrs const &p, } std::optional get_attribute(CastAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.dtype; @@ -22,7 +22,7 @@ std::optional get_attribute(CastAttrs const &p, } std::optional get_attribute(CombineAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.combine_dim; @@ -34,7 +34,7 @@ std::optional get_attribute(CombineAttrs const &p, } std::optional get_attribute(ConcatAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; @@ -44,7 +44,7 @@ std::optional get_attribute(ConcatAttrs const &p, } std::optional get_attribute(Conv2DAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: return p.kernel_h; @@ -70,7 +70,7 @@ std::optional get_attribute(Conv2DAttrs const &p, } std::optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -78,15 +78,15 @@ std::optional get_attribute(ElementBinaryAttrs const &p, } std::optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; } } -std::optional get_attribute(ElementScalarUnaryAttrs const &p, - OperatorAttributeKey key) { +std::optional + get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -94,7 +94,7 @@ std::optional get_attribute(ElementScalarUnaryAttrs cons } std::optional get_attribute(DropoutAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -102,7 +102,7 @@ std::optional get_attribute(DropoutAttrs const &p, } std::optional get_attribute(EmbeddingAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.data_type; @@ -118,7 +118,7 @@ std::optional get_attribute(EmbeddingAttrs const &p, } std::optional get_attribute(FlatAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -126,7 +126,7 @@ std::optional get_attribute(FlatAttrs const &p, } std::optional get_attribute(GatherAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; @@ -136,7 +136,7 @@ std::optional get_attribute(GatherAttrs const &p, } std::optional get_attribute(LayerNormAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -144,7 +144,7 @@ std::optional get_attribute(LayerNormAttrs const &p, } std::optional get_attribute(LinearAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; @@ -163,8 +163,8 @@ std::optional get_attribute(LinearAttrs const &p, } } -std::optional get_attribute(MultiHeadAttentionAttrs const &p, - OperatorAttributeKey key) { +std::optional + get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::NUM_HEADS: return p.num_heads; @@ -176,7 +176,7 @@ std::optional get_attribute(MultiHeadAttentionAttrs cons } std::optional get_attribute(Pool2DAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: return p.kernel_h; @@ -200,7 +200,7 @@ std::optional get_attribute(Pool2DAttrs const &p, } std::optional get_attribute(ReduceAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -208,7 +208,7 @@ std::optional get_attribute(ReduceAttrs const &p, } std::optional get_attribute(ReductionAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.reduction_dim; @@ -220,7 +220,7 @@ std::optional get_attribute(ReductionAttrs const &p, } std::optional get_attribute(RepartitionAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.repartition_dim; @@ -232,7 +232,7 @@ std::optional get_attribute(RepartitionAttrs const &p, } std::optional get_attribute(ReplicateAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.replicate_dim; @@ -244,7 +244,7 @@ std::optional get_attribute(ReplicateAttrs const &p, } std::optional get_attribute(ReshapeAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -252,7 +252,7 @@ std::optional get_attribute(ReshapeAttrs const &p, } std::optional get_attribute(SplitAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; @@ -262,7 +262,7 @@ std::optional get_attribute(SplitAttrs const &p, } std::optional get_attribute(SoftmaxAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; @@ -272,7 +272,7 @@ std::optional get_attribute(SoftmaxAttrs const &p, } std::optional get_attribute(TopKAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -280,7 +280,7 @@ std::optional get_attribute(TopKAttrs const &p, } std::optional get_attribute(TransposeAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PERMUTATION: return p.perm; @@ -309,7 +309,7 @@ struct GetOpType { }; std::optional get_attribute(PCGOperatorAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { if (key == OperatorAttributeKey::OP_TYPE) { return std::visit(GetOpType{}, p); } diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 4f6572948a..15816185ee 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -120,14 +120,15 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, switch (op_type) { case Op::BATCHMATMUL: return Operator{ - BatchMatmulAttrs{ - std::get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), - std::get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, + BatchMatmulAttrs{std::get(assignments.at( + OperatorAttributeKey::A_SEQ_LENGTH_DIM)), + std::get(assignments.at( + OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, std::nullopt}; case Op::BATCHNORM: - return Operator{ - BatchNormAttrs{std::get(assignments.at(OperatorAttributeKey::RELU))}, - std::nullopt}; + return Operator{BatchNormAttrs{std::get( + assignments.at(OperatorAttributeKey::RELU))}, + std::nullopt}; case Op::CAST: return Operator{CastAttrs{std::get( assignments.at(OperatorAttributeKey::DATA_TYPE))}, @@ -135,13 +136,13 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::CONCAT: return Operator{ ConcatAttrs{ - std::get(assignments.at(OperatorAttributeKey::AXIS)), + std::get(assignments.at(OperatorAttributeKey::AXIS)), std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, std::nullopt}; case Op::CONV2D: return Operator{ Conv2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), @@ -149,15 +150,16 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::get(assignments.at(OperatorAttributeKey::PADDING_H)), std::get(assignments.at(OperatorAttributeKey::PADDING_W)), std::get(assignments.at(OperatorAttributeKey::GROUPS)), - std::get(assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get( + assignments.at(OperatorAttributeKey::ACTIVATION)), std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, std::nullopt}; case Op::DROPOUT: - return Operator{ - DropoutAttrs{std::get(assignments.at(OperatorAttributeKey::RATE)), - std::get( - assignments.at(OperatorAttributeKey::SEED))}, - std::nullopt}; + return Operator{DropoutAttrs{std::get(assignments.at( + OperatorAttributeKey::RATE)), + std::get(assignments.at( + OperatorAttributeKey::SEED))}, + std::nullopt}; case Op::EW_ADD: case Op::EW_DIV: case Op::EW_EQUAL: @@ -168,13 +170,13 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EW_MUL: case Op::EW_SUB: return Operator{ - ElementBinaryAttrs{ - op_type, - std::get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - std::get( - assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - std::get( - assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, + ElementBinaryAttrs{op_type, + std::get(assignments.at( + OperatorAttributeKey::DATA_TYPE)), + std::get(assignments.at( + OperatorAttributeKey::SHOULD_BROADCAST_LHS)), + std::get(assignments.at( + OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, std::nullopt}; case Op::SCALAR_ADD: case Op::SCALAR_FLOOR_DIV: @@ -197,23 +199,24 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EMBEDDING: return Operator{ EmbeddingAttrs{ - std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), + std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), std::get(assignments.at(OperatorAttributeKey::AGGR)), - std::get(assignments.at(OperatorAttributeKey::OP_TYPE))}, + std::get( + assignments.at(OperatorAttributeKey::OP_TYPE))}, std::nullopt}; case Op::FLAT: return Operator{FlatAttrs{}, std::nullopt}; case Op::GATHER: - return Operator{ - GatherAttrs{std::get(assignments.at(OperatorAttributeKey::DIM))}, - std::nullopt}; + return Operator{GatherAttrs{std::get( + assignments.at(OperatorAttributeKey::DIM))}, + std::nullopt}; case Op::INPUT: return Operator{InputAttrs{}, std::nullopt}; case Op::LAYERNORM: return Operator{ LayerNormAttrs{ - std::get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), std::get( assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), @@ -222,31 +225,34 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::LINEAR: return Operator{ LinearAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), - std::get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - std::get(assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get( + assignments.at(OperatorAttributeKey::DATA_TYPE)), + std::get( + assignments.at(OperatorAttributeKey::ACTIVATION)), std::get>( assignments.at(OperatorAttributeKey::REGULARIZER))}, std::nullopt}; case Op::MULTIHEAD_ATTENTION: return Operator{ MultiHeadAttentionAttrs{ - std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), + std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), std::get(assignments.at(OperatorAttributeKey::VDIM)), std::get(assignments.at(OperatorAttributeKey::DROPOUT)), std::get(assignments.at(OperatorAttributeKey::BIAS)), std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - std::get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, + std::get( + assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, std::nullopt}; case Op::NOOP: return Operator{NoopAttrs{}, std::nullopt}; case Op::POOL2D: return Operator{ Pool2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), @@ -265,7 +271,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::REDUCE_SUM: return Operator{ ReduceAttrs{ - std::get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), op_type, std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, @@ -280,9 +286,10 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::nullopt}; case Op::SPLIT: return Operator{ - SplitAttrs{std::get>( - assignments.at(OperatorAttributeKey::SPLITS)), - std::get(assignments.at(OperatorAttributeKey::AXIS))}, + SplitAttrs{ + std::get>( + assignments.at(OperatorAttributeKey::SPLITS)), + std::get(assignments.at(OperatorAttributeKey::AXIS))}, std::nullopt}; case Op::SOFTMAX: return Operator{SoftmaxAttrs{std::get( @@ -290,8 +297,9 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::nullopt}; case Op::TOPK: return Operator{ - TopKAttrs{std::get(assignments.at(OperatorAttributeKey::K)), - std::get(assignments.at(OperatorAttributeKey::SORTED))}, + TopKAttrs{ + std::get(assignments.at(OperatorAttributeKey::K)), + std::get(assignments.at(OperatorAttributeKey::SORTED))}, std::nullopt}; case Op::TRANSPOSE: return Operator{ @@ -299,28 +307,31 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.at(OperatorAttributeKey::PERMUTATION))}, std::nullopt}; case Op::COMBINE: - return Operator{ - CombineAttrs{ - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; + return Operator{CombineAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, + std::nullopt}; case Op::REDUCTION: return Operator{ - ReductionAttrs{ - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + ReductionAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REPARTITION: return Operator{ - RepartitionAttrs{ - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + RepartitionAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REPLICATE: return Operator{ - ReplicateAttrs{ - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + ReplicateAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; default: throw mk_runtime_error("Unknown Operator"); diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 40ac0a4a1c..0332a331b2 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -4,9 +4,9 @@ #include "utils/bidict.h" #include "utils/required_core.h" #include "utils/type_traits_core.h" +#include #include #include -#include namespace FlexFlow { diff --git a/lib/utils/include/utils/dot_file.h b/lib/utils/include/utils/dot_file.h index 6cf06d12a7..1fd9813646 100644 --- a/lib/utils/include/utils/dot_file.h +++ b/lib/utils/include/utils/dot_file.h @@ -5,12 +5,12 @@ #include #include #include +#include #include #include #include #include #include -#include template class DotFile { diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 9aed91f107..856dd4434e 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -41,8 +41,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -97,13 +96,11 @@ struct NodeLabelledMultiDiGraph NodeLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 0fea57cab7..c864c7dacf 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -55,8 +55,7 @@ struct NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -121,13 +120,11 @@ struct NodeLabelledOpenMultiDiGraph NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 8aab0320b5..ac5648c2e1 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -47,8 +47,7 @@ struct OutputLabelledMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -119,13 +118,11 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index aaf051c83d..bc4fe3d828 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -64,8 +64,7 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -145,13 +144,11 @@ struct OutputLabelledOpenMultiDiGraph OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index c6d1521471..34dabb5391 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -47,8 +47,7 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -104,13 +103,11 @@ struct LabelledMultiDiGraph LabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraph); diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index feb263335a..272caaffde 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -194,12 +194,13 @@ auto narrow(Container const &c) { return transform(c, [](VariantIn const &e) { return get(e); }); } -template , VariantIn>::value>> +template < + typename T1, + typename T2, + typename... Trest, + typename VariantIn, + typename = std::enable_if_t< + !is_subeq_variant, VariantIn>::value>> std::optional> narrow(VariantIn const &v) { return visit(VariantNarrowFunctor>{}, v); } diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index ecad1db3f0..bdfe5ff599 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -14,8 +14,7 @@ std::unordered_set } IDiGraphView const &DiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node DiGraph::add_node() { diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 41ae3e1aa3..771e01e573 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -66,8 +66,7 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 1494f0ac27..541ff40920 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -4,7 +4,8 @@ TEST_CASE("widen and narrow functions") { SUBCASE("widen function") { std::variant v1 = 42; - std::variant result = widen>(v1); + std::variant result = + widen>(v1); std::variant expected = 42; CHECK(result == expected); } @@ -12,7 +13,8 @@ TEST_CASE("widen and narrow functions") { SUBCASE("narrow function fail") { std::variant v2 = 3.14; // this is a doule, because 3.14 default to double - std::optional> result = narrow>(v2); + std::optional> result = + narrow>(v2); std::optional> expected = float(3.14); CHECK(!result.has_value()); // result should be empty due to narrowing } @@ -20,14 +22,16 @@ TEST_CASE("widen and narrow functions") { SUBCASE("narrow function success") { std::variant v2 = 3.14; // this is a doule, because 3.14 default to double - std::optional> result = narrow>(v2); + std::optional> result = + narrow>(v2); std::optional> expected = 3.14; CHECK(result == expected); // } SUBCASE("cast function") { std::variant v3 = 42; - std::optional> result = cast>(v3); + std::optional> result = + cast>(v3); std::optional> expected = 42; CHECK(result == expected); } @@ -53,7 +57,8 @@ TEST_CASE("casting and widening a variant") { std::variant wider_variant; // Perform the cast operation - std::optional> cast_result = cast>(smaller_variant); + std::optional> cast_result = + cast>(smaller_variant); REQUIRE(cast_result); // Ensure the cast was successful // Perform the widening operation From d2eb505120fb3a09abfb6811dee106e7f24ba7f9 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 24 Mar 2024 17:17:22 -0400 Subject: [PATCH 28/32] upd tests name to match ci --- lib/compiler/test/CMakeLists.txt | 2 +- lib/substitutions/test/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index cbd7e233c0..13b1fd3b83 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -1,6 +1,6 @@ ff_add_test_executable( NAME - compiler-test + compiler-tests SRC_PATTERNS src/*.cc PRIVATE_INCLUDE diff --git a/lib/substitutions/test/CMakeLists.txt b/lib/substitutions/test/CMakeLists.txt index d7e35ef9af..cfd6383e95 100644 --- a/lib/substitutions/test/CMakeLists.txt +++ b/lib/substitutions/test/CMakeLists.txt @@ -1,6 +1,6 @@ ff_add_test_executable( NAME - substitutions-test + substitutions-tests SRC_PATTERNS src/*.cc PRIVATE_INCLUDE From 371324a505a5f61aca276ecf621e2eb862f2cb5c Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 26 Mar 2024 16:19:08 -0700 Subject: [PATCH 29/32] Add TEST_SUITE declaration to make tests findable by ctest --- .../test/src/test_labelled_open_graph.cc | 240 +++---- lib/compiler/test/src/test_machine_mapping.cc | 36 +- lib/compiler/test/src/test_open_graph.cc | 136 ++-- lib/compiler/test/src/test_optimal_cost.cc | 102 +-- lib/compiler/test/src/test_unity_algorithm.cc | 45 +- .../test/src/test_pattern_matches.cc | 68 +- .../test/src/test_substitution.cc | 240 +++---- lib/utils/test/src/test_algorithms.cc | 408 +++++------ lib/utils/test/src/test_bidict.cc | 100 +-- lib/utils/test/src/test_containers.cc | 651 +++++++++--------- lib/utils/test/src/test_cow_ptr.cc | 66 +- .../src/test_deduplicated_priority_queue.cc | 48 +- lib/utils/test/src/test_disjoint_set.cc | 76 +- lib/utils/test/src/test_dot_file.cc | 76 +- lib/utils/test/src/test_format.cc | 46 +- lib/utils/test/src/test_hash.cc | 20 +- lib/utils/test/src/test_multidigraph.cc | 138 ++-- lib/utils/test/src/test_random_utils.cc | 72 +- lib/utils/test/src/test_sequence.cc | 308 +++++---- lib/utils/test/src/test_stack_map.cc | 88 +-- lib/utils/test/src/test_stack_string.cc | 124 ++-- lib/utils/test/src/test_stack_vector.cc | 142 ++-- lib/utils/test/src/test_tuple.cc | 118 ++-- lib/utils/test/src/test_type_index.cc | 42 +- lib/utils/test/src/test_undirected_graph.cc | 54 +- lib/utils/test/src/test_variant.cc | 110 +-- lib/utils/test/src/test_vector.cc | 46 +- 27 files changed, 1828 insertions(+), 1772 deletions(-) diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index c59d7ee78a..e3498a769a 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,123 +4,125 @@ using namespace FlexFlow; -TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { - auto g = OpenMultiDiGraph::create(); - - Node n0 = g.add_node(); - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - Node n4 = g.add_node(); - - NodePort p0 = g.add_node_port(); - NodePort p1 = g.add_node_port(); - NodePort p2 = g.add_node_port(); - NodePort p3 = g.add_node_port(); - NodePort p4 = g.add_node_port(); - NodePort p5 = g.add_node_port(); - NodePort p6 = g.add_node_port(); - NodePort p7 = g.add_node_port(); - NodePort p8 = g.add_node_port(); - NodePort p9 = g.add_node_port(); - - MultiDiEdge e0{n1, p1, n0, p0}; - MultiDiEdge e1{n2, p2, n0, p0}; - MultiDiEdge e2{n3, p5, n1, p3}; - MultiDiEdge e3{n3, p6, n2, p4}; - MultiDiEdge e4{n4, p8, n3, p7}; - OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - g.add_edge(e5); - - std::unordered_set node_set0{n3, n4}; - - auto subgraph0 = get_subgraph(g, node_set0); - auto subgraph1 = get_subgraph(g, node_set0); - auto subgraph2 = get_subgraph(g, node_set0); - auto subgraph3 = get_subgraph(g, node_set0); - - CHECK(bool(get_nodes(subgraph0) == node_set0)); - CHECK(bool(get_nodes(subgraph1) == node_set0)); - CHECK(bool(get_nodes(subgraph2) == node_set0)); - CHECK(bool(get_nodes(subgraph3) == node_set0)); - - std::unordered_set input_set{split_edge(e2).second, - split_edge(e3).second}; - std::unordered_set output_set{e5}; - - CHECK(bool(get_open_inputs(subgraph0) == input_set)); - CHECK(bool(get_open_inputs(subgraph1) == input_set)); - CHECK(bool(get_open_inputs(subgraph2).empty())); - CHECK(bool(get_open_inputs(subgraph3).empty())); - - CHECK(bool(get_open_outputs(subgraph0) == output_set)); - CHECK(bool(get_open_outputs(subgraph1).empty())); - CHECK(bool(get_open_outputs(subgraph2) == output_set)); - CHECK(bool(get_open_outputs(subgraph3).empty())); - - CHECK(bool(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5})); - CHECK(bool(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4})); - CHECK(bool(get_edges(subgraph2) == - std::unordered_set{e4, e5})); - CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); - - CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); -} - -TEST_CASE("view OutputLabelledMultiDiGraph as open") { - OutputLabelledMultiDiGraph g = - OutputLabelledMultiDiGraph::create< - UnorderedOutputLabelledMultiDiGraph>(); - - Node n0 = g.add_node(0); - Node n1 = g.add_node(1); - - NodePort p0 = g.add_node_port(); - NodePort p1 = g.add_node_port(); - - MultiDiEdge e0{n1, p1, n0, p0}; - - g.add_edge(e0); - g.add_output(e0, 2); - - CHECK(bool(get_edges(g).size() == 1)); - - OutputLabelledOpenMultiDiGraphView open_graph = - view_output_labelled_as_output_labelled_open(g); - - CHECK(bool(open_graph.at(n0) == 0)); - CHECK(bool(open_graph.at(n1) == 1)); - CHECK(bool(open_graph.at(e0) == 2)); - - CHECK(get_edges(open_graph).size() == 1); -} - -TEST_CASE("OutputLabelledOpenMultiDiGraph") { - OutputLabelledOpenMultiDiGraph g = - OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); - - Node n0 = g.add_node(0); - Node n1 = g.add_node(1); - - NodePort p0 = g.add_node_port(); - NodePort p1 = g.add_node_port(); - - MultiDiEdge e0{n1, p1, n0, p0}; - - g.add_edge(e0); - g.add_label(e0, 2); - - CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); - CHECK(bool(get_edges(g).size() == 1)); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { + auto g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); + Node n4 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + NodePort p2 = g.add_node_port(); + NodePort p3 = g.add_node_port(); + NodePort p4 = g.add_node_port(); + NodePort p5 = g.add_node_port(); + NodePort p6 = g.add_node_port(); + NodePort p7 = g.add_node_port(); + NodePort p8 = g.add_node_port(); + NodePort p9 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e1{n2, p2, n0, p0}; + MultiDiEdge e2{n3, p5, n1, p3}; + MultiDiEdge e3{n3, p6, n2, p4}; + MultiDiEdge e4{n4, p8, n3, p7}; + OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + std::unordered_set node_set0{n3, n4}; + + auto subgraph0 = get_subgraph(g, node_set0); + auto subgraph1 = get_subgraph(g, node_set0); + auto subgraph2 = get_subgraph(g, node_set0); + auto subgraph3 = get_subgraph(g, node_set0); + + CHECK(bool(get_nodes(subgraph0) == node_set0)); + CHECK(bool(get_nodes(subgraph1) == node_set0)); + CHECK(bool(get_nodes(subgraph2) == node_set0)); + CHECK(bool(get_nodes(subgraph3) == node_set0)); + + std::unordered_set input_set{split_edge(e2).second, + split_edge(e3).second}; + std::unordered_set output_set{e5}; + + CHECK(bool(get_open_inputs(subgraph0) == input_set)); + CHECK(bool(get_open_inputs(subgraph1) == input_set)); + CHECK(bool(get_open_inputs(subgraph2).empty())); + CHECK(bool(get_open_inputs(subgraph3).empty())); + + CHECK(bool(get_open_outputs(subgraph0) == output_set)); + CHECK(bool(get_open_outputs(subgraph1).empty())); + CHECK(bool(get_open_outputs(subgraph2) == output_set)); + CHECK(bool(get_open_outputs(subgraph3).empty())); + + CHECK(bool(get_edges(subgraph0) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); + CHECK(bool(get_edges(subgraph1) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + CHECK(bool(get_edges(subgraph2) == + std::unordered_set{e4, e5})); + CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); + + CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); + } + + TEST_CASE("view OutputLabelledMultiDiGraph as open") { + OutputLabelledMultiDiGraph g = + OutputLabelledMultiDiGraph::create< + UnorderedOutputLabelledMultiDiGraph>(); + + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + + g.add_edge(e0); + g.add_output(e0, 2); + + CHECK(bool(get_edges(g).size() == 1)); + + OutputLabelledOpenMultiDiGraphView open_graph = + view_output_labelled_as_output_labelled_open(g); + + CHECK(bool(open_graph.at(n0) == 0)); + CHECK(bool(open_graph.at(n1) == 1)); + CHECK(bool(open_graph.at(e0) == 2)); + + CHECK(get_edges(open_graph).size() == 1); + } + + TEST_CASE("OutputLabelledOpenMultiDiGraph") { + OutputLabelledOpenMultiDiGraph g = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + + g.add_edge(e0); + g.add_label(e0, 2); + + CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); + CHECK(bool(get_edges(g).size() == 1)); + } } diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc index b2abc6929d..365ed3e1db 100644 --- a/lib/compiler/test/src/test_machine_mapping.cc +++ b/lib/compiler/test/src/test_machine_mapping.cc @@ -1,21 +1,23 @@ -// #include "doctest/doctest.h" -// #include "test_generator.h" +#include "doctest/doctest.h" +#include "test_generator.h" -// TEST_CASE("MachineMapping::combine") { -// rc::check([](MachineMapping const &m0, MachineMapping const &m1) { -// RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); +TEST_SUITE(FF_TEST_SUITE) { + // TEST_CASE("MachineMapping::combine") { + // rc::check([](MachineMapping const &m0, MachineMapping const &m1) { + // RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); -// MachineMapping comb = MachineMapping::combine(m0, m1); + // MachineMapping comb = MachineMapping::combine(m0, m1); -// RC_ASSERT(comb.machine_views.size() == -// m0.machine_views.size() + m1.machine_views.size()); -// RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); -// RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); -// }); -// } + // RC_ASSERT(comb.machine_views.size() == + // m0.machine_views.size() + m1.machine_views.size()); + // RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); + // RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); + // }); + // } -// TEST_CASE("OptimalCostResult::infinity") { -// rc::check([](OptimalCostResult const &c) { -// RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); -// }); -// } + // TEST_CASE("OptimalCostResult::infinity") { + // rc::check([](OptimalCostResult const &c) { + // RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); + // }); + // } +} diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc index 7436f213d7..db3630d316 100644 --- a/lib/compiler/test/src/test_open_graph.cc +++ b/lib/compiler/test/src/test_open_graph.cc @@ -4,71 +4,73 @@ using namespace FlexFlow; -TEST_CASE("get_source_sink_open_graph") { - OpenMultiDiGraph g = OpenMultiDiGraph::create(); - - Node n0 = g.add_node(); - NodePort p0 = g.add_node_port(); - InputMultiDiEdge e0{ - n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; - g.add_edge(e0); - - CHECK(bool(get_closed_sources(g) == std::unordered_set{})); - CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - - CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); - CHECK(bool(get_open_sinks(g) == std::unordered_set{})); -} - -TEST_CASE("get_source_sink_open_graph:unconnected") { - OpenMultiDiGraph g = OpenMultiDiGraph::create(); - - Node n0 = g.add_node(); - Node n1 = g.add_node(); - - NodePort p0 = g.add_node_port(); - NodePort p1 = g.add_node_port(); - - InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; - OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; - g.add_edge(e0); - g.add_edge(e1); - - /* - g: ->n0 - n1-> - */ - - CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); - CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - - CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); - CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); -} - -TEST_CASE("get_cut") { - auto g = OpenMultiDiGraph::create(); - - std::vector ns = add_nodes(g, 5); - - MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; - MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; - MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; - MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; - MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; - OutputMultiDiEdge e5{ - ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - g.add_edge(e5); - - GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; - CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, e2})); - - GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; - CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, e4})); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_source_sink_open_graph") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + NodePort p0 = g.add_node_port(); + InputMultiDiEdge e0{ + n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; + g.add_edge(e0); + + CHECK(bool(get_closed_sources(g) == std::unordered_set{})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{})); + } + + TEST_CASE("get_source_sink_open_graph:unconnected") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; + OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; + g.add_edge(e0); + g.add_edge(e1); + + /* + g: ->n0 + n1-> + */ + + CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); + } + + TEST_CASE("get_cut") { + auto g = OpenMultiDiGraph::create(); + + std::vector ns = add_nodes(g, 5); + + MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; + MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; + MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; + OutputMultiDiEdge e5{ + ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; + CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, e2})); + + GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; + CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, e4})); + } } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 5f5f7d093e..da303e3ccc 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -4,63 +4,65 @@ using namespace FlexFlow; -// Rapidcheck infrastructures for graphs does not work for now -/* -Tests whether optimal_cost can give a valid result given random PCG, trivial -allowed machine views, trivial cost estimator and random machine specification. -*/ -// TEST_CASE("optimal_cost") { -// auto test_allowed_machine_views = [](Operator const &, -// MachineSpecification const &) { -// return std::unordered_set{make_1d_machine_view(0, 1, 1)}; -// }; -// rc::check([](ParallelComputationGraph const &g, -// MachineSpecification const &machine_spec) { -// OptimalCostCache cached_subgraph_costs; -// OptimalCostResult result = optimal_cost(g, -// test_allowed_machine_views, -// TestCostEstimator{}, -// machine_spec, -// cached_subgraph_costs); -// RC_ASSERT(result.runtime > 0); -// RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); -// }); -// } +TEST_SUITE(FF_TEST_SUITE) { + // Rapidcheck infrastructures for graphs does not work for now + /* + Tests whether optimal_cost can give a valid result given random PCG, trivial + allowed machine views, trivial cost estimator and random machine specification. + */ + // TEST_CASE("optimal_cost") { + // auto test_allowed_machine_views = [](Operator const &, + // MachineSpecification const &) { + // return std::unordered_set{make_1d_machine_view(0, 1, 1)}; + // }; + // rc::check([](ParallelComputationGraph const &g, + // MachineSpecification const &machine_spec) { + // OptimalCostCache cached_subgraph_costs; + // OptimalCostResult result = optimal_cost(g, + // test_allowed_machine_views, + // TestCostEstimator{}, + // machine_spec, + // cached_subgraph_costs); + // RC_ASSERT(result.runtime > 0); + // RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); + // }); + // } -TEST_CASE("optimal_cost_0") { - auto pcg = - OutputLabelledMultiDiGraph::template create< - UnorderedOutputLabelledMultiDiGraph>(); + TEST_CASE("optimal_cost_0") { + auto pcg = + OutputLabelledMultiDiGraph::template create< + UnorderedOutputLabelledMultiDiGraph>(); - Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); - Node n1 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, - "linear"}); + Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); + Node n1 = pcg.add_node(Operator{ + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, + "linear"}); - MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; - pcg.add_edge(e); - pcg.add_output(e, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); + MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; + pcg.add_edge(e); + pcg.add_output(e, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); - auto test_allowed_machine_views = [](Operator const &, - MachineSpecification const &) { - return std::unordered_set{ - make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; - }; + auto test_allowed_machine_views = [](Operator const &, + MachineSpecification const &) { + return std::unordered_set{ + make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + }; - CostEstimator estimator = CostEstimator::create(); + CostEstimator estimator = CostEstimator::create(); - MachineSpecification machine_spec{1, 1, 1, 1, 1}; + MachineSpecification machine_spec{1, 1, 1, 1, 1}; - OptimalCostCache cached_results; + OptimalCostCache cached_results; - OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), - test_allowed_machine_views, - estimator, - machine_spec, - cached_results); + OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), + test_allowed_machine_views, + estimator, + machine_spec, + cached_results); - CHECK(bool(result.runtime > 0)); + CHECK(bool(result.runtime > 0)); + } } diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index c39b3ef14f..b8fde91c51 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -1,25 +1,28 @@ #include "compiler/unity_algorithm.h" #include "test_cost_estimator.h" #include "test_generator.h" +#include "doctest/doctest.h" -// Rapidcheck does not work for now -// TEST_CASE("graph_optimize") { -// rc::check([](ComputationGraph const &g, -// float alpha, -// int budget, -// float threshold, -// int max_num_ops) { -// Strategy s = graph_optimize( -// g, -// TestCostEstimator{}, -// MachineSpecification{1, 1, 4, 0.1, 0.2}, -// [](Operator const &, MachineSpecification const &) { -// return std::unordered_set{make_1d_machine_view(0, 1, -// 1)}; -// }, -// OptimizerConfig{alpha, budget, threshold, max_num_ops}); -// RC_ASSERT(get_nodes(s.pcg).size() > 0); -// RC_ASSERT(s.machine_mapping.runtime > 0); -// RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); -// }); -// } +TEST_SUITE(FF_TEST_SUITE) { + // Rapidcheck does not work for now + // TEST_CASE("graph_optimize") { + // rc::check([](ComputationGraph const &g, + // float alpha, + // int budget, + // float threshold, + // int max_num_ops) { + // Strategy s = graph_optimize( + // g, + // TestCostEstimator{}, + // MachineSpecification{1, 1, 4, 0.1, 0.2}, + // [](Operator const &, MachineSpecification const &) { + // return std::unordered_set{make_1d_machine_view(0, 1, + // 1)}; + // }, + // OptimizerConfig{alpha, budget, threshold, max_num_ops}); + // RC_ASSERT(get_nodes(s.pcg).size() > 0); + // RC_ASSERT(s.machine_mapping.runtime > 0); + // RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); + // }); + // } +} diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index cc8a5cd5bd..f1abd5c17e 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -62,46 +62,48 @@ struct Arbitrary { // }); // } -TEST_CASE("find_pattern_matches_small") { - MultiDiGraph g = MultiDiGraph::template create(); - - { - Node n0 = g.add_node(); - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - - MultiDiEdge e0{n1, g.add_node_port(), n0, g.add_node_port()}; - MultiDiEdge e1{n2, g.add_node_port(), n1, g.add_node_port()}; - MultiDiEdge e2{n3, g.add_node_port(), n2, g.add_node_port()}; - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_pattern_matches_small") { + MultiDiGraph g = MultiDiGraph::template create(); - MultiDiGraph sg0 = MultiDiGraph::template create(); + { + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); - { - Node n0 = sg0.add_node(); - Node n1 = sg0.add_node(); + MultiDiEdge e0{n1, g.add_node_port(), n0, g.add_node_port()}; + MultiDiEdge e1{n2, g.add_node_port(), n1, g.add_node_port()}; + MultiDiEdge e2{n3, g.add_node_port(), n2, g.add_node_port()}; - MultiDiEdge e0{n1, sg0.add_node_port(), n0, sg0.add_node_port()}; + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + } - sg0.add_edge(e0); - } + MultiDiGraph sg0 = MultiDiGraph::template create(); + + { + Node n0 = sg0.add_node(); + Node n1 = sg0.add_node(); + + MultiDiEdge e0{n1, sg0.add_node_port(), n0, sg0.add_node_port()}; + + sg0.add_edge(e0); + } - MatchAdditionalCriterion always_true{ - [](Node const &, Node const &) { return true; }, - [](OpenMultiDiEdge const &, OpenMultiDiEdge const &) { return true; }}; + MatchAdditionalCriterion always_true{ + [](Node const &, Node const &) { return true; }, + [](OpenMultiDiEdge const &, OpenMultiDiEdge const &) { return true; }}; - std::vector matches = find_pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), always_true); + std::vector matches = find_pattern_matches( + as_openmultidigraph(sg0), as_openmultidigraph(g), always_true); - RC_ASSERT(matches.size() == 3); + RC_ASSERT(matches.size() == 3); - for (MultiDiGraphPatternMatch const &match : matches) { - RC_ASSERT(pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), match, always_true)); + for (MultiDiGraphPatternMatch const &match : matches) { + RC_ASSERT(pattern_matches( + as_openmultidigraph(sg0), as_openmultidigraph(g), match, always_true)); + } } } diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 552d46a98f..86ee087a29 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -5,123 +5,125 @@ using namespace FlexFlow; -TEST_CASE("apply_substitution") { - OperatorPattern operator_pattern_n0{ - std::vector{OperatorAttributeConstraint{ - ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; - - ParallelTensorPattern tensor_pattern_e0{ - std::vector{TensorAttributeConstraint{ - ConstraintType::EQUAL, - ListIndexAccess{TensorAttributeKey::DIM_SIZES, 0}, - 2}}}; - - ParallelTensorPattern tensor_pattern_empty{ - std::vector{}}; - - auto ig = OutputLabelledOpenMultiDiGraph:: - create>(); - Node n0 = ig.add_node(operator_pattern_n0); - NodePort p0 = ig.add_node_port(); - InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; - ig.add_edge(e0); - ig.add_label(e0, tensor_pattern_e0); - - RC_ASSERT(get_nodes(ig).size() == 1); - RC_ASSERT(get_edges(ig).size() == 1); - - GraphPattern input_graph{ig}; - - OperatorAttrAssignment op_ass_n1{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, - {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, - {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; - - OperatorAttrAssignment op_ass_n2{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, - {OperatorAttributeKey::OUT_CHANNELS, - OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, - {OperatorAttributeKey::USE_BIAS, - OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, - {OperatorAttributeKey::DATA_TYPE, - OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, - {OperatorAttributeKey::ACTIVATION, - OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, - {OperatorAttributeKey::REGULARIZER, - OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; - - OperatorAttrAssignment op_ass_n3{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, - {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, - {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; - - auto og = NodeLabelledOpenMultiDiGraph::create< - UnorderedNodeLabelledOpenMultiDiGraph>(); - Node n1 = og.add_node(op_ass_n1); - Node n2 = og.add_node(op_ass_n2); - Node n3 = og.add_node(op_ass_n3); - NodePort p1 = og.add_node_port(); - NodePort p2 = og.add_node_port(); - NodePort p3 = og.add_node_port(); - InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; - MultiDiEdge e2{n2, p2, n1, p1}; - MultiDiEdge e3{n3, p3, n2, p2}; - og.add_edge(e1); - og.add_edge(e2); - og.add_edge(e3); - OutputGraphExpr output_graph_expr{og}; - - RC_ASSERT(get_nodes(og).size() == 3); - RC_ASSERT(get_edges(og).size() == 3); - - bidict input_mapping; - input_mapping.equate(e0, e1); - bidict output_mapping; - - Substitution substitution{ - input_graph, output_graph_expr, input_mapping, output_mapping}; - - SubParallelComputationGraph pcg = - OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); - - Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); - Node n5 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, - "linear"}); - NodePort p4 = pcg.add_node_port(); - NodePort p5 = pcg.add_node_port(); - - MultiDiEdge e4{n5, p5, n4, p4}; - pcg.add_edge(e4); - pcg.add_label(e4, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); - - MatchAdditionalCriterion criterion{ - [&](Node const &pattern_node, Node const &graph_node) { - return operator_satisfies(pcg.at(graph_node), - input_graph.value().at(pattern_node)); - }, - [&](OpenMultiDiEdge const &pattern_edge, - OpenMultiDiEdge const &graph_edge) { - return parallel_tensor_satisfies(pcg.at(graph_edge), - input_graph.value().at(pattern_edge)); - }}; - - RC_ASSERT(criterion.node_criterion(n0, n5)); - - std::vector matches = - find_pattern_matches(input_graph, pcg, criterion); - - RC_ASSERT(matches.size() == 1); - - SubParallelComputationGraph new_pcg = - apply_substitution(pcg, substitution, matches[0]); - - RC_ASSERT(get_nodes(new_pcg).size() == 4); - RC_ASSERT(get_edges(new_pcg).size() == 3); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("apply_substitution") { + OperatorPattern operator_pattern_n0{ + std::vector{OperatorAttributeConstraint{ + ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; + + ParallelTensorPattern tensor_pattern_e0{ + std::vector{TensorAttributeConstraint{ + ConstraintType::EQUAL, + ListIndexAccess{TensorAttributeKey::DIM_SIZES, 0}, + 2}}}; + + ParallelTensorPattern tensor_pattern_empty{ + std::vector{}}; + + auto ig = OutputLabelledOpenMultiDiGraph:: + create>(); + Node n0 = ig.add_node(operator_pattern_n0); + NodePort p0 = ig.add_node_port(); + InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; + ig.add_edge(e0); + ig.add_label(e0, tensor_pattern_e0); + + RC_ASSERT(get_nodes(ig).size() == 1); + RC_ASSERT(get_edges(ig).size() == 1); + + GraphPattern input_graph{ig}; + + OperatorAttrAssignment op_ass_n1{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, + {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, + {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; + + OperatorAttrAssignment op_ass_n2{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, + {OperatorAttributeKey::OUT_CHANNELS, + OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, + {OperatorAttributeKey::USE_BIAS, + OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, + {OperatorAttributeKey::DATA_TYPE, + OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, + {OperatorAttributeKey::ACTIVATION, + OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, + {OperatorAttributeKey::REGULARIZER, + OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; + + OperatorAttrAssignment op_ass_n3{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, + {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, + {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; + + auto og = NodeLabelledOpenMultiDiGraph::create< + UnorderedNodeLabelledOpenMultiDiGraph>(); + Node n1 = og.add_node(op_ass_n1); + Node n2 = og.add_node(op_ass_n2); + Node n3 = og.add_node(op_ass_n3); + NodePort p1 = og.add_node_port(); + NodePort p2 = og.add_node_port(); + NodePort p3 = og.add_node_port(); + InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; + MultiDiEdge e2{n2, p2, n1, p1}; + MultiDiEdge e3{n3, p3, n2, p2}; + og.add_edge(e1); + og.add_edge(e2); + og.add_edge(e3); + OutputGraphExpr output_graph_expr{og}; + + RC_ASSERT(get_nodes(og).size() == 3); + RC_ASSERT(get_edges(og).size() == 3); + + bidict input_mapping; + input_mapping.equate(e0, e1); + bidict output_mapping; + + Substitution substitution{ + input_graph, output_graph_expr, input_mapping, output_mapping}; + + SubParallelComputationGraph pcg = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + + Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); + Node n5 = pcg.add_node(Operator{ + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, + "linear"}); + NodePort p4 = pcg.add_node_port(); + NodePort p5 = pcg.add_node_port(); + + MultiDiEdge e4{n5, p5, n4, p4}; + pcg.add_edge(e4); + pcg.add_label(e4, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); + + MatchAdditionalCriterion criterion{ + [&](Node const &pattern_node, Node const &graph_node) { + return operator_satisfies(pcg.at(graph_node), + input_graph.value().at(pattern_node)); + }, + [&](OpenMultiDiEdge const &pattern_edge, + OpenMultiDiEdge const &graph_edge) { + return parallel_tensor_satisfies(pcg.at(graph_edge), + input_graph.value().at(pattern_edge)); + }}; + + RC_ASSERT(criterion.node_criterion(n0, n5)); + + std::vector matches = + find_pattern_matches(input_graph, pcg, criterion); + + RC_ASSERT(matches.size() == 1); + + SubParallelComputationGraph new_pcg = + apply_substitution(pcg, substitution, matches[0]); + + RC_ASSERT(get_nodes(new_pcg).size() == 4); + RC_ASSERT(get_edges(new_pcg).size() == 3); + } } diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 2e97496b6b..d3236a7b1c 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -12,232 +12,234 @@ using namespace FlexFlow; -TEST_CASE("MultiDiGraph") { - MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector p = add_node_ports(g, 4); - - MultiDiEdge e0{n[3], p[3], n[0], p[0]}; - MultiDiEdge e1{n[2], p[2], n[1], p[0]}; - MultiDiEdge e2{n[3], p[3], n[1], p[1]}; - MultiDiEdge e3{n[3], p[3], n[2], p[2]}; - - std::vector e = {e0, e1, e2, e3}; - - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[1], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{e[3]}); - std::unordered_map> expected_result = - std::unordered_map>{ - {n[1], {}}, - {n[2], {n[1]}}, - {n[3], {n[0], n[1], n[2]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); -} - -TEST_CASE("DiGraph") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - std::vector e = { - {n[0], n[3]}, - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[2]}, - }; - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[2], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{}); - auto expected_result = std::unordered_map>{ - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - {n[3], {n[0]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - - SUBCASE("get_imm_dominators") { - std::unordered_map> result = get_imm_dominators(g); - - std::unordered_map> expected_result = { - {n[2], n[0]}, - {n[1], n[0]}, - {n[3], n[0]}, - {n[0], nullopt}, - }; - CHECK(result == expected_result); - } - - SUBCASE("get_dominators") { - std::unordered_map> expected = { - {n[0], {n[0]}}, - {n[1], {n[0], n[1]}}, - {n[2], {n[0], n[2]}}, - {n[3], {n[0], n[3]}}, - }; - CHECK(get_dominators(g) == expected); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MultiDiGraph") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 4); + std::vector p = add_node_ports(g, 4); + + MultiDiEdge e0{n[3], p[3], n[0], p[0]}; + MultiDiEdge e1{n[2], p[2], n[1], p[0]}; + MultiDiEdge e2{n[3], p[3], n[1], p[1]}; + MultiDiEdge e3{n[3], p[3], n[2], p[2]}; + + std::vector e = {e0, e1, e2, e3}; + + add_edges(g, e); + + CHECK(get_incoming_edges(g, {n[1], n[3]}) == + std::unordered_set{e[0], e[2], e[3]}); + CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); + CHECK(get_outgoing_edges(g, {n[2], n[3]}) == + std::unordered_set{e[3]}); + std::unordered_map> expected_result = + std::unordered_map>{ + {n[1], {}}, + {n[2], {n[1]}}, + {n[3], {n[0], n[1], n[2]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); } - SUBCASE("get_sinks") { - auto expected = std::unordered_set{n[2], n[3]}; - CHECK(get_sinks(g) == expected); - } + TEST_CASE("DiGraph") { + DiGraph g = DiGraph::create(); - SUBCASE("get_bfs") { - std::unordered_set start_points = std::unordered_set{n[0]}; - auto expected = std::vector{n[0], n[2], n[1], n[3]}; - CHECK(get_bfs_ordering(g, start_points) == expected); - } + std::vector n = add_nodes(g, 4); + std::vector e = { + {n[0], n[3]}, + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[2]}, + }; + add_edges(g, e); - SUBCASE("get_predecessors") { - std::unordered_map> expected_result = { + CHECK(get_incoming_edges(g, {n[2], n[3]}) == + std::unordered_set{e[0], e[2], e[3]}); + CHECK(get_outgoing_edges(g, {n[2], n[3]}) == + std::unordered_set{}); + auto expected_result = std::unordered_map>{ {n[1], {n[0]}}, {n[2], {n[0], n[1]}}, + {n[3], {n[0]}}, }; - CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); - } -} + CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); -TEST_CASE("traversal") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 5); - std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; - add_edges(g, edges); - - CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); - CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == true); - CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); - CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); - - SUBCASE("with root") { - g.add_edge({n[3], n[2]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); + SUBCASE("get_imm_dominators") { + std::unordered_map> result = get_imm_dominators(g); + + std::unordered_map> expected_result = { + {n[2], n[0]}, + {n[1], n[0]}, + {n[3], n[0]}, + {n[0], nullopt}, + }; + CHECK(result == expected_result); + } + + SUBCASE("get_dominators") { + std::unordered_map> expected = { + {n[0], {n[0]}}, + {n[1], {n[0], n[1]}}, + {n[2], {n[0], n[2]}}, + {n[3], {n[0], n[3]}}, + }; + CHECK(get_dominators(g) == expected); + } + + SUBCASE("get_sinks") { + auto expected = std::unordered_set{n[2], n[3]}; + CHECK(get_sinks(g) == expected); + } + + SUBCASE("get_bfs") { + std::unordered_set start_points = std::unordered_set{n[0]}; + auto expected = std::vector{n[0], n[2], n[1], n[3]}; + CHECK(get_bfs_ordering(g, start_points) == expected); + } + + SUBCASE("get_predecessors") { + std::unordered_map> expected_result = { + {n[1], {n[0]}}, + {n[2], {n[0], n[1]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); + } } - SUBCASE("without root") { - g.add_edge({n[3], n[0]}); + TEST_CASE("traversal") { + DiGraph g = DiGraph::create(); + std::vector const n = add_nodes(g, 5); + std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; + add_edges(g, edges); - CHECK(get_dfs_ordering(g, {n[0]}) == + CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); + CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - SUBCASE("nonlinear") { - g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + CHECK(get_bfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == true); + CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); + CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); + + SUBCASE("with root") { + g.add_edge({n[3], n[2]}); + + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == false); + } + + SUBCASE("without root") { + g.add_edge({n[3], n[0]}); + + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == false); + } + SUBCASE("nonlinear") { + g.add_edge({n[1], n[3]}); + CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + } + + SUBCASE("not connected") { + g.remove_edge({n[2], n[3]}); + CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); + } } - SUBCASE("not connected") { - g.remove_edge({n[2], n[3]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); + TEST_CASE("bfs") { + DiGraph g = DiGraph::create(); + std::vector const n = add_nodes(g, 7); + + std::vector e = { + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[6]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}, + {n[5], n[6]}, + {n[6], n[0]}, + }; + + add_edges(g, e); + + std::vector ordering = get_bfs_ordering(g, {n[0]}); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).has_value()); + CHECK(index_of(ordering, n[r]).has_value()); + CHECK(index_of(ordering, n[l]).value() < index_of(ordering, n[r]).value()); + }; + + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + + CHECK_BEFORE(1, 3); + CHECK_BEFORE(1, 6); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(2, 6); + + CHECK_BEFORE(3, 4); + CHECK_BEFORE(6, 4); + + CHECK_BEFORE(4, 5); } -} -TEST_CASE("bfs") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 7); - - std::vector e = { - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[6]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}, - {n[5], n[6]}, - {n[6], n[0]}, - }; - - add_edges(g, e); - - std::vector ordering = get_bfs_ordering(g, {n[0]}); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]).value() < index_of(ordering, n[r]).value()); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - - CHECK_BEFORE(1, 3); - CHECK_BEFORE(1, 6); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(2, 6); - - CHECK_BEFORE(3, 4); - CHECK_BEFORE(6, 4); - - CHECK_BEFORE(4, 5); -} + TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {{n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[5]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).has_value()); + CHECK(index_of(ordering, n[r]).has_value()); + CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); + }; -TEST_CASE("get_topological_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 6); - std::vector edges = {{n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[5]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}}; - add_edges(g, edges); - std::vector ordering = get_topological_ordering(g); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - CHECK_BEFORE(1, 5); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(3, 4); - CHECK_BEFORE(4, 5); -} + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + CHECK_BEFORE(1, 5); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(3, 4); + CHECK_BEFORE(4, 5); + } -TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 4); + std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; + add_edges(g, edges); + std::unordered_set> expected_components = { + {n[0], n[1], n[2]}, + {n[3]}, + }; - CHECK(get_connected_components(g) == expected_components); -} + CHECK(get_connected_components(g) == expected_components); + } -TEST_CASE("get_weakly_connected_components") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); + TEST_CASE("get_weakly_connected_components") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; + std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; + add_edges(g, edges); + std::unordered_set> expected_components = { + {n[0], n[1], n[2]}, + {n[3]}, + }; - CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); + CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); - CHECK(get_weakly_connected_components(g) == expected_components); + CHECK(get_weakly_connected_components(g) == expected_components); + } } diff --git a/lib/utils/test/src/test_bidict.cc b/lib/utils/test/src/test_bidict.cc index 6c288089b6..afc32b3658 100644 --- a/lib/utils/test/src/test_bidict.cc +++ b/lib/utils/test/src/test_bidict.cc @@ -3,61 +3,63 @@ using namespace FlexFlow; -TEST_CASE("bidict") { - bidict dict; - dict.equate(1, "one"); - dict.equate(2, "two"); - - // Test the equate() function - SUBCASE("Equate") { - CHECK(dict.at_l(1) == "one"); - CHECK(dict.at_r("one") == 1); - CHECK(dict.at_l(2) == "two"); - CHECK(dict.at_r("two") == 2); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("bidict") { + bidict dict; + dict.equate(1, "one"); + dict.equate(2, "two"); - // Test the erase_l() function - SUBCASE("EraseL") { - dict.erase_l(1); - CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_l(1), std::out_of_range); - CHECK(dict.at_r("two") == 2); - } + // Test the equate() function + SUBCASE("Equate") { + CHECK(dict.at_l(1) == "one"); + CHECK(dict.at_r("one") == 1); + CHECK(dict.at_l(2) == "two"); + CHECK(dict.at_r("two") == 2); + } - // Test the erase_r() function - SUBCASE("EraseR") { - dict.erase_r("one"); - CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_r("one"), std::out_of_range); - CHECK(dict.at_l(2) == "two"); - } + // Test the erase_l() function + SUBCASE("EraseL") { + dict.erase_l(1); + CHECK(dict.size() == 1); + CHECK_THROWS_AS(dict.at_l(1), std::out_of_range); + CHECK(dict.at_r("two") == 2); + } - // Test the reversed() function - SUBCASE("Reversed") { - bidict reversed_dict = dict.reversed(); - CHECK(reversed_dict.at_l("one") == 1); - CHECK(reversed_dict.at_r(2) == "two"); - } + // Test the erase_r() function + SUBCASE("EraseR") { + dict.erase_r("one"); + CHECK(dict.size() == 1); + CHECK_THROWS_AS(dict.at_r("one"), std::out_of_range); + CHECK(dict.at_l(2) == "two"); + } - // Test the size() function - SUBCASE("Size") { - CHECK(dict.size() == 2); - } + // Test the reversed() function + SUBCASE("Reversed") { + bidict reversed_dict = dict.reversed(); + CHECK(reversed_dict.at_l("one") == 1); + CHECK(reversed_dict.at_r(2) == "two"); + } - SUBCASE("implicitly convert to std::unordered_map") { - std::unordered_map res = dict; - std::unordered_map expected = {{1, "one"}, {2, "two"}}; - CHECK(res == expected); - } + // Test the size() function + SUBCASE("Size") { + CHECK(dict.size() == 2); + } - SUBCASE("begin") { - auto it = dict.begin(); - CHECK(it->first == 2); - CHECK(it->second == "two"); - } + SUBCASE("implicitly convert to std::unordered_map") { + std::unordered_map res = dict; + std::unordered_map expected = {{1, "one"}, {2, "two"}}; + CHECK(res == expected); + } + + SUBCASE("begin") { + auto it = dict.begin(); + CHECK(it->first == 2); + CHECK(it->second == "two"); + } - SUBCASE("end") { - auto it = dict.end(); - CHECK(it == dict.end()); + SUBCASE("end") { + auto it = dict.end(); + CHECK(it == dict.end()); + } } } diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index 8c37abf877..f6ac6e2d42 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -5,384 +5,387 @@ #include using namespace FlexFlow; -TEST_CASE("join_strings") { - std::vector const v = {"Hello", "world", "!"}; - CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); -} -TEST_CASE("join_strings with container") { - std::vector const v = {"Hello", "world"}; - CHECK(join_strings(v, " ") == "Hello world"); -} +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("join_strings") { + std::vector const v = {"Hello", "world", "!"}; + CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); + } -TEST_CASE("find") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(find(v, 3) != v.cend()); - CHECK(find(v, 6) == v.cend()); -} + TEST_CASE("join_strings with container") { + std::vector const v = {"Hello", "world"}; + CHECK(join_strings(v, " ") == "Hello world"); + } -TEST_CASE("sum") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(sum(v) == 15); -} + TEST_CASE("find") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(find(v, 3) != v.cend()); + CHECK(find(v, 6) == v.cend()); + } -TEST_CASE("sum with condition") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; // Sum of even numbers only - CHECK(sum_where(v, condition) == 6); -} + TEST_CASE("sum") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(sum(v) == 15); + } -TEST_CASE("product") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(product(v) == 120); -} + TEST_CASE("sum with condition") { + std::vector v = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; // Sum of even numbers only + CHECK(sum_where(v, condition) == 6); + } -TEST_CASE("product_where") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Product of even numbers only - CHECK(product_where(v, condition) == 8); -} + TEST_CASE("product") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(product(v) == 120); + } -TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); -} + TEST_CASE("product_where") { + std::vector v = {1, 2, 3, 4, 5}; + auto condition = [](int x) { + return x % 2 == 0; + }; // Product of even numbers only + CHECK(product_where(v, condition) == 8); + } -TEST_CASE("contains_key") { - std::unordered_map m = { - {"one", 1}, {"two", 2}, {"three", 3}}; - CHECK(contains_key(m, "one")); - CHECK(!contains_key(m, "four")); -} + TEST_CASE("contains") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } -TEST_CASE("map_keys") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](int x) { return x * x; }; // Mapping function - auto result = map_keys(m, f); - CHECK(result.size() == 2); - CHECK(result[1] == "one"); - CHECK(result[4] == "two"); -} + TEST_CASE("contains_key") { + std::unordered_map m = { + {"one", 1}, {"two", 2}, {"three", 3}}; + CHECK(contains_key(m, "one")); + CHECK(!contains_key(m, "four")); + } -TEST_CASE("filter_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = [](int x) { return x % 2 == 1; }; // Filtering function - std::unordered_map result = filter_keys(m, f); - std::unordered_map expected = {{1, "one"}, {3, "three"}}; - CHECK(result == expected); -} + TEST_CASE("map_keys") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](int x) { return x * x; }; // Mapping function + auto result = map_keys(m, f); + CHECK(result.size() == 2); + CHECK(result[1] == "one"); + CHECK(result[4] == "two"); + } -TEST_CASE("map_values") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](std::string const &s) { return s.size(); }; // Mapping function - std::unordered_map result = map_values(m, f); - std::unordered_map expected = {{1, 3}, {2, 3}}; - CHECK(result == expected); -} + TEST_CASE("filter_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + auto f = [](int x) { return x % 2 == 1; }; // Filtering function + std::unordered_map result = filter_keys(m, f); + std::unordered_map expected = {{1, "one"}, {3, "three"}}; + CHECK(result == expected); + } -TEST_CASE("keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set result = keys(m); - std::unordered_set expected = {3, 2, 1}; - CHECK(result == expected); -} + TEST_CASE("map_values") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](std::string const &s) { return s.size(); }; // Mapping function + std::unordered_map result = map_values(m, f); + std::unordered_map expected = {{1, 3}, {2, 3}}; + CHECK(result == expected); + } -TEST_CASE("values") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::vector result = values(m); - std::vector expected = {"three", "two", "one"}; - CHECK(result == expected); -} + TEST_CASE("keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set result = keys(m); + std::unordered_set expected = {3, 2, 1}; + CHECK(result == expected); + } -// TEST_CASE("items") { -// std::unordered_map m = {{1, std::string("one")}, {2, -// std::string("two")}, {3,std::string("three")}}; -// std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; - std::unordered_set result = unique(v); - std::unordered_set expected = {1, 2, 3}; - CHECK(result == expected); -} + TEST_CASE("values") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::vector result = values(m); + std::vector expected = {"three", "two", "one"}; + CHECK(result == expected); + } -TEST_CASE("without_order") { - std::vector v = {1, 4, 6, 4, 6}; - std::unordered_set expected = {1, 4, 6}; - CHECK(without_order(v) == expected); -} + // TEST_CASE("items") { + // std::unordered_map m = {{1, std::string("one")}, {2, + // std::string("two")}, {3,std::string("three")}}; + // std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; + std::unordered_set result = unique(v); + std::unordered_set expected = {1, 2, 3}; + CHECK(result == expected); + } -TEST_CASE("index_of") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(index_of(v, 3) == 2); - CHECK(!index_of(v, 6).has_value()); -} + TEST_CASE("without_order") { + std::vector v = {1, 4, 6, 4, 6}; + std::unordered_set expected = {1, 4, 6}; + CHECK(without_order(v) == expected); + } -TEST_CASE("intersection") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {2, 3, 4}; - std::unordered_set result = intersection(l, r); - std::unordered_set expected = {2, 3}; - CHECK(result == expected); -} + TEST_CASE("index_of") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(index_of(v, 3) == 2); + CHECK(!index_of(v, 6).has_value()); + } -TEST_CASE("are_disjoint") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {4, 5, 6}; - CHECK(are_disjoint(l, r)); - r.insert(3); - CHECK_FALSE(are_disjoint(l, r)); -} + TEST_CASE("intersection") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {2, 3, 4}; + std::unordered_set result = intersection(l, r); + std::unordered_set expected = {2, 3}; + CHECK(result == expected); + } -TEST_CASE("restrict_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set mask = {2, 3, 4}; - std::unordered_map result = restrict_keys(m, mask); - std::unordered_map expected = {{2, "two"}, {3, "three"}}; - CHECK(result == expected); -} + TEST_CASE("are_disjoint") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {4, 5, 6}; + CHECK(are_disjoint(l, r)); + r.insert(3); + CHECK_FALSE(are_disjoint(l, r)); + } -TEST_CASE("merge_maps(unordered_map)") { - std::unordered_map lhs = {{1, "one"}, {2, "two"}}; - std::unordered_map rhs = {{3, "three"}, {4, "four"}}; - std::unordered_map result = merge_maps(lhs, rhs); - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); -} + TEST_CASE("restrict_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set mask = {2, 3, 4}; + std::unordered_map result = restrict_keys(m, mask); + std::unordered_map expected = {{2, "two"}, {3, "three"}}; + CHECK(result == expected); + } -TEST_CASE("merge_maps(bidict)") { - std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; - std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; - std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; - std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; - bidict lhs{fwd_map1, bwd_map1}; - bidict rhs{fwd_map2, bwd_map2}; - - std::unordered_map result = - merge_maps(lhs, rhs); // impicit conversion - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); -} + TEST_CASE("merge_maps(unordered_map)") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{3, "three"}, {4, "four"}}; + std::unordered_map result = merge_maps(lhs, rhs); + std::unordered_map expected = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK(result == expected); + } -TEST_CASE("lookup_in") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = lookup_in(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - CHECK(f(3) == "three"); -} + TEST_CASE("merge_maps(bidict)") { + std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; + std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; + std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; + std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; + bidict lhs{fwd_map1, bwd_map1}; + bidict rhs{fwd_map2, bwd_map2}; + + std::unordered_map result = + merge_maps(lhs, rhs); // impicit conversion + std::unordered_map expected = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK(result == expected); + } -TEST_CASE("lookup_in_l") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_l(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); -} + TEST_CASE("lookup_in") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + auto f = lookup_in(m); + CHECK(f(1) == "one"); + CHECK(f(2) == "two"); + CHECK(f(3) == "three"); + } -TEST_CASE("lookup_in_r") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_r(m); - CHECK(f("one") == 1); - CHECK(f("two") == 2); -} + TEST_CASE("lookup_in_l") { + bidict m; + m.equate(1, "one"); + m.equate(2, "two"); + auto f = lookup_in_l(m); + CHECK(f(1) == "one"); + CHECK(f(2) == "two"); + } -TEST_CASE("set_union") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {2, 3, 4}; - std::unordered_set result = set_union(s1, s2); - std::unordered_set expected = {1, 2, 3, 4}; - CHECK(result == expected); -} + TEST_CASE("lookup_in_r") { + bidict m; + m.equate(1, "one"); + m.equate(2, "two"); + auto f = lookup_in_r(m); + CHECK(f("one") == 1); + CHECK(f("two") == 2); + } -TEST_CASE("is_subseteq_of") { - std::unordered_set s1 = {1, 2}; - std::unordered_set s2 = {1, 2, 3}; - CHECK(is_subseteq_of(s1, s2) == true); - CHECK(is_subseteq_of(s2, s1) == false); - CHECK(is_subseteq_of(s1, s1) == true); - CHECK(is_subseteq_of(s2, s2) == true); -} + TEST_CASE("set_union") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {2, 3, 4}; + std::unordered_set result = set_union(s1, s2); + std::unordered_set expected = {1, 2, 3, 4}; + CHECK(result == expected); + } -TEST_CASE("is_superseteq_of") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {1, 2}; - CHECK(is_supserseteq_of(s1, s2) == true); - CHECK(is_supserseteq_of(s2, s1) == false); -} + TEST_CASE("is_subseteq_of") { + std::unordered_set s1 = {1, 2}; + std::unordered_set s2 = {1, 2, 3}; + CHECK(is_subseteq_of(s1, s2) == true); + CHECK(is_subseteq_of(s2, s1) == false); + CHECK(is_subseteq_of(s1, s1) == true); + CHECK(is_subseteq_of(s2, s2) == true); + } -TEST_CASE("get_only") { - std::unordered_set s = {42}; - CHECK(get_only(s) == 42); -} + TEST_CASE("is_superseteq_of") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {1, 2}; + CHECK(is_supserseteq_of(s1, s2) == true); + CHECK(is_supserseteq_of(s2, s1) == false); + } -TEST_CASE("get_first") { - std::unordered_set s = {1, 2, 3}; - CHECK(s.count(get_first(s)) == 1); -} + TEST_CASE("get_only") { + std::unordered_set s = {42}; + CHECK(get_only(s) == 42); + } -TEST_CASE("extend") { - std::vector v = {1, 2, 3}; - std::unordered_set s = {4, 5, 6}; - extend(v, s); - CHECK(v.size() == 6); - std::vector expected = {1, 2, 3, 6, 5, 4}; - CHECK(v == expected); -} + TEST_CASE("get_first") { + std::unordered_set s = {1, 2, 3}; + CHECK(s.count(get_first(s)) == 1); + } -TEST_CASE("all_of") { - std::vector v = {2, 4, 6, 8}; - CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); - CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); -} + TEST_CASE("extend") { + std::vector v = {1, 2, 3}; + std::unordered_set s = {4, 5, 6}; + extend(v, s); + CHECK(v.size() == 6); + std::vector expected = {1, 2, 3, 6, 5, 4}; + CHECK(v == expected); + } -TEST_CASE("count") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); - CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); -} + TEST_CASE("all_of") { + std::vector v = {2, 4, 6, 8}; + CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); + CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); + } -TEST_CASE("are_all_same") { - std::vector v1 = {2, 2, 2, 2}; - std::vector v2 = {1, 2, 3, 4}; - CHECK(are_all_same(v1) == true); - CHECK(are_all_same(v2) == false); -} + TEST_CASE("count") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); + CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); + } -TEST_CASE("vector_transform") { - std::vector v = {1, 2, 3}; - auto result = vector_transform([](int x) { return x * 2; }, v); - CHECK(result == std::vector({2, 4, 6})); -} + TEST_CASE("are_all_same") { + std::vector v1 = {2, 2, 2, 2}; + std::vector v2 = {1, 2, 3, 4}; + CHECK(are_all_same(v1) == true); + CHECK(are_all_same(v2) == false); + } -TEST_CASE("as_vector") { - std::unordered_set s = {1, 2, 3}; - std::vector result = as_vector(s); - CHECK(result == std::vector({3, 2, 1})); -} + TEST_CASE("vector_transform") { + std::vector v = {1, 2, 3}; + auto result = vector_transform([](int x) { return x * 2; }, v); + CHECK(result == std::vector({2, 4, 6})); + } -TEST_CASE("transform_vector") { - std::vector v = {1, 2, 3}; - auto result = transform(v, [](int x) { return x * 2; }); - CHECK(result == std::vector({2, 4, 6})); -} + TEST_CASE("as_vector") { + std::unordered_set s = {1, 2, 3}; + std::vector result = as_vector(s); + CHECK(result == std::vector({3, 2, 1})); + } -TEST_CASE("transform_unordered_set") { - std::unordered_set s = {1, 2, 3}; - auto result = transform(s, [](int x) { return x * 2; }); - CHECK(result == std::unordered_set({2, 4, 6})); -} + TEST_CASE("transform_vector") { + std::vector v = {1, 2, 3}; + auto result = transform(v, [](int x) { return x * 2; }); + CHECK(result == std::vector({2, 4, 6})); + } -TEST_CASE("transform_string") { - std::string s = "abc"; - auto result = transform(s, ::toupper); - CHECK(result == "ABC"); -} + TEST_CASE("transform_unordered_set") { + std::unordered_set s = {1, 2, 3}; + auto result = transform(s, [](int x) { return x * 2; }); + CHECK(result == std::unordered_set({2, 4, 6})); + } -TEST_CASE("repeat") { - int ctr = 0; - std::vector result = repeat(5, [&] { return ctr++; }); + TEST_CASE("transform_string") { + std::string s = "abc"; + auto result = transform(s, ::toupper); + CHECK(result == "ABC"); + } - CHECK(result == std::vector{0, 1, 2, 3, 4}); -} + TEST_CASE("repeat") { + int ctr = 0; + std::vector result = repeat(5, [&] { return ctr++; }); -TEST_CASE("Testing the 'enumerate' function") { - std::unordered_set input_set = {1, 2, 3, 4, 5}; - std::unordered_map result = enumerate(input_set); - std::unordered_map expected = { - {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; - CHECK(result == expected); -} + CHECK(result == std::vector{0, 1, 2, 3, 4}); + } -TEST_CASE("Testing the 'maximum' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - auto result = maximum(input_vec); + TEST_CASE("Testing the 'enumerate' function") { + std::unordered_set input_set = {1, 2, 3, 4, 5}; + std::unordered_map result = enumerate(input_set); + std::unordered_map expected = { + {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; + CHECK(result == expected); + } - // Checking the maximum is as expected - REQUIRE(result == 5); -} + TEST_CASE("Testing the 'maximum' function") { + std::vector input_vec = {1, 2, 3, 4, 5}; + auto result = maximum(input_vec); -TEST_CASE("Testing the 'reversed' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - std::vector result = reversed(input_vec); - std::vector expected = {5, 4, 3, 2, 1}; + // Checking the maximum is as expected + REQUIRE(result == 5); + } - // Checking the reversed sequence is as expected - CHECK(result == expected); -} + TEST_CASE("Testing the 'reversed' function") { + std::vector input_vec = {1, 2, 3, 4, 5}; + std::vector result = reversed(input_vec); + std::vector expected = {5, 4, 3, 2, 1}; + + // Checking the reversed sequence is as expected + CHECK(result == expected); + } -TEST_CASE("Testing sorted_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); - CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); + TEST_CASE("Testing sorted_by function") { + std::unordered_set s = {5, 2, 3, 4, 1}; + auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); + CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); - std::unordered_set s2 = {-5, -1, -3, -2, -4}; - auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); - CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); -} + std::unordered_set s2 = {-5, -1, -3, -2, -4}; + auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); + CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); + } -TEST_CASE("Testing compare_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - std::vector result = - sorted_by(s, compare_by([](int i) { return (-i); })); - CHECK(result == std::vector{5, 4, 3, 2, 1}); -} + TEST_CASE("Testing compare_by function") { + std::unordered_set s = {5, 2, 3, 4, 1}; + std::vector result = + sorted_by(s, compare_by([](int i) { return (-i); })); + CHECK(result == std::vector{5, 4, 3, 2, 1}); + } -TEST_CASE("Testing vector_split function") { - std::vector v = {1, 2, 3, 4, 5}; - auto result = vector_split(v, 2); - std::vector prefix = result.first; - std::vector postfix = result.second; - CHECK(prefix == std::vector({1, 2})); - CHECK(postfix == std::vector({3, 4, 5})); -} + TEST_CASE("Testing vector_split function") { + std::vector v = {1, 2, 3, 4, 5}; + auto result = vector_split(v, 2); + std::vector prefix = result.first; + std::vector postfix = result.second; + CHECK(prefix == std::vector({1, 2})); + CHECK(postfix == std::vector({3, 4, 5})); + } -TEST_CASE("Testing value_all function") { - std::vector> v = {1, 2, 3, 4, 5}; - auto value_all_v = value_all(v); - CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); -} + TEST_CASE("Testing value_all function") { + std::vector> v = {1, 2, 3, 4, 5}; + auto value_all_v = value_all(v); + CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); + } -TEST_CASE("Testing subvec function") { - std::vector v = {1, 2, 3, 4, 5}; - auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); + TEST_CASE("Testing subvec function") { + std::vector v = {1, 2, 3, 4, 5}; + auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); - CHECK(subvec_v == std::vector({2, 3, 4})); + CHECK(subvec_v == std::vector({2, 3, 4})); - auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); - CHECK(subvec_v2 == std::vector({1, 2, 3})); -} + auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); + CHECK(subvec_v2 == std::vector({1, 2, 3})); + } -auto get_factors = [](int x) -> std::vector { - // Returns a vector of factors of x - std::vector factors; - for (int i = 1; i <= x; i++) { - if (x % i == 0) { - factors.push_back(i); + auto get_factors = [](int x) -> std::vector { + // Returns a vector of factors of x + std::vector factors; + for (int i = 1; i <= x; i++) { + if (x % i == 0) { + factors.push_back(i); + } } + return factors; + }; + + // Example for vector + TEST_CASE("Test for flatmap function on vectors") { + std::vector v = {2, 3, 4, 5}; + auto result = flatmap(v, get_factors); + CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); } - return factors; -}; - -// Example for vector -TEST_CASE("Test for flatmap function on vectors") { - std::vector v = {2, 3, 4, 5}; - auto result = flatmap(v, get_factors); - CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); } diff --git a/lib/utils/test/src/test_cow_ptr.cc b/lib/utils/test/src/test_cow_ptr.cc index 62406bddec..de573d0c9b 100644 --- a/lib/utils/test/src/test_cow_ptr.cc +++ b/lib/utils/test/src/test_cow_ptr.cc @@ -22,39 +22,41 @@ struct TestObjectDerived : public TestObject { } }; -TEST_CASE("cow_ptr_t constructor") { - std::shared_ptr sp = std::make_shared(1); - cow_ptr_t p1(sp); - cow_ptr_t p2(std::make_shared(3)); - cow_ptr_t p3(TestObject(2)); - cow_ptr_t p4(p3); - cow_ptr_t p5 = p1; - CHECK(p1->x == 1); - CHECK(p2->x == 3); - CHECK(p3->x == 2); - CHECK(p4->x == p3->x); - CHECK(p5->x == p1->x); -} +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cow_ptr_t constructor") { + std::shared_ptr sp = std::make_shared(1); + cow_ptr_t p1(sp); + cow_ptr_t p2(std::make_shared(3)); + cow_ptr_t p3(TestObject(2)); + cow_ptr_t p4(p3); + cow_ptr_t p5 = p1; + CHECK(p1->x == 1); + CHECK(p2->x == 3); + CHECK(p3->x == 2); + CHECK(p4->x == p3->x); + CHECK(p5->x == p1->x); + } -TEST_CASE("cow_ptr_t copy") { - cow_ptr_t p1(std::make_shared(1)); - cow_ptr_t p2(std::make_shared(2)); - p1 = p2; - CHECK(p1->x == p2->x); -} + TEST_CASE("cow_ptr_t copy") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(std::make_shared(2)); + p1 = p2; + CHECK(p1->x == p2->x); + } -TEST_CASE("cow_ptr_t cast") { - cow_ptr_t p1(std::make_shared(1, 2)); - cow_ptr_t p2(p1); - CHECK(p2->x == 1); -} + TEST_CASE("cow_ptr_t cast") { + cow_ptr_t p1(std::make_shared(1, 2)); + cow_ptr_t p2(p1); + CHECK(p2->x == 1); + } -TEST_CASE("cow_ptr_t get_mutable") { - cow_ptr_t p1(std::make_shared(1)); - cow_ptr_t p2(p1); - p1.get_mutable()->x = 3; - CHECK(p1->x == 3); - CHECK(p2->x == 1); - p2.get_mutable()->x = 2; - CHECK(p1->x == 3); + TEST_CASE("cow_ptr_t get_mutable") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(p1); + p1.get_mutable()->x = 3; + CHECK(p1->x == 3); + CHECK(p2->x == 1); + p2.get_mutable()->x = 2; + CHECK(p1->x == 3); + } } diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/test_deduplicated_priority_queue.cc index a5c97fa0f8..66cfd395bc 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/test_deduplicated_priority_queue.cc @@ -1,34 +1,36 @@ #include "test/utils/doctest.h" #include "utils/deduplicated_priority_queue.h" -TEST_CASE("DeduplicatedPriorityQueue push and pop") { - DeduplicatedPriorityQueue queue; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DeduplicatedPriorityQueue push and pop") { + DeduplicatedPriorityQueue queue; - SUBCASE("Push elements") { - queue.push(5); - queue.push(2); - queue.push(7); - queue.push(2); + SUBCASE("Push elements") { + queue.push(5); + queue.push(2); + queue.push(7); + queue.push(2); - CHECK(queue.size() == 3); - CHECK(queue.top() == 7); - CHECK_FALSE(queue.empty()); - } + CHECK(queue.size() == 3); + CHECK(queue.top() == 7); + CHECK_FALSE(queue.empty()); + } - SUBCASE("Pop elements") { - queue.push(5); - queue.push(2); - queue.push(7); + SUBCASE("Pop elements") { + queue.push(5); + queue.push(2); + queue.push(7); - queue.pop(); - CHECK(queue.size() == 2); - CHECK(queue.top() == 5); + queue.pop(); + CHECK(queue.size() == 2); + CHECK(queue.top() == 5); - queue.pop(); - CHECK(queue.size() == 1); - CHECK(queue.top() == 2); + queue.pop(); + CHECK(queue.size() == 1); + CHECK(queue.top() == 2); - queue.pop(); - CHECK(queue.empty()); + queue.pop(); + CHECK(queue.empty()); + } } } diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/test_disjoint_set.cc index fe2c4bae33..8bcf2e533f 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/test_disjoint_set.cc @@ -16,53 +16,55 @@ std::string generate_element(int seed) { return "Element" + std::to_string(seed); } -TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { - disjoint_set> ds; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { + disjoint_set> ds; - SUBCASE("SingleElementSets") { - optional element = generate_element(1); - CHECK(ds.find(element) == element); + SUBCASE("SingleElementSets") { + optional element = generate_element(1); + CHECK(ds.find(element) == element); - element = generate_element(2); - CHECK(ds.find(element) == element); - } + element = generate_element(2); + CHECK(ds.find(element) == element); + } - SUBCASE("UnionAndFind") { - optional element1 = generate_element(1); - optional element2 = generate_element(2); - optional element3 = generate_element(3); - optional element4 = generate_element(4); + SUBCASE("UnionAndFind") { + optional element1 = generate_element(1); + optional element2 = generate_element(2); + optional element3 = generate_element(3); + optional element4 = generate_element(4); - ds.m_union(element1, element2); - CHECK(ds.find(element1) == ds.find(element2)); + ds.m_union(element1, element2); + CHECK(ds.find(element1) == ds.find(element2)); - ds.m_union(element3, element4); - CHECK(ds.find(element3) == ds.find(element4)); + ds.m_union(element3, element4); + CHECK(ds.find(element3) == ds.find(element4)); - ds.m_union(element1, element3); - CHECK(ds.find(element1) == ds.find(element3)); - CHECK(ds.find(element2) == ds.find(element4)); - CHECK(ds.find(element1) == ds.find(element2)); - CHECK(ds.find(element1) == ds.find(element4)); + ds.m_union(element1, element3); + CHECK(ds.find(element1) == ds.find(element3)); + CHECK(ds.find(element2) == ds.find(element4)); + CHECK(ds.find(element1) == ds.find(element2)); + CHECK(ds.find(element1) == ds.find(element4)); + } } -} -TEST_CASE_TEMPLATE("DisjointSetMapping", T, int, std::string) { - disjoint_set ds; - ds.m_union(1, 2); - ds.m_union(3, 4); - ds.m_union(1, 4); - ds.m_union(5, 6); + TEST_CASE_TEMPLATE("DisjointSetMapping", T, int, std::string) { + disjoint_set ds; + ds.m_union(1, 2); + ds.m_union(3, 4); + ds.m_union(1, 4); + ds.m_union(5, 6); - std::map, optional, OptionalComparator> - expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; + std::map, optional, OptionalComparator> + expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; - std::map, optional, OptionalComparator> mapping = - ds.get_mapping(); + std::map, optional, OptionalComparator> mapping = + ds.get_mapping(); - for (auto const &kv : mapping) { - CHECK( - *kv.second == - *expectedMapping[kv.first]); // Compare the values inside the optionals + for (auto const &kv : mapping) { + CHECK( + *kv.second == + *expectedMapping[kv.first]); // Compare the values inside the optionals + } } } diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/test_dot_file.cc index a65265afbd..ed4c32bb1c 100644 --- a/lib/utils/test/src/test_dot_file.cc +++ b/lib/utils/test/src/test_dot_file.cc @@ -2,67 +2,68 @@ #include "utils/dot_file.h" #include -TEST_CASE("DotFile") { - std::ostringstream oss; - DotFile dotFile(oss); - SUBCASE("add_node") { - dotFile.add_node("A", {{"shape", "circle"}, {"label", "Node A"}}); - dotFile.add_node("B", {{"shape", "rectangle"}, {"label", "Node B"}}); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DotFile") { + std::ostringstream oss; + DotFile dotFile(oss); + SUBCASE("add_node") { + dotFile.add_node("A", {{"shape", "circle"}, {"label", "Node A"}}); + dotFile.add_node("B", {{"shape", "rectangle"}, {"label", "Node B"}}); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { node0 [label=Node A,shape=circle]; node1 [label=Node B,shape=rectangle]; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_edge") { - dotFile.add_edge("A", "B"); - dotFile.add_edge("B", "C"); + SUBCASE("add_edge") { + dotFile.add_edge("A", "B"); + dotFile.add_edge("B", "C"); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { node0 -> node1; node1 -> node2; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_record_node") { - RecordFormatter rf; + SUBCASE("add_record_node") { + RecordFormatter rf; - rf << "Field1"; - rf << 42; - rf << "Field2"; - rf << float(3.14); + rf << "Field1"; + rf << 42; + rf << "Field2"; + rf << float(3.14); - dotFile.add_record_node("A", rf); + dotFile.add_record_node("A", rf); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = - R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = + R"EXPECTED_OUTPUT(digraph taskgraph { node0 [label="{ Field1 | 42 | Field2 | 3.140000e+00 }",shape=record]; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_node_to_subgraph") { - size_t subgraph1 = dotFile.add_subgraph(); - size_t subgraph2 = dotFile.add_subgraph(subgraph1); + SUBCASE("add_node_to_subgraph") { + size_t subgraph1 = dotFile.add_subgraph(); + size_t subgraph2 = dotFile.add_subgraph(subgraph1); - dotFile.add_node_to_subgraph("A", subgraph1); - dotFile.add_node_to_subgraph("B", subgraph2); + dotFile.add_node_to_subgraph("A", subgraph1); + dotFile.add_node_to_subgraph("B", subgraph2); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { subgraph cluster_0 { node1; node0; @@ -72,6 +73,7 @@ node1; } })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); + CHECK(oss.str() == expectedOutput); + } } } diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/test_format.cc index 2f653c85af..eeed2eae81 100644 --- a/lib/utils/test/src/test_format.cc +++ b/lib/utils/test/src/test_format.cc @@ -7,32 +7,34 @@ std::string formatRecord(RecordFormatter const &formatter) { return oss.str(); } -TEST_CASE("RecordFormatter") { - RecordFormatter formatter; - SUBCASE("Appending string") { - formatter << "Hello"; - formatter << "World"; - CHECK(formatRecord(formatter) == "{ Hello | World }"); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RecordFormatter") { + RecordFormatter formatter; + SUBCASE("Appending string") { + formatter << "Hello"; + formatter << "World"; + CHECK(formatRecord(formatter) == "{ Hello | World }"); + } - SUBCASE("Appending integer and float") { - formatter << 42; - formatter << 3.14f; - CHECK(formatRecord(formatter) == "{ 42 | 3.140000e+00 }"); - } + SUBCASE("Appending integer and float") { + formatter << 42; + formatter << 3.14f; + CHECK(formatRecord(formatter) == "{ 42 | 3.140000e+00 }"); + } - SUBCASE("Appending another RecordFormatter") { - RecordFormatter subFormatter; - subFormatter << "Sub"; - subFormatter << "Formatter"; + SUBCASE("Appending another RecordFormatter") { + RecordFormatter subFormatter; + subFormatter << "Sub"; + subFormatter << "Formatter"; - RecordFormatter formatter; - formatter << "Hello"; - formatter << subFormatter; + RecordFormatter formatter; + formatter << "Hello"; + formatter << subFormatter; - std::ostringstream oss; - oss << formatter; + std::ostringstream oss; + oss << formatter; - CHECK(formatRecord(formatter) == "{ Hello | { Sub | Formatter } }"); + CHECK(formatRecord(formatter) == "{ Hello | { Sub | Formatter } }"); + } } } diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc index f0d907b741..b38c43fe30 100644 --- a/lib/utils/test/src/test_hash.cc +++ b/lib/utils/test/src/test_hash.cc @@ -3,16 +3,18 @@ using namespace FlexFlow; -TEST_CASE("hash:unordered_map") { - std::unordered_map map1{{1, 2}}; - std::unordered_map map2{{1, 2}, {3, 4}}; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("hash:unordered_map") { + std::unordered_map map1{{1, 2}}; + std::unordered_map map2{{1, 2}, {3, 4}}; - size_t hash1 = get_std_hash(map1); - size_t hash2 = get_std_hash(map2); + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); - CHECK(hash1 != hash2); + CHECK(hash1 != hash2); - map1.insert({1, 2}); - hash1 = get_std_hash(map1); - CHECK(hash1 == hash2); + map1.insert({1, 2}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } } diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc index 944ff0b7ca..91631f0391 100644 --- a/lib/utils/test/src/test_multidigraph.cc +++ b/lib/utils/test/src/test_multidigraph.cc @@ -5,86 +5,88 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { - MultiDiGraph g = MultiDiGraph::create(); - - std::vector n = repeat(3, [&] { return g.add_node(); }); - std::vector p = repeat(3, [&] { return g.add_node_port(); }); - - std::vector e = {{n[1], p[1], n[0], p[0]}, - {n[2], p[2], n[0], p[0]}, - {n[0], p[0], n[2], p[2]}, - {n[1], p[1], n[2], p[2]}}; - for (MultiDiEdge const &edge : e) { - g.add_edge(edge); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { + MultiDiGraph g = MultiDiGraph::create(); + + std::vector n = repeat(3, [&] { return g.add_node(); }); + std::vector p = repeat(3, [&] { return g.add_node_port(); }); - CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[0], n[1], n[2]}); - - CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == - std::unordered_set{n[0], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[0], e[1], e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( - {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( - {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set( - {p[1], p[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set( - {p[0], p[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all() - .with_src_nodes({n[1]}) - .with_dst_nodes({n[2]}) - .with_src_idxs({p[1]}) - .with_dst_idxs({p[2]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == - std::unordered_set{e[1]}); - - SUBCASE("remove node") { - g.remove_node_unsafe(n[0]); + std::vector e = {{n[1], p[1], n[0], p[0]}, + {n[2], p[2], n[0], p[0]}, + {n[0], p[0], n[2], p[2]}, + {n[1], p[1], n[2], p[2]}}; + for (MultiDiEdge const &edge : e) { + g.add_edge(edge); + } CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[1], n[2]}); + std::unordered_set{n[0], n[1], n[2]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[2], e[3]}); + std::unordered_set{e[0], e[1], e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == + std::unordered_set{e[0], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == + std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == + std::unordered_set{e[0], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( + {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( + {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set( + {p[1], p[2]}))) == std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set( + {p[0], p[2]}))) == std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all() + .with_src_nodes({n[1]}) + .with_dst_nodes({n[2]}) + .with_src_idxs({p[1]}) + .with_dst_idxs({p[2]})) == + std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == + std::unordered_set{e[1]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == - std::unordered_set{e[2]}); + SUBCASE("remove node") { + g.remove_node_unsafe(n[0]); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == - std::unordered_set{e[2]}); - } + CHECK(g.query_nodes(NodeQuery::all()) == + std::unordered_set{n[1], n[2]}); - SUBCASE("remove_edge") { - g.remove_edge(e[0]); + CHECK(g.query_edges(MultiDiEdgeQuery::all()) == + std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges( - MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( - {n[1]})) == std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == + std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == - std::unordered_set{e[1]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == + std::unordered_set{e[2]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == + std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == + std::unordered_set{e[2]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges( + MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( + {n[1]})) == std::unordered_set{}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == + std::unordered_set{e[1]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == + std::unordered_set{e[2], e[3]}); + } } } diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/test_random_utils.cc index dd7c320d85..88a566a198 100644 --- a/lib/utils/test/src/test_random_utils.cc +++ b/lib/utils/test/src/test_random_utils.cc @@ -14,52 +14,54 @@ void checkProbabilities(std::vector const &counts, } } -TEST_CASE("select_random") { - std::vector values = {1, 2, 3, 4, 5}; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("select_random") { + std::vector values = {1, 2, 3, 4, 5}; - SUBCASE("Select random value") { - int result = select_random(values); + SUBCASE("Select random value") { + int result = select_random(values); - CHECK(std::find(values.begin(), values.end(), result) != values.end()); - } + CHECK(std::find(values.begin(), values.end(), result) != values.end()); + } - SUBCASE("Invalid arguments") { - std::vector weights = {0.1f, 0.3f, 0.2f}; - CHECK(select_random(values, weights) == 2); + SUBCASE("Invalid arguments") { + std::vector weights = {0.1f, 0.3f, 0.2f}; + CHECK(select_random(values, weights) == 2); + } } -} -TEST_CASE("select_random - Weighted Random Selection") { - SUBCASE("Test with equal weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + TEST_CASE("select_random - Weighted Random Selection") { + SUBCASE("Test with equal weights") { + std::vector values = {1, 2, 3, 4, 5}; + std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; + std::vector counts(values.size(), 0); + int const numIterations = 10000; + for (int i = 0; i < numIterations; i++) { + int selected = select_random(values, weights); + counts[selected - 1]++; + } + + checkProbabilities(counts, numIterations, weights, values.size()); } - checkProbabilities(counts, numIterations, weights, values.size()); - } + SUBCASE("Test with different weights") { + std::vector values = {1, 2, 3, 4, 5}; + std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; - SUBCASE("Test with different weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; + std::vector counts(values.size(), 0); + int const numIterations = 10000; + for (int i = 0; i < numIterations; i++) { + int selected = select_random(values, weights); + counts[selected - 1]++; + } - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; - } + float totalWeight = 0.0f; + for (float weight : weights) { + totalWeight += weight; + } - float totalWeight = 0.0f; - for (float weight : weights) { - totalWeight += weight; + checkProbabilities(counts, numIterations, weights, totalWeight); } - - checkProbabilities(counts, numIterations, weights, totalWeight); } } diff --git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/test_sequence.cc index 576271a858..ee72febe05 100644 --- a/lib/utils/test/src/test_sequence.cc +++ b/lib/utils/test/src/test_sequence.cc @@ -3,169 +3,171 @@ using namespace FlexFlow; -TEST_CASE("seq_head") { - SUBCASE("seq_head with non-empty sequence") { - using Seq = seq<1, 2, 3, 4>; - constexpr int result = seq_head::value; - CHECK(result == 1); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("seq_head") { + SUBCASE("seq_head with non-empty sequence") { + using Seq = seq<1, 2, 3, 4>; + constexpr int result = seq_head::value; + CHECK(result == 1); + } + + SUBCASE("seq_head with empty sequence") { + using Seq = seq<>; + constexpr int result = seq_head::value; + CHECK(result == -1); + } } - SUBCASE("seq_head with empty sequence") { - using Seq = seq<>; - constexpr int result = seq_head::value; - CHECK(result == -1); + TEST_CASE("seq_tail") { + SUBCASE("seq_tail with non-empty sequence") { + using Seq = seq<1, 2, 3, 4>; + using ResultType = typename seq_tail::type; + using ExpectedType = seq<2, 3, 4>; + CHECK(std::is_same::value); + } + + SUBCASE("seq_tail with empty sequence") { + using Seq = seq<>; + using ResultType = typename seq_tail::type; + using ExpectedType = seq<>; + CHECK(std::is_same::value); + } } -} -TEST_CASE("seq_tail") { - SUBCASE("seq_tail with non-empty sequence") { - using Seq = seq<1, 2, 3, 4>; - using ResultType = typename seq_tail::type; - using ExpectedType = seq<2, 3, 4>; + TEST_CASE("seq_prepend") { + using ResultType = typename FlexFlow::seq_prepend<1, 2, 3>::type; + using ExpectedType = FlexFlow::seq<1, 2, 3>; CHECK(std::is_same::value); } - SUBCASE("seq_tail with empty sequence") { - using Seq = seq<>; - using ResultType = typename seq_tail::type; - using ExpectedType = seq<>; + TEST_CASE("seq_append") { + using Seq = seq<1, 2, 3>; + using ResultType = typename seq_append::type; + using ExpectedType = seq<1, 2, 3, 4>; CHECK(std::is_same::value); } -} -TEST_CASE("seq_prepend") { - using ResultType = typename FlexFlow::seq_prepend<1, 2, 3>::type; - using ExpectedType = FlexFlow::seq<1, 2, 3>; - CHECK(std::is_same::value); -} - -TEST_CASE("seq_append") { - using Seq = seq<1, 2, 3>; - using ResultType = typename seq_append::type; - using ExpectedType = seq<1, 2, 3, 4>; - CHECK(std::is_same::value); -} + TEST_CASE("seq_count") { + using ResultType = seq_count_t<5>; + using ExpectedType = seq<1, 2, 3, 4, 5>; + CHECK(!std::is_same::value); + } -TEST_CASE("seq_count") { - using ResultType = seq_count_t<5>; - using ExpectedType = seq<1, 2, 3, 4, 5>; - CHECK(!std::is_same::value); -} + TEST_CASE("seq_enumerate_args") { + using Args = std::tuple; + using ResultType = seq_enumerate_args_t; + using ExpectedType = seq<0, 1, 2>; + CHECK(std::is_same::value); + } -TEST_CASE("seq_enumerate_args") { - using Args = std::tuple; - using ResultType = seq_enumerate_args_t; - using ExpectedType = seq<0, 1, 2>; - CHECK(std::is_same::value); + // template + // int square(std::integral_constant) { + // return X * X; + // } + + // TEST_CASE("seq_select") { + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_select(square, 1, seq<1, 2, 3>); + // CHECK(result == 4); + // } + + // SUBCASE("Invalid index") { + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_select(square, 3, Seq{}), std::runtime_error); + // } + // } + + // TEST_CASE("seq_get") { + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_get(square, 2, Seq{}); + // CHECK(result == 9); + // } + + // SUBCASE("Invalid index") { + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_get(square, 3, Seq{}), std::runtime_error); + // } + // } + + // TEST_CASE("seq_get") { + // struct F { + // template + // int operator()(std::integral_constant) const { + // return X * X; + // } + // }; + + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_get(F{}, 2, Seq{}); + // CHECK(result == 9); + // } + + // SUBCASE("Invalid index") { + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_get(F{}, 3, Seq{}), std::runtime_error); + // } + // } + + // struct F { + // template + // struct type { + // using result = std::integral_constant; + // }; + // }; + + // TEST_CASE("seq_transform_type") { + // using Seq = seq<1, 2, 3>; + // using ResultType = seq_transform_type_t; + // using ExpectedType = std::tuple, + // std::integral_constant, + // std::integral_constant>; + // CHECK(std::is_same::value); + // } + + // TEST_CASE("seq_transform") { + // struct F { + // template + // int operator()(std::integral_constant) { + // return X * X; + // } + // }; + + // using Seq = seq<1, 2, 3>; + // auto result = seq_transform(F{}, Seq{}); + // std::tuple expected{1, 4, 9}; + // CHECK(result == expected); + // } + + // TEST_CASE("seq_select") { + // struct F { + // template + // tl::optional operator()(std::integral_constant) { + // if (X % 2 == 0) { + // return X; + // } else { + // return tl::nullopt; + // } + // } + // }; + + // using Seq = seq<1, 2, 3, 4, 5>; + // int result = seq_select(F{}, Seq{}); + // CHECK(result == 2); + // } + + // TEST_CASE("seq_get") { + // struct F { + // template + // int operator()(std::integral_constant) { + // return X * X; + // } + // }; + + // using Seq = seq<1, 2, 3, 4, 5>; + // int result = seq_get(F{}, 3, Seq{}); + // CHECK(result == 16); + // } } - -// template -// int square(std::integral_constant) { -// return X * X; -// } - -// TEST_CASE("seq_select") { -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_select(square, 1, seq<1, 2, 3>); -// CHECK(result == 4); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// CHECK_THROWS_AS(seq_select(square, 3, Seq{}), std::runtime_error); -// } -// } - -// TEST_CASE("seq_get") { -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_get(square, 2, Seq{}); -// CHECK(result == 9); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// CHECK_THROWS_AS(seq_get(square, 3, Seq{}), std::runtime_error); -// } -// } - -// TEST_CASE("seq_get") { -// struct F { -// template -// int operator()(std::integral_constant) const { -// return X * X; -// } -// }; - -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_get(F{}, 2, Seq{}); -// CHECK(result == 9); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// CHECK_THROWS_AS(seq_get(F{}, 3, Seq{}), std::runtime_error); -// } -// } - -// struct F { -// template -// struct type { -// using result = std::integral_constant; -// }; -// }; - -// TEST_CASE("seq_transform_type") { -// using Seq = seq<1, 2, 3>; -// using ResultType = seq_transform_type_t; -// using ExpectedType = std::tuple, -// std::integral_constant, -// std::integral_constant>; -// CHECK(std::is_same::value); -// } - -// TEST_CASE("seq_transform") { -// struct F { -// template -// int operator()(std::integral_constant) { -// return X * X; -// } -// }; - -// using Seq = seq<1, 2, 3>; -// auto result = seq_transform(F{}, Seq{}); -// std::tuple expected{1, 4, 9}; -// CHECK(result == expected); -// } - -// TEST_CASE("seq_select") { -// struct F { -// template -// tl::optional operator()(std::integral_constant) { -// if (X % 2 == 0) { -// return X; -// } else { -// return tl::nullopt; -// } -// } -// }; - -// using Seq = seq<1, 2, 3, 4, 5>; -// int result = seq_select(F{}, Seq{}); -// CHECK(result == 2); -// } - -// TEST_CASE("seq_get") { -// struct F { -// template -// int operator()(std::integral_constant) { -// return X * X; -// } -// }; - -// using Seq = seq<1, 2, 3, 4, 5>; -// int result = seq_get(F{}, 3, Seq{}); -// CHECK(result == 16); -// } diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/test_stack_map.cc index 11d332afa4..21c1b07d1b 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/test_stack_map.cc @@ -3,48 +3,50 @@ using namespace FlexFlow; -TEST_CASE("stack_map") { - stack_map map; - // Test the [] operator to insert and access elements - SUBCASE("BracketOperator") { - map[1] = 10; - map[2] = 20; - - CHECK(map[1] == 10); - CHECK(map[2] == 20); - } - - // Test the insert() function - SUBCASE("Insert") { - map.insert(1, 10); - map.insert(2, 20); - - CHECK(map[1] == 10); - CHECK(map[2] == 20); - } - - // Test the at() function to access elements - SUBCASE("At") { - map[1] = 10; - map[2] = 20; - - CHECK(map.at(1) == 10); - CHECK(map.at(2) == 20); - CHECK(map.at(1) != 20); - // Test const version of at() function - stack_map const &const_map = map; - CHECK(const_map.at(1) == 10); - CHECK(const_map.at(2) == 20); - } - - // Test the begin() and end() functions for iterator - SUBCASE("Iterator") { - map[1] = 10; - map[2] = 20; - map[3] = 30; - - std::vector> expected = {{1, 10}, {2, 20}, {3, 30}}; - std::vector> actual = map; - CHECK(actual == expected); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("stack_map") { + stack_map map; + // Test the [] operator to insert and access elements + SUBCASE("BracketOperator") { + map[1] = 10; + map[2] = 20; + + CHECK(map[1] == 10); + CHECK(map[2] == 20); + } + + // Test the insert() function + SUBCASE("Insert") { + map.insert(1, 10); + map.insert(2, 20); + + CHECK(map[1] == 10); + CHECK(map[2] == 20); + } + + // Test the at() function to access elements + SUBCASE("At") { + map[1] = 10; + map[2] = 20; + + CHECK(map.at(1) == 10); + CHECK(map.at(2) == 20); + CHECK(map.at(1) != 20); + // Test const version of at() function + stack_map const &const_map = map; + CHECK(const_map.at(1) == 10); + CHECK(const_map.at(2) == 20); + } + + // Test the begin() and end() functions for iterator + SUBCASE("Iterator") { + map[1] = 10; + map[2] = 20; + map[3] = 30; + + std::vector> expected = {{1, 10}, {2, 20}, {3, 30}}; + std::vector> actual = map; + CHECK(actual == expected); + } } } diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/test_stack_string.cc index 700b7d6a0f..1836e0824a 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/test_stack_string.cc @@ -3,79 +3,81 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("StackStringConstruction", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("StackStringConstruction", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - SUBCASE("DefaultConstruction") { - StackString str; - CHECK(str.size() == 0); - CHECK(str.length() == 0); - CHECK(static_cast(str) == ""); - } + SUBCASE("DefaultConstruction") { + StackString str; + CHECK(str.size() == 0); + CHECK(str.length() == 0); + CHECK(static_cast(str) == ""); + } - SUBCASE("CStringConstruction") { - char const *cstr = "Hello"; - StackString str(cstr); - CHECK(str.size() == 5); - CHECK(str.length() == 5); - CHECK(static_cast(str) == "Hello"); - } + SUBCASE("CStringConstruction") { + char const *cstr = "Hello"; + StackString str(cstr); + CHECK(str.size() == 5); + CHECK(str.length() == 5); + CHECK(static_cast(str) == "Hello"); + } - SUBCASE("ShortCStringConstruction") { - char const *cstr = "CMU"; - StackString str(cstr); - CHECK(str.size() == 3); - CHECK(str.length() == 3); - CHECK(static_cast(str) == "CMU"); - } + SUBCASE("ShortCStringConstruction") { + char const *cstr = "CMU"; + StackString str(cstr); + CHECK(str.size() == 3); + CHECK(str.length() == 3); + CHECK(static_cast(str) == "CMU"); + } - SUBCASE("StdStringConstruction") { - std::basic_string stdStr = "World"; - StackString str(stdStr); - CHECK(str.size() == 5); - CHECK(str.length() == 5); - CHECK(static_cast(str) == "World"); + SUBCASE("StdStringConstruction") { + std::basic_string stdStr = "World"; + StackString str(stdStr); + CHECK(str.size() == 5); + CHECK(str.length() == 5); + CHECK(static_cast(str) == "World"); + } } -} -TEST_CASE_TEMPLATE("StackStringComparison", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + TEST_CASE_TEMPLATE("StackStringComparison", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - StackString str1{"abc"}; - StackString str2{"def"}; - StackString str3{"abc"}; + StackString str1{"abc"}; + StackString str2{"def"}; + StackString str3{"abc"}; - CHECK(str1 == str1); - CHECK(str1 == str3); - CHECK(str1 != str2); - CHECK(str2 != str3); - CHECK(str1 < str2); -} + CHECK(str1 == str1); + CHECK(str1 == str3); + CHECK(str1 != str2); + CHECK(str2 != str3); + CHECK(str1 < str2); + } -TEST_CASE_TEMPLATE("StackStringSize", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + TEST_CASE_TEMPLATE("StackStringSize", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - SUBCASE("EmptyString") { - StackString str; - CHECK(str.size() == 0); - CHECK(str.length() == 0); - } + SUBCASE("EmptyString") { + StackString str; + CHECK(str.size() == 0); + CHECK(str.length() == 0); + } - SUBCASE("NonEmptyString") { - StackString str{"Hello"}; - CHECK(str.size() == 5); - CHECK(str.length() == 5); + SUBCASE("NonEmptyString") { + StackString str{"Hello"}; + CHECK(str.size() == 5); + CHECK(str.length() == 5); + } } -} -TEST_CASE_TEMPLATE("StackStringConversion", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + TEST_CASE_TEMPLATE("StackStringConversion", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - StackString str{"Hello"}; - std::string stdStr = static_cast(str); - CHECK(stdStr == "Hello"); + StackString str{"Hello"}; + std::string stdStr = static_cast(str); + CHECK(stdStr == "Hello"); + } } diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 08101527f9..6c0ecf36f3 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -4,74 +4,76 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - std::vector res = vector; - std::vector expected = {10}; - CHECK(res == expected); - - vector.push_back(20); - expected = {10, 20}; - res = vector; - CHECK(res == expected); -} - -TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - vector.push_back(20); - vector.push_back(30); - - CHECK(vector[0] == 10); - CHECK(vector[1] == 20); - CHECK(vector[2] == 30); -} - -TEST_CASE_TEMPLATE("Size", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - CHECK(vector.size() == 0); - - vector.push_back(10); - CHECK(vector.size() == 1); - - vector.push_back(20); - CHECK(vector.size() == 2); -} - -TEST_CASE_TEMPLATE("==", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector1, vector2; - - vector1.push_back(10); - vector1.push_back(15); - vector1.push_back(20); - - vector2.push_back(10); - vector2.push_back(15); - vector2.push_back(20); - - CHECK(vector1 == vector2); -} - -TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - CHECK(vector.back() == 10); - - vector.push_back(20); - CHECK(vector.back() == 20); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + std::vector res = vector; + std::vector expected = {10}; + CHECK(res == expected); + + vector.push_back(20); + expected = {10, 20}; + res = vector; + CHECK(res == expected); + } + + TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + vector.push_back(20); + vector.push_back(30); + + CHECK(vector[0] == 10); + CHECK(vector[1] == 20); + CHECK(vector[2] == 30); + } + + TEST_CASE_TEMPLATE("Size", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + CHECK(vector.size() == 0); + + vector.push_back(10); + CHECK(vector.size() == 1); + + vector.push_back(20); + CHECK(vector.size() == 2); + } + + TEST_CASE_TEMPLATE("==", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector1, vector2; + + vector1.push_back(10); + vector1.push_back(15); + vector1.push_back(20); + + vector2.push_back(10); + vector2.push_back(15); + vector2.push_back(20); + + CHECK(vector1 == vector2); + } + + TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + CHECK(vector.back() == 10); + + vector.push_back(20); + CHECK(vector.back() == 20); + } } diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/test_tuple.cc index 344a2cd0fb..31308dec2c 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/test_tuple.cc @@ -6,74 +6,76 @@ using namespace FlexFlow; -TEST_CASE("get function") { - std::tuple t(42, 3.14f, 2.71828); - - SUBCASE("get mutable reference") { - int &result = get(t); - CHECK(result == 42); - - result = 100; - CHECK(std::get<0>(t) == 100); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get function") { + std::tuple t(42, 3.14f, 2.71828); + + SUBCASE("get mutable reference") { + int &result = get(t); + CHECK(result == 42); + + result = 100; + CHECK(std::get<0>(t) == 100); + } + + SUBCASE("get rvalue reference") { + int &&result = get(std::move(t)); + CHECK(result == 42); + + // t is in a valid but unspecified state after move + CHECK(std::get<0>(t) == 42); // Uncomment this line to check the behavior + } + + SUBCASE("get const reference") { + int const &result = get(t); + CHECK(result == 42); + } + + SUBCASE("get const rvalue reference") { + int const &&result = get(std::move(t)); + CHECK(result == 42); + } } - SUBCASE("get rvalue reference") { - int &&result = get(std::move(t)); - CHECK(result == 42); + TEST_CASE("tuple_prepend function") { + std::tuple t1(3.14f, 2.71828); + int value = 42; - // t is in a valid but unspecified state after move - CHECK(std::get<0>(t) == 42); // Uncomment this line to check the behavior + auto result = tuple_prepend(value, t1); + std::tuple expected(42, 3.14f, 2.71828); + CHECK(result == expected); } - SUBCASE("get const reference") { - int const &result = get(t); - CHECK(result == 42); + TEST_CASE("Testing tuple_head_t") { + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple<>>::value); } - SUBCASE("get const rvalue reference") { - int const &&result = get(std::move(t)); - CHECK(result == 42); + TEST_CASE("Testing tuple_slice_t") { + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple>::value); } -} - -TEST_CASE("tuple_prepend function") { - std::tuple t1(3.14f, 2.71828); - int value = 42; - auto result = tuple_prepend(value, t1); - std::tuple expected(42, 3.14f, 2.71828); - CHECK(result == expected); -} - -TEST_CASE("Testing tuple_head_t") { - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple<>>::value); -} + TEST_CASE("Testing tuple_compare function") { + std::tuple tup1{1, 3.14, 'a'}; + std::tuple tup2{1, 3.14, 'a'}; + std::tuple tup3{2, 3.14, 'b'}; -TEST_CASE("Testing tuple_slice_t") { - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple>::value); -} - -TEST_CASE("Testing tuple_compare function") { - std::tuple tup1{1, 3.14, 'a'}; - std::tuple tup2{1, 3.14, 'a'}; - std::tuple tup3{2, 3.14, 'b'}; - - CHECK(tuple_compare(tup1, tup2)); - CHECK(!tuple_compare(tup1, tup3)); -} + CHECK(tuple_compare(tup1, tup2)); + CHECK(!tuple_compare(tup1, tup3)); + } -TEST_CASE("Testing get function with valid index") { - std::tuple tup{1, 3.14, 'a'}; + TEST_CASE("Testing get function with valid index") { + std::tuple tup{1, 3.14, 'a'}; - CHECK(get(tup) == 1); - CHECK(get(tup) == 3.14); - CHECK(get(tup) == 'a'); + CHECK(get(tup) == 1); + CHECK(get(tup) == 3.14); + CHECK(get(tup) == 'a'); + } } diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/test_type_index.cc index 1b9a811846..b2d8aea848 100644 --- a/lib/utils/test/src/test_type_index.cc +++ b/lib/utils/test/src/test_type_index.cc @@ -4,30 +4,32 @@ using namespace FlexFlow; -TEST_CASE("type_index function") { - SUBCASE("int type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(int); - CHECK(idx == expected_idx); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("type_index function") { + SUBCASE("int type") { + std::type_index idx = type_index(); + std::type_index expected_idx = typeid(int); + CHECK(idx == expected_idx); + } - SUBCASE("string type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(std::string); - CHECK(idx == expected_idx); + SUBCASE("string type") { + std::type_index idx = type_index(); + std::type_index expected_idx = typeid(std::string); + CHECK(idx == expected_idx); + } } -} -TEST_CASE("matches function") { - std::type_index idx = typeid(float); + TEST_CASE("matches function") { + std::type_index idx = typeid(float); - SUBCASE("matching type") { - bool result = matches(idx); - CHECK(result == true); - } + SUBCASE("matching type") { + bool result = matches(idx); + CHECK(result == true); + } - SUBCASE("non-matching type") { - bool result = matches(idx); - CHECK(result == false); + SUBCASE("non-matching type") { + bool result = matches(idx); + CHECK(result == false); + } } } diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index c6f2003ee4..a60a330ad3 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -31,30 +31,32 @@ using namespace rc; /* static_assert(is_streamable::value, ""); */ /* static_assert(is_fmtable::value, ""); */ -TEST_CASE_TEMPLATE("UndirectedGraph implementations", - T, - HashmapUndirectedGraph) { - - rc::dc_check("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct(gen::elementOf(n), gen::elementOf(n))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(NodeQuery::all()) == without_order(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(UndirectedEdgeQuery::all()) == without_order(e)); - }); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("UndirectedGraph implementations", + T, + HashmapUndirectedGraph) { + + rc::dc_check("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct(gen::elementOf(n), gen::elementOf(n))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(NodeQuery::all()) == without_order(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(UndirectedEdgeQuery::all()) == without_order(e)); + }); + } } diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 541ff40920..f7d08889de 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,69 +1,71 @@ #include "test/utils/doctest.h" #include "utils/variant.h" -TEST_CASE("widen and narrow functions") { - SUBCASE("widen function") { - std::variant v1 = 42; - std::variant result = - widen>(v1); - std::variant expected = 42; - CHECK(result == expected); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("widen and narrow functions") { + SUBCASE("widen function") { + std::variant v1 = 42; + std::variant result = + widen>(v1); + std::variant expected = 42; + CHECK(result == expected); + } - SUBCASE("narrow function fail") { - std::variant v2 = - 3.14; // this is a doule, because 3.14 default to double - std::optional> result = - narrow>(v2); - std::optional> expected = float(3.14); - CHECK(!result.has_value()); // result should be empty due to narrowing - } + SUBCASE("narrow function fail") { + std::variant v2 = + 3.14; // this is a doule, because 3.14 default to double + std::optional> result = + narrow>(v2); + std::optional> expected = float(3.14); + CHECK(!result.has_value()); // result should be empty due to narrowing + } - SUBCASE("narrow function success") { - std::variant v2 = - 3.14; // this is a doule, because 3.14 default to double - std::optional> result = - narrow>(v2); - std::optional> expected = 3.14; - CHECK(result == expected); // - } + SUBCASE("narrow function success") { + std::variant v2 = + 3.14; // this is a doule, because 3.14 default to double + std::optional> result = + narrow>(v2); + std::optional> expected = 3.14; + CHECK(result == expected); // + } - SUBCASE("cast function") { - std::variant v3 = 42; - std::optional> result = - cast>(v3); - std::optional> expected = 42; - CHECK(result == expected); + SUBCASE("cast function") { + std::variant v3 = 42; + std::optional> result = + cast>(v3); + std::optional> expected = 42; + CHECK(result == expected); + } } -} -TEST_CASE("Narrow and cast variants") { - std::variant original_variant = 42; + TEST_CASE("Narrow and cast variants") { + std::variant original_variant = 42; - // narrow - std::optional> narrow_result = - narrow>(original_variant); - CHECK(narrow_result.has_value()); // assert narrow has value + // narrow + std::optional> narrow_result = + narrow>(original_variant); + CHECK(narrow_result.has_value()); // assert narrow has value - // cast - std::optional> cast_result = - cast>(narrow_result.value()); - CHECK(cast_result.has_value()); // assert cast has value - CHECK(get(cast_result.value()) == 42); -} + // cast + std::optional> cast_result = + cast>(narrow_result.value()); + CHECK(cast_result.has_value()); // assert cast has value + CHECK(get(cast_result.value()) == 42); + } -TEST_CASE("casting and widening a variant") { - std::variant smaller_variant = 42; - std::variant wider_variant; + TEST_CASE("casting and widening a variant") { + std::variant smaller_variant = 42; + std::variant wider_variant; - // Perform the cast operation - std::optional> cast_result = - cast>(smaller_variant); - REQUIRE(cast_result); // Ensure the cast was successful + // Perform the cast operation + std::optional> cast_result = + cast>(smaller_variant); + REQUIRE(cast_result); // Ensure the cast was successful - // Perform the widening operation - wider_variant = widen>(cast_result.value()); + // Perform the widening operation + wider_variant = widen>(cast_result.value()); - // Check the result - CHECK(get(wider_variant) == 42); + // Check the result + CHECK(get(wider_variant) == 42); + } } diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/test_vector.cc index 5eba16c312..4bdc724dd8 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/test_vector.cc @@ -1,29 +1,31 @@ #include "test/utils/doctest.h" #include "utils/vector.h" -TEST_CASE("concat function") { - SUBCASE("concatenates two vectors") { - std::vector v1 = {1, 2, 3}; - std::vector v2 = {4, 5, 6}; - std::vector result = concat(v1, v2); - std::vector expected = {1, 2, 3, 4, 5, 6}; - CHECK(result == expected); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("concat function") { + SUBCASE("concatenates two vectors") { + std::vector v1 = {1, 2, 3}; + std::vector v2 = {4, 5, 6}; + std::vector result = concat(v1, v2); + std::vector expected = {1, 2, 3, 4, 5, 6}; + CHECK(result == expected); + } - SUBCASE("concatenates two string vectors") { - std::vector v1 = {"1", "2", "3"}; - std::vector v2 = {"4", "5", "6"}; - std::vector result = concat(v1, v2); - std::vector expected = {"1", "2", "3", "4", "5", "6"}; - CHECK(result == expected); - } + SUBCASE("concatenates two string vectors") { + std::vector v1 = {"1", "2", "3"}; + std::vector v2 = {"4", "5", "6"}; + std::vector result = concat(v1, v2); + std::vector expected = {"1", "2", "3", "4", "5", "6"}; + CHECK(result == expected); + } - SUBCASE("concatenates multiple vectors") { - std::vector v1 = {1, 2, 3}; - std::vector v2 = {4, 5, 6}; - std::vector v3 = {7, 8, 9}; - std::vector result = concat(v1, v2, v3); - std::vector expected = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - CHECK(result == expected); + SUBCASE("concatenates multiple vectors") { + std::vector v1 = {1, 2, 3}; + std::vector v2 = {4, 5, 6}; + std::vector v3 = {7, 8, 9}; + std::vector result = concat(v1, v2, v3); + std::vector expected = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + CHECK(result == expected); + } } } From da7481790b5cfdb9a2ed4b30f0840fb7bdbef97f Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 26 Mar 2024 16:39:53 -0700 Subject: [PATCH 30/32] Remove unnecessary nix files, add utils test to ci --- .flake/pkgs/fmt.nix | 73 ----------------------------- .flake/pkgs/rapidcheck.nix | 48 ------------------- .flake/pkgs/tokenizers-cpp.nix | 43 ----------------- .github/workflows/per-lib-check.yml | 4 ++ 4 files changed, 4 insertions(+), 164 deletions(-) delete mode 100644 .flake/pkgs/fmt.nix delete mode 100644 .flake/pkgs/rapidcheck.nix delete mode 100644 .flake/pkgs/tokenizers-cpp.nix diff --git a/.flake/pkgs/fmt.nix b/.flake/pkgs/fmt.nix deleted file mode 100644 index e2677bdea2..0000000000 --- a/.flake/pkgs/fmt.nix +++ /dev/null @@ -1,73 +0,0 @@ -{ lib -, stdenv -, fetchFromGitHub, fetchpatch -, cmake -, enableShared ? !stdenv.hostPlatform.isStatic - -# tests -, mpd -, openimageio -, fcitx5 -, spdlog -}: - -let - generic = { version, sha256, patches ? [ ] }: - stdenv.mkDerivation { - pname = "fmt"; - inherit version; - - outputs = [ "out" "dev" ]; - - src = fetchFromGitHub { - owner = "fmtlib"; - repo = "fmt"; - rev = version; - inherit sha256; - }; - - inherit patches; - - nativeBuildInputs = [ cmake ]; - - cmakeFlags = [ - "-DBUILD_SHARED_LIBS=${if enableShared then "ON" else "OFF"}" - ]; - - doCheck = true; - - passthru.tests = { - inherit mpd openimageio fcitx5 spdlog; - }; - - meta = with lib; { - description = "Small, safe and fast formatting library"; - longDescription = '' - fmt (formerly cppformat) is an open-source formatting library. It can be - used as a fast and safe alternative to printf and IOStreams. - ''; - homepage = "https://fmt.dev/"; - changelog = "https://github.com/fmtlib/fmt/blob/${version}/ChangeLog.rst"; - downloadPage = "https://github.com/fmtlib/fmt/"; - maintainers = [ maintainers.jdehaas ]; - license = licenses.mit; - platforms = platforms.all; - }; - }; -in -{ - fmt_8 = generic { - version = "8.1.1"; - sha256 = "sha256-leb2800CwdZMJRWF5b1Y9ocK0jXpOX/nwo95icDf308="; - }; - - fmt_9 = generic { - version = "9.1.0"; - sha256 = "sha256-rP6ymyRc7LnKxUXwPpzhHOQvpJkpnRFOt2ctvUNlYI0="; - }; - - fmt_10 = generic { - version = "10.1.1"; - sha256 = "sha256-H9+1lEaHM12nzXSmo9m8S6527t+97e6necayyjCPm1A="; - }; -} diff --git a/.flake/pkgs/rapidcheck.nix b/.flake/pkgs/rapidcheck.nix deleted file mode 100644 index 3ff63207b2..0000000000 --- a/.flake/pkgs/rapidcheck.nix +++ /dev/null @@ -1,48 +0,0 @@ -{ lib -, stdenv -, fetchFromGitHub -, cmake -, unstableGitUpdater -, testers -}: - -stdenv.mkDerivation (finalAttrs: { - pname = "rapidcheck"; - version = "unstable-2023-12-14"; - - src = fetchFromGitHub { - owner = "emil-e"; - repo = "rapidcheck"; - rev = "ff6af6fc683159deb51c543b065eba14dfcf329b"; - hash = "sha256-Ixz5RpY0n8Un/Pv4XoTfbs40+70iyMbkQUjDqoLaWOg="; - }; - - nativeBuildInputs = [ cmake ]; - - cmakeFlags = [ - (lib.cmakeBool "BUILD_SHARED_LIBS" (!stdenv.hostPlatform.isStatic)) - (lib.cmakeBool "RC_INSTALL_ALL_EXTRAS" true) - ]; - - passthru = { - updateScript = unstableGitUpdater { }; - tests.pkg-config = testers.testMetaPkgConfig finalAttrs.finalPackage; - }; - - meta = with lib; { - description = "A C++ framework for property based testing inspired by QuickCheck"; - inherit (finalAttrs.src.meta) homepage; - maintainers = with maintainers; [ ]; - license = licenses.bsd2; - pkgConfigModules = [ - "rapidcheck" - # Extras - "rapidcheck_boost" - "rapidcheck_boost_test" - "rapidcheck_catch" - "rapidcheck_doctest" - "rapidcheck_gtest" - ]; - platforms = platforms.all; - }; -}) diff --git a/.flake/pkgs/tokenizers-cpp.nix b/.flake/pkgs/tokenizers-cpp.nix deleted file mode 100644 index a705667ae6..0000000000 --- a/.flake/pkgs/tokenizers-cpp.nix +++ /dev/null @@ -1,43 +0,0 @@ -{ lib -, stdenv -, fetchFromGitHub -, cmake -, rustc -, cargo -}: - -stdenv.mkDerivation rec { - pname = "tokenizers-cpp"; - version = "2024-03-13"; - - src = fetchFromGitHub { - owner = "mlc-ai"; - repo = "tokenizers-cpp"; - rev = "4f42c9fa74946d70af86671a3804b6f2433e5dac"; - sha256 = "sha256-p7OYx9RVnKUAuMexy3WjW2zyfMJ/Q9ss4xFLsbQK7wA="; - fetchSubmodules = true; - }; - - nativeBuildInputs = [ - cmake - rustc - ]; - - # cmakeFlags = [ - # "-DLegion_USE_Python=1" - # "-DLegion_BUILD_BINDINGS=1" - # "-DLegion_USE_CUDA=1" - # "-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}" - # ]; - - buildInputs = [ ]; - # python3 - # cudatoolkit - # ]; - - meta = with lib; { - description = "Universal cross-platform tokenizers binding to HF and sentencepiece"; - homepage = "https://github.com/mlc-ai/tokenizers-cpp"; - license = licenses.asl20; - }; -} diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index f1d069f252..874a298587 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -80,6 +80,10 @@ jobs: run: | build_libs.sh compiler + - name: Test utils + run: | + test_libs.sh utils + - name: Test substitutions run: | test_libs.sh substitutions From 0db60db6e5c6a53460a209ea72f9c70bd63caccb Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 26 Mar 2024 16:46:21 -0700 Subject: [PATCH 31/32] Fix utils tests name, format --- .../test/src/test_labelled_open_graph.cc | 6 +++-- lib/compiler/test/src/test_optimal_cost.cc | 3 ++- lib/compiler/test/src/test_unity_algorithm.cc | 2 +- .../test/src/test_pattern_matches.cc | 6 +++-- .../test/src/test_substitution.cc | 25 +++++++++++-------- lib/utils/test/CMakeLists.txt | 2 +- lib/utils/test/src/test_algorithms.cc | 8 +++--- lib/utils/test/src/test_containers.cc | 4 ++- lib/utils/test/src/test_disjoint_set.cc | 5 ++-- lib/utils/test/src/test_multidigraph.cc | 10 +++++--- lib/utils/test/src/test_undirected_graph.cc | 8 +++--- lib/utils/test/src/test_variant.cc | 3 ++- 12 files changed, 48 insertions(+), 34 deletions(-) diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index e3498a769a..ccad7b19ff 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -43,7 +43,8 @@ TEST_SUITE(FF_TEST_SUITE) { auto subgraph0 = get_subgraph(g, node_set0); auto subgraph1 = get_subgraph(g, node_set0); - auto subgraph2 = get_subgraph(g, node_set0); + auto subgraph2 = + get_subgraph(g, node_set0); auto subgraph3 = get_subgraph(g, node_set0); CHECK(bool(get_nodes(subgraph0) == node_set0)); @@ -73,7 +74,8 @@ TEST_SUITE(FF_TEST_SUITE) { split_edge(e2).second, split_edge(e3).second, e4})); CHECK(bool(get_edges(subgraph2) == std::unordered_set{e4, e5})); - CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); + CHECK( + bool(get_edges(subgraph3) == std::unordered_set{e4})); CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index da303e3ccc..91c7a11888 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -8,7 +8,8 @@ TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck infrastructures for graphs does not work for now /* Tests whether optimal_cost can give a valid result given random PCG, trivial - allowed machine views, trivial cost estimator and random machine specification. + allowed machine views, trivial cost estimator and random machine + specification. */ // TEST_CASE("optimal_cost") { // auto test_allowed_machine_views = [](Operator const &, diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index b8fde91c51..614e9bb182 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -1,7 +1,7 @@ #include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" #include "test_cost_estimator.h" #include "test_generator.h" -#include "doctest/doctest.h" TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck does not work for now diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index f1abd5c17e..5d72bbff7e 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -102,8 +102,10 @@ TEST_SUITE(FF_TEST_SUITE) { RC_ASSERT(matches.size() == 3); for (MultiDiGraphPatternMatch const &match : matches) { - RC_ASSERT(pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), match, always_true)); + RC_ASSERT(pattern_matches(as_openmultidigraph(sg0), + as_openmultidigraph(g), + match, + always_true)); } } } diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 86ee087a29..df22d8a620 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -12,18 +12,20 @@ TEST_SUITE(FF_TEST_SUITE) { ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; ParallelTensorPattern tensor_pattern_e0{ - std::vector{TensorAttributeConstraint{ - ConstraintType::EQUAL, - ListIndexAccess{TensorAttributeKey::DIM_SIZES, 0}, - 2}}}; + std::vector{ + TensorAttributeConstraint{ConstraintType::EQUAL, + ListIndexAccess{ + TensorAttributeKey::DIM_SIZES, 0}, + 2}}}; ParallelTensorPattern tensor_pattern_empty{ std::vector{}}; - auto ig = OutputLabelledOpenMultiDiGraph:: - create>(); + auto ig = + OutputLabelledOpenMultiDiGraph:: + create>(); Node n0 = ig.add_node(operator_pattern_n0); NodePort p0 = ig.add_node_port(); InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; @@ -86,7 +88,8 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph pcg = OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); + UnorderedOutputLabelledOpenMultiDiGraph>(); Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n5 = pcg.add_node(Operator{ @@ -109,8 +112,8 @@ TEST_SUITE(FF_TEST_SUITE) { }, [&](OpenMultiDiEdge const &pattern_edge, OpenMultiDiEdge const &graph_edge) { - return parallel_tensor_satisfies(pcg.at(graph_edge), - input_graph.value().at(pattern_edge)); + return parallel_tensor_satisfies( + pcg.at(graph_edge), input_graph.value().at(pattern_edge)); }}; RC_ASSERT(criterion.node_criterion(n0, n5)); diff --git a/lib/utils/test/CMakeLists.txt b/lib/utils/test/CMakeLists.txt index 97253b4ab7..40ff07285e 100644 --- a/lib/utils/test/CMakeLists.txt +++ b/lib/utils/test/CMakeLists.txt @@ -1,6 +1,6 @@ ff_add_test_executable( NAME - utils-test + utils-tests SRC_PATTERNS src/test_cow_ptr.cc PRIVATE_INCLUDE diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index d3236a7b1c..0fb258bf15 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -109,7 +109,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("traversal") { DiGraph g = DiGraph::create(); std::vector const n = add_nodes(g, 5); - std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; + std::vector edges = { + {n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; add_edges(g, edges); CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); @@ -138,7 +139,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nonlinear") { g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs } SUBCASE("not connected") { @@ -168,7 +169,8 @@ TEST_SUITE(FF_TEST_SUITE) { auto CHECK_BEFORE = [&](int l, int r) { CHECK(index_of(ordering, n[l]).has_value()); CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]).value() < index_of(ordering, n[r]).value()); + CHECK(index_of(ordering, n[l]).value() < + index_of(ordering, n[r]).value()); }; CHECK(ordering.size() == n.size()); diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index f6ac6e2d42..a6776d492e 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -30,7 +30,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("sum with condition") { std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; // Sum of even numbers only + auto condition = [](int x) { + return x % 2 == 0; + }; // Sum of even numbers only CHECK(sum_where(v, condition) == 6); } diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/test_disjoint_set.cc index 8bcf2e533f..80fcf87d6b 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/test_disjoint_set.cc @@ -62,9 +62,8 @@ TEST_SUITE(FF_TEST_SUITE) { ds.get_mapping(); for (auto const &kv : mapping) { - CHECK( - *kv.second == - *expectedMapping[kv.first]); // Compare the values inside the optionals + CHECK(*kv.second == *expectedMapping[kv.first]); // Compare the values + // inside the optionals } } } diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc index 91631f0391..90e1bb2187 100644 --- a/lib/utils/test/src/test_multidigraph.cc +++ b/lib/utils/test/src/test_multidigraph.cc @@ -41,10 +41,12 @@ TEST_SUITE(FF_TEST_SUITE) { {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set( - {p[1], p[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set( - {p[0], p[2]}))) == std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs( + query_set({p[1], p[2]}))) == + std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs( + query_set({p[0], p[2]}))) == + std::unordered_set{e[1], e[2]}); CHECK(g.query_edges(MultiDiEdgeQuery::all() .with_src_nodes({n[1]}) .with_dst_nodes({n[2]}) diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index a60a330ad3..3616ee59aa 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -32,9 +32,8 @@ using namespace rc; /* static_assert(is_fmtable::value, ""); */ TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("UndirectedGraph implementations", - T, - HashmapUndirectedGraph) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { rc::dc_check("Full", [&]() { UndirectedGraph g = UndirectedGraph::create(); @@ -45,7 +44,8 @@ TEST_SUITE(FF_TEST_SUITE) { if (num_nodes > 0) { e = *gen::unique>( num_edges, - gen::construct(gen::elementOf(n), gen::elementOf(n))); + gen::construct(gen::elementOf(n), + gen::elementOf(n))); } for (UndirectedEdge const &edge : e) { g.add_edge(edge); diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index f7d08889de..0fef782c0e 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -63,7 +63,8 @@ TEST_SUITE(FF_TEST_SUITE) { REQUIRE(cast_result); // Ensure the cast was successful // Perform the widening operation - wider_variant = widen>(cast_result.value()); + wider_variant = + widen>(cast_result.value()); // Check the result CHECK(get(wider_variant) == 42); From c21d66eb7591d2072c9006d12d40560eef2310ab Mon Sep 17 00:00:00 2001 From: Bob Chen <70640928+Bob-Chen222@users.noreply.github.com> Date: Sun, 31 Mar 2024 22:52:32 -0400 Subject: [PATCH 32/32] add tutorial --- lib/substitutions/TUTORIAL.md | 206 ++++++++++++++++++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 lib/substitutions/TUTORIAL.md diff --git a/lib/substitutions/TUTORIAL.md b/lib/substitutions/TUTORIAL.md new file mode 100644 index 0000000000..bcf39da603 --- /dev/null +++ b/lib/substitutions/TUTORIAL.md @@ -0,0 +1,206 @@ +## Tutorial of substitution lib with simple example + +#### Create a pattern + +```c++ +//we should specify both the node pattern and edge pattern when defining a GraphPattern + +//first define an operator pattern for example, specify the node to have a linear +//operator +OperatorPattern operator_pattern_n0{ + std::vector{OperatorAttributeConstraint{ + ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; + +//then define a tensor_pattern that restrict the pattern of edge in pcg. for example, +//specify that the first dimension (indexed by 0) of a tensor should be 2 +ParallelTensorPattern tensor_pattern_e0{ + std::vector{ + TensorAttributeConstraint{ConstraintType::EQUAL, + ListIndexAccess{ + TensorAttributeKey::DIM_SIZES, 0}, + 2}}}; +/* +remeber that both operator_pattern and tensor_pattern are std::vector, meaning that you +can define more than one constraint depending on the context +*/ +``` + + +#### Pack into GraphPattern +```c++ +//create a graph with node label of OperatorPattern and edge label of ParallelTensorPattern +auto ig = + OutputLabelledOpenMultiDiGraph:: + create>(); +//add constraints defined above as argument to create a node +Node n0 = ig.add_node(operator_pattern_n0); +//add port number to distinguish different edges going to the same node +NodePort p0 = ig.add_node_port(); +//create edge +InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; +ig.add_edge(e0); +//add edge constraints above to the edge e0 +ig.add_label(e0, tensor_pattern_e0); + +//a pattern graph with one input edge pointing to a node +/* + n0 (Linear) + ↑ +*/ +RC_ASSERT(get_nodes(ig).size() == 1); +RC_ASSERT(get_edges(ig).size() == 1); +``` + +#### Define OutputGraph +```cpp + +//define a 3-node PCG that can be applied from the input graph ig + +//Partition node that can partite the input into two parts +OperatorAttrAssignment op_ass_n1{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, + {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, + {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; + +//Linear node +OperatorAttrAssignment op_ass_n2{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, + {OperatorAttributeKey::OUT_CHANNELS, + OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, + {OperatorAttributeKey::USE_BIAS, + OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, + {OperatorAttributeKey::DATA_TYPE, + OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, + {OperatorAttributeKey::ACTIVATION, + OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, + {OperatorAttributeKey::REGULARIZER, + OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; + +//Reduce node that will combine the result of two partitions +OperatorAttrAssignment op_ass_n3{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, + {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, + {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; + +//notice that these assignments will be evaluated +//into new operators in the apply_substitution function +//and be inserted into the new pcg + +//create outputgraph with 3 nodes and 3 edges +auto og = NodeLabelledOpenMultiDiGraph::create< + UnorderedNodeLabelledOpenMultiDiGraph>(); +Node n1 = og.add_node(op_ass_n1); +Node n2 = og.add_node(op_ass_n2); +Node n3 = og.add_node(op_ass_n3); +NodePort p1 = og.add_node_port(); +NodePort p2 = og.add_node_port(); +NodePort p3 = og.add_node_port(); + +InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; +MultiDiEdge e2{n2, p2, n1, p1}; +MultiDiEdge e3{n3, p3, n2, p2}; +og.add_edge(e1); +og.add_edge(e2); +og.add_edge(e3); +OutputGraphExpr output_graph_expr{og}; + +/* +The output graph looks like this + n3 (Reduce) + ↑ + n2 (Linear) + ↑ + n1 (Partition) + ↑ +*/ +RC_ASSERT(get_nodes(og).size() == 3); +RC_ASSERT(get_edges(og).size() == 3); +``` + +#### Define substitution +```cpp +//define two dict that specify how the input and output edges are mapped in the substitution +bidict input_mapping; +input_mapping.equate(e0, e1); +bidict output_mapping; + +Substitution substitution{ + input_graph, output_graph_expr, input_mapping, output_mapping}; +``` + +#### Apply substitution +```cpp + +//create the target pcg that we want to apply for substitution +SubParallelComputationGraph pcg = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + +Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); +Node n5 = pcg.add_node(Operator{ + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, + "linear"}); +NodePort p4 = pcg.add_node_port(); +NodePort p5 = pcg.add_node_port(); + +MultiDiEdge e4{n5, p5, n4, p4}; +pcg.add_edge(e4); +pcg.add_label(e4, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); + +/* Our target pcg looks like this + n5 (Linear) + ↑ + n4 (input) +*/ + +//create criterion function that will test every predefined edge and node constraints +MatchAdditionalCriterion criterion{ + [&](Node const &pattern_node, Node const &graph_node) { + return operator_satisfies(pcg.at(graph_node), + input_graph.value().at(pattern_node)); + }, + [&](OpenMultiDiEdge const &pattern_edge, + OpenMultiDiEdge const &graph_edge) { + return parallel_tensor_satisfies( + pcg.at(graph_edge), input_graph.value().at(pattern_edge)); + }}; + +RC_ASSERT(criterion.node_criterion(n0, n5)); + + +//find the match point that we can apply the substitution in the target pcg +std::vector matches = + find_pattern_matches(input_graph, pcg, criterion); + +//there is only one match point in the pcg that we defined +RC_ASSERT(matches.size() == 1); + +//apply substitution +//the number of new pcg generated is bounded by O(2^(sn))where s is the number of +//different substitutions and n is the number of nodes +SubParallelComputationGraph new_pcg = + apply_substitution(pcg, substitution, matches[0]); + +//now the new pcg becomes as follow +/* + n3 (Reduce) + ↑ + n2 (Linear) + ↑ + n1 (Partition) + ↑ + n4 (Input) +*/ +RC_ASSERT(get_nodes(new_pcg).size() == 4); +RC_ASSERT(get_edges(new_pcg).size() == 3); +``` + + + +