diff --git a/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h b/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h index 71c942976f..0946794b39 100644 --- a/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h +++ b/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h @@ -12,7 +12,7 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/digraph_generation.h" +#include "utils/graph/series_parallel/digraph_generation.h" #include #include @@ -118,7 +118,7 @@ DiGraph generate_nasnet_bench_cell() { } DiGraph generate_nasnet_bench_network() { - DiGraph g = serial_composition( + DiGraph g = series_composition( transform(repeat(NUM_CELLS, generate_nasnet_bench_cell), [](auto const cell) -> DiGraphView { return cell; })); return g; diff --git a/bin/sp_ization_benchmarking/sample_graphs.h b/bin/sp_ization_benchmarking/sample_graphs.h index a3286f3337..709258e502 100644 --- a/bin/sp_ization_benchmarking/sample_graphs.h +++ b/bin/sp_ization_benchmarking/sample_graphs.h @@ -10,7 +10,7 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/serial_parallel/digraph_generation.h" +#include "utils/graph/series_parallel/digraph_generation.h" #include namespace FlexFlow { diff --git a/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc b/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc index 5febc22474..7fd9064688 100644 --- a/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc +++ b/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc @@ -29,10 +29,10 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/digraph/digraph_view.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_metrics.h" -#include "utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h" -#include "utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.h" +#include "utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.h" #include #include #include @@ -56,10 +56,10 @@ CombinedResult perform_benchmark_given_graph(DiGraphView const &g, for (int i = 0; i < repeat; i++) { auto cost_map = make_cost_map(get_nodes(g), Dist); - SerialParallelDecomposition sp1 = + SeriesParallelDecomposition sp1 = critical_path_preserving_sp_ization_with_coalescing(g); - SerialParallelDecomposition sp2 = stratum_sync_sp_ization(g); - SerialParallelDecomposition sp3 = + SeriesParallelDecomposition sp2 = stratum_sync_sp_ization(g); + SeriesParallelDecomposition sp3 = cost_aware_stratum_sync_sp_ization(g, cost_map); auto noisy_cost_map = add_noise_to_cost_map(cost_map, Noise); @@ -108,10 +108,10 @@ CombinedResult DiGraphView g = graph_generator(); auto cost_map = make_cost_map(get_nodes(g), Dist); - SerialParallelDecomposition sp1 = + SeriesParallelDecomposition sp1 = critical_path_preserving_sp_ization_with_coalescing(g); - SerialParallelDecomposition sp2 = stratum_sync_sp_ization(g); - SerialParallelDecomposition sp3 = + SeriesParallelDecomposition sp2 = stratum_sync_sp_ization(g); + SeriesParallelDecomposition sp3 = cost_aware_stratum_sync_sp_ization(g, cost_map); auto noisy_cost_map = add_noise_to_cost_map(cost_map, Noise); diff --git a/lib/utils/include/utils/graph/digraph/algorithms/transitive_reduction.h b/lib/utils/include/utils/graph/digraph/algorithms/transitive_reduction.h index b4cdc62f83..ad11c6388c 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/transitive_reduction.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/transitive_reduction.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_REDUCTION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_REDUCTION_H +#include "utils/graph/digraph/digraph.h" #include "utils/graph/digraph/digraph_view.h" namespace FlexFlow { @@ -21,7 +22,7 @@ struct DirectedEdgeMaskView final : public IDiGraphView { std::unordered_set edge_mask; }; -DiGraphView transitive_reduction(DiGraphView const &); +DiGraph transitive_reduction(DiGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h deleted file mode 100644 index b67783c888..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/optional.h" -#include -#include - -namespace FlexFlow { - -std::optional - get_serial_parallel_decomposition(DiGraphView const &); -std::optional - get_serial_parallel_decomposition_with_dummy_nodes( - DiGraphView const &, std::unordered_set const &); - -} // namespace FlexFlow - -#endif - diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h deleted file mode 100644 index 96a552c9e0..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include - -namespace FlexFlow { - -std::variant internal_to_final_ast( - std::variant const &ast); -SerialParallelDecomposition - to_final_ast(std::variant const &); - -std::unordered_set get_nodes(SerialParallelDecomposition const &sp); -std::unordered_set get_nodes(SerialSplit const &); -std::unordered_set get_nodes(ParallelSplit const &); -std::unordered_set get_nodes(Node const &); - -bool is_empty(Node const &node); -bool is_empty(SerialSplit const &serial); -bool is_empty(ParallelSplit const ¶llel); -bool is_empty(SerialParallelDecomposition const &sp); - -bool has_no_duplicate_nodes(SerialParallelDecomposition const &sp); - -SerialParallelDecomposition delete_node(SerialParallelDecomposition sp, - Node const &node); - -// duplicate nodes within `sp` are counted multiple times -size_t num_nodes(SerialParallelDecomposition const &sp); - -SerialParallelDecomposition serial_composition( - std::vector const &sp_compositions); -SerialParallelDecomposition parallel_composition( - std::unordered_set const &sp_compositions); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h deleted file mode 100644 index 5434d8fb7a..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h +++ /dev/null @@ -1,83 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_SPLITS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_SPLITS_H - -#include "utils/graph/node/node.dtg.h" -#include -#include - -namespace FlexFlow { - -struct SerialSplit; -struct ParallelSplit; - -struct SerialSplit { -public: - SerialSplit(); - explicit SerialSplit(std::vector> const &); - explicit SerialSplit( - std::initializer_list> const &); - explicit SerialSplit(std::vector const &nodes); - - bool operator==(SerialSplit const &) const; - bool operator!=(SerialSplit const &) const; - -public: - std::vector> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(SerialSplit const &); -std::ostream &operator<<(std::ostream &, SerialSplit const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::SerialSplit> { - size_t operator()(::FlexFlow::SerialSplit const &) const; -}; - -} // namespace std - -namespace FlexFlow { - -struct ParallelSplit { -public: - ParallelSplit(); - explicit ParallelSplit( - std::unordered_set> const &); - explicit ParallelSplit( - std::initializer_list> const &); - explicit ParallelSplit(std::unordered_set const &nodes); - - bool operator==(ParallelSplit const &) const; - bool operator!=(ParallelSplit const &) const; - -public: - std::unordered_set> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(ParallelSplit const &); -std::ostream &operator<<(std::ostream &, ParallelSplit const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::ParallelSplit> { - size_t operator()(::FlexFlow::ParallelSplit const &) const; -}; - -} // namespace std - -#endif - diff --git a/lib/utils/include/utils/graph/series_parallel/digraph_generation.h b/lib/utils/include/utils/graph/series_parallel/digraph_generation.h index 40fddc4c59..aa724d9567 100644 --- a/lib/utils/include/utils/graph/series_parallel/digraph_generation.h +++ b/lib/utils/include/utils/graph/series_parallel/digraph_generation.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_DIGRAPH_GENERATION_H #include "utils/graph/digraph/digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { @@ -10,12 +10,12 @@ std::unordered_map parallel_extend(DiGraph &g, DiGraphView const &ext); std::unordered_map serial_extend(DiGraph &g, DiGraphView const &ext); -DiGraph serial_composition(DiGraphView const &g1, DiGraphView const &g2); +DiGraph series_composition(DiGraphView const &g1, DiGraphView const &g2); DiGraph parallel_composition(DiGraphView const &g1, DiGraphView const &g2); -DiGraph serial_composition(std::vector const &graphs); +DiGraph series_composition(std::vector const &graphs); DiGraph parallel_composition(std::vector const &graphs); -DiGraph digraph_from_sp_decomposition(SerialParallelDecomposition const &sp); +DiGraph digraph_from_sp_decomposition(SeriesParallelDecomposition const &sp); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/get_ancestors.h b/lib/utils/include/utils/graph/series_parallel/get_ancestors.h index aa580fe189..b7ae79bf49 100644 --- a/lib/utils/include/utils/graph/series_parallel/get_ancestors.h +++ b/lib/utils/include/utils/graph/series_parallel/get_ancestors.h @@ -1,11 +1,11 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLLEL_GET_ANCESTORS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLLEL_GET_ANCESTORS_H -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { -std::unordered_set get_ancestors(SerialParallelDecomposition const &sp, +std::unordered_set get_ancestors(SeriesParallelDecomposition const &sp, Node const &starting_node); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h index f2a006d899..ebaa2eb967 100644 --- a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -12,6 +12,10 @@ namespace FlexFlow { std::optional get_series_parallel_decomposition(DiGraphView const &); +std::optional + get_series_parallel_decomposition_with_sync_nodes( + DiGraphView const &, std::unordered_set const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/normalize_sp_decomposition.h b/lib/utils/include/utils/graph/series_parallel/normalize_sp_decomposition.h index 00a85a7514..46b60cd636 100644 --- a/lib/utils/include/utils/graph/series_parallel/normalize_sp_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/normalize_sp_decomposition.h @@ -1,23 +1,23 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_normalize_sp_decomposition_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_normalize_sp_decomposition_H -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { /** - * @brief Recursively normalizes a SerialParallelDecomposition. + * @brief Recursively normalizes a SeriesParallelDecomposition. * * @details This function performs the following semantic substitutions: - * - Deletes every empty SerialSplit and ParallelSplit item, e.g., + * - Deletes every empty SeriesSplit and ParallelSplit item, e.g., * S(P(S()), Node(1), Node(2)) -> S(Node(1), Node(2)) * - * - Replaces SerialSplit and ParallelSplit of size 1 with their content, e.g., + * - Replaces SeriesSplit and ParallelSplit of size 1 with their content, e.g., * S(S(Node(1)), P(Node(2))) -> S(Node(1), Node(2))) * */ -SerialParallelDecomposition - normalize_sp_decomposition(SerialParallelDecomposition const &sp); +SeriesParallelDecomposition + normalize_sp_decomposition(SeriesParallelDecomposition const &sp); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h index 52d2cb7236..b3fc201ca5 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -17,6 +17,25 @@ std::unordered_multiset get_nodes(SeriesSplit const &); std::unordered_multiset get_nodes(ParallelSplit const &); std::unordered_multiset get_nodes(Node const &); +bool is_empty(Node const &node); +bool is_empty(SeriesSplit const &serial); +bool is_empty(ParallelSplit const ¶llel); +bool is_empty(SeriesParallelDecomposition const &sp); + +bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, + Node const &node); + +// duplicate nodes within `sp` are counted multiple times +size_t num_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition series_composition( + std::vector const &sp_compositions); +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/serial_parallel_metrics.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_metrics.h similarity index 60% rename from lib/utils/include/utils/graph/series_parallel/serial_parallel_metrics.h rename to lib/utils/include/utils/graph/series_parallel/series_parallel_metrics.h index 7d2c546117..2ca8fb0825 100644 --- a/lib/utils/include/utils/graph/series_parallel/serial_parallel_metrics.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_metrics.h @@ -1,48 +1,48 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_METRICS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_METRICS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_series_parallel_metrics_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_series_parallel_metrics_H #include "utils/graph/digraph/digraph_view.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include namespace FlexFlow { std::unordered_map get_node_frequency_map(Node const &node); std::unordered_map - get_node_frequency_map(SerialSplit const &serial); + get_node_frequency_map(SeriesSplit const &serial); std::unordered_map get_node_frequency_map(ParallelSplit const ¶llel); std::unordered_map - get_node_frequency_map(SerialParallelDecomposition const &sp); + get_node_frequency_map(SeriesParallelDecomposition const &sp); -float work_cost(SerialParallelDecomposition const &sp, +float work_cost(SeriesParallelDecomposition const &sp, std::unordered_map cost_map); float work_cost(DiGraphView const &g, std::unordered_map const &cost_map); -int num_dependencies(SerialParallelDecomposition const &sp); +int num_dependencies(SeriesParallelDecomposition const &sp); int num_dependencies(DiGraphView const &g); -float critical_path_cost(SerialParallelDecomposition const &sp, +float critical_path_cost(SeriesParallelDecomposition const &sp, std::unordered_map const &cost_map); float critical_path_cost(DiGraphView const &g, std::unordered_map const &cost_map); float relative_work_increase(DiGraphView const &g, - SerialParallelDecomposition const &sp, + SeriesParallelDecomposition const &sp, std::unordered_map const &cost_map); float relative_critical_path_cost_increase( DiGraphView const &g, - SerialParallelDecomposition const &sp, + SeriesParallelDecomposition const &sp, std::unordered_map const &cost_map); float relative_num_dependencies_increase(DiGraphView const &g, - SerialParallelDecomposition const &sp); + SeriesParallelDecomposition const &sp); } // namespace FlexFlow -#endif // FLEXFLOW_SERIAL_PARALLEL_METRICS_H +#endif // FLEXFLOW_series_parallel_metrics_H diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.h index 7cd756b68e..680e10493a 100644 --- a/lib/utils/include/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.h +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_UTILS_GRAPH_SERIAL_PARALLEL_CRITICAL_PATH_PRESERVING_SP_IZATION_H #include "utils/graph/digraph/digraph_view.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include namespace FlexFlow { @@ -13,10 +13,10 @@ namespace FlexFlow { * through node (work) duplication. * * @details - * The resulting graph, encoded as a SerialParallelDecomposition, is a tree + * The resulting graph, encoded as a SeriesParallelDecomposition, is a tree * whose critical path is the same as that of the original graph. The tree is * constructed as follows: - * - Denote SP(n) as the SerialParallelDecomposition of the subgraph of g whose + * - Denote SP(n) as the SeriesParallelDecomposition of the subgraph of g whose * nodes are all the ancestors of n. * - Denote the predecessors of n as M. * - Then: @@ -63,7 +63,7 @@ namespace FlexFlow { * @note g must be a 2 terminal (i.e. single source and single sink) directed * acyclic graph. */ -SerialParallelDecomposition +SeriesParallelDecomposition critical_path_preserving_sp_ization(DiGraphView const &g); /** @@ -111,7 +111,7 @@ SerialParallelDecomposition * } * */ -SerialParallelDecomposition +SeriesParallelDecomposition critical_path_preserving_sp_ization_with_coalescing(DiGraphView const &g); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h similarity index 58% rename from lib/utils/include/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.h rename to lib/utils/include/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h index 883feaf87b..fed5ad727d 100644 --- a/lib/utils/include/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.h +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h @@ -2,13 +2,13 @@ #define _FLEXFLOW_UTILS_GRAPH_SERIAL_PARALLEL_IS_VALID_SP_IZATION_H #include "utils/graph/digraph/digraph_view.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include namespace FlexFlow { -bool is_valid_sp_ization(DiGraphView const &g, - SerialParallelDecomposition const &sp); +bool dependencies_are_maintained(DiGraphView const &g, + SeriesParallelDecomposition const &sp); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.enum.toml b/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.enum.toml new file mode 100644 index 0000000000..bdc5940383 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.enum.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "NodeRole" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "PURE" + +[[values]] +name = "SYNC" + +[[values]] +name = "DUMMY" diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/spanish_algo.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/spanish_algo.h index bc31eb733b..c4a2081a55 100644 --- a/lib/utils/include/utils/graph/series_parallel/sp_ization/spanish_algo.h +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/spanish_algo.h @@ -2,12 +2,25 @@ #define _FLEXFLOW_UTILS_GRAPH_SERIAL_PARALLEL_SP_IZATION_SPANISH_ALGO_H #include "utils/graph/digraph/digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" #include - namespace FlexFlow { -SerialParallelDecomposition one_node_at_a_time_spanish_sp_ization(DiGraph g); +DiGraph add_dummy_nodes(DiGraph g, + std::unordered_map &node_roles); + +DiGraph + delete_dummy_nodes(DiGraph g, + std::unordered_map const &node_roles); + +std::unordered_set + get_component(DiGraph const &g, + Node const &node, + std::unordered_map const &depth_map, + std::unordered_map const &node_roles); + +SeriesParallelDecomposition spanish_strata_sync(DiGraph g); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.h index ffbf9bfcfe..061c5ae5f0 100644 --- a/lib/utils/include/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.h +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_UTILS_GRAPH_SERIAL_PARALLEL_WORK_PRESERVING_SP_IZATION_H #include "utils/graph/digraph/digraph_view.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include namespace FlexFlow { @@ -44,7 +44,7 @@ namespace FlexFlow { * * @note g must be a directed acyclic graph. **/ -SerialParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g); +SeriesParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g); /** * @brief @@ -61,7 +61,7 @@ SerialParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g); *similar critical path cost, thus minimizing the overall critical path cost of *the SP-ized graph. **/ -SerialParallelDecomposition cost_aware_stratum_sync_sp_ization( +SeriesParallelDecomposition cost_aware_stratum_sync_sp_ization( DiGraphView const &g, std::unordered_map const &cost_map); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc index 207fa8d192..0ccaa1bb0a 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc @@ -22,7 +22,6 @@ std::unordered_set get_descendants(DiGraphView const &g, to_visit.pop(); descendants.insert(current); - // add all unvisited successors of `current` to `to_visit` for (auto const &s : filter(get_successors(g, current), [&](Node const &n) { return !contains(descendants, n); })) { diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc index 1cfad742d1..6cc3b73805 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" -#include "utils/containers.h" +#include "utils/containers/maximum.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc index d5ea96693b..5053924035 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc @@ -1,6 +1,6 @@ -#include "utils/containers.h" #include "utils/containers/intersection.h" #include "utils/containers/is_subseteq_of.h" +#include "utils/containers/maximum.h" #include "utils/containers/transform.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/get_ancestors.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc index a7ff74b2ca..18f5ba34ca 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,5 +1,6 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/contains.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms.h" @@ -7,6 +8,7 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" #include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" @@ -30,7 +32,7 @@ DirectedEdgeMaskView *DirectedEdgeMaskView::clone() const { return new DirectedEdgeMaskView(this->g, this->edge_mask); } -DiGraphView transitive_reduction(DiGraphView const &g) { +DiGraph transitive_reduction(DiGraphView const &g) { // Logic dropped down to raw adjacency matrix for performance. // The version going through the full graph abstraction was // incredibly slow (> minutes) for even moderately sized graphs @@ -87,9 +89,4 @@ DiGraphView transitive_reduction(DiGraphView const &g) { return result; } -DiGraphView transitive_reduction(DiGraphView const &g) { - assert(is_acyclic(g)); - return unchecked_transitive_reduction(g); -} - } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc deleted file mode 100644 index 5bb5c1229c..0000000000 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc +++ /dev/null @@ -1,101 +0,0 @@ -#include "utils/graph/serial_parallel/serial_parallel_splits.h" -#include "utils/containers/transform.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" -#include "utils/hash-utils.h" -#include "utils/hash/unordered_set.h" -#include "utils/hash/vector.h" - -namespace FlexFlow { - -SerialSplit::SerialSplit() : children{} {} - -SerialSplit::SerialSplit( - std::vector> const &children) - : children(children) {} - -SerialSplit::SerialSplit( - std::initializer_list> const &children) - : children(children) {} - -SerialSplit::SerialSplit(std::vector const &nodes) - : children(transform(nodes, [](Node const &node) { - return std::variant(node); - })) {} - -bool SerialSplit::operator==(SerialSplit const &other) const { - return this->tie() == other.tie(); -} - -bool SerialSplit::operator!=(SerialSplit const &other) const { - return this->tie() != other.tie(); -} - -SerialSplit::Tie SerialSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(SerialSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, SerialSplit const &split) { - return s << fmt::to_string(split); -} - -ParallelSplit::ParallelSplit() : children{} {} - -ParallelSplit::ParallelSplit( - std::unordered_set> const &children) - : children(children) {} - -ParallelSplit::ParallelSplit( - std::initializer_list> const &children) - : children(children) {} - -ParallelSplit::ParallelSplit(std::unordered_set const &nodes) - : children(transform(nodes, [](Node const &node) { - return std::variant(node); - })) {} - -bool ParallelSplit::operator==(ParallelSplit const &other) const { - return this->tie() == other.tie(); -} - -bool ParallelSplit::operator!=(ParallelSplit const &other) const { - return this->tie() != other.tie(); -} - -ParallelSplit::Tie ParallelSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(ParallelSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { - return s << fmt::to_string(split); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::SerialSplit>::operator()( - ::FlexFlow::SerialSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -size_t hash<::FlexFlow::ParallelSplit>::operator()( - ::FlexFlow::ParallelSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -} // namespace std - diff --git a/lib/utils/src/utils/graph/series_parallel/digraph_generation.cc b/lib/utils/src/utils/graph/series_parallel/digraph_generation.cc index b1b572d676..d34b599583 100644 --- a/lib/utils/src/utils/graph/series_parallel/digraph_generation.cc +++ b/lib/utils/src/utils/graph/series_parallel/digraph_generation.cc @@ -1,12 +1,13 @@ -#include "utils/graph/serial_parallel/digraph_generation.h" -#include "utils/containers/as_vector.h" +#include "utils/graph/series_parallel/digraph_generation.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" +#include "utils/variant.h" namespace FlexFlow { @@ -34,7 +35,7 @@ std::unordered_map serial_extend(DiGraph &g, return node_map; } -DiGraph serial_composition(DiGraphView const &g1, DiGraphView const &g2) { +DiGraph series_composition(DiGraphView const &g1, DiGraphView const &g2) { DiGraph g = materialize_digraph_view(g1); serial_extend(g, g2); return g; @@ -46,10 +47,10 @@ DiGraph parallel_composition(DiGraphView const &g1, DiGraphView const &g2) { return g; } -DiGraph serial_composition(std::vector const &graphs) { +DiGraph series_composition(std::vector const &graphs) { DiGraph g = DiGraph::create(); for (DiGraphView const &gs : graphs) { - g = materialize_digraph_view(serial_composition(g, gs)); + g = materialize_digraph_view(series_composition(g, gs)); } return g; } @@ -70,21 +71,21 @@ DiGraph digraph_from_sp_decomposition(Node const &node) { return g; } -DiGraph digraph_from_sp_decomposition(SerialSplit const &serial) { - std::vector children = +DiGraph digraph_from_sp_decomposition(SeriesSplit const &serial) { + std::vector children = transform(serial.children, [](auto const &child) { - return widen(child); + return widen(child); }); - return serial_composition( + return series_composition( transform(children, [](auto const child) -> DiGraphView { return digraph_from_sp_decomposition(child); })); } DiGraph digraph_from_sp_decomposition(ParallelSplit const ¶llel) { - std::vector children = - transform(as_vector(parallel.children), [](auto const &child) { - return widen(child); + std::vector children = + transform(vector_of(parallel.get_children()), [](auto const &child) { + return widen(child); }); return parallel_composition( transform(children, [](auto const child) -> DiGraphView { @@ -92,7 +93,7 @@ DiGraph digraph_from_sp_decomposition(ParallelSplit const ¶llel) { })); } -DiGraph digraph_from_sp_decomposition(SerialParallelDecomposition const &sp) { +DiGraph digraph_from_sp_decomposition(SeriesParallelDecomposition const &sp) { return sp.visit( [](auto const &x) { return digraph_from_sp_decomposition(x); }); } diff --git a/lib/utils/src/utils/graph/series_parallel/get_ancestors.cc b/lib/utils/src/utils/graph/series_parallel/get_ancestors.cc index dafb783384..d4661e14ff 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_ancestors.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_ancestors.cc @@ -1,15 +1,15 @@ -#include "utils/graph/serial_parallel/get_ancestors.h" +#include "utils/graph/series_parallel/get_ancestors.h" #include "utils/containers/contains.h" #include "utils/containers/filter.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/variant.h" #include namespace FlexFlow { -static bool perform_traversal(SerialParallelDecomposition const &sp, +static bool perform_traversal(SeriesParallelDecomposition const &sp, Node const &starting_node, std::unordered_set &ancestors) { return sp.visit([&](auto const &sp) { @@ -17,14 +17,14 @@ static bool perform_traversal(SerialParallelDecomposition const &sp, }); } -static bool perform_traversal(SerialSplit const &serial, +static bool perform_traversal(SeriesSplit const &serial, Node const &starting_node, std::unordered_set &ancestors) { - std::vector children = + std::vector children = transform(serial.children, [](auto const &child) { - return widen(child); + return widen(child); }); - for (SerialParallelDecomposition const &child : children) { + for (SeriesParallelDecomposition const &child : children) { bool found_starting_node = perform_traversal(child, starting_node, ancestors); if (found_starting_node) { @@ -37,22 +37,21 @@ static bool perform_traversal(SerialSplit const &serial, static bool perform_traversal(ParallelSplit const ¶llel, Node const &starting_node, std::unordered_set &ancestors) { - std::unordered_set children = - transform(parallel.children, [](auto const &child) { - return widen(child); + std::unordered_multiset children = + transform(parallel.get_children(), [](auto const &child) { + return widen(child); }); - // starting_node is in this ParallelSplit if (contains(get_nodes(parallel), starting_node)) { - SerialParallelDecomposition branch_with_starting_node = get_only( - filter(children, [&](SerialParallelDecomposition const &child) { + SeriesParallelDecomposition branch_with_starting_node = get_only( + filter(children, [&](SeriesParallelDecomposition const &child) { return contains(get_nodes(child), starting_node); })); perform_traversal(branch_with_starting_node, starting_node, ancestors); return true; } - for (SerialParallelDecomposition const &child : children) { + for (SeriesParallelDecomposition const &child : children) { perform_traversal(child, starting_node, ancestors); } return false; @@ -68,7 +67,7 @@ static bool perform_traversal(Node const &node, return true; } -std::unordered_set get_ancestors(SerialParallelDecomposition const &sp, +std::unordered_set get_ancestors(SeriesParallelDecomposition const &sp, Node const &starting_node) { assert(contains(get_nodes(sp), starting_node)); std::unordered_set ancestors; diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index cd29af59a0..5a9d91ee8b 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -9,6 +9,7 @@ #include "utils/graph/node/algorithms.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" #include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/graph/series_parallel/series_reduction.h" @@ -92,4 +93,19 @@ std::optional } } +std::optional + get_series_parallel_decomposition_with_sync_nodes( + DiGraphView const &g, std::unordered_set const &dummy_nodes) { + std::optional maybe_sp = + get_series_parallel_decomposition(g); + if (!maybe_sp) { + return std::nullopt; + } + SeriesParallelDecomposition sp = maybe_sp.value(); + for (Node const &dummy : dummy_nodes) { + sp = delete_node(sp, dummy); + } + return normalize_sp_decomposition(sp); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc index 47ec908fe4..19b02e5121 100644 --- a/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc @@ -1,44 +1,45 @@ -#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" #include "utils/containers/filter.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/variant.h" +#include namespace FlexFlow { template static auto filter_empty(T const &container) { return filter(container, [](auto const &child) { - return !is_empty(widen(child)); + return !is_empty(widen(child)); }); } -SerialParallelDecomposition normalize_sp_decomposition(Node const &node) { - return SerialParallelDecomposition(node); +SeriesParallelDecomposition normalize_sp_decomposition(Node const &node) { + return SeriesParallelDecomposition(node); } -SerialParallelDecomposition - normalize_sp_decomposition(SerialSplit const &serial) { - std::vector normalized_children = +SeriesParallelDecomposition + normalize_sp_decomposition(SeriesSplit const &serial) { + std::vector normalized_children = transform(filter_empty(serial.children), [](auto const &child) { return normalize_sp_decomposition( - widen(child)); + widen(child)); }); if (normalized_children.size() == 1) { return get_only(normalized_children); } - return serial_composition(normalized_children); + return series_composition(normalized_children); } -SerialParallelDecomposition +SeriesParallelDecomposition normalize_sp_decomposition(ParallelSplit const ¶llel) { - std::unordered_set normalized_children = - transform(filter_empty(parallel.children), [](auto const &child) { + std::unordered_multiset normalized_children = + transform(filter_empty(parallel.get_children()), [](auto const &child) { return normalize_sp_decomposition( - widen(child)); + widen(child)); }); if (normalized_children.size() == 1) { @@ -47,9 +48,9 @@ SerialParallelDecomposition return parallel_composition(normalized_children); } -SerialParallelDecomposition - normalize_sp_decomposition(SerialParallelDecomposition const &sp) { - return sp.visit( +SeriesParallelDecomposition + normalize_sp_decomposition(SeriesParallelDecomposition const &sp) { + return sp.visit( [](auto const &x) { return normalize_sp_decomposition(x); }); } diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 51ecdee862..2d8d88c490 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,27 +1,18 @@ -<<<<<<< HEAD:lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/containers.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/containers/all_of.h" #include "utils/containers/extend.h" -#include "utils/containers/get_only.h" -#include "utils/containers/set_union.h" -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/containers/values.h" -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" -#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_metrics.h" -======= -#include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" +#include "utils/containers/sum.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" ->>>>>>> origin/repo-refactor:lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +#include "utils/graph/series_parallel/series_parallel_metrics.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -93,33 +84,33 @@ bool is_empty(Node const &node) { return false; } -bool is_empty(SerialSplit const &serial) { +bool is_empty(SeriesSplit const &serial) { return all_of(serial.children, [](auto const &child) { - return is_empty(widen(child)); + return is_empty(widen(child)); }); } bool is_empty(ParallelSplit const ¶llel) { - return all_of(parallel.children, [](auto const &child) { - return is_empty(widen(child)); + return all_of(parallel.get_children(), [](auto const &child) { + return is_empty(widen(child)); }); } -bool is_empty(SerialParallelDecomposition const &sp) { +bool is_empty(SeriesParallelDecomposition const &sp) { return sp.visit([](auto const &t) { return is_empty(t); }); } -SerialParallelDecomposition delete_node(SerialParallelDecomposition sp, +SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, Node const &n) { - return sp.visit( + return sp.visit( [&n](auto const &t) { return delete_node(t, n); }); } -SerialParallelDecomposition delete_node(ParallelSplit const ¶llel, +SeriesParallelDecomposition delete_node(ParallelSplit const ¶llel, Node const &n) { - std::unordered_set children; - for (auto const &child : parallel.children) { - auto widened = widen(child); + std::unordered_multiset children; + for (auto const &child : parallel.get_children()) { + auto widened = widen(child); if (!(widened.has() && widened.get() == n)) { children.insert(delete_node(widened, n)); } @@ -127,67 +118,70 @@ SerialParallelDecomposition delete_node(ParallelSplit const ¶llel, return parallel_composition(children); } -SerialParallelDecomposition delete_node(SerialSplit const &serial, +SeriesParallelDecomposition delete_node(SeriesSplit const &serial, Node const &n) { - std::vector children; + std::vector children; for (auto const &child : serial.children) { - auto widened = widen(child); + auto widened = widen(child); if (!(widened.has() && widened.get() == n)) { children.push_back(delete_node(widened, n)); } } - return serial_composition(children); + return series_composition(children); } -SerialParallelDecomposition delete_node(Node const &node, Node const &n) { +SeriesParallelDecomposition delete_node(Node const &node, Node const &n) { if (node == n) { throw mk_runtime_error( - "Cannot delete Node from Node, only from ParallelSplit or SerialSplit"); + "Cannot delete Node from Node, only from ParallelSplit or SeriesSplit"); } - return SerialParallelDecomposition{node}; + return SeriesParallelDecomposition{node}; } -size_t num_nodes(SerialParallelDecomposition const &sp) { +size_t num_nodes(SeriesParallelDecomposition const &sp) { return sum(values(get_node_frequency_map(sp))); } -bool has_no_duplicate_nodes(SerialParallelDecomposition const &sp) { +bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp) { return all_of(values(get_node_frequency_map(sp)), [](int count) { return count == 1; }); } -SerialParallelDecomposition serial_composition( - std::vector const &sp_compositions) { - SerialSplit composition{}; - for (SerialParallelDecomposition const &sp_comp : sp_compositions) { - if (sp_comp.has()) { - extend(composition.children, sp_comp.get().children); +SeriesParallelDecomposition series_composition( + std::vector const &sp_compositions) { + std::vector> composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + extend(composition, sp_comp.get().children); } else if (sp_comp.has()) { - composition.children.push_back(sp_comp.get()); + composition.push_back(sp_comp.get()); } else { assert(sp_comp.has()); - composition.children.push_back(sp_comp.get()); + composition.push_back(sp_comp.get()); } } - return SerialParallelDecomposition(composition); + return SeriesParallelDecomposition{SeriesSplit{composition}}; } -SerialParallelDecomposition parallel_composition( - std::unordered_set const &sp_compositions) { - ParallelSplit composition{}; - for (SerialParallelDecomposition const &sp_comp : sp_compositions) { +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions) { + std::unordered_multiset< + std::variant<::FlexFlow::SeriesSplit, ::FlexFlow::Node>> + composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { if (sp_comp.has()) { - composition.children = set_union(composition.children, - sp_comp.get().children); - } else if (sp_comp.has()) { - composition.children.insert(sp_comp.get()); + composition = multiset_union(composition, + sp_comp.get().get_children()); + } else if (sp_comp.has()) { + composition.insert(sp_comp.get()); } else { assert(sp_comp.has()); - composition.children.insert(sp_comp.get()); + composition.insert(sp_comp.get()); } } - return SerialParallelDecomposition(composition); + return SeriesParallelDecomposition(ParallelSplit{composition}); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/serial_parallel_metrics.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc similarity index 67% rename from lib/utils/src/utils/graph/series_parallel/serial_parallel_metrics.cc rename to lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc index e25a8dc80a..3b95357fcb 100644 --- a/lib/utils/src/utils/graph/series_parallel/serial_parallel_metrics.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc @@ -1,15 +1,18 @@ -#include "utils/graph/serial_parallel/serial_parallel_metrics.h" -#include "utils/containers.h" -#include "utils/containers/as_vector.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/containers/maximum.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/unordered_multiset.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" #include "utils/graph/digraph/digraph_view.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/digraph_generation.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/digraph_generation.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/variant.h" #include - namespace FlexFlow { std::unordered_map get_node_frequency_map(Node const &node) { @@ -19,9 +22,9 @@ std::unordered_map get_node_frequency_map(Node const &node) { std::unordered_map get_node_frequency_map(ParallelSplit const ¶llel) { std::unordered_map counter; - for (std::variant const &child : parallel.children) { + for (std::variant const &child : parallel.get_children()) { for (auto const &[node, count] : - get_node_frequency_map(widen(child))) { + get_node_frequency_map(widen(child))) { counter[node] += count; } } @@ -29,11 +32,11 @@ std::unordered_map } std::unordered_map - get_node_frequency_map(SerialSplit const &serial) { + get_node_frequency_map(SeriesSplit const &serial) { std::unordered_map counter; for (std::variant const &child : serial.children) { for (auto const &[node, count] : - get_node_frequency_map(widen(child))) { + get_node_frequency_map(widen(child))) { counter[node] += count; } } @@ -41,12 +44,12 @@ std::unordered_map } std::unordered_map - get_node_frequency_map(SerialParallelDecomposition const &sp) { + get_node_frequency_map(SeriesParallelDecomposition const &sp) { return sp.visit>( [](auto const &t) { return get_node_frequency_map(t); }); } -float work_cost(SerialParallelDecomposition const &sp, +float work_cost(SeriesParallelDecomposition const &sp, std::unordered_map cost_map) { auto cost_per_node_group = [&](std::pair const &pair) { return pair.second * cost_map.at(pair.first); @@ -58,7 +61,7 @@ float work_cost(SerialParallelDecomposition const &sp, float work_cost(DiGraphView const &g, std::unordered_map const &cost_map) { - return sum(transform(as_vector(get_nodes(g)), + return sum(transform(vector_of(get_nodes(g)), [&](Node const &node) { return cost_map.at(node); })); } @@ -67,25 +70,26 @@ float critical_path_cost(Node const &node, return cost_map.at(node); } -float critical_path_cost(SerialSplit const &serial, +float critical_path_cost(SeriesSplit const &serial, std::unordered_map const &cost_map) { return sum(transform( serial.children, [&](std::variant const &child) { - return critical_path_cost(widen(child), + return critical_path_cost(widen(child), cost_map); })); } float critical_path_cost(ParallelSplit const ¶llel, std::unordered_map const &cost_map) { - return maximum(transform( - parallel.children, [&](std::variant const &child) { - return critical_path_cost(widen(child), - cost_map); - })); + return maximum(transform(parallel.get_children(), + [&](std::variant const &child) { + return critical_path_cost( + widen(child), + cost_map); + })); } -float critical_path_cost(SerialParallelDecomposition const &sp, +float critical_path_cost(SeriesParallelDecomposition const &sp, std::unordered_map const &cost_map) { return sp.visit( [&](auto const &t) { return critical_path_cost(t, cost_map); }); @@ -97,7 +101,7 @@ float critical_path_cost(DiGraphView const &g, values(get_weighted_longest_path_lengths_from_root(g, cost_map))); } -int num_dependencies(SerialParallelDecomposition const &sp) { +int num_dependencies(SeriesParallelDecomposition const &sp) { return num_dependencies(digraph_from_sp_decomposition(sp)); } @@ -106,20 +110,20 @@ int num_dependencies(DiGraphView const &g) { } float relative_work_increase(DiGraphView const &g, - SerialParallelDecomposition const &sp, + SeriesParallelDecomposition const &sp, std::unordered_map const &cost_map) { return work_cost(sp, cost_map) / work_cost(g, cost_map); } float relative_critical_path_cost_increase( DiGraphView const &g, - SerialParallelDecomposition const &sp, + SeriesParallelDecomposition const &sp, std::unordered_map const &cost_map) { return critical_path_cost(sp, cost_map) / critical_path_cost(g, cost_map); } float relative_num_dependencies_increase( - DiGraphView const &g, SerialParallelDecomposition const &sp) { + DiGraphView const &g, SeriesParallelDecomposition const &sp) { return static_cast(num_dependencies(sp)) / num_dependencies(g); } diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.cc index 4f7c8bc87e..33fd66e2eb 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.cc @@ -1,22 +1,24 @@ -#include "utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h" -#include "utils/containers/as_vector.h" +#include "utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" #include "utils/graph/digraph/digraph_view.h" -#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/variant.h" +#include namespace FlexFlow { -static SerialSplit cut_off_head(SerialSplit const &s) { +static SeriesSplit cut_off_head(SeriesSplit const &s) { assert(s.children.size() > 0); - return SerialSplit{std::vector>( + return SeriesSplit{std::vector>( s.children.begin() + 1, s.children.end())}; } @@ -26,17 +28,17 @@ static SerialSplit cut_off_head(SerialSplit const &s) { * without coalescing: P(S(1, 2, 5), S(1, 3, 4)) * with coalescing: S(1, P( S(2,5), S(3,4) )) */ -static SerialParallelDecomposition parallel_composition_with_coalescing( - std::unordered_set const &strands) { +static SeriesParallelDecomposition parallel_composition_with_coalescing( + std::unordered_set const &strands) { if (strands.size() == 1) { - return SerialParallelDecomposition(get_only(strands)); + return SeriesParallelDecomposition(get_only(strands)); } // group strands by their first ("head") node std::unordered_map, - std::unordered_set> + std::unordered_set> grouped_strands; - for (SerialSplit predecessor : filter(strands, [](SerialSplit const &serial) { + for (SeriesSplit predecessor : filter(strands, [](SeriesSplit const &serial) { return !is_empty(serial); })) { grouped_strands[predecessor.children.at(0)].insert( @@ -44,64 +46,65 @@ static SerialParallelDecomposition parallel_composition_with_coalescing( } // recursively coalesce the strands - std::unordered_set coalesced_strands; + std::unordered_multiset coalesced_strands; for (auto const &[head, tails] : grouped_strands) { - SerialParallelDecomposition parallel_comp = + SeriesParallelDecomposition parallel_comp = parallel_composition_with_coalescing(tails); - coalesced_strands.insert(serial_composition( - {widen(head), parallel_comp})); + coalesced_strands.insert(series_composition( + {widen(head), parallel_comp})); } return normalize_sp_decomposition(parallel_composition(coalesced_strands)); } -static SerialParallelDecomposition +static SeriesParallelDecomposition critical_path_preserving_sp_ization_unchecked_with_coalescing( DiGraphView const &g) { - std::unordered_map node_to_sp; + std::unordered_map node_to_sp; Node source = get_only(get_sources(g)); - node_to_sp[source] = SerialSplit{{source}}; + node_to_sp.emplace(source, SeriesSplit{{source}}); for (Node const &node : get_topological_ordering(g)) { if (node == source) { continue; } - std::unordered_set predecessors_as_sp = + std::unordered_set predecessors_as_sp = transform(get_predecessors(g, node), [&](Node const &p) { return node_to_sp.at(p); }); - SerialParallelDecomposition parallel_composed_predecessors = + SeriesParallelDecomposition parallel_composed_predecessors = parallel_composition_with_coalescing(predecessors_as_sp); - SerialParallelDecomposition sp_decomp = serial_composition( - {parallel_composed_predecessors, SerialParallelDecomposition(node)}); - node_to_sp[node] = sp_decomp.get(); + SeriesParallelDecomposition sp_decomp = series_composition( + {parallel_composed_predecessors, SeriesParallelDecomposition(node)}); + node_to_sp.emplace(node, sp_decomp.get()); } Node sink = get_only(get_sinks(g)); return normalize_sp_decomposition( - SerialParallelDecomposition(node_to_sp.at(sink))); + SeriesParallelDecomposition(node_to_sp.at(sink))); } -SerialParallelDecomposition +SeriesParallelDecomposition critical_path_preserving_sp_ization_with_coalescing(DiGraphView const &g) { assert(is_2_terminal_dag(g)); return critical_path_preserving_sp_ization_unchecked_with_coalescing(g); } -static SerialParallelDecomposition +static SeriesParallelDecomposition critical_path_preserving_sp_ization_unchecked(DiGraphView const &g) { - std::unordered_map node_to_sp; + std::unordered_map node_to_sp; for (Node const &node : get_topological_ordering(g)) { - std::unordered_set predecessors_as_sp = - transform(get_predecessors(g, node), - [&](Node const &p) { return node_to_sp.at(p); }); + std::unordered_multiset predecessors_as_sp = + unordered_multiset_of( + transform(get_predecessors(g, node), + [&](Node const &p) { return node_to_sp.at(p); })); - SerialParallelDecomposition sp_decomp = serial_composition( + SeriesParallelDecomposition sp_decomp = series_composition( {normalize_sp_decomposition(parallel_composition(predecessors_as_sp)), - SerialParallelDecomposition(node)}); + SeriesParallelDecomposition(node)}); node_to_sp.emplace(node, sp_decomp); } @@ -110,7 +113,7 @@ static SerialParallelDecomposition return node_to_sp.at(sink); } -SerialParallelDecomposition +SeriesParallelDecomposition critical_path_preserving_sp_ization(DiGraphView const &g) { assert(is_2_terminal_dag(g)); return critical_path_preserving_sp_ization_unchecked(g); diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc similarity index 61% rename from lib/utils/src/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.cc rename to lib/utils/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc index af17837248..240e5ae063 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc @@ -1,19 +1,20 @@ -#include "utils/graph/serial_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" #include "utils/containers/is_subseteq_of.h" +#include "utils/containers/unordered_set_of.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_ancestors.h" #include "utils/graph/digraph/algorithms/get_descendants.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/digraph_generation.h" -#include "utils/graph/serial_parallel/get_ancestors.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/digraph_generation.h" +#include "utils/graph/series_parallel/get_ancestors.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" namespace FlexFlow { bool dependencies_are_maintained(DiGraphView const &g, - SerialParallelDecomposition const &sp) { + SeriesParallelDecomposition const &sp) { assert(has_no_duplicate_nodes(sp)); - if (get_nodes(sp) != get_nodes(g)) { + if (unordered_set_of(get_nodes(sp)) != get_nodes(g)) { return false; } diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/spanish_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/spanish_algo.cc index 0854aa3969..9ca4d22d53 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/spanish_algo.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/spanish_algo.cc @@ -1,127 +1,264 @@ -#include "utils/containers/contains.h" -#include "utils/graph/serial_parallel/sp_ization/spanish_algo.h" -#include "utils/containers/filter.h" +#include "utils/graph/series_parallel/sp_ization/spanish_algo.h" +#include "utils/containers/filter_keys.h" +#include "utils/containers/filtrans.h" +#include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" -#include "utils/containers/sorted_by.h" -#include "utils/containers/set_difference.h" -#include "utils/containers/transform.h" +#include "utils/containers/group_by.h" +#include "utils/containers/intersection.h" +#include "utils/containers/map_values.h" +#include "utils/containers/maximum.h" +#include "utils/containers/range.h" +#include "utils/containers/set_union.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/unordered_multiset.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/graph/digraph/algorithms/get_incoming_edges.h" #include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" #include "utils/graph/digraph/algorithms/get_lowest_common_ancestors.h" +#include "utils/graph/digraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" #include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/digraph/directed_edge.dtg.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/sp_ization/dependencies_are_maintained.h" -#include "utils/graph/digraph/algorithms/transitive_reduction.h" -#include "utils/graph/digraph/algorithms/transitive_reduction.h" -#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" + +#include +#include +#include namespace FlexFlow { -static std::unordered_set find_down_class_which_contains_node( - DiGraphView const &g, - Node const &n, - std::unordered_map const &depth_map) { - int max_depth = depth_map.at(n); - std::unordered_set last_2_layers = - filter(get_nodes(g), [&](Node const &node) { - return (depth_map.at(node) == max_depth) || - (depth_map.at(node) == max_depth - 1); - }); - DiGraphView last_2_layers_subgraph = get_subgraph(g, last_2_layers); - std::unordered_set component_containing_n = - get_only(filter(get_weakly_connected_components(last_2_layers_subgraph), - [&](std::unordered_set const &component) { - return contains(component, n); - })); - // TODO(@pietro): check that it has height 2 - return set_union( - transform(get_nodes(last_2_layers_subgraph), - [&](Node const &node) { return get_successors(g, node); })); +std::unordered_map get_initial_node_role_map(DiGraph g) { + return generate_map(get_nodes(g), + [&](Node const &n) { return NodeRole::PURE; }); } -static Node get_handle(DiGraphView const &g, std::unordered_set const &down_class) { - return get_only(get_lowest_common_ancestors(g, down_class).value()); +std::unordered_set + filter_sync_nodes(std::unordered_set const &nodes, + std::unordered_map const &node_roles) { + return filter( + nodes, [&](Node const &n) { return node_roles.at(n) != NodeRole::SYNC; }); } +int get_max_depth(DiGraph const &sp, + std::unordered_map const &depth_map) { + return maximum(values(filter_keys( + depth_map, [&](Node const &n) { return contains(get_nodes(sp), n); }))); +} -SerialParallelDecomposition - one_node_at_a_time_spanish_sp_ization_unchecked(DiGraph g) { - // TODO(@pietro): apply transitive reduction - g = materialize_digraph_view(transitive_reduction(g)); - std::unordered_set original_nodes = get_nodes(g); +DiGraph add_dummy_nodes(DiGraph g, + std::unordered_map &node_roles) { std::unordered_map depth_map = get_longest_path_lengths_from_root(g); - std::vector nodes = - sorted_by(get_nodes(g), [&](Node const &n1, Node const &n2) { - return depth_map.at(n1) < depth_map.at(n2); - }); - Node root = nodes.at(0); - std::unordered_set to_consider = {root}; - for (auto const &n : nodes) { - std::cout << n << get_edges(g) << std::endl; - if (n == root) { - continue; + for (DirectedEdge const &e : get_edges(g)) { + Node src = e.src; + Node dst = e.dst; + int depth_diff = depth_map.at(dst) - depth_map.at(src); + if (depth_diff > 1) { + g.remove_edge(e); + Node prev_node = src; + Node intermediate_node = Node{0}; + for (int i : range(1, depth_diff)) { + intermediate_node = g.add_node(); + node_roles[intermediate_node] = NodeRole::DUMMY; + g.add_edge(DirectedEdge{prev_node, intermediate_node}); + prev_node = intermediate_node; + } + g.add_edge(DirectedEdge{prev_node, dst}); + } + } + return g; +} + +DiGraph + delete_dummy_nodes(DiGraph g, + std::unordered_map const &node_roles) { + for (Node const &n : get_nodes(g)) { + if (node_roles.at(n) == NodeRole::DUMMY) { + for (Node const &pred : get_predecessors(g, n)) { + for (Node const &succ : get_successors(g, n)) { + g.add_edge(DirectedEdge{pred, succ}); + } + } + remove_node(g, n); } + } + return g; +} - to_consider.insert(n); - DiGraphView subgraph = get_subgraph(g, to_consider); - std::unordered_set component = - find_down_class_which_contains_node(subgraph, n, depth_map); - - Node handle = get_handle(subgraph, component); - - std::unordered_set forest = filter(get_descendants(subgraph, handle), - [&](Node const &n) { return contains(original_nodes, n); }); - if (forest.empty()) {continue;} - std::unordered_set last_layer = filter(forest, [&](auto const &node) { - return depth_map.at(node) == depth_map.at(n); - }); - std::unordered_set penultimate_layer = - filter(forest, [&](auto const &node) { - return depth_map.at(node) == depth_map.at(n) - 1; - }); - - std::cout << handle << forest << std::endl << std::endl; - - Node sync_node = g.add_node(); - to_consider.insert(sync_node); - for (DirectedEdge const &e : get_edges(g)) { - if (contains(last_layer, e.dst)) { - g.remove_edge(e); - g.add_edge(DirectedEdge{e.src, sync_node}); - g.add_edge(DirectedEdge{sync_node, e.dst}); +DiGraph + delete_sync_nodes(DiGraph g, + std::unordered_map const &node_roles) { + for (Node const &n : get_nodes(g)) { + if (node_roles.at(n) == NodeRole::SYNC) { + for (Node const &pred : get_predecessors(g, n)) { + for (Node const &succ : get_successors(g, n)) { + g.add_edge(DirectedEdge{pred, succ}); + } } + remove_node(g, n); } - - for (DirectedEdge const &e : get_edges(g)) { - if (contains(forest, e.src) && (depth_map.at(e.dst) > depth_map.at(n))) { - g.remove_edge(e); - g.add_edge(DirectedEdge{sync_node, e.dst}); + } + return g; +} + +std::unordered_set + get_component(DiGraph const &g, + Node const &node, + std::unordered_map const &depth_map, + std::unordered_map const &node_roles) { + + int max_depth = get_max_depth(g, depth_map); + auto is_in_last_2_layers = [&](Node const &n) { + if (node_roles.at(n) == NodeRole::SYNC) { + if (get_successors(g, n).empty()) { + return true; } + int successors_depth = + get_only(transform(get_successors(g, n), + [&](Node const &n) { return depth_map.at(n); })); + return successors_depth == max_depth; + } else { + return (depth_map.at(n) == max_depth) || + (depth_map.at(n) == max_depth - 1); } - - g = materialize_digraph_view(transitive_reduction(g)); + }; + std::unordered_set last_two_layers_nodes = + filter(get_nodes(g), is_in_last_2_layers); + + DiGraph subgraph = materialize_digraph_view( + get_subgraph(g, last_two_layers_nodes)); + std::unordered_set component = + get_only(filter(get_weakly_connected_components(subgraph), + [&](std::unordered_set const &component) { + return contains(component, node); + })); + std::unordered_set component_without_sync_nodes = + filter_sync_nodes(component, node_roles); + return component_without_sync_nodes; +} + +std::unordered_set + get_forest(DiGraph const &g, + Node const &handle, + std::unordered_set const &component, + std::unordered_map const &node_roles) { + std::unordered_set> subtrees = + transform(get_successors(g, handle), [&](Node const &n) { + return set_union(get_descendants(g, n), {n}); + }); + auto subtrees_overlapping_with_component = + filter(subtrees, [&](std::unordered_set subtree) { + return intersection(subtree, component).size() > 0; + }); + std::unordered_set forest = + set_union(subtrees_overlapping_with_component); + forest.insert(handle); + return filter_sync_nodes(forest, node_roles); +} + +std::pair, std::unordered_set> + get_up_and_down(DiGraph const &g, + std::unordered_set const &forest, + std::unordered_map const &depth_map) { + int max_depth = get_max_depth(g, depth_map); + std::unordered_map> grouped_by_depth = + group_by(forest, [&](Node const &n) { return depth_map.at(n); }); + return {grouped_by_depth.at(max_depth - 1), grouped_by_depth.at(max_depth)}; +} + +std::unordered_set + edges_to_remove(DiGraph const &g, + std::unordered_set const &up, + std::unordered_set const &down) { + std::unordered_set to_remove; + + for (Node const &u : up) { + to_remove = set_union(to_remove, get_outgoing_edges(g, u)); + } + for (Node const &d : down) { + to_remove = set_union(to_remove, get_incoming_edges(g, d)); } - std::cout << get_edges(g) << std::endl; - std::unordered_set dummy_nodes = set_difference(get_nodes(g), original_nodes); - return get_serial_parallel_decomposition_with_dummy_nodes(g, dummy_nodes).value(); + return to_remove; } -SerialParallelDecomposition one_node_at_a_time_spanish_sp_ization(DiGraph g) { - assert(is_2_terminal_dag(g)); - SerialParallelDecomposition sp = - one_node_at_a_time_spanish_sp_ization_unchecked(g); - assert(dependencies_are_maintained(g, sp)); - return sp; +std::unordered_set + edges_to_add(std::unordered_set const &up, + std::unordered_set const &down, + Node const &sync_node) { + std::unordered_set to_add; + + for (Node const &u : up) { + to_add.insert(DirectedEdge{u, sync_node}); + } + + for (Node const &d : down) { + to_add.insert(DirectedEdge{sync_node, d}); + } + + return to_add; } +SeriesParallelDecomposition spanish_strata_sync(DiGraph g) { + assert(is_2_terminal_dag(g)); + assert(is_acyclic(g)); + + std::unordered_map node_roles = get_initial_node_role_map(g); + + g = add_dummy_nodes(g, node_roles); + std::unordered_map depth_map = + get_longest_path_lengths_from_root(g); + + DiGraph sp = DiGraph::create(); + Node root = get_only(get_sources(g)); + sp.add_node_unsafe(root); + size_t sync_node_counter = maximum( + transform(get_nodes(g), [&](Node const &n) { return n.raw_uid; })); + for (Node const &node : get_bfs_ordering(g, {root})) { + if (node == root) { + continue; + } + sp.add_node_unsafe(node); + add_edges(sp, vector_of(get_incoming_edges(g, node))); + std::unordered_set component = + get_component(sp, node, depth_map, node_roles); + Node handle = get_only(get_lowest_common_ancestors(sp, component).value()); + std::unordered_set forest = + get_forest(sp, handle, component, node_roles); + auto [up, down] = get_up_and_down(sp, forest, depth_map); + + for (DirectedEdge const &e : edges_to_remove(sp, up, down)) { + sp.remove_edge(e); + } + + Node sync_node = Node{++sync_node_counter}; + node_roles[sync_node] = NodeRole::SYNC; + sp.add_node_unsafe(sync_node); + for (DirectedEdge const &e : edges_to_add(up, down, sync_node)) { + sp.add_edge(e); + } + } + sp = delete_dummy_nodes(sp, node_roles); + sp = transitive_reduction(sp); + std::unordered_set sync_nodes = + filter(get_nodes(sp), [&](Node const &node) { + return node_roles.at(node) == NodeRole::SYNC; + }); + return get_series_parallel_decomposition_with_sync_nodes(sp, sync_nodes) + .value(); +} } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.cc index 64a3bfccf0..8ebfa7c50c 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.cc @@ -1,13 +1,14 @@ -#include "utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h" -#include "utils/containers.h" +#include "utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_only.h" #include "utils/containers/invert_map.h" #include "utils/containers/keys.h" +#include "utils/containers/maximum.h" #include "utils/containers/sorted.h" +#include "utils/containers/unordered_multiset_of.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" +#include "utils/fmt/unordered_multiset.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" @@ -18,11 +19,11 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_metrics.h" -#include "utils/graph/serial_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" #include "utils/hash/unordered_set.h" #include "utils/hash/vector.h" #include @@ -30,11 +31,11 @@ namespace FlexFlow { -std::vector> +std::vector> stratum_split_assuming_unit_cost(DiGraphView const &g) { std::unordered_map node_to_stratum = get_longest_path_lengths_from_root(g); - std::vector> result( + std::vector> result( maximum(values(node_to_stratum))); for (auto const &[node, depth] : node_to_stratum) { result[depth - 1].insert(node); @@ -42,26 +43,28 @@ std::vector> return result; } -static SerialParallelDecomposition - naive_stratum_merge(std::vector> stratum_split) { - std::vector strata = transform( - stratum_split, [](std::unordered_set const &stratum_nodes) { - return SerialParallelDecomposition(ParallelSplit{stratum_nodes}); +static SeriesParallelDecomposition naive_stratum_merge( + std::vector> stratum_split) { + std::vector strata = transform( + stratum_split, [](std::unordered_multiset const &stratum_nodes) { + return parallel_composition(transform(stratum_nodes, [](Node const &n) { + return SeriesParallelDecomposition{n}; + })); }); - return normalize_sp_decomposition(serial_composition(strata)); + return normalize_sp_decomposition(series_composition(strata)); } -SerialParallelDecomposition +SeriesParallelDecomposition stratum_sync_sp_ization_unchecked(DiGraphView const &g) { - std::vector> stratum_split = + std::vector> stratum_split = stratum_split_assuming_unit_cost(g); return naive_stratum_merge(stratum_split); } -SerialParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g) { +SeriesParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g) { assert(is_acyclic(g)); - SerialParallelDecomposition sp = stratum_sync_sp_ization_unchecked(g); + SeriesParallelDecomposition sp = stratum_sync_sp_ization_unchecked(g); assert(dependencies_are_maintained(g, sp)); return sp; } @@ -163,32 +166,33 @@ static std::vector>> return strata; } -SerialParallelDecomposition cost_aware_stratum_sync_sp_ization_unchecked( +SeriesParallelDecomposition cost_aware_stratum_sync_sp_ization_unchecked( DiGraphView const &g, std::unordered_map const &cost_map) { if (get_nodes(g).size() == 1) { - return SerialParallelDecomposition(get_only(get_nodes(g))); + return SeriesParallelDecomposition(get_only(get_nodes(g))); } - std::vector> sp_ized_strata; + std::vector> + sp_ized_strata; for (auto const &stratum : cost_aware_stratum_split(g, cost_map)) { - auto sp_ized_stratum = + auto sp_ized_stratum = unordered_multiset_of( transform(stratum, [&](std::unordered_set const &nodes) { return cost_aware_stratum_sync_sp_ization_unchecked( get_subgraph(g, nodes), cost_map); - }); + })); sp_ized_strata.push_back(sp_ized_stratum); } return normalize_sp_decomposition( - serial_composition(transform(sp_ized_strata, parallel_composition))); + series_composition(transform(sp_ized_strata, parallel_composition))); } -SerialParallelDecomposition cost_aware_stratum_sync_sp_ization( +SeriesParallelDecomposition cost_aware_stratum_sync_sp_ization( DiGraphView const &g, std::unordered_map const &cost_map) { assert(is_acyclic(g)); - SerialParallelDecomposition sp = + SeriesParallelDecomposition sp = cost_aware_stratum_sync_sp_ization_unchecked(g, cost_map); assert(dependencies_are_maintained(g, sp)); return sp; diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 5ef40ac2c5..2148b13a0c 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -245,4 +245,3 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(get_weakly_connected_components(g) == expected_components); } } - diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index a635658755..6c670ae93a 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -4,6 +4,7 @@ #include "utils/containers/transform.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" diff --git a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc deleted file mode 100644 index 39596ee6f1..0000000000 --- a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ /dev/null @@ -1,201 +0,0 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_serial_parallel_decomposition (base case)") { - DiGraph g = DiGraph::create(); - Node n = g.add_node(); - - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{n}; - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (parallel)") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 2); - - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ParallelSplit{ - n.at(0), - n.at(1), - }}; - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (serial)") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 2); - g.add_edge(DirectedEdge{n.at(0), n.at(1)}); - - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ - n.at(0), - n.at(1), - }}; - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (composite)") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 3); - add_edges(g, - { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - }); - - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ - n.at(0), - ParallelSplit{ - n.at(1), - n.at(2), - }, - }, - }; - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (diamond graph)") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 6); - - add_edges(g, - { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(2), n.at(4)}, - DirectedEdge{n.at(3), n.at(5)}, - DirectedEdge{n.at(4), n.at(5)}, - }); - - std::optional correct = - SerialParallelDecomposition{SerialSplit{ - n.at(0), - ParallelSplit{ - SerialSplit{ - n.at(1), - n.at(3), - }, - SerialSplit{ - n.at(2), - n.at(4), - }, - }, - n.at(5), - }}; - - std::optional result = - get_serial_parallel_decomposition(g); - - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (all-to-all connection)") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - - add_edges(g, - { - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - }); - - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ - ParallelSplit{ - n.at(0), - n.at(1), - }, - ParallelSplit{ - n.at(2), - n.at(3), - }, - }, - }; - - std::optional result = - get_serial_parallel_decomposition(g); - - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (non-sp graph)") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - - // N-graph - add_edges(g, - { - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - }); - - std::optional correct = std::nullopt; - std::optional result = - get_serial_parallel_decomposition(g); - - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition") { - DiGraph g = DiGraph::create(); - SUBCASE("base case") { - Node n = g.add_node(); - std::optional result = - get_serial_parallel_decomposition_with_dummy_nodes(g, {}); - std::optional correct = - SerialParallelDecomposition{n}; - CHECK(result == correct); - } - SUBCASE("SerialSplit") { - std::vector n = add_nodes(g, 3); - add_edges( - g, {DirectedEdge{n.at(0), n.at(1)}, DirectedEdge{n.at(1), n.at(2)}}); - std::optional result = - get_serial_parallel_decomposition_with_dummy_nodes(g, {n[1]}); - std::optional correct = - SerialParallelDecomposition{SerialSplit{n[0], n[2]}}; - CHECK(result == correct); - } - - SUBCASE("ParallelSplit") { - std::vector n = add_nodes(g, 5); - add_edges(g, - {DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(2), n.at(3)}, - DirectedEdge{n.at(2), n.at(4)}}); - std::optional result = - get_serial_parallel_decomposition_with_dummy_nodes(g, {n[2]}); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ParallelSplit{n[0], n[1]}, - ParallelSplit{n[3], n[4]}}}; - CHECK(result == correct); - } - - // TODO(@pietro) additional testing - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index fee971e5e0..5ee2e7c224 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -62,7 +62,7 @@ TEST_SUITE(FF_TEST_SUITE) { BinarySPDecompositionTree result = left_associative_binary_sp_tree_from_nary(input); - // we use multiple checks here because SerialParallelDecomposition's + // we use multiple checks here because SeriesParallelDecomposition's // ParallelSplit is unordered, so there are multiple possible // left-associative binary SP trees CHECK(is_binary_sp_tree_left_associative(result)); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 532ff86c90..7b43f52b8f 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -60,7 +60,7 @@ TEST_SUITE(FF_TEST_SUITE) { BinarySPDecompositionTree result = right_associative_binary_sp_tree_from_nary(input); - // we use multiple checks here because SerialParallelDecomposition's + // we use multiple checks here because SeriesParallelDecomposition's // ParallelSplit is unordered, so there are multiple possible // right-associative binary SP trees CHECK(is_binary_sp_tree_right_associative(result)); diff --git a/lib/utils/test/src/utils/graph/series_parallel/digraph_generation.cc b/lib/utils/test/src/utils/graph/series_parallel/digraph_generation.cc index 07f080e180..6199aa6ae3 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/digraph_generation.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/digraph_generation.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/digraph_generation.h" +#include "utils/graph/series_parallel/digraph_generation.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" @@ -9,30 +9,31 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("digraph_from_sp_decomposition") { SUBCASE("Empty") { - SerialParallelDecomposition input = - SerialParallelDecomposition(ParallelSplit{}); + SeriesParallelDecomposition input = + SeriesParallelDecomposition(ParallelSplit{{}}); DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 0); CHECK(num_edges(result) == 0); } + SUBCASE("Complex Empty") { - SerialParallelDecomposition input = SerialParallelDecomposition( - ParallelSplit{SerialSplit{}, SerialSplit{ParallelSplit{}}}); + SeriesParallelDecomposition input = SeriesParallelDecomposition( + ParallelSplit{{SeriesSplit{{}}, SeriesSplit{{ParallelSplit{{}}}}}}); DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 0); CHECK(num_edges(result) == 0); } SUBCASE("Single Node") { - SerialParallelDecomposition input = SerialParallelDecomposition(Node(1)); + SeriesParallelDecomposition input = SeriesParallelDecomposition(Node(1)); DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 1); CHECK(num_edges(result) == 0); } - SUBCASE("Simple SerialSplit") { - SerialParallelDecomposition input = - SerialParallelDecomposition{SerialSplit{Node(1), Node(2), Node(3)}}; + SUBCASE("Simple SeriesSplit") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{Node(1), Node(2), Node(3)}}}; DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 3); CHECK(num_edges(result) == 2); @@ -41,8 +42,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Simple ParallelSplit") { - SerialParallelDecomposition input = - SerialParallelDecomposition{ParallelSplit{Node(1), Node(2), Node(3)}}; + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{Node(1), Node(2), Node(3)}}}; DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 3); CHECK(num_edges(result) == 0); @@ -51,9 +52,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Mixed Serial-Parallel") { - SerialParallelDecomposition input = SerialParallelDecomposition{ - SerialSplit{ParallelSplit{Node(1), Node(2)}, - ParallelSplit{Node(3), Node(4)}}}; + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{{Node(1), Node(2)}}, + ParallelSplit{{Node(3), Node(4)}}}}}; DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 4); CHECK(num_edges(result) == 4); @@ -62,9 +63,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Mixed Parallel-Serial") { - SerialParallelDecomposition input = - SerialParallelDecomposition{ParallelSplit{ - SerialSplit{Node(1), Node(2)}, SerialSplit{Node(3), Node(4)}}}; + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{SeriesSplit{{Node(1), Node(2)}}, + SeriesSplit{{Node(3), Node(4)}}}}}; DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 4); CHECK(num_edges(result) == 2); @@ -73,8 +74,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Rhombus") { - SerialParallelDecomposition input = SerialParallelDecomposition{ - SerialSplit{Node(1), ParallelSplit{Node(2), Node(3)}, Node(4)}}; + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{Node(1), ParallelSplit{{Node(2), Node(3)}}, Node(4)}}}; DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 4); CHECK(num_edges(result) == 4); @@ -83,8 +84,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Duplicate Nodes") { - SerialParallelDecomposition input = SerialParallelDecomposition{ - SerialSplit{Node(1), ParallelSplit{Node(1), Node(2)}, Node(1)}}; + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{Node(1), ParallelSplit{{Node(1), Node(2)}}, Node(1)}}}; DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 4); CHECK(num_edges(result) == 4); @@ -93,12 +94,13 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Complex Graph") { - SerialParallelDecomposition input = SerialParallelDecomposition{ - SerialSplit{ParallelSplit{SerialSplit{ParallelSplit{Node(1), Node(2)}, - ParallelSplit{Node(3), Node(4)}, - Node(5)}, - SerialSplit{Node(6), Node(7)}}, - Node(8)}}; + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{ + {ParallelSplit{{SeriesSplit{{ParallelSplit{{Node(1), Node(2)}}, + ParallelSplit{{Node(3), Node(4)}}, + Node(5)}}, + SeriesSplit{{Node(6), Node(7)}}}}, + Node(8)}}}; DiGraph result = digraph_from_sp_decomposition(input); CHECK(num_nodes(result) == 8); diff --git a/lib/utils/test/src/utils/graph/series_parallel/get_ancestors.cc b/lib/utils/test/src/utils/graph/series_parallel/get_ancestors.cc index 5959c70c1d..5e4f19a0a2 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/get_ancestors.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/get_ancestors.cc @@ -1,5 +1,5 @@ -#include "utils/graph/serial_parallel/get_ancestors.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/get_ancestors.h" +#include "utils/fmt/unordered_set.h" #include #include @@ -11,51 +11,53 @@ TEST_SUITE(FF_TEST_SUITE) { Node(0), Node(1), Node(2), Node(3), Node(4), Node(5), Node(6), Node(7)}; SUBCASE("Single Node") { - SerialParallelDecomposition sp = SerialParallelDecomposition{n.at(0)}; + SeriesParallelDecomposition sp = SeriesParallelDecomposition{n.at(0)}; std::unordered_set correct = {}; std::unordered_set result = get_ancestors(sp, n.at(0)); CHECK(correct == result); } SUBCASE("Simple Serial") { - SerialParallelDecomposition sp = - SerialParallelDecomposition{SerialSplit{n.at(0), n.at(1), n.at(2)}}; + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{SeriesSplit{{n.at(0), n.at(1), n.at(2)}}}; std::unordered_set correct = {n.at(0), n.at(1)}; std::unordered_set result = get_ancestors(sp, n.at(2)); CHECK(correct == result); } SUBCASE("Simple Parallel") { - SerialParallelDecomposition sp = - SerialParallelDecomposition{ParallelSplit{n.at(0), n.at(1), n.at(2)}}; + SeriesParallelDecomposition sp = SeriesParallelDecomposition{ + ParallelSplit{{n.at(0), n.at(1), n.at(2)}}}; std::unordered_set correct = {}; std::unordered_set result = get_ancestors(sp, n.at(1)); CHECK(correct == result); } SUBCASE("Tree") { - SerialParallelDecomposition sp = SerialParallelDecomposition{SerialSplit{ - n.at(0), ParallelSplit{SerialSplit{n.at(1), n.at(2)}, n.at(3)}}}; + SeriesParallelDecomposition sp = SeriesParallelDecomposition{SeriesSplit{ + {n.at(0), + ParallelSplit{{SeriesSplit{{n.at(1), n.at(2)}}, n.at(3)}}}}}; std::unordered_set correct = {n.at(0), n.at(1)}; std::unordered_set result = get_ancestors(sp, n.at(2)); CHECK(correct == result); } SUBCASE("Rhombus") { - SerialParallelDecomposition sp = SerialParallelDecomposition{ - SerialSplit{n.at(0), ParallelSplit{n.at(1), n.at(2)}, n.at(3)}}; + SeriesParallelDecomposition sp = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), ParallelSplit{{n.at(1), n.at(2)}}, n.at(3)}}}; std::unordered_set correct = {n.at(0), n.at(1), n.at(2)}; std::unordered_set result = get_ancestors(sp, n.at(3)); CHECK(correct == result); } SUBCASE("Complex Structure") { - SerialParallelDecomposition sp = SerialParallelDecomposition{SerialSplit{ - n.at(0), - ParallelSplit{ - SerialSplit{n.at(1), ParallelSplit{n.at(2), n.at(3)}, n.at(4)}, - SerialSplit{n.at(5), n.at(6)}}, - n.at(7)}}; + SeriesParallelDecomposition sp = SeriesParallelDecomposition{SeriesSplit{ + {n.at(0), + ParallelSplit{ + {SeriesSplit{ + {n.at(1), ParallelSplit{{n.at(2), n.at(3)}}, n.at(4)}}, + SeriesSplit{{n.at(5), n.at(6)}}}}, + n.at(7)}}}; std::unordered_set correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; std::unordered_set result = get_ancestors(sp, n.at(4)); CHECK(correct == result); diff --git a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index e5b9045739..7b174bc384 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -189,4 +189,43 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } + + TEST_CASE("get_series_parallel_decomposition with dummy nodes") { + DiGraph g = DiGraph::create(); + + SUBCASE("base case") { + Node n = g.add_node(); + std::optional result = + get_series_parallel_decomposition_with_sync_nodes(g, {}); + std::optional correct = + SeriesParallelDecomposition{n}; + CHECK(result == correct); + } + + SUBCASE("SeriesSplit") { + std::vector n = add_nodes(g, 3); + add_edges( + g, {DirectedEdge{n.at(0), n.at(1)}, DirectedEdge{n.at(1), n.at(2)}}); + std::optional result = + get_series_parallel_decomposition_with_sync_nodes(g, {n[1]}); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{{n[0], n[2]}}}; + CHECK(result == correct); + } + + SUBCASE("ParallelSplit") { + std::vector n = add_nodes(g, 5); + add_edges(g, + {DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}}); + std::optional result = + get_series_parallel_decomposition_with_sync_nodes(g, {n[2]}); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ + {ParallelSplit{{n[0], n[1]}}, ParallelSplit{{n[3], n[4]}}}}}; + CHECK(result == correct); + } + } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/normalize_sp_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/normalize_sp_decomposition.cc index d1300da137..dd51f31093 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/normalize_sp_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/normalize_sp_decomposition.cc @@ -1,5 +1,5 @@ -#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include using namespace FlexFlow; @@ -11,61 +11,62 @@ TEST_SUITE(FF_TEST_SUITE) { Node n3 = Node(3); SUBCASE("Empty") { - SerialParallelDecomposition input = SerialParallelDecomposition{ - SerialSplit{ParallelSplit{}, ParallelSplit{}}}; - SerialParallelDecomposition correct = - SerialParallelDecomposition{SerialSplit{}}; - SerialParallelDecomposition result = normalize_sp_decomposition(input); + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{{}}, ParallelSplit{{}}}}}; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{}}}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); CHECK(correct == result); } SUBCASE("Node Decomposition") { - SerialParallelDecomposition input = SerialParallelDecomposition{n1}; - SerialParallelDecomposition correct = SerialParallelDecomposition{n1}; - SerialParallelDecomposition result = normalize_sp_decomposition(input); + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); CHECK(correct == result); } SUBCASE("Serial with Single Node") { - SerialParallelDecomposition input = - SerialParallelDecomposition{SerialSplit{n1}}; - SerialParallelDecomposition correct = SerialParallelDecomposition{n1}; - SerialParallelDecomposition result = normalize_sp_decomposition(input); + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{n1}}}; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); CHECK(correct == result); } SUBCASE("Parallel with Single Node") { - SerialParallelDecomposition input = - SerialParallelDecomposition{ParallelSplit{n1}}; - SerialParallelDecomposition correct = SerialParallelDecomposition{n1}; - SerialParallelDecomposition result = normalize_sp_decomposition(input); + SeriesParallelDecomposition input = + SeriesParallelDecomposition{ParallelSplit{{n1}}}; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); CHECK(correct == result); } SUBCASE("Mixed Serial") { - SerialParallelDecomposition input = - SerialParallelDecomposition{SerialSplit{ParallelSplit{n1}, n2}}; - SerialParallelDecomposition correct = - SerialParallelDecomposition{SerialSplit{n1, n2}}; - SerialParallelDecomposition result = normalize_sp_decomposition(input); + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{ParallelSplit{{n1}}, n2}}}; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n1, n2}}}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); CHECK(correct == result); } SUBCASE("Mixed Parallel") { - SerialParallelDecomposition input = - SerialParallelDecomposition{ParallelSplit{SerialSplit{n1}, n2}}; - SerialParallelDecomposition correct = - SerialParallelDecomposition{ParallelSplit{n1, n2}}; - SerialParallelDecomposition result = normalize_sp_decomposition(input); + SeriesParallelDecomposition input = + SeriesParallelDecomposition{ParallelSplit{{SeriesSplit{{n1}}, n2}}}; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{{n1, n2}}}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); CHECK(correct == result); } SUBCASE("Nested") { - SerialParallelDecomposition input = SerialParallelDecomposition{ - ParallelSplit{SerialSplit{ParallelSplit{n1, n2}}, n3, SerialSplit{}}}; - SerialParallelDecomposition correct = - SerialParallelDecomposition{ParallelSplit{n1, n2, n3}}; - SerialParallelDecomposition result = normalize_sp_decomposition(input); + SeriesParallelDecomposition input = + SeriesParallelDecomposition{ParallelSplit{ + {SeriesSplit{{ParallelSplit{{n1, n2}}}}, n3, SeriesSplit{{}}}}}; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{{n1, n2, n3}}}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); CHECK(correct == result); } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/serial_parallel_splits.cc b/lib/utils/test/src/utils/graph/series_parallel/serial_parallel_splits.cc deleted file mode 100644 index c08a926875..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/serial_parallel_splits.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("ParallelSplit and SerialSplit equality checks") { - - SUBCASE("ParallelSplit::operator== - commutativity") { - ParallelSplit p1 = ParallelSplit{Node(1), Node(2), Node(3)}; - ParallelSplit p2 = ParallelSplit{Node(2), Node(1), Node(3)}; - ParallelSplit p3 = ParallelSplit{Node(3), Node(2), Node(1)}; - CHECK(p1 == p2); - CHECK(p2 == p3); - CHECK(p1 == p3); - } - - SUBCASE("SerialSplit::operator== - non-commutativity") { - SerialSplit p1 = SerialSplit{Node(1), Node(2), Node(3)}; - SerialSplit p2 = SerialSplit{Node(2), Node(1), Node(3)}; - SerialSplit p3 = SerialSplit{Node(3), Node(2), Node(1)}; - CHECK(p1 != p2); - CHECK(p2 != p3); - CHECK(p1 != p3); - } - - SUBCASE("operator==, mixed case, nested commutativity") { - std::vector n = {Node(0), Node(1), Node(2), Node(3)}; - - // All definitions are equivalent, since ParallelSplit commutes - ParallelSplit p1 = ParallelSplit{ - n.at(3), SerialSplit{ParallelSplit{n.at(2), n.at(1)}, n.at(2)}}; - ParallelSplit p2 = ParallelSplit{ - n.at(3), SerialSplit{ParallelSplit{n.at(1), n.at(2)}, n.at(2)}}; - ParallelSplit p3 = ParallelSplit{ - SerialSplit{ParallelSplit{n.at(1), n.at(2)}, n.at(2)}, n.at(3)}; - ParallelSplit p4 = ParallelSplit{ - SerialSplit{ParallelSplit{n.at(2), n.at(1)}, n.at(2)}, n.at(3)}; - - CHECK(p1 == p2); - CHECK(p1 == p3); - CHECK(p1 == p4); - CHECK(p2 == p3); - CHECK(p2 == p4); - CHECK(p3 == p4); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc index a89b414e0b..6d0ec45dd5 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -158,70 +158,79 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE("is_empty(SerialParallelDecomposition)") { + TEST_CASE("is_empty(SeriesParallelDecomposition)") { Node n1{1}; Node n2{2}; SUBCASE("Node Decomposition") { - SerialParallelDecomposition sp{n1}; + SeriesParallelDecomposition sp{n1}; CHECK_FALSE(is_empty(sp)); } SUBCASE("Empty Serial") { - SerialParallelDecomposition sp{SerialSplit{}}; + SeriesParallelDecomposition sp{ + SeriesSplit{std::vector>{}}}; CHECK(is_empty(sp)); } SUBCASE("Empty Parallel") { - SerialParallelDecomposition sp{ParallelSplit{}}; + SeriesParallelDecomposition sp{ParallelSplit{{}}}; CHECK(is_empty(sp)); } SUBCASE("Serial with Node") { - SerialParallelDecomposition sp{SerialSplit{n1}}; + SeriesParallelDecomposition sp{SeriesSplit{{n1}}}; CHECK_FALSE(is_empty(sp)); } SUBCASE("Parallel with Node") { - SerialParallelDecomposition sp{ParallelSplit{n1}}; + SeriesParallelDecomposition sp{ParallelSplit{{n1}}}; CHECK_FALSE(is_empty(sp)); } SUBCASE("Nested Serial") { - SerialParallelDecomposition sp{SerialSplit{ParallelSplit{}}}; + SeriesParallelDecomposition sp{SeriesSplit{{ParallelSplit{{}}}}}; CHECK(is_empty(sp)); } SUBCASE("Nested Parallel") { - SerialParallelDecomposition sp{ParallelSplit{SerialSplit{}}}; + SeriesParallelDecomposition sp{ParallelSplit{ + {SeriesSplit{std::vector>{}}}}}; CHECK(is_empty(sp)); } SUBCASE("Sparse") { - SerialSplit sp{ParallelSplit{}, ParallelSplit{SerialSplit{}}}; + SeriesSplit sp{{ParallelSplit{{}}, + ParallelSplit{{SeriesSplit{ + std::vector>{}}}}}}; CHECK(is_empty(sp)); } SUBCASE("Sparse with Node") { - SerialSplit sp{ParallelSplit{}, ParallelSplit{SerialSplit{}, n2}}; + SeriesSplit sp{ + {ParallelSplit{{}}, + ParallelSplit{ + {SeriesSplit{std::vector>{}}, + n2}}}}; CHECK_FALSE(is_empty(sp)); } } + TEST_CASE("delete_node") { Node n1{1}, n2{2}, n3{3}, n4{4}, n5{5}; SUBCASE("Node") { - SerialParallelDecomposition sp{n1}; + SeriesParallelDecomposition sp{n1}; - SerialParallelDecomposition result = delete_node(sp, n2); + SeriesParallelDecomposition result = delete_node(sp, n2); CHECK(result == sp); } - SUBCASE("SerialSplit") { - SerialParallelDecomposition sp{SerialSplit{{n1, n2, n3}}}; + SUBCASE("SeriesSplit") { + SeriesParallelDecomposition sp{SeriesSplit{{n1, n2, n3}}}; auto result = delete_node(sp, n2); - SerialParallelDecomposition expected{SerialSplit{{n1, n3}}}; + SeriesParallelDecomposition expected{SeriesSplit{{n1, n3}}}; CHECK(result == expected); result = delete_node(sp, n4); @@ -229,10 +238,10 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("ParallelSplit") { - SerialParallelDecomposition sp{ParallelSplit{{n1, n2, n3}}}; + SeriesParallelDecomposition sp{ParallelSplit{{n1, n2, n3}}}; auto result = delete_node(sp, n2); - SerialParallelDecomposition expected{ParallelSplit{{n1, n3}}}; + SeriesParallelDecomposition expected{ParallelSplit{{n1, n3}}}; CHECK(result == expected); result = delete_node(sp, n4); @@ -240,17 +249,17 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested structure, duplicate nodes") { - SerialParallelDecomposition sp{SerialSplit{ - {n1, ParallelSplit{{n2, SerialSplit{{n3, n4, n1}}, n5, n1}}}}}; + SeriesParallelDecomposition sp{SeriesSplit{ + {n1, ParallelSplit{{n2, SeriesSplit{{n3, n4, n1}}, n5, n1}}}}}; auto result = delete_node(sp, n3); - SerialParallelDecomposition expected{SerialSplit{ - {n1, ParallelSplit{{n2, SerialSplit{{n4, n1}}, n5, n1}}}}}; + SeriesParallelDecomposition expected{SeriesSplit{ + {n1, ParallelSplit{{n2, SeriesSplit{{n4, n1}}, n5, n1}}}}}; CHECK(result == expected); result = delete_node(sp, n1); - expected = SerialParallelDecomposition{ - SerialSplit{{ParallelSplit{{n2, SerialSplit{{n3, n4}}, n5}}}}}; + expected = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{{n2, SeriesSplit{{n3, n4}}, n5}}}}}; CHECK(result == expected); } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.cc index 130fa93411..b2a797cb69 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.cc @@ -1,12 +1,14 @@ -#include "utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h" -#include "test/utils/doctest.h" +#include "utils/graph/series_parallel/sp_ization/critical_path_preserving_sp_ization.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_metrics.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" +#include + +using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { @@ -37,17 +39,17 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 13); CHECK(critical_path_cost(g, cost_map) == 12); - SerialParallelDecomposition sp = critical_path_preserving_sp_ization(g); + SeriesParallelDecomposition sp = critical_path_preserving_sp_ization(g); SUBCASE("structure") { Node sp0 = n.at(0); - SerialSplit sp1 = SerialSplit{sp0, n.at(1)}; - SerialSplit sp2 = SerialSplit{ParallelSplit{sp0, sp1}, n.at(2)}; - SerialSplit sp3 = SerialSplit{n.at(0), n.at(1), n.at(3)}; - SerialSplit sp4 = SerialSplit{ParallelSplit{sp2, sp3}, n.at(4)}; - SerialSplit sp5 = SerialSplit{ParallelSplit{sp3, sp4}, n.at(5)}; - SerialParallelDecomposition correct(sp5); - SerialParallelDecomposition result = sp; + SeriesSplit sp1 = SeriesSplit{{sp0, n.at(1)}}; + SeriesSplit sp2 = SeriesSplit{{ParallelSplit{{sp0, sp1}}, n.at(2)}}; + SeriesSplit sp3 = SeriesSplit{{n.at(0), n.at(1), n.at(3)}}; + SeriesSplit sp4 = SeriesSplit{{ParallelSplit{{sp2, sp3}}, n.at(4)}}; + SeriesSplit sp5 = SeriesSplit{{ParallelSplit{{sp3, sp4}}, n.at(5)}}; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{sp5}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -86,14 +88,15 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 15); CHECK(critical_path_cost(g, cost_map) == 12); - SerialParallelDecomposition sp = critical_path_preserving_sp_ization(g); + SeriesParallelDecomposition sp = critical_path_preserving_sp_ization(g); SUBCASE("structure") { - SerialParallelDecomposition correct(SerialSplit{ - ParallelSplit{SerialSplit{n.at(0), n.at(1), n.at(3), n.at(4)}, - SerialSplit{n.at(0), n.at(2)}}, - n.at(5)}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{ + {SeriesSplit{{n.at(0), n.at(1), n.at(3), n.at(4)}}, + SeriesSplit{{n.at(0), n.at(2)}}}}, + n.at(5)}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -137,17 +140,19 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 9); CHECK(critical_path_cost(g, cost_map) == 7); - SerialParallelDecomposition sp = + SeriesParallelDecomposition sp = critical_path_preserving_sp_ization_with_coalescing(g); SUBCASE("structure") { - SerialParallelDecomposition correct(SerialSplit{ - n.at(0), - n.at(1), - ParallelSplit{SerialSplit{ParallelSplit{n.at(2), n.at(3)}, n.at(4)}, - n.at(3)}, - n.at(5)}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{ + {n.at(0), + n.at(1), + ParallelSplit{ + {SeriesSplit{{ParallelSplit{{n.at(2), n.at(3)}}, n.at(4)}}, + n.at(3)}}, + n.at(5)}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -186,15 +191,15 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 15); CHECK(critical_path_cost(g, cost_map) == 12); - SerialParallelDecomposition sp = + SeriesParallelDecomposition sp = critical_path_preserving_sp_ization_with_coalescing(g); SUBCASE("structure") { - SerialParallelDecomposition correct(SerialSplit{ - n.at(0), - ParallelSplit{SerialSplit{n.at(1), n.at(3), n.at(4)}, n.at(2)}, - n.at(5)}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct(SeriesSplit{ + {n.at(0), + ParallelSplit{{SeriesSplit{{n.at(1), n.at(3), n.at(4)}}, n.at(2)}}, + n.at(5)}}); + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -243,24 +248,27 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 43); CHECK(critical_path_cost(g, cost_map) == 26); - SerialParallelDecomposition sp = + SeriesParallelDecomposition sp = critical_path_preserving_sp_ization_with_coalescing(g); SUBCASE("structure") { - SerialParallelDecomposition correct(SerialSplit{ - n.at(0), - ParallelSplit{ - SerialSplit{n.at(1), n.at(2), n.at(6)}, - SerialSplit{ParallelSplit{ - SerialSplit{ParallelSplit{n.at(1), n.at(3)}, - ParallelSplit{ - n.at(4), - SerialSplit{n.at(5), n.at(7)}}}, - n.at(3)}, - n.at(8)}}, - n.at(9)}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{ + {n.at(0), + ParallelSplit{ + {SeriesSplit{{n.at(1), n.at(2), n.at(6)}}, + SeriesSplit{ + {ParallelSplit{ + {SeriesSplit{ + {ParallelSplit{{n.at(1), n.at(3)}}, + ParallelSplit{ + {n.at(4), + SeriesSplit{{n.at(5), n.at(7)}}}}}}, + n.at(3)}}, + n.at(8)}}}}, + n.at(9)}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); }; SUBCASE("work cost") { @@ -290,10 +298,10 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(3), n.at(4)}, }); - SerialParallelDecomposition result = + SeriesParallelDecomposition result = critical_path_preserving_sp_ization_with_coalescing(g); - SerialParallelDecomposition correct = - SerialParallelDecomposition{SerialSplit{ + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{ {n.at(0), n.at(1), ParallelSplit{{n.at(2), n.at(3)}}, n.at(4)}}}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc similarity index 71% rename from lib/utils/test/src/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.cc rename to lib/utils/test/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc index 89f5e23234..e3f09253ea 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/is_valid_sp_ization.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" #include "utils/containers/get_only.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" @@ -11,17 +11,17 @@ TEST_SUITE(FF_TEST_SUITE) { DiGraph g = DiGraph::create(); SUBCASE("Single Node") { std::vector n = add_nodes(g, 1); - SerialParallelDecomposition sp = - SerialParallelDecomposition{SerialSplit{n[0]}}; + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{SeriesSplit{{n[0]}}}; CHECK(dependencies_are_maintained(g, sp)); } - SUBCASE("SerialSplit") { + SUBCASE("SeriesSplit") { SUBCASE("Valid SP-ization") { std::vector n = add_nodes(g, 3); add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}}); - SerialParallelDecomposition sp = - SerialParallelDecomposition{SerialSplit{{n[0], n[1], n[2]}}}; + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{SeriesSplit{{n[0], n[1], n[2]}}}; CHECK(dependencies_are_maintained(g, sp)); } @@ -29,8 +29,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n = add_nodes(g, 3); add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}}); - SerialParallelDecomposition sp = - SerialParallelDecomposition{SerialSplit{{n[1], n[0], n[2]}}}; + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{SeriesSplit{{n[1], n[0], n[2]}}}; CHECK_FALSE(dependencies_are_maintained(g, sp)); } } @@ -38,16 +38,16 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("ParallelSplit") { SUBCASE("Valid SP-ization") { std::vector n = add_nodes(g, 3); - SerialParallelDecomposition sp = - SerialParallelDecomposition{ParallelSplit{{n[0], n[1], n[2]}}}; + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{ParallelSplit{{n[0], n[1], n[2]}}}; CHECK(dependencies_are_maintained(g, sp)); } SUBCASE("Incorrect SP-ization") { std::vector n = add_nodes(g, 3); - SerialParallelDecomposition sp = - SerialParallelDecomposition{ParallelSplit{{n[0], n[2]}}}; + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{ParallelSplit{{n[0], n[2]}}}; CHECK_FALSE(dependencies_are_maintained(g, sp)); } } @@ -60,17 +60,17 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(1), n.at(3)}, DirectedEdge{n.at(2), n.at(3)}}); SUBCASE("Valid SP-izations") { - SerialParallelDecomposition sp_correct = SerialParallelDecomposition{ - SerialSplit{{n.at(0), ParallelSplit{{n.at(1), n.at(2)}}, n.at(3)}}}; + SeriesParallelDecomposition sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), ParallelSplit{{n.at(1), n.at(2)}}, n.at(3)}}}; CHECK(dependencies_are_maintained(g, sp_correct)); - sp_correct = SerialParallelDecomposition{ - SerialSplit{{n.at(0), n.at(1), n.at(2), n.at(3)}}}; + sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), n.at(1), n.at(2), n.at(3)}}}; CHECK(dependencies_are_maintained(g, sp_correct)); } SUBCASE("Invalid SP-ization") { - SerialParallelDecomposition sp_incorrect = SerialParallelDecomposition{ - ParallelSplit{{n.at(0), SerialSplit{{n.at(1), n.at(3)}}, n.at(2)}}}; + SeriesParallelDecomposition sp_incorrect = SeriesParallelDecomposition{ + ParallelSplit{{n.at(0), SeriesSplit{{n.at(1), n.at(3)}}, n.at(2)}}}; CHECK_FALSE(dependencies_are_maintained(g, sp_incorrect)); } } @@ -88,23 +88,23 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Valid SP-izations") { - SerialParallelDecomposition sp_correct = SerialParallelDecomposition{ - SerialSplit{{n.at(0), + SeriesParallelDecomposition sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), ParallelSplit{{n.at(1), n.at(2)}}, ParallelSplit{{n.at(3), n.at(4)}}, n.at(5)}}}; CHECK(dependencies_are_maintained(g, sp_correct)); - sp_correct = SerialParallelDecomposition{ - SerialSplit{{n.at(0), + sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), n.at(1), n.at(2), ParallelSplit{{n.at(3), n.at(4)}}, n.at(5)}}}; CHECK(dependencies_are_maintained(g, sp_correct)); - sp_correct = SerialParallelDecomposition{ - SerialSplit{{n.at(0), + sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), ParallelSplit{{n.at(1), n.at(2)}}, n.at(3), n.at(4), @@ -113,8 +113,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Invalid SP-izations") { - SerialParallelDecomposition sp_correct = SerialParallelDecomposition{ - SerialSplit{{n.at(0), + SeriesParallelDecomposition sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), ParallelSplit{{n.at(1), n.at(2), n.at(4)}}, n.at(3), n.at(5)}}}; diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/spanish_algo.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/spanish_algo.cc index 1152d881f2..95d5e9f753 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/spanish_algo.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/spanish_algo.cc @@ -1,70 +1,384 @@ -#include "utils/graph/serial_parallel/sp_ization/spanish_algo.h" +#include "utils/graph/series_parallel/sp_ization/spanish_algo.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/containers/values.h" #include "utils/graph/algorithms.h" -#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_incoming_edges.h" +#include "utils/graph/digraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/directed_edge.dtg.h" #include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/serial_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" #include +#include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("spanish_algo - subcomponents") { + SUBCASE("add_dummy_nodes") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + std::unordered_map node_types = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::PURE}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + }; - TEST_CASE("spanish_algorithm") { + DiGraph result = add_dummy_nodes(g, node_types); + CHECK(get_edges(result).size() == 6); + CHECK(get_nodes(result).size() == 6); + CHECK(get_incoming_edges(g, n.at(3)).size() == 2); + CHECK(get_outgoing_edges(g, n.at(0)).size() == 2); + + CHECK(node_types.size() == 6); + CHECK(values(node_types) == + std::unordered_multiset{NodeRole::PURE, + NodeRole::PURE, + NodeRole::PURE, + NodeRole::PURE, + NodeRole::DUMMY, + NodeRole::DUMMY}); + } - // SUBCASE("Single Node") { - // DiGraph g = DiGraph::create(); - // g.add_node(); - // SerialParallelDecomposition sp = one_node_at_a_time_spanish_sp_ization(g); - // CHECK(dependencies_are_maintained(g, sp)); - // } - // SUBCASE("Linear Graph") { - // DiGraph g = DiGraph::create(); - // std::vector n = add_nodes(g, 3); - // add_edges(g, - // { - // DirectedEdge{n[0], n[1]}, - // DirectedEdge{n[1], n[2]}}); - // SerialParallelDecomposition sp = one_node_at_a_time_spanish_sp_ization(g); - - // CHECK(dependencies_are_maintained(g, sp)); - - // } + SUBCASE("delete_dummy_nodes") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + std::vector edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + add_edges(g, edges); - SUBCASE("Rhombus") { + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::DUMMY}, + {n.at(2), NodeRole::DUMMY}, + {n.at(3), NodeRole::PURE}, + {n.at(4), NodeRole::PURE}, + }; + + DiGraph result = delete_dummy_nodes(g, node_roles); + + CHECK(get_nodes(result) == + std::unordered_set{n.at(0), n.at(3), n.at(4)}); + CHECK(get_edges(result) == + std::unordered_set{DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}}); + } + SUBCASE("get_component") { + SUBCASE("2 layer graph, single simple component") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}}); + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::SYNC}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + }; + std::unordered_map depth_map = { + {n.at(0), 0}, + {n.at(2), 1}, + {n.at(3), 1}, + }; + std::unordered_set correct = {n.at(0), n.at(2), n.at(3)}; + std::unordered_set result = + get_component(g, n.at(2), depth_map, node_roles); + CHECK(correct == result); + } + SUBCASE("2 layer graph, single complex component") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}}); + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::PURE}, + {n.at(2), NodeRole::SYNC}, + {n.at(3), NodeRole::SYNC}, + {n.at(4), NodeRole::PURE}, + {n.at(5), NodeRole::PURE}, + }; + std::unordered_map depth_map = { + {n.at(0), 0}, + {n.at(1), 0}, + {n.at(4), 1}, + {n.at(5), 1}, + }; + SUBCASE("n.at(4)'s component") { + std::unordered_set correct = { + n.at(0), n.at(1), n.at(4), n.at(5)}; + std::unordered_set result = + get_component(g, n.at(4), depth_map, node_roles); + CHECK(correct == result); + } + SUBCASE("n.at(5)'s component") { + std::unordered_set correct = { + n.at(0), n.at(1), n.at(4), n.at(5)}; + std::unordered_set result = + get_component(g, n.at(5), depth_map, node_roles); + CHECK(correct == result); + } + } + SUBCASE("3 layer graph, single connected component") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 7); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}, + DirectedEdge{n.at(4), n.at(6)}}); + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::SYNC}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + {n.at(4), NodeRole::SYNC}, + {n.at(5), NodeRole::PURE}, + {n.at(6), NodeRole::PURE}}; + + std::unordered_map depth_map = {{n.at(0), 0}, + {n.at(2), 1}, + {n.at(3), 1}, + {n.at(5), 2}, + {n.at(6), 2}}; + SUBCASE("n.at(5)'s component") { + std::unordered_set correct = { + n.at(2), n.at(3), n.at(5), n.at(6)}; + std::unordered_set result = + get_component(g, n.at(5), depth_map, node_roles); + CHECK(correct == result); + } + + SUBCASE("n.at(6)'s component") { + std::unordered_set correct = { + n.at(2), n.at(3), n.at(5), n.at(6)}; + std::unordered_set result = + get_component(g, n.at(6), depth_map, node_roles); + CHECK(correct == result); + } + } + SUBCASE("3 layer graph, multiple weakly connected components") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 10); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(5)}, + DirectedEdge{n.at(3), n.at(6)}, + DirectedEdge{n.at(4), n.at(6)}, + DirectedEdge{n.at(5), n.at(7)}, + DirectedEdge{n.at(5), n.at(8)}, + DirectedEdge{n.at(6), n.at(9)}, + }); + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::SYNC}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + {n.at(4), NodeRole::PURE}, + {n.at(5), NodeRole::SYNC}, + {n.at(6), NodeRole::SYNC}, + {n.at(7), NodeRole::PURE}, + {n.at(8), NodeRole::PURE}, + {n.at(9), NodeRole::PURE}, + }; + + std::unordered_map depth_map = {{n.at(0), 0}, + {n.at(2), 1}, + {n.at(3), 1}, + {n.at(4), 1}, + {n.at(7), 2}, + {n.at(8), 2}, + {n.at(9), 2}}; + SUBCASE("n.at(7)'s component") { + std::unordered_set correct = {n.at(2), n.at(7), n.at(8)}; + std::unordered_set result = + get_component(g, n.at(7), depth_map, node_roles); + CHECK(correct == result); + } + SUBCASE("n.at(8)'s component") { + std::unordered_set correct = {n.at(2), n.at(7), n.at(8)}; + std::unordered_set result = + get_component(g, n.at(8), depth_map, node_roles); + CHECK(correct == result); + } + SUBCASE("n.at(9)'s component") { + std::unordered_set correct = {n.at(3), n.at(4), n.at(9)}; + std::unordered_set result = + get_component(g, n.at(9), depth_map, node_roles); + CHECK(correct == result); + } + } + } + } + + TEST_CASE("spanish_algorithm") { + + SUBCASE("Single Node") { + DiGraph g = DiGraph::create(); + Node n = g.add_node(); + SeriesParallelDecomposition sp = spanish_strata_sync(g); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{Node{n}}; + CHECK(sp == correct); + CHECK(dependencies_are_maintained(g, sp)); + } + SUBCASE("Linear Graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}}); + SeriesParallelDecomposition sp = spanish_strata_sync(g); + CHECK(dependencies_are_maintained(g, sp)); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n[0], n[1], n[2]}}}; + CHECK(sp == correct); + } + + SUBCASE("Rhombus") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[3]}}); + SeriesParallelDecomposition sp = spanish_strata_sync(g); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], ParallelSplit{{n[1], n[2]}}, n[3]}}}; + + CHECK(dependencies_are_maintained(g, sp)); + CHECK(correct == sp); + } + + SUBCASE("Sample Graph #1") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + }); + SeriesParallelDecomposition sp = spanish_strata_sync(g); + CHECK(dependencies_are_maintained(g, sp)); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], n[1], ParallelSplit{{n[2], n[3]}}, n[4], n[5]}}}; + CHECK(sp == correct); + } + + SUBCASE("Diamond without crossing") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, { DirectedEdge{n[0], n[1]}, DirectedEdge{n[0], n[2]}, DirectedEdge{n[1], n[3]}, - DirectedEdge{n[2], n[3]}} - ); - SerialParallelDecomposition sp = one_node_at_a_time_spanish_sp_ization(g); + DirectedEdge{n[2], n[5]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}, + }); + SeriesParallelDecomposition sp = spanish_strata_sync(g); CHECK(dependencies_are_maintained(g, sp)); - + // SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + // SeriesSplit{{n[0], + // ParallelSplit{{SeriesSplit{{n[1], n[3], n[4]}}, + // n[2]}}, n[5]}}}; + // SeriesParallelDecomposition result = sp; + // CHECK(correct == result); + } - + SUBCASE("Diamond Graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + }); + SeriesParallelDecomposition sp = spanish_strata_sync(g); + CHECK(dependencies_are_maintained(g, sp)); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n[0], + ParallelSplit{{n[1], n[2]}}, + ParallelSplit{{n[3], n[4]}}, + n[5]}}}; + CHECK(sp == correct); } -// SUBCASE("Sample Graph #1") { -// DiGraph g = DiGraph::create(); -// std::vector n = add_nodes(g, 6); -// add_edges(g, -// {DirectedEdge{n[1], n[2]}}, - -// DirectedEdge{n[0], n[1]}, -// DirectedEdge{n[0], n[2]}, -// DirectedEdge{n[1], n[2]}, -// DirectedEdge{n[1], n[3]}, -// DirectedEdge{n[2], n[4]}, -// DirectedEdge{n[3], n[4]}, -// DirectedEdge{n[3], n[5]}, -// DirectedEdge{n[4], n[5]}, -// }); -// SerialParallelDecomposition sp = one_node_at_a_time_spanish_sp_ization(g); -// CHECK(dependencies_are_maintained(g, sp)); -// } + SUBCASE("Sample Graph #2") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 10); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[5]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[6]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[3], n[8]}, + DirectedEdge{n[4], n[8]}, + DirectedEdge{n[5], n[7]}, + DirectedEdge{n[7], n[8]}, + DirectedEdge{n[6], n[9]}, + DirectedEdge{n[8], n[9]}}); + SeriesParallelDecomposition sp = spanish_strata_sync(g); + CHECK(dependencies_are_maintained(g, sp)); + + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{n[1], n[3]}}, + ParallelSplit{ + {SeriesSplit{{n[2], n[6]}}, + SeriesSplit{{ParallelSplit{ + {SeriesSplit{{n[5], n[7]}}, n[4]}}, + n[8]}}}}, + n[9]}}}; + CHECK(sp == correct); + } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.cc index 892af6ed26..00e44eecc4 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.cc @@ -1,16 +1,18 @@ -#include "utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h" -#include "test/utils/doctest.h" +#include "utils/graph/series_parallel/sp_ization/work_preserving_sp_ization.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_metrics.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" +#include + +using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("work_preserving_") { + TEST_CASE("work_preserving_sp_ization") { SUBCASE("Sample Graph #1") { DiGraph g = DiGraph::create(); @@ -33,12 +35,12 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 9); CHECK(critical_path_cost(g, cost_map) == 7); - SerialParallelDecomposition sp = stratum_sync_sp_ization(g); + SeriesParallelDecomposition sp = stratum_sync_sp_ization(g); SUBCASE("structure") { - SerialParallelDecomposition correct( - SerialSplit{n[0], n[1], ParallelSplit{n[2], n[3]}, n[4], n[5]}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], n[1], ParallelSplit{{n[2], n[3]}}, n[4], n[5]}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -73,12 +75,12 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 15); CHECK(critical_path_cost(g, cost_map) == 12); - SerialParallelDecomposition sp = stratum_sync_sp_ization(g); + SeriesParallelDecomposition sp = stratum_sync_sp_ization(g); SUBCASE("structure") { - SerialParallelDecomposition correct( - SerialSplit{n[0], ParallelSplit{n[1], n[2]}, n[3], n[4], n[5]}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], ParallelSplit{{n[1], n[2]}}, n[3], n[4], n[5]}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -126,16 +128,16 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 45); CHECK(critical_path_cost(g, cost_map) == 23); - SerialParallelDecomposition sp = stratum_sync_sp_ization(g); + SeriesParallelDecomposition sp = stratum_sync_sp_ization(g); SUBCASE("structure") { - SerialParallelDecomposition correct( - SerialSplit{n[0], - ParallelSplit{n[1], n[3]}, - ParallelSplit{n[2], n[4], n[5]}, - ParallelSplit{n[6], n[7]}, - n[8]}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{n[1], n[3]}}, + ParallelSplit{{n[2], n[4], n[5]}}, + ParallelSplit{{n[6], n[7]}}, + n[8]}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -175,13 +177,13 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 9); CHECK(critical_path_cost(g, cost_map) == 7); - SerialParallelDecomposition sp = + SeriesParallelDecomposition sp = cost_aware_stratum_sync_sp_ization(g, cost_map); SUBCASE("structure") { - SerialParallelDecomposition correct( - SerialSplit{n[0], n[1], ParallelSplit{n[2], n[3]}, n[4], n[5]}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], n[1], ParallelSplit{{n[2], n[3]}}, n[4], n[5]}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -216,13 +218,15 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 15); CHECK(critical_path_cost(g, cost_map) == 12); - SerialParallelDecomposition sp = + SeriesParallelDecomposition sp = cost_aware_stratum_sync_sp_ization(g, cost_map); SUBCASE("structure") { - SerialParallelDecomposition correct(SerialSplit{ - n[0], ParallelSplit{SerialSplit{n[1], n[3], n[4]}, n[2]}, n[5]}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{SeriesSplit{{n[1], n[3], n[4]}}, n[2]}}, + n[5]}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -270,16 +274,16 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 42); CHECK(critical_path_cost(g, cost_map) == 25); - SerialParallelDecomposition sp = + SeriesParallelDecomposition sp = cost_aware_stratum_sync_sp_ization(g, cost_map); SUBCASE("structure") { - SerialParallelDecomposition correct( - SerialSplit{n[0], - ParallelSplit{SerialSplit{n[1], n[2], n[6]}, n[3]}, - ParallelSplit{n[4], SerialSplit{n[5], n[7]}}, - n[8]}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{SeriesSplit{{n[1], n[2], n[6]}}, n[3]}}, + ParallelSplit{{n[4], SeriesSplit{{n[5], n[7]}}}}, + n[8]}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") { @@ -333,19 +337,21 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(work_cost(g, cost_map) == 41); CHECK(critical_path_cost(g, cost_map) == 24); - SerialParallelDecomposition sp = + SeriesParallelDecomposition sp = cost_aware_stratum_sync_sp_ization(g, cost_map); SUBCASE("structure") { - SerialParallelDecomposition correct(SerialSplit{ - n[0], - ParallelSplit{SerialSplit{n[1], ParallelSplit{n[2], n[3]}, n[7]}, - n[5]}, - ParallelSplit{n[4], n[8], n[10]}, - ParallelSplit{n[6], n[12]}, - ParallelSplit{n[11], SerialSplit{n[9], n[13]}}, - n[14]}); - SerialParallelDecomposition result = sp; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{ + {n[0], + ParallelSplit{ + {SeriesSplit{{n[1], ParallelSplit{{n[2], n[3]}}, n[7]}}, + n[5]}}, + ParallelSplit{{n[4], n[8], n[10]}}, + ParallelSplit{{n[6], n[12]}}, + ParallelSplit{{n[11], SeriesSplit{{n[9], n[13]}}}}, + n[14]}}}; + SeriesParallelDecomposition result = sp; CHECK(correct == result); } SUBCASE("work cost") {