From 2be849a61a3ceaab3da67f2fb4f9dd14f79fedff Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 9 Oct 2024 21:18:53 -0700 Subject: [PATCH 1/8] Graph Testing initial cleanup --- lib/utils/include/utils/graph/README.md | 168 ++++------- .../graph/dataflow_graph/dataflow_graph.h | 2 - .../graph/digraph/algorithms/get_dominators.h | 15 + .../include/utils/graph/digraph/digraph.h | 2 - .../utils/graph/digraph/digraph_view.h | 2 - .../instances/hashmap_undirected_graph.h | 0 .../utils/graph/multidigraph/multidigraph.h | 2 - lib/utils/include/utils/graph/node/graph.h | 2 - .../include/utils/graph/node/graph_view.h | 2 - .../include/utils/graph/node/node.struct.toml | 1 + .../utils/graph/undirected/undirected_edge.h | 27 +- .../undirected/undirected_edge.struct.toml | 18 ++ .../utils/graph/undirected/undirected_graph.h | 2 - .../graph/undirected/undirected_graph_view.h | 2 - lib/utils/src/utils/graph/algorithms.cc | 5 +- .../instances/hashmap_undirected_graph.cc | 26 +- .../unordered_set_undirected_graph.cc | 4 +- .../algorithms/get_neighboring_nodes.cc | 2 +- .../utils/graph/undirected/undirected_edge.cc | 33 +-- .../graph/undirected/undirected_edge_query.cc | 3 +- lib/utils/src/utils/graph/views/views.cc | 35 ++- .../graph/digraph/algorithms/algorithms.cc | 106 +++++++ .../utils/graph/digraph/algorithms/digraph.cc | 85 ++++++ .../digraph/algorithms/directed_edge_query.cc | 70 +++++ .../digraph/algorithms/get_dominators.cc | 68 +++++ .../algorithms/get_topological_ordering.cc | 36 +++ .../graph/digraph/algorithms/traversal.cc | 112 ++++++++ .../algorithms/get_incoming_edges.cc | 36 +++ .../algorithms/get_outgoing_edges.cc | 40 +++ .../algorithms/get_connected_components.cc | 26 ++ .../src/utils/graph/undirected/undirected.cc | 75 +++++ lib/utils/test/src/utils/graph/views/views.cc | 262 ++++++++++++++++++ 32 files changed, 1049 insertions(+), 220 deletions(-) rename lib/utils/{src => include}/utils/graph/instances/hashmap_undirected_graph.h (100%) create mode 100644 lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc create mode 100644 lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc create mode 100644 lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc create mode 100644 lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc create mode 100644 lib/utils/test/src/utils/graph/undirected/undirected.cc create mode 100644 lib/utils/test/src/utils/graph/views/views.cc diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 25b0103f9c..f3d31e7bc8 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -15,10 +15,16 @@ There is no single type of graph. Should it be directed? Allow multiple edges be Because there is no single answer to this question, similar to [networkx](https://networkx.org/) we provide a number of different graph variants. At their core, they are as follows: -- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected -- `DirectedGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) -- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index. - +- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. +- `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) +- `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. +- `DataflowGraph`: similar to `MultiDiGraph`, but with the following differences: + - The edges entering, exiting a given nodes now have a well-defined order. + - Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. + - Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. + +Conceptually, `DataflowGraph` is used within FlexFlow to represent computation-style graphs, where edges represent value uses and nodes represent multivariate functions from tuples of inputs to tuples of outputs. + Examples of the different graph variants are shown below. Example of `UndirectedGraph`: @@ -37,7 +43,7 @@ flowchart TD D --- B ``` -Example of `DirectedGraph`: +Example of `DiGraph`: ```mermaid flowchart TD A(" ") @@ -58,98 +64,34 @@ flowchart TD Example of `MultiDiGraph`: ```mermaid flowchart TD - A("A") - B("B") - C("C") - D("D") - E("E") - F("F") - - A -->|"(■, ★)"| B - B -->|"(●, ★)"| C - C -->|"(♥, ▲)"| D - D -->|"(●, ■)"| A - B -->|"(★, ●)"| E - E -->|"(■, ■)"| B - D -->|"(●, ●)"| A - A -->|"(●, ■)"| E - D -->|"(■, ●)"| D - E -->|"(■, ■)"| E -``` -or visualized a different way, -```mermaid -flowchart TD - Acirc("●") - Asqua("■") - Bcirc("●") - Bstar("★") - Bsqua("■") - Chear("♥") - Cstar("★") - Dsqua("■") - Dcirc("●") - Dtria("▲") - Ecirc("●") - Esqua("■") - Fplaceholder(" ") - - style Fplaceholder fill:#0000,stroke:#0000 - - subgraph "A" - Acirc - Asqua - end - - subgraph "B" - Bsqua - Bcirc - Bstar - end - - subgraph "C" - Chear - Cstar - end - - subgraph "D" - Dsqua - Dcirc - Dtria - end - - subgraph "E" - Ecirc - Esqua - end - - subgraph "F" - Fplaceholder - end - - Asqua --> Bstar - Bcirc --> Cstar - Chear --> Dtria - Dcirc --> Asqua - Bstar --> Ecirc - Esqua --> Bsqua - Dcirc --> Acirc - Acirc --> Esqua - Dsqua --> Dcirc - Esqua --> Esqua + A + B + C + D + E + F + + A --> B + B --> C + C --> D + D --> A + B --> E + E --> B + D --> A + A --> E + D --> D + E --> E ``` -Note that the nodes and source/destination indices are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. -This is the case as well with `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`. +Note that the node names are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. +This is the case with all of the 4 core graph classes. Nodes are of type `Node`, and from a user perspective are simply opaque handles, and source and destination indices should similarly be considered opaque from a user point of view. In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. All three core graph variants allow insertion and deletion of both edges and nodes. To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. -In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph` and `MultiDiGraph`. -`MultiDiGraph::add_edge` takes in two additional arguments of type `NodePort`, specifying the source and destination indices. -Similar to `Node`s, `NodePort`s can be generated via `g.add_node_port()`. -`NodePort:` an opaque object used within `MultiDiGraph` to disambiguate between multiple edges. `MultiDiGraph` will be able to distinguish between 2 edges that share the same source and destination as long as at at least one `NodePort` differs. Within the context of a PCG, `NodePorts` must be thought of as the various inputs and outputs of a single node. +In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataflowGraph`. The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations). @@ -158,11 +100,13 @@ The argument to `query_nodes` is a `NodeQuery` (which is simply a set of `Node`s The set of nodes in the query is actually an `optional`, so `nullopt` could also be passed, which would simply retrieve all nodes from the target graph (essentially `nullopt` acts as the set of all nodes that could ever exist). `query_edges` functions similarly, but as with `add_edge` its behavior is differs slightly between the three graph variants. `UndirectedGraph::query_edges` simply takes an optional set of nodes and returns all edges that touch any of those nodes. -`DirectedGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. +`DiGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. In practice you will rarely ever use `query_nodes` and `query_edges` as the graph library provides a large number of algorithms that do that work for you, but it can be helpful to understand this base layer if you ever need to implement your own algorithms. -The layer users will most commonly interact with is the interface provided by [algorithms.h](./algorithms.h), which provides a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. -You may notice that the most of the functions declared in `algorithms.h` take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but actually operator on `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. +The layer users will most commonly interact with is the interface provided within either the `algorithms.h` header files or the `algorithms` folders, present in their respective graph class folders. +They provide a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. +Note that, due to the internal virtual inheritance structure, some functions for more privitive classes can be employed by the derived classes. (For example, `get_nodes` present in `node/algorithms.h` can be used by `DiGraph`). +You may notice that the most of algorithms present take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but rather `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. These `GraphView` objects represent read-only (i.e., immutable) graphs. Similar to C++'s `const` semantics, `Graph`s can be coerced[^2] to `GraphView`s but not the other way around. To transform a `GraphView` to a `Graph`, we can perform an explicit copy with `materialize_view`. @@ -171,40 +115,35 @@ This may seem wasteful (oftentimes graphs are large objects that are passed arou At this point, however, we still have not discussed how to create a graph. The user-facing graph interface is intentially separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. -For example, to construct a `DiGraph` which internally uses a representation `MyDiGraphImpl`: +For example, to construct a `DiGDiraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: ```cpp -DiGraph g = DiGraph::create(); +DiGraph g = DiGraph::create(); ``` Generally users will use underlying representations provided by the graph library, but advanced users can create their own implementations (see the [Internals](#internals) section). [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. [^2]: See if you're not familiar with the term _type coercion_ -### Open, Upward, Downward +### Open DataFlow Variant `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. -We can further specify the "openeness" of a **directed** graph by specifying whether they are `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). - -![Open graphs inheritance diagram](docs/open.svg) +This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. -Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a 'cow_ptr' of the type of the pointed class. (for more info, see [cow_ptr](#cow_ptr-and-interfaces)) - - -### Labelled Graphs +### Labelled Dataflow Variant As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. -Thus, FlexFlow's graph library provides the ability to add labels via [labelled\_graphs.h](./labelled_graphs.h): examples include `NodeLabelledMultiDiGraph` (nodes have labels of type `T` and edges are unlabelled) and `OutputLabelledMultiDiGraph` (nodes have labels of type `T` and source indices have labels of type `U`). -While the interfaces of these graphs differ slightly from the core graph variants, they still have corresponding `GraphView` types, `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. -Note that all of the labelled graph types require that each element of the labelled types have a label (e.g., every node in a `NodeLabelledMultiDiGraph` must have a label of type `T`)., which is enforced via the interfaces they provide. +Thus, FlexFlow's graph library provides the ability to add labels to `DataflowGraph`, through the `LabelleledDataflowGraph` and `OpenLabelleledDataflowGraph`, which allow users to label different components of the graph. +- `LabelledDataflowGraph` allows for labelling of `Node`s and `DataflowOutput`s. +- `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s, which represent the open inputs to the graph (i.e. the inputs for which their corresponding output is not present in the graph). + +While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node` methods, and `query_nodes`/`query_edges` methods. (Note that there is no `add_edge` method since, for `DataflowGraph`, edges are implicitly added when we add a node and specify its predecessors) +Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. Partial labelling can be implement via wrapping the label type in `optional`. -Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes (or other types depending in which labelled graph type is used) to labels. -As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants for use in functions provided by `algorithms.h`, etc. +Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes/edges to labels. +As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants. [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible. -![Labelled Graphs Inheritance Diagram](docs/labelled.svg) - - ## Internals @@ -236,12 +175,7 @@ To address this, graph classes store a `cow_ptr` as a member variable, which poi All member functions present in `ClassName` and `ClassNameView` delegate their calls to their corresponding interface classes (which implement the actual logic), meaning that these classes essentially act as wrappers to their interface counterparts. -To create graphs within the library, we thus use the following syntax: -`BaseGraph obj = BaseGraph::create();` - -Resulting in an object that, while of type `BaseGraph`, can access at runtime the member functions defined in `DerivedGraph` - ### Virtual Inheritance -Due to the complexity of the graph library, diamond-style inheritance patterns emerge (consider, for example, the `OutputLabelledOpenMultiDiGraphView` class, which inherits from both `NodeLabelledOpenMultiDiGraphView` and `OutputLabelledMultiDiGraphView`, which in turn inherit from both `NodeLabelledMultiDiGraphView`). -In the case of a diamond inheritance pattern C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. +Due to the complexity of the graph library, diamond-style inheritance patterns emerge. +In the case of a diamond inheritance pattern, C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index 6a1898dd13..d73175c7dd 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -42,8 +42,6 @@ struct DataflowGraph : virtual public DataflowGraphView { private: IDataflowGraph &get_interface(); IDataflowGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h index 1e4d09d3ae..96e8864bc1 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h @@ -5,7 +5,22 @@ namespace FlexFlow { +/** + * @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory) + * + * @note By definition, the root node dominates every node and every node + * dominates itself. + * + */ std::unordered_set get_dominators(DiGraphView const &, Node const &); + +/** + * @brief Returns the intersection of the dominators of the given set of nodes. + * @note This is conceptually equivalent to merging the given set of nodes and + * then finding the set of dominators of the new merged node (where merged means + * that all edges belonging to the set of nodes now pass through a single + * unified node). + */ std::unordered_set get_dominators(DiGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/graph/digraph/digraph.h b/lib/utils/include/utils/graph/digraph/digraph.h index e36b90d4bf..3d320b1c06 100644 --- a/lib/utils/include/utils/graph/digraph/digraph.h +++ b/lib/utils/include/utils/graph/digraph/digraph.h @@ -40,8 +40,6 @@ struct DiGraph : virtual DiGraphView { private: IDiGraph &get_ptr(); IDiGraph const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph); diff --git a/lib/utils/include/utils/graph/digraph/digraph_view.h b/lib/utils/include/utils/graph/digraph/digraph_view.h index 54f84f8d2c..0380751c55 100644 --- a/lib/utils/include/utils/graph/digraph/digraph_view.h +++ b/lib/utils/include/utils/graph/digraph/digraph_view.h @@ -31,8 +31,6 @@ struct DiGraphView : virtual public GraphView { private: IDiGraphView const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h b/lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h similarity index 100% rename from lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h rename to lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph/multidigraph.h index 69080b9348..692ee33783 100644 --- a/lib/utils/include/utils/graph/multidigraph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph/multidigraph.h @@ -40,8 +40,6 @@ struct MultiDiGraph : virtual public MultiDiGraphView { private: IMultiDiGraph &get_interface(); IMultiDiGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph.h b/lib/utils/include/utils/graph/node/graph.h index bddefdacb3..1d94d1a65e 100644 --- a/lib/utils/include/utils/graph/node/graph.h +++ b/lib/utils/include/utils/graph/node/graph.h @@ -31,8 +31,6 @@ struct Graph : virtual GraphView { private: IGraph const &get_ptr() const; IGraph &get_ptr(); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph_view.h b/lib/utils/include/utils/graph/node/graph_view.h index fce3177ef1..8d904e05f2 100644 --- a/lib/utils/include/utils/graph/node/graph_view.h +++ b/lib/utils/include/utils/graph/node/graph_view.h @@ -22,8 +22,6 @@ struct GraphView { GraphView(); cow_ptr_t ptr; GraphView(cow_ptr_t ptr); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index d5c22e5d3d..46e0255de3 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -6,6 +6,7 @@ features = [ "hash", "fmt", "json", + "rapidcheck", ] includes = [ diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h b/lib/utils/include/utils/graph/undirected/undirected_edge.h index 33d50192cb..d051413faa 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -2,33 +2,12 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H #include "utils/graph/node/node.dtg.h" -namespace FlexFlow { - -struct UndirectedEdge { -public: - UndirectedEdge() = delete; - UndirectedEdge(Node const &src, Node const &dst); +#include "utils/graph/undirected/undirected_edge.dtg.h" - bool operator==(UndirectedEdge const &) const; - bool operator!=(UndirectedEdge const &) const; - bool operator<(UndirectedEdge const &) const; - -public: - Node smaller; - Node bigger; -}; +namespace FlexFlow { -bool is_connected_to(UndirectedEdge const &, Node const &); +bool is_connected_to(UndirectedEdge const &e, Node const &n); } // namespace FlexFlow -namespace std { - -template <> -struct hash<::FlexFlow::UndirectedEdge> { - size_t operator()(::FlexFlow::UndirectedEdge const &) const; -}; - -} // namespace std - #endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml new file mode 100644 index 0000000000..f5258b0bfd --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "UndirectedEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck" +] + +includes = [ + "utils/commutative_pair.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "endpoints" +type = "::FlexFlow::commutative_pair<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph.h b/lib/utils/include/utils/graph/undirected/undirected_graph.h index 69975991ce..09b6495699 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph.h @@ -34,8 +34,6 @@ struct UndirectedGraph : virtual UndirectedGraphView { using UndirectedGraphView::UndirectedGraphView; - friend struct GraphInternal; - private: IUndirectedGraph const &get_ptr() const; IUndirectedGraph &get_ptr(); diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h index c2df96abc0..90dd5dd5d8 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h @@ -29,8 +29,6 @@ struct UndirectedGraphView : virtual GraphView { using GraphView::GraphView; - friend struct GraphInternal; - private: IUndirectedGraphView const &get_ptr() const; }; diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 6ed41daf43..79c4fc9964 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -184,7 +184,8 @@ bool contains_edge(DiGraphView const &g, DirectedEdge const &e) { } bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) { - UndirectedEdgeQuery q = UndirectedEdgeQuery{{e.bigger, e.smaller}}; + UndirectedEdgeQuery q = + UndirectedEdgeQuery{{e.endpoints.max(), e.endpoints.min()}}; return contains(g.query_edges(q), e); } @@ -212,7 +213,7 @@ void remove_edges(UndirectedGraph &g, } std::unordered_set get_endpoints(UndirectedEdge const &e) { - return {e.smaller, e.bigger}; + return {e.endpoints.min(), e.endpoints.max()}; } // std::unordered_set get_edges(MultiDiGraphView const &g) { diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 61c4f80763..df84683a6b 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -22,23 +22,25 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { } void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { - if (!contains_key(this->adjacency, e.bigger)) { - throw mk_runtime_error(fmt::format( - "Could not add edge connected to non-existent node {}", e.bigger)); + if (!contains_key(this->adjacency, e.endpoints.max())) { + throw mk_runtime_error( + fmt::format("Could not add edge connected to non-existent node {}", + e.endpoints.max())); } - if (!contains_key(this->adjacency, e.smaller)) { - throw mk_runtime_error(fmt::format( - "Could not add edge connected to non-existent node {}", e.smaller)); + if (!contains_key(this->adjacency, e.endpoints.min())) { + throw mk_runtime_error( + fmt::format("Could not add edge connected to non-existent node {}", + e.endpoints.min())); } - this->adjacency.at(e.bigger).insert(e.smaller); - this->adjacency.at(e.smaller).insert(e.bigger); + this->adjacency.at(e.endpoints.max()).insert(e.endpoints.min()); + this->adjacency.at(e.endpoints.min()).insert(e.endpoints.max()); } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.bigger); - m.erase(e.smaller); - m.erase(e.bigger); + std::unordered_set &m = this->adjacency.at(e.endpoints.max()); + m.erase(e.endpoints.min()); + m.erase(e.endpoints.max()); } std::unordered_set HashmapUndirectedGraph::query_edges( @@ -46,7 +48,7 @@ std::unordered_set HashmapUndirectedGraph::query_edges( std::unordered_set result; for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { for (auto const &dst : src_kv.second) { - result.insert({src_kv.first, dst}); + result.insert(UndirectedEdge{{src_kv.first, dst}}); } } return result; diff --git a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc index 6f6722f635..cb44f4636d 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc @@ -27,8 +27,8 @@ void UnorderedSetUndirectedGraph::remove_node_unsafe(Node const &n) { } void UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) { - assert(contains(this->nodes, e.bigger)); - assert(contains(this->nodes, e.smaller)); + assert(contains(this->nodes, e.endpoints.min())); + assert(contains(this->nodes, e.endpoints.max())); this->edges.insert(e); } diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc index 3c05b9d5d5..726fda8af7 100644 --- a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -10,7 +10,7 @@ std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, std::unordered_set result = set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { - return std::unordered_set{e.bigger, e.smaller}; + return std::unordered_set{e.endpoints.max(), e.endpoints.max()}; })); result.erase(n); return result; diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc index 0a575e115c..4cfc6aaaa8 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -1,40 +1,11 @@ #include "utils/graph/undirected/undirected_edge.h" #include "utils/hash/tuple.h" +#include namespace FlexFlow { -UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) - : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} - -static std::tuple tie(UndirectedEdge const &e) { - return std::tie(e.smaller, e.bigger); -} - -bool UndirectedEdge::operator==(UndirectedEdge const &other) const { - return tie(*this) == tie(other); -} - -bool UndirectedEdge::operator!=(UndirectedEdge const &other) const { - return tie(*this) != tie(other); -} - -bool UndirectedEdge::operator<(UndirectedEdge const &other) const { - return tie(*this) < tie(other); -} - bool is_connected_to(UndirectedEdge const &e, Node const &n) { - return e.bigger == n || e.smaller == n; + return e.endpoints.min() == n || e.endpoints.max() == n; } } // namespace FlexFlow - -namespace std { - -using namespace FlexFlow; - -size_t hash::operator()(UndirectedEdge const &e) const { - std::tuple members = ::FlexFlow::tie(e); - return std::hash{}(members); -} - -} // namespace std diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc index 3cccf1c6eb..e9e948aa40 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -7,7 +7,8 @@ UndirectedEdgeQuery undirected_edge_query_all() { } bool matches_edge(UndirectedEdgeQuery const &q, UndirectedEdge const &e) { - return includes(q.nodes, e.bigger) && includes(q.nodes, e.smaller); + return includes(q.nodes, e.endpoints.max()) && + includes(q.nodes, e.endpoints.min()); } UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 9b5353de9f..c29d478f1e 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,4 +1,5 @@ #include "utils/graph/views/views.h" +#include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/flatmap.h" #include "utils/containers/transform.h" #include "utils/disjoint_set.h" @@ -7,8 +8,8 @@ #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" +#include "utils/graph/query_set.h" #include "utils/graph/undirected/undirected_edge_query.h" - namespace FlexFlow { UndirectedSubgraphView::UndirectedSubgraphView( @@ -78,9 +79,13 @@ JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { std::unordered_set JoinedNodeView::query_nodes(NodeQuery const &query) const { - // TODO @lockshaw this is going to be reimplemented in 984, so don't bother - // fixing it for now - NOT_IMPLEMENTED(); + std::unordered_set nodes = right_entries(this->mapping); + if (query == node_query_all()) { + return nodes; + } + return filter(nodes, [&](Node const &n) { + return contains(allowed_values(query.nodes), n); + }); } std::pair, std::unordered_set> @@ -146,17 +151,18 @@ std::unordered_set JoinedUndirectedGraphView::query_edges( UndirectedEdge JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return { - this->joined_nodes.at_join_key(JoinNodeKey{e.smaller, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.bigger, LRDirection::LEFT})}; + return UndirectedEdge{{this->joined_nodes.at_join_key( + JoinNodeKey{e.endpoints.min(), LRDirection::LEFT}), + this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.max(), LRDirection::LEFT})}}; } UndirectedEdge JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return {this->joined_nodes.at_join_key( - JoinNodeKey{e.smaller, LRDirection::RIGHT}), - this->joined_nodes.at_join_key( - JoinNodeKey{e.bigger, LRDirection::RIGHT})}; + return UndirectedEdge{{this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.min(), LRDirection::RIGHT}), + this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.max(), LRDirection::RIGHT})}}; } JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, @@ -208,7 +214,7 @@ DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { } UndirectedEdge to_undirected_edge(DirectedEdge const &e) { - return {e.src, e.dst}; + return UndirectedEdge{{e.src, e.dst}}; } std::unordered_set to_undirected_edges( @@ -218,8 +224,9 @@ std::unordered_set to_undirected_edges( } std::unordered_set to_directed_edges(UndirectedEdge const &e) { - return std::unordered_set{DirectedEdge{e.smaller, e.bigger}, - DirectedEdge{e.bigger, e.smaller}}; + return std::unordered_set{ + DirectedEdge{e.endpoints.min(), e.endpoints.max()}, + DirectedEdge{e.endpoints.max(), e.endpoints.min()}}; } std::unordered_set to_directed_edges( diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc new file mode 100644 index 0000000000..0817c69e06 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc @@ -0,0 +1,106 @@ +#include "utils/graph/digraph/algorithms.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DiGraph - algorithms.cc") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + }; + add_edges(g, e); + + SUBCASE("get_edges") { + SUBCASE("Base") { + std::unordered_set correct = unordered_set_of(e); + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge") { + g.add_edge(DirectedEdge{n[3], n[1]}); + std::unordered_set correct = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[3], n[1]}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge") { + g.remove_edge(DirectedEdge{n[0], n[3]}); + std::unordered_set correct = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + } + + SUBCASE("get_sinks") { + SUBCASE("Base") { + std::unordered_set correct = {n[2], n[3]}; + std::unordered_set result = get_sinks(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a sink") { + g.add_edge(DirectedEdge{n[3], n[2]}); + std::unordered_set correct = {n[2]}; + std::unordered_set result = get_sinks(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set result = get_sinks(g); + std::unordered_set correct = {n[3]}; + CHECK(result == correct); + } + } + + SUBCASE("get_sources") { + SUBCASE("Base") { + std::unordered_set correct = {n[0]}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a source") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set correct = {}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge to create a new source") { + g.remove_edge(DirectedEdge{n[0], n[1]}); + std::unordered_set correct = {n[0], n[1]}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set result = get_sources(g); + std::unordered_set correct = {}; + CHECK(result.empty()); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc new file mode 100644 index 0000000000..3a3648eec8 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc @@ -0,0 +1,85 @@ +#include "utils/graph/digraph/digraph.h" +#include "utils/containers/repeat.h" +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/node_query.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("DiGraph implementations", T, AdjacencyDiGraph) { + /* + graph TD + + n0 --> n1 + n0 --> n2 + n1 --> n2 + n2 --> n4 + n1 --> n3 + */ + + DiGraph g = DiGraph::create(); + std::vector n = repeat(5, [&] { return g.add_node(); }); + std::vector e = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[1], n[3]}}; + for (DirectedEdge const &edge : e) { + g.add_edge(edge); + } + + SUBCASE("query_nodes") { + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); + + std::unordered_set queried_edges = + g.query_edges(directed_edge_query_all()); + std::unordered_set expected = { + e[0], e[1], e[2], e[3], e[4]}; + CHECK(queried_edges == expected); + + queried_edges = g.query_edges( + DirectedEdgeQuery{query_set{{n[0]}}, query_set{{n[1]}}}); + expected = std::unordered_set{e[0]}; + CHECK(queried_edges == expected); + } + SUBCASE("remove_node_unsafe") { + g.remove_node_unsafe(n[0]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[1], n[2], n[3], n[4]}); + + // removing a node also removes its adjacent edges + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[3], e[4]}); + + g.remove_node_unsafe(n[1]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[2], n[3], n[4]}); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[3]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[1], e[2], e[3], e[4]}); + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + g.remove_edge(e[1]); + g.remove_edge(e[3]); + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[4]}); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc new file mode 100644 index 0000000000..1dde5c8f69 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc @@ -0,0 +1,70 @@ +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("directed_edge_query") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 5); + + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}}); + + SUBCASE("directed_edge_query_all") { + + DirectedEdgeQuery result = directed_edge_query_all(); + + CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(1)})); + CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(2)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(2)})); + CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); + } + + SUBCASE("matches_edge") { + DirectedEdgeQuery q = + DirectedEdgeQuery{query_set{n.at(0)}, query_set{n.at(1)}}; + + CHECK(matches_edge(q, DirectedEdge{n.at(0), n.at(1)})); + CHECK_FALSE(matches_edge(q, DirectedEdge{n.at(1), n.at(2)})); + } + + SUBCASE("query_intersection") { + SUBCASE("standard intersection") { + DirectedEdgeQuery q1 = DirectedEdgeQuery{ + query_set{n.at(0), n.at(1)}, query_set{n.at(1), n.at(2), n.at(4)}}; + DirectedEdgeQuery q2 = DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, + query_set{n.at(2), n.at(3)}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery correct = DirectedEdgeQuery{ + query_set{n.at(1)}, + query_set{n.at(2)}, + }; + + CHECK(result == correct); + } + SUBCASE("intersection with std::nullopt") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, matchall()}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{matchall(), query_set{n.at(3), n.at(4)}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery correct = DirectedEdgeQuery{ + query_set{n.at(1), n.at(2)}, query_set{n.at(3), n.at(4)}}; + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc new file mode 100644 index 0000000000..e9151b53e5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -0,0 +1,68 @@ +#include "utils/graph/digraph/algorithms/get_dominators.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dominators") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("single node") { + Node node = n.at(2); + std::unordered_set correct = {n.at(0), n.at(2)}; + std::unordered_set result = get_dominators(g, node); + CHECK(correct == result); + } + + SUBCASE("multiple nodes") { + std::unordered_set nodes = {n.at(1), n.at(3)}; + std::unordered_set result = get_dominators(g, nodes); + std::unordered_set correct = {n.at(0)}; + CHECK(correct == result); + } + + SUBCASE("graph with cycles") { + // example from + // https://en.wikipedia.org/w/index.php?title=Dominator_(graph_theory)&oldid=1189814332 + + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 6); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(1)}, + }); + + SUBCASE("node 1") { + std::unordered_set result = get_dominators(g, n.at(1)); + std::unordered_set correct = {n.at(0), n.at(1)}; + CHECK(result == correct); + } + + SUBCASE("node 3") { + std::unordered_set result = get_dominators(g, n.at(3)); + std::unordered_set correct = {n.at(0), n.at(1), n.at(3)}; + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc new file mode 100644 index 0000000000..de6953fad4 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -0,0 +1,36 @@ +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/containers.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).value() < + index_of(ordering, n[r]).value()); + }; + + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + CHECK_BEFORE(1, 5); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(3, 4); + CHECK_BEFORE(4, 5); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc new file mode 100644 index 0000000000..0d8e7ca53a --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc @@ -0,0 +1,112 @@ +#include "utils/graph/traversal.h" +#include "utils/fmt/vector.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/hash/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_unchecked_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); + + SUBCASE("simple path") { + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_unchecked_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + } + + TEST_CASE("get_bfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[3]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}}); + + SUBCASE("branching path") { + std::unordered_set> corrects = { + {n[0], n[1], n[2], n[3], n[4], n[5]}, + {n[0], n[2], n[1], n[3], n[4], n[5]}}; + std::vector result = get_bfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + + SUBCASE("isolated node") { + std::vector correct = {n[5]}; + std::vector result = get_bfs_ordering(g, {n[5]}); + CHECK(correct == result); + } + + SUBCASE("graph with cycle") { + g = DiGraph::create(); + n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[0]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[0]}, + DirectedEdge{n[2], n[1]}}); + std::unordered_set> corrects = {{n[0], n[1], n[2]}, + {n[0], n[2], n[1]}}; + std::vector result = get_bfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + } + + TEST_CASE("get_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); + + SUBCASE("simple path") { + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("with cycle") { + g.add_edge(DirectedEdge{n[3], n[1]}); + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("branching") { + g.add_edge(DirectedEdge{n[1], n[3]}); + std::unordered_set> corrects = { + {n[0], n[1], n[2], n[3]}, {n[0], n[1], n[3], n[2]}}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + + SUBCASE("disconnected") { + g.remove_edge(DirectedEdge{n[2], n[3]}); + std::vector correct = {n[0], n[1], n[2]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("isolated node") { + g.remove_edge(DirectedEdge{n[2], n[3]}); + std::vector correct = {n[3]}; + std::vector result = get_dfs_ordering(g, {n[3]}); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc new file mode 100644 index 0000000000..b5943cd99f --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -0,0 +1,36 @@ +#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_incoming_edges(MultiDiGraphView, Node)") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3); + + std::vector edges = add_edges(g, + {{n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(0)}}); + + SUBCASE("node has incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(1)); + std::unordered_set correct = {edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc new file mode 100644 index 0000000000..d4748e8422 --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -0,0 +1,40 @@ +#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_outgoing_edges(MultiDiGraph, Node)") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3); + + std::vector> input = { + {n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(0)}, + }; + + std::vector edges = add_edges(g, input); + + SUBCASE("node has outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(0)); + std::unordered_set correct = { + edges.at(0), edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc new file mode 100644 index 0000000000..7f6f0dd064 --- /dev/null +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -0,0 +1,26 @@ +#include "utils/graph/undirected/algorithms/get_connected_components.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/undirected/undirected_graph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); + + std::unordered_set> correct = { + {n[0], n[1], n[2]}, + {n[3]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } +} diff --git a/lib/utils/test/src/utils/graph/undirected/undirected.cc b/lib/utils/test/src/utils/graph/undirected/undirected.cc new file mode 100644 index 0000000000..7973cf8af5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/undirected/undirected.cc @@ -0,0 +1,75 @@ +#include "test/utils/rapidcheck.h" +#include "test/utils/rapidcheck/visitable.h" +#include "utils/commutative_pair.h" +#include "utils/containers/repeat.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" +#include "utils/graph/undirected/undirected_graph.h" + +using namespace FlexFlow; + +using namespace rc; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { + + RC_SUBCASE("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct( + gen::construct>(gen::elementOf(n), + gen::elementOf(n)))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); + }); + } +} +/* static_assert(is_fmtable::value, ""); */ + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { + + RC_SUBCASE("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct( + gen::construct>(gen::elementOf(n), + gen::elementOf(n)))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); + }); + } +} diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc new file mode 100644 index 0000000000..58f2e35cb5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -0,0 +1,262 @@ +#include "utils/graph/views/views.h" +#include "utils/containers/set_union.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/undirected/undirected_graph.h" +#include "utils/graph/undirected/undirected_graph_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + // TEST_CASE("UndirectedSubgraphView") { + // UndirectedGraph g = UndirectedGraph::create(); + // std::vector n = add_nodes(g, 5); + // add_edges(g, + // {UndirectedEdge{{n.at(0), n.at(3)}}, + // UndirectedEdge{{n.at(1), n.at(1)}}, + // UndirectedEdge{{n.at(1), n.at(2)}}, + // UndirectedEdge{{n.at(1), n.at(3)}}, + // UndirectedEdge{{n.at(2), n.at(3)}}, + // UndirectedEdge{{n.at(2), n.at(4)}}}); + // std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + // UndirectedGraphView view = get_subgraph(g, sub_nodes); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // UndirectedEdge{{n.at(0), n.at(3)}}, + // UndirectedEdge{{n.at(1), n.at(1)}}, + // UndirectedEdge{{n.at(1), n.at(3)}}, + // }; + + // std::unordered_set result = get_edges(view); + + // // TODO(@pietro) TODO(@lockshaw) current BUG, get_edges also + // CHECK(result == expected); + // } + // } + + TEST_CASE("DiSubgraphView") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + DiGraphView view = get_subgraph(g, sub_nodes); + + SUBCASE("get_nodes") { + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("JoinedNodeView") { + UndirectedGraph g1 = UndirectedGraph::create(); + UndirectedGraph g2 = UndirectedGraph::create(); + + std::vector n1 = add_nodes(g1, 3); + std::vector n2 = add_nodes(g2, 2); + std::unordered_set joined_nodes = + set_union(unordered_set_of(n1), unordered_set_of(n2)); + add_edges(g1, + {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], n1[2]}}}); + add_edges(g2, {UndirectedEdge{{n2[0], n2[1]}}}); + + JoinedNodeView joined_view(g1, g2); + + SUBCASE("trace_nodes") { + std::pair, std::unordered_set> result = + joined_view.trace_nodes(joined_nodes); + std::pair, std::unordered_set> correct = { + {n1[0], n1[1], n1[2]}, {n2[0], n2[1]}}; + + CHECK(result == correct); + } + + SUBCASE("query_nodes") { + SUBCASE("matchall") {} + SUBCASE("subset") {} + } + } + + // TEST_CASE("JoinedUndirectedGraphView") { + // UndirectedGraph g1 = UndirectedGraph::create(); + // UndirectedGraph g2 = UndirectedGraph::create(); + + // std::vector n1 = add_nodes(g1, 3); + // std::vector n2 = add_nodes(g2, 3); + + // add_edges(g1, + // {UndirectedEdge{{n1.at(0), n1.at(1)}}, + // UndirectedEdge{{n1.at(1), n1.at(2)}}}); + // add_edges(g2, + // {UndirectedEdge{{n2.at(0), n2.at(2)}}, + // UndirectedEdge{{n2.at(1), n2.at(2)}}}); + + // UndirectedGraphView view = join(g1, g2); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = + // set_union(unordered_set_of(n1), unordered_set_of(n2)); + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // UndirectedEdge{{n1.at(0), n1.at(1)}}, + // UndirectedEdge{{n1.at(1), n1.at(2)}}, + // UndirectedEdge{{n2.at(0), n2.at(2)}}, + // UndirectedEdge{{n2.at(1), n2.at(2)}}}; + + // std::unordered_set result = get_edges(view); + + // CHECK(result == expected); + // } + // } + + // TEST_CASE("JoinedDigraphView") { + // DiGraph g1 = DiGraph::create(); + // DiGraph g2 = DiGraph::create(); + + // std::vector n1 = add_nodes(g1, 3); + // std::vector n2 = add_nodes(g2, 3); + + // add_edges( + // g1, + // {DirectedEdge{n1.at(0), n1.at(1)}, DirectedEdge{n1.at(1), + // n1.at(2)}}); + // add_edges( + // g2, + // {DirectedEdge{n2.at(0), n2.at(2)}, DirectedEdge{n2.at(1), + // n2.at(2)}}); + + // DiGraphView view = join(g1, g2); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = + // set_union(unordered_set_of(n1), unordered_set_of(n2)); + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // DirectedEdge{n1.at(0), n1.at(1)}, + // DirectedEdge{n1.at(1), n1.at(2)}, + // DirectedEdge{n2.at(0), n2.at(2)}, + // DirectedEdge{n2.at(1), n2.at(2)}}; + + // std::unordered_set result = get_edges(view); + + // CHECK(result == expected); + // } + // } + + TEST_CASE("ViewDiGraphAsUndirectedGraph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}); + + UndirectedGraphView view = as_undirected(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(0)}}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("ViewUndirectedGraphAsDiGraph") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {UndirectedEdge{{n.at(0), n.at(0)}}, + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(0)}}}); + + DiGraphView view = as_digraph(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(0)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(0)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } +} From 0e8c9625195c322b29960a1fa074191bf6dfbc2d Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 9 Oct 2024 21:21:57 -0700 Subject: [PATCH 2/8] fmt --- lib/utils/test/src/utils/graph/views/views.cc | 56 ++++++++++--------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc index 58f2e35cb5..fb5374bd6a 100644 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -87,34 +87,36 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("JoinedNodeView") { - UndirectedGraph g1 = UndirectedGraph::create(); - UndirectedGraph g2 = UndirectedGraph::create(); - - std::vector n1 = add_nodes(g1, 3); - std::vector n2 = add_nodes(g2, 2); - std::unordered_set joined_nodes = - set_union(unordered_set_of(n1), unordered_set_of(n2)); - add_edges(g1, - {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], n1[2]}}}); - add_edges(g2, {UndirectedEdge{{n2[0], n2[1]}}}); - - JoinedNodeView joined_view(g1, g2); - - SUBCASE("trace_nodes") { - std::pair, std::unordered_set> result = - joined_view.trace_nodes(joined_nodes); - std::pair, std::unordered_set> correct = { - {n1[0], n1[1], n1[2]}, {n2[0], n2[1]}}; - - CHECK(result == correct); - } + // TEST_CASE("JoinedNodeView") { + // UndirectedGraph g1 = UndirectedGraph::create(); + // UndirectedGraph g2 = UndirectedGraph::create(); - SUBCASE("query_nodes") { - SUBCASE("matchall") {} - SUBCASE("subset") {} - } - } + // std::vector n1 = add_nodes(g1, 3); + // std::vector n2 = add_nodes(g2, 2); + // std::unordered_set joined_nodes = + // set_union(unordered_set_of(n1), unordered_set_of(n2)); + // add_edges(g1, + // {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], + // n1[2]}}}); + // add_edges(g2, {UndirectedEdge{{n2[0], n2[1]}}}); + + // JoinedNodeView joined_view(g1, g2); + + // SUBCASE("trace_nodes") { + // std::pair, std::unordered_set> result = + // joined_view.trace_nodes(joined_nodes); + // std::pair, std::unordered_set> correct = + // { + // {n1[0], n1[1], n1[2]}, {n2[0], n2[1]}}; + + // CHECK(result == correct); + // } + + // SUBCASE("query_nodes") { + // SUBCASE("matchall") {} + // SUBCASE("subset") {} + // } + // } // TEST_CASE("JoinedUndirectedGraphView") { // UndirectedGraph g1 = UndirectedGraph::create(); From 8e11b0b2e0db8777bfc75f526fa57b8945a8575f Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 14 Oct 2024 17:09:00 -0700 Subject: [PATCH 3/8] removed unneccesary views, fixed views adjacent bugs --- lib/utils/include/utils/graph/views/views.h | 126 +------------ lib/utils/src/utils/graph/algorithms.cc | 9 - .../instances/hashmap_undirected_graph.cc | 9 +- lib/utils/src/utils/graph/views/views.cc | 155 +--------------- .../algorithms/get_connected_components.cc | 41 ++++- lib/utils/test/src/utils/graph/views/views.cc | 167 +++--------------- 6 files changed, 77 insertions(+), 430 deletions(-) diff --git a/lib/utils/include/utils/graph/views/views.h b/lib/utils/include/utils/graph/views/views.h index aaa1e033f4..5e0109ed5b 100644 --- a/lib/utils/include/utils/graph/views/views.h +++ b/lib/utils/include/utils/graph/views/views.h @@ -41,104 +41,11 @@ struct DiSubgraphView : public IDiGraphView { std::unordered_set subgraph_nodes; }; -struct JoinedNodeView { -public: - JoinedNodeView() = delete; - explicit JoinedNodeView(GraphView const &lhs, GraphView const &rhs); - - std::unordered_set query_nodes(NodeQuery const &) const; - std::pair, std::unordered_set> - trace_nodes(std::unordered_set const &) const; - - Node at_join_key(JoinNodeKey const &) const; - JoinNodeKey at_node(Node const &) const; - -private: - bidict mapping; - NodeSource node_source; -}; - -struct JoinedUndirectedGraphView : public IUndirectedGraphView { -public: - JoinedUndirectedGraphView() = delete; - explicit JoinedUndirectedGraphView(UndirectedGraphView const &lhs, - UndirectedGraphView const &rhs); - - std::unordered_set - query_edges(UndirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedUndirectedGraphView *clone() const override; - -private: - UndirectedEdge fix_lhs_edge(UndirectedEdge const &) const; - UndirectedEdge fix_rhs_edge(UndirectedEdge const &) const; - -private: - UndirectedGraphView lhs; - UndirectedGraphView rhs; - JoinedNodeView joined_nodes; -}; - -struct JoinedDigraphView : virtual public IDiGraphView { -public: - JoinedDigraphView() = delete; - explicit JoinedDigraphView(DiGraphView const &lhs, DiGraphView const &rhs); - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedNodeView const &joined_nodes_view() const; +UndirectedGraphView view_subgraph(UndirectedGraphView const &, + std::unordered_set const &); - JoinedDigraphView *clone() const override; - -private: - DirectedEdge fix_lhs_edge(DirectedEdge const &) const; - DirectedEdge fix_rhs_edge(DirectedEdge const &) const; - -private: - DiGraphView lhs; - DiGraphView rhs; - JoinedNodeView joined_nodes; -}; - -struct AddDirectedEdgesView : public IDiGraphView { -public: - AddDirectedEdgesView() = delete; - - explicit AddDirectedEdgesView(DiGraphView const &g, - std::unordered_set const &edges); - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - AddDirectedEdgesView *clone() const override; - -private: - DiGraphView g; - std::unordered_set edges; -}; - -struct SingleSourceNodeView : public IDiGraphView { -public: - SingleSourceNodeView() = delete; - - explicit SingleSourceNodeView(DiGraphView const &g) : g(g) {} - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - SingleSourceNodeView *clone() const override; - -private: - DiGraphView g; - std::optional singleton_src; - std::optional joined_view; - std::unique_ptr added_edges_view; -}; +DiGraphView view_subgraph(DiGraphView const &, + std::unordered_set const &); UndirectedEdge to_undirected_edge(DirectedEdge const &); std::unordered_set @@ -176,31 +83,6 @@ struct ViewUndirectedGraphAsDiGraph : public IDiGraphView { UndirectedGraphView g; }; -std::unordered_map - flatten_contraction(std::unordered_map const &); - -template -Impl materialize_view(View const &g) { - Impl result; - for (Node const &n : get_nodes(g)) { - result.add_node_unsafe(n); - } - for (auto const &e : get_edges(g)) { - result.add_edge(e); - } - return result; -} - -template -Impl materialize_undirected_graph_view(IUndirectedGraphView const &g) { - return materialize_view(g); -} - -template -Impl materialize_digraph_view(IDiGraphView const &g) { - return materialize_view(g); -} - } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 79c4fc9964..d7cd979f14 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -481,15 +481,6 @@ DiGraphView get_subgraph(DiGraphView const &g, // return MultiDiGraphView::create(lhs, rhs); // } -DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs) { - return DiGraphView::create(lhs, rhs); -} - -UndirectedGraphView join(UndirectedGraphView const &lhs, - UndirectedGraphView const &rhs) { - return UndirectedGraphView::create(lhs, rhs); -} - UndirectedGraphView as_undirected(DiGraphView const &g) { return UndirectedGraphView::create(g); } diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index df84683a6b..5d16304701 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -38,16 +38,17 @@ void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.endpoints.max()); - m.erase(e.endpoints.min()); - m.erase(e.endpoints.max()); + std::unordered_set &max_map = this->adjacency.at(e.endpoints.max()); + max_map.erase(e.endpoints.min()); + std::unordered_set &min_map = this->adjacency.at(e.endpoints.min()); + min_map.erase(e.endpoints.max()); } std::unordered_set HashmapUndirectedGraph::query_edges( UndirectedEdgeQuery const &query) const { std::unordered_set result; for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { - for (auto const &dst : src_kv.second) { + for (auto const &dst : apply_query(query.nodes, src_kv.second)) { result.insert(UndirectedEdge{{src_kv.first, dst}}); } } diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index c29d478f1e..7bb039d314 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -66,153 +66,6 @@ DiGraphView view_subgraph(DiGraphView const &g, return DiGraphView::create(g, subgraph_nodes); } -JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { - for (Node const &n : get_nodes(lhs)) { - this->mapping.equate(JoinNodeKey{n, LRDirection::LEFT}, - this->node_source.new_node()); - } - for (Node const &n : get_nodes(rhs)) { - this->mapping.equate(JoinNodeKey{n, LRDirection::RIGHT}, - this->node_source.new_node()); - } -} - -std::unordered_set - JoinedNodeView::query_nodes(NodeQuery const &query) const { - std::unordered_set nodes = right_entries(this->mapping); - if (query == node_query_all()) { - return nodes; - } - return filter(nodes, [&](Node const &n) { - return contains(allowed_values(query.nodes), n); - }); -} - -std::pair, std::unordered_set> - JoinedNodeView::trace_nodes(std::unordered_set const &nodes) const { - std::unordered_set left_nodes, right_nodes; - - for (Node const &n : nodes) { - JoinNodeKey k = this->at_node(n); - if (k.direction == LRDirection::LEFT) { - left_nodes.insert(k.node); - } else { - assert(k.direction == LRDirection::RIGHT); - right_nodes.insert(k.node); - } - } - - return {left_nodes, right_nodes}; -} - -Node JoinedNodeView::at_join_key(JoinNodeKey const &k) const { - return this->mapping.at_l(k); -} - -JoinNodeKey JoinedNodeView::at_node(Node const &n) const { - return this->mapping.at_r(n); -} - -JoinedUndirectedGraphView::JoinedUndirectedGraphView( - UndirectedGraphView const &lhs, UndirectedGraphView const &rhs) - : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} - -std::unordered_set - JoinedUndirectedGraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set JoinedUndirectedGraphView::query_edges( - UndirectedEdgeQuery const &query) const { - std::unordered_set nodes = this->query_nodes(NodeQuery{query.nodes}); - std::unordered_set left_nodes, right_nodes; - for (Node const &n : nodes) { - JoinNodeKey k = this->joined_nodes.at_node(n); - if (k.direction == LRDirection::LEFT) { - left_nodes.insert(k.node); - } else { - assert(k.direction == LRDirection::RIGHT); - right_nodes.insert(k.node); - } - } - - std::unordered_set result; - for (UndirectedEdge const &e : - this->lhs.query_edges(UndirectedEdgeQuery{left_nodes})) { - result.insert(this->fix_lhs_edge(e)); - } - for (UndirectedEdge const &e : - this->rhs.query_edges(UndirectedEdgeQuery{right_nodes})) { - result.insert(this->fix_rhs_edge(e)); - } - - return result; -} - -UndirectedEdge - JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return UndirectedEdge{{this->joined_nodes.at_join_key( - JoinNodeKey{e.endpoints.min(), LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.max(), LRDirection::LEFT})}}; -} - -UndirectedEdge - JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return UndirectedEdge{{this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.min(), LRDirection::RIGHT}), - this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.max(), LRDirection::RIGHT})}}; -} - -JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, - DiGraphView const &rhs) - : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} - -JoinedDigraphView *JoinedDigraphView::clone() const { - return new JoinedDigraphView(lhs, rhs); -} - -std::unordered_set - JoinedDigraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set - JoinedDigraphView::query_edges(DirectedEdgeQuery const &query) const { - - std::unordered_set srcs = this->query_nodes(NodeQuery{query.srcs}); - std::unordered_set dsts = this->query_nodes(NodeQuery{query.dsts}); - auto traced_srcs = this->joined_nodes.trace_nodes(srcs); - auto traced_dsts = this->joined_nodes.trace_nodes(dsts); - DirectedEdgeQuery left_query = - DirectedEdgeQuery{traced_srcs.first, traced_dsts.first}; - DirectedEdgeQuery right_query = - DirectedEdgeQuery{traced_srcs.second, traced_dsts.second}; - - std::unordered_set result; - for (DirectedEdge const &e : this->lhs.query_edges(left_query)) { - result.insert(this->fix_lhs_edge(e)); - } - for (DirectedEdge const &e : this->rhs.query_edges(right_query)) { - result.insert(this->fix_rhs_edge(e)); - } - - return result; -} - -DirectedEdge JoinedDigraphView::fix_lhs_edge(DirectedEdge const &e) const { - return DirectedEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::LEFT})}; -} - -DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { - return DirectedEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::RIGHT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT})}; -} - UndirectedEdge to_undirected_edge(DirectedEdge const &e) { return UndirectedEdge{{e.src, e.dst}}; } @@ -265,8 +118,8 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - intersection(g.query_edges(UndirectedEdgeQuery{q.srcs}), - g.query_edges(UndirectedEdgeQuery{q.dsts})); + set_union(g.query_edges(UndirectedEdgeQuery{q.srcs}), + g.query_edges(UndirectedEdgeQuery{q.dsts})); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); @@ -279,8 +132,4 @@ std::unordered_set return g.query_nodes(q); } -JoinedUndirectedGraphView *JoinedUndirectedGraphView::clone() const { - return new JoinedUndirectedGraphView(lhs, rhs); -} - } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc index 7f6f0dd064..179cce7db7 100644 --- a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -7,10 +7,24 @@ using namespace FlexFlow; -TEST_SUITE(FF_TEST_SUITE) { +TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); - TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); + SUBCASE("disjoint nodes") { + std::vector n = add_nodes(g, 3); + + std::unordered_set> correct = { + {n[0]}, + {n[1]}, + {n[2]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("2 components") { std::vector n = add_nodes(g, 4); add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); @@ -23,4 +37,25 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(correct == result); } + + SUBCASE("3 components") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + UndirectedEdge{{n[0], n[1]}}, + UndirectedEdge{{n[0], n[2]}}, + UndirectedEdge{{n[1], n[2]}}, + UndirectedEdge{{n[3], n[4]}}, + }); + + std::unordered_set> correct = { + {n[0], n[1], n[2]}, + {n[3], n[4]}, + {n[5]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } } diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc index fb5374bd6a..8a6a44d1cc 100644 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -14,41 +14,39 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("UndirectedSubgraphView") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {UndirectedEdge{{n.at(0), n.at(3)}}, + UndirectedEdge{{n.at(1), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(1), n.at(3)}}, + UndirectedEdge{{n.at(2), n.at(3)}}, + UndirectedEdge{{n.at(2), n.at(4)}}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + UndirectedGraphView view = view_subgraph(g, sub_nodes); - // TEST_CASE("UndirectedSubgraphView") { - // UndirectedGraph g = UndirectedGraph::create(); - // std::vector n = add_nodes(g, 5); - // add_edges(g, - // {UndirectedEdge{{n.at(0), n.at(3)}}, - // UndirectedEdge{{n.at(1), n.at(1)}}, - // UndirectedEdge{{n.at(1), n.at(2)}}, - // UndirectedEdge{{n.at(1), n.at(3)}}, - // UndirectedEdge{{n.at(2), n.at(3)}}, - // UndirectedEdge{{n.at(2), n.at(4)}}}); - // std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; - // UndirectedGraphView view = get_subgraph(g, sub_nodes); - - // SUBCASE("get_nodes") { - // std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + SUBCASE("get_nodes") { + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; - // std::unordered_set result = get_nodes(view); + std::unordered_set result = get_nodes(view); - // CHECK(result == expected); - // } + CHECK(result == expected); + } - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // UndirectedEdge{{n.at(0), n.at(3)}}, - // UndirectedEdge{{n.at(1), n.at(1)}}, - // UndirectedEdge{{n.at(1), n.at(3)}}, - // }; + SUBCASE("get_edges") { + std::unordered_set expected = { + UndirectedEdge{{n.at(0), n.at(3)}}, + UndirectedEdge{{n.at(1), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(3)}}, + }; - // std::unordered_set result = get_edges(view); + std::unordered_set result = get_edges(view); - // // TODO(@pietro) TODO(@lockshaw) current BUG, get_edges also - // CHECK(result == expected); - // } - // } + CHECK(result == expected); + } + } TEST_CASE("DiSubgraphView") { DiGraph g = DiGraph::create(); @@ -63,7 +61,7 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(3), n.at(2)}, DirectedEdge{n.at(2), n.at(4)}}); std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; - DiGraphView view = get_subgraph(g, sub_nodes); + DiGraphView view = view_subgraph(g, sub_nodes); SUBCASE("get_nodes") { std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; @@ -87,115 +85,6 @@ TEST_SUITE(FF_TEST_SUITE) { } } - // TEST_CASE("JoinedNodeView") { - // UndirectedGraph g1 = UndirectedGraph::create(); - // UndirectedGraph g2 = UndirectedGraph::create(); - - // std::vector n1 = add_nodes(g1, 3); - // std::vector n2 = add_nodes(g2, 2); - // std::unordered_set joined_nodes = - // set_union(unordered_set_of(n1), unordered_set_of(n2)); - // add_edges(g1, - // {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], - // n1[2]}}}); - // add_edges(g2, {UndirectedEdge{{n2[0], n2[1]}}}); - - // JoinedNodeView joined_view(g1, g2); - - // SUBCASE("trace_nodes") { - // std::pair, std::unordered_set> result = - // joined_view.trace_nodes(joined_nodes); - // std::pair, std::unordered_set> correct = - // { - // {n1[0], n1[1], n1[2]}, {n2[0], n2[1]}}; - - // CHECK(result == correct); - // } - - // SUBCASE("query_nodes") { - // SUBCASE("matchall") {} - // SUBCASE("subset") {} - // } - // } - - // TEST_CASE("JoinedUndirectedGraphView") { - // UndirectedGraph g1 = UndirectedGraph::create(); - // UndirectedGraph g2 = UndirectedGraph::create(); - - // std::vector n1 = add_nodes(g1, 3); - // std::vector n2 = add_nodes(g2, 3); - - // add_edges(g1, - // {UndirectedEdge{{n1.at(0), n1.at(1)}}, - // UndirectedEdge{{n1.at(1), n1.at(2)}}}); - // add_edges(g2, - // {UndirectedEdge{{n2.at(0), n2.at(2)}}, - // UndirectedEdge{{n2.at(1), n2.at(2)}}}); - - // UndirectedGraphView view = join(g1, g2); - - // SUBCASE("get_nodes") { - // std::unordered_set expected = - // set_union(unordered_set_of(n1), unordered_set_of(n2)); - - // std::unordered_set result = get_nodes(view); - - // CHECK(result == expected); - // } - - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // UndirectedEdge{{n1.at(0), n1.at(1)}}, - // UndirectedEdge{{n1.at(1), n1.at(2)}}, - // UndirectedEdge{{n2.at(0), n2.at(2)}}, - // UndirectedEdge{{n2.at(1), n2.at(2)}}}; - - // std::unordered_set result = get_edges(view); - - // CHECK(result == expected); - // } - // } - - // TEST_CASE("JoinedDigraphView") { - // DiGraph g1 = DiGraph::create(); - // DiGraph g2 = DiGraph::create(); - - // std::vector n1 = add_nodes(g1, 3); - // std::vector n2 = add_nodes(g2, 3); - - // add_edges( - // g1, - // {DirectedEdge{n1.at(0), n1.at(1)}, DirectedEdge{n1.at(1), - // n1.at(2)}}); - // add_edges( - // g2, - // {DirectedEdge{n2.at(0), n2.at(2)}, DirectedEdge{n2.at(1), - // n2.at(2)}}); - - // DiGraphView view = join(g1, g2); - - // SUBCASE("get_nodes") { - // std::unordered_set expected = - // set_union(unordered_set_of(n1), unordered_set_of(n2)); - - // std::unordered_set result = get_nodes(view); - - // CHECK(result == expected); - // } - - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // DirectedEdge{n1.at(0), n1.at(1)}, - // DirectedEdge{n1.at(1), n1.at(2)}, - // DirectedEdge{n2.at(0), n2.at(2)}, - // DirectedEdge{n2.at(1), n2.at(2)}}; - - // std::unordered_set result = get_edges(view); - - // CHECK(result == expected); - // } - // } - TEST_CASE("ViewDiGraphAsUndirectedGraph") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 3); From 87e767cdafed8d520180c73e49a3cb692ff87200 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 14 Oct 2024 18:54:44 -0700 Subject: [PATCH 4/8] minor optimization --- .../series_parallel/parallel_reduction.cc | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 12a6630bf0..609d065660 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,5 +1,9 @@ #include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" namespace FlexFlow { @@ -10,13 +14,17 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &e1, std::optional find_parallel_reduction(MultiDiGraphView const &g) { - std::unordered_set edges = get_edges(g); - for (MultiDiEdge const &e1 : edges) { - for (MultiDiEdge const &e2 : edges) { - if (e1 != e2 && g.get_multidiedge_src(e1) == g.get_multidiedge_src(e2) && - g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { - return make_parallel_reduction(e1, e2); + for (auto const &[directed_edge, count] : get_edge_counts(g)) { + + if (count <= 1) {continue;} + + std::unordered_set const &outgoing_edges = get_outgoing_edges(g, directed_edge.src); + for (MultiDiEdge const &e1 : outgoing_edges) { + for (MultiDiEdge const &e2 : outgoing_edges) { + if (e1 != e2 && g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { + return make_parallel_reduction(e1, e2); + } } } } From 44b32f88c47b1a10eeaab6d42843afb9d090d7c0 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 14 Oct 2024 19:04:17 -0700 Subject: [PATCH 5/8] fmt --- .../graph/series_parallel/parallel_reduction.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 609d065660..78265f6856 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,9 +1,9 @@ #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" -#include "utils/graph/node/algorithms.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" -#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -17,12 +17,16 @@ std::optional for (auto const &[directed_edge, count] : get_edge_counts(g)) { - if (count <= 1) {continue;} + if (count <= 1) { + continue; + } - std::unordered_set const &outgoing_edges = get_outgoing_edges(g, directed_edge.src); + std::unordered_set const &outgoing_edges = + get_outgoing_edges(g, directed_edge.src); for (MultiDiEdge const &e1 : outgoing_edges) { for (MultiDiEdge const &e2 : outgoing_edges) { - if (e1 != e2 && g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { + if (e1 != e2 && + g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { return make_parallel_reduction(e1, e2); } } From e7055ad8ac1da6dce17e5691600bb4aa45cf8b7d Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 14 Oct 2024 19:40:37 -0700 Subject: [PATCH 6/8] small fix --- .../utils/graph/digraph/algorithms/get_topological_ordering.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc index de6953fad4..5adc0cc4df 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" -#include "utils/containers.h" +#include "utils/containers/index_of.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" From 8564b8b979419ced5a67774c89777d380061e300 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 27 Nov 2024 11:38:54 -0800 Subject: [PATCH 7/8] get_series_parallel_decomposition fix --- lib/utils/include/utils/containers/find.h | 7 +++ .../series_parallel/parallel_reduction.h | 7 +++ .../series_parallel_decomposition.h | 19 ++++++ .../get_series_parallel_decomposition.cc | 61 +++++++++--------- .../series_parallel/parallel_reduction.cc | 63 +++++++++++++------ .../series_parallel_decomposition.cc | 61 ++++++++++++++++++ .../graph/series_parallel/series_reduction.cc | 33 ++++------ .../test/src/utils/containers/contains.cc | 15 ++++- 8 files changed, 192 insertions(+), 74 deletions(-) diff --git a/lib/utils/include/utils/containers/find.h b/lib/utils/include/utils/containers/find.h index eed5f8453c..7b103fed16 100644 --- a/lib/utils/include/utils/containers/find.h +++ b/lib/utils/include/utils/containers/find.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FIND_H #include +#include namespace FlexFlow { @@ -11,6 +12,12 @@ typename Container::const_iterator return std::find(c.cbegin(), c.cend(), e); } +template +typename std::unordered_set::const_iterator + find(std::unordered_set const &c, V const &e) { + return c.find(e); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 3fc1347ee5..0b3c7f3619 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -12,8 +12,15 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &, std::optional find_parallel_reduction(MultiDiGraphView const &); +std::unordered_map> + find_all_extended_parallel_reductions(MultiDiGraphView const &); + MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &); +MultiDiEdge + apply_extended_parallel_reduction(MultiDiGraph &, + std::unordered_set const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h index 52d2cb7236..d56d4a55f7 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -17,6 +17,25 @@ std::unordered_multiset get_nodes(SeriesSplit const &); std::unordered_multiset get_nodes(ParallelSplit const &); std::unordered_multiset get_nodes(Node const &); +bool is_empty(Node const &node); +bool is_empty(SeriesSplit const &serial); +bool is_empty(ParallelSplit const ¶llel); +bool is_empty(SeriesParallelDecomposition const &sp); + +bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, + Node const &node); + +// duplicate nodes within `sp` are counted multiple times +size_t num_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition serial_composition( + std::vector const &sp_compositions); +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index cd29af59a0..7a5cb1ea82 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -2,14 +2,18 @@ #include "utils/containers/get_only.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" #include "utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h" #include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/graph/series_parallel/series_reduction.h" @@ -26,39 +30,18 @@ std::optional if (!maybe_line_graph.has_value()) { return std::nullopt; } - maybe_line_graph.value(); }); MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); - std::unordered_map + std::unordered_map ttsp_edge_to_sp_tree = map_values( inverse_line_graph_result.inverse_edge_to_line_node_bidict .as_unordered_map(), - [](Node const &n) { return BinarySPDecompositionTree{n}; }); + [](Node const &n) { return SeriesParallelDecomposition{n}; }); while (true) { - assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); - std::optional maybe_parallel_reduction = - find_parallel_reduction(ttsp); - if (maybe_parallel_reduction.has_value()) { - ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); - auto [e1, e2] = parallel_reduction.edges.ordered(); - MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinaryParallelSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, - }; - ttsp_edge_to_sp_tree.erase(e1); - ttsp_edge_to_sp_tree.erase(e2); - ttsp_edge_to_sp_tree.insert({merged, new_tree}); - - continue; - } - std::optional maybe_series_reduction = find_series_reduction(ttsp); if (maybe_series_reduction.has_value()) { @@ -66,15 +49,33 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinarySeriesSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, - }; + + SeriesParallelDecomposition new_tree = serial_composition({ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }); + ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); + + continue; + } + std::unordered_map> + parallel_reductions = find_all_extended_parallel_reductions(ttsp); + if (!parallel_reductions.empty()) { + for (auto const &[_, parallel_reduction] : parallel_reductions) { + MultiDiEdge merged = + apply_extended_parallel_reduction(ttsp, parallel_reduction); + + SeriesParallelDecomposition new_tree = parallel_composition(transform( + unordered_multiset_of(parallel_reduction), + [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); + for (MultiDiEdge const &e : parallel_reduction) { + ttsp_edge_to_sp_tree.erase(e); + } + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + } continue; } @@ -87,7 +88,7 @@ std::optional MultiDiEdge e = get_only(get_edges(ttsp)); if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { - return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e)); + return ttsp_edge_to_sp_tree.at(e); } } } diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 78265f6856..c7eb866b62 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,9 +1,17 @@ #include "utils/graph/series_parallel/parallel_reduction.h" -#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" +#include "utils/containers/get_one_of.h" +#include "utils/containers/group_by.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" -#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" -#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" +#include +#include namespace FlexFlow { @@ -15,31 +23,48 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &e1, std::optional find_parallel_reduction(MultiDiGraphView const &g) { - for (auto const &[directed_edge, count] : get_edge_counts(g)) { - - if (count <= 1) { - continue; - } - - std::unordered_set const &outgoing_edges = - get_outgoing_edges(g, directed_edge.src); - for (MultiDiEdge const &e1 : outgoing_edges) { - for (MultiDiEdge const &e2 : outgoing_edges) { - if (e1 != e2 && - g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { - return make_parallel_reduction(e1, e2); - } - } + std::unordered_map seen; + for (MultiDiEdge const &edge : get_edges(g)) { + DirectedEdge diedge = get_directed_edge(g, edge); + if (seen.find(diedge) != seen.end()) { + return make_parallel_reduction(seen.at(diedge), edge); } + seen.emplace(diedge, edge); } - return std::nullopt; } +std::unordered_map> + find_all_extended_parallel_reductions(MultiDiGraphView const &g) { + std::unordered_map> + parallel_groups = group_by(get_edges(g), [&](MultiDiEdge const &edge) { + return get_directed_edge(g, edge); + }); + + return filter( + parallel_groups, + [](std::pair> const + &group) { return group.second.size() > 1; }); +} + MultiDiEdge apply_parallel_reduction(MultiDiGraph &g, ParallelReduction const &r) { g.remove_edge(r.edges.max()); return r.edges.min(); } +MultiDiEdge apply_extended_parallel_reduction( + MultiDiGraph &g, std::unordered_set const ¶llel_edges) { + + MultiDiEdge keep_edge = get_one_of(parallel_edges); + + for (MultiDiEdge const ¶llel_edge : parallel_edges) { + if (parallel_edge != keep_edge) { + g.remove_edge(parallel_edge); + } + } + + return keep_edge; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index b7a84b871a..dc99ef6c5a 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,12 +1,17 @@ #include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/containers/all_of.h" +#include "utils/containers/extend.h" #include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" +#include "utils/containers/sum.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -74,4 +79,60 @@ std::unordered_multiset get_nodes(Node const &node) { return {node}; } +bool is_empty(Node const &node) { + return false; +} + +bool is_empty(SeriesSplit const &serial) { + return all_of(serial.children, [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(ParallelSplit const ¶llel) { + return all_of(parallel.get_children(), [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(SeriesParallelDecomposition const &sp) { + return sp.visit([](auto const &t) { return is_empty(t); }); +} + +SeriesParallelDecomposition serial_composition( + std::vector const &sp_compositions) { + std::vector> composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + extend(composition, sp_comp.get().children); + } else if (sp_comp.has()) { + composition.push_back(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.push_back(sp_comp.get()); + } + } + return SeriesParallelDecomposition{SeriesSplit{composition}}; +} + +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions) { + std::unordered_multiset< + std::variant<::FlexFlow::SeriesSplit, ::FlexFlow::Node>> + composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + composition = multiset_union(composition, + sp_comp.get().get_children()); + } else if (sp_comp.has()) { + composition.insert(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.insert(sp_comp.get()); + } + } + return SeriesParallelDecomposition(ParallelSplit{composition}); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc index 7300c93fb0..c312bb4a6b 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,8 +1,14 @@ #include "utils/graph/series_parallel/series_reduction.h" +#include "utils/containers/contains.h" +#include "utils/containers/get_only.h" #include "utils/containers/require_same.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/node/algorithms.h" +#include namespace FlexFlow { @@ -26,30 +32,13 @@ SeriesReduction make_series_reduction(MultiDiEdge const &e1, std::optional find_series_reduction(MultiDiGraphView const &g) { - std::unordered_set edges = get_edges(g); - - for (MultiDiEdge const &e1 : edges) { - for (MultiDiEdge const &e2 : edges) { - if (e1 == e2) { - continue; - } - Node e1_dst = g.get_multidiedge_dst(e1); - Node e2_src = g.get_multidiedge_src(e2); - if (e1_dst != e2_src) { - continue; - } - - std::unordered_set outgoing = get_outgoing_edges(g, e1_dst); - std::unordered_set incoming = get_incoming_edges(g, e1_dst); - - if (outgoing.size() > 1 || incoming.size() > 1) { - continue; - } - - return SeriesReduction{e1, e2}; + for (Node const &node : get_nodes(g)) { + if (get_incoming_edges(g, node).size() == 1 && + get_outgoing_edges(g, node).size() == 1) { + return make_series_reduction(get_only(get_incoming_edges(g, node)), + get_only(get_outgoing_edges(g, node))); } } - return std::nullopt; } diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc index 6e0a84c7ab..fc42d25eea 100644 --- a/lib/utils/test/src/utils/containers/contains.cc +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -1,13 +1,22 @@ #include "utils/containers/contains.h" #include +#include #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); + SUBCASE("std::vector") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } + + SUBCASE("std::unordered_set") { + std::unordered_set s = {1, 2, 3, 4, 5}; + CHECK(contains(s, 3)); + CHECK(!contains(s, 6)); + } } } From 8a04a2ee3bcde144165c269dbfed4417d8df94d9 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Fri, 20 Dec 2024 14:55:52 -0800 Subject: [PATCH 8/8] updates to graph documentation + get_series_parallel_decomposition speedup --- lib/utils/include/utils/graph/README.md | 114 +++++++++++-- .../algorithms/get_incoming_edges.h | 3 + .../algorithms/get_outgoing_edges.h | 3 + .../graph/series_parallel/series_reduction.h | 7 + .../algorithms/get_incoming_edges.cc | 17 ++ .../algorithms/get_outgoing_edges.cc | 17 +- .../get_series_parallel_decomposition.cc | 51 +++--- .../graph/series_parallel/series_reduction.cc | 51 ++++++ .../graph/series_parallel/series_reduction.cc | 152 ++++++++++++++++++ 9 files changed, 382 insertions(+), 33 deletions(-) diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index f3d31e7bc8..41777b9b9a 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -18,13 +18,8 @@ At their core, they are as follows: - `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. - `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) - `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. -- `DataflowGraph`: similar to `MultiDiGraph`, but with the following differences: - - The edges entering, exiting a given nodes now have a well-defined order. - - Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. - - Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. +- `DataflowGraph`: used to model computation graphs. See the [DataflowGraph](#dataflowgraph) section for a detailed explanation. -Conceptually, `DataflowGraph` is used within FlexFlow to represent computation-style graphs, where edges represent value uses and nodes represent multivariate functions from tuples of inputs to tuples of outputs. - Examples of the different graph variants are shown below. Example of `UndirectedGraph`: @@ -89,7 +84,9 @@ Nodes are of type `Node`, and from a user perspective are simply opaque handles, In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. All three core graph variants allow insertion and deletion of both edges and nodes. -To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). +To add a node to an `UndirectedGraph g`, simply call `g.add_node()`, which will return a `Node` object. +For semantics closer to `networkx`'s method of adding nodes, `g.add_node_unsafe(my_node)` can be used. This is useful when constructing a modified copy of an existing graph (given that it maintains node bijection), though it is not generally recommended. +The interface for node addition is identical for `DiGraph` and `MultiDiGraph`. To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataflowGraph`. @@ -114,8 +111,8 @@ Both `Graph` and `GraphView` types follow normal value semantics. This may seem wasteful (oftentimes graphs are large objects that are passed around via reference to avoid making additional copies), but the `Graph` and `GraphView` types internally implement copy-on-write optimizations to only perform the minimum number of actual copies while maintaining immutability and lifetime safety (if you allocate a `DiGraph` use for example `get_subgraph` to get a `DiGraphView` representing a part of this graph, modifications to the underlying `DiGraph` will not be mirrored in the `DiGraphView` and the `DiGraphView` will remain valid even after the base `DiGraph` leaves scope. At this point, however, we still have not discussed how to create a graph. -The user-facing graph interface is intentially separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. -For example, to construct a `DiGDiraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: +The user-facing graph interface is intentionally separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. +For example, to construct a `DiGraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: ```cpp DiGraph g = DiGraph::create(); ``` @@ -124,7 +121,104 @@ Generally users will use underlying representations provided by the graph librar [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. [^2]: See if you're not familiar with the term _type coercion_ -### Open DataFlow Variant +### DataflowGraph + +The primary abstraction for representing computation graphs / task graphs is the `DataflowGraph` interface (along with its variants, `OpenDataflowGraph`, `LabelleledDataflowGraph` and `OpenLabelleledDataflowGraph`). +At a high level, nodes represent multivariate functions (from tuples of inputs to tuple of outputs), while edges represent value uses of such functions. + +`DataflowGraph` is similar to `MultiDiGraph`, but with the following important differences: + - The edges entering, exiting a given nodes have a well-defined order. + - `DataflowGraph`s are directed acyclic graphs. This is enforced by the interface used to construct them, since a node can only be added to the graph after all of its predecessor nodes have already been added. + +The main components of `DataflowGraph` are as follows: +- `DataflowInput`: used to represent the ordered sequence of incoming dependencies (arguments) of a given node (operator). +- `DataflowOutput`: used to represent the ordered sequence of outgoing results (value uses) from a given node (operator). +- `DataflowEdge`: wrapper around a `DataflowInput`, `DataflowOutput` pair between 2 nodes. +- `NodeAddedResult`: returned upon adding a new node. Contains the newly generated `Node` and the vector of `DataflowOutput`s for the given node. + +`DataflowGraph`s are constructed as follows: + +```cpp + auto g = DataflowGraph::create(); + + // Node with no inputs and 2 outputs + NodeAddedResult n1_result = g.add_node({}, 2); + Node n1 = n1_result.node; + DataflowOutput n1_o1 = n1_result.outputs[0]; + DataflowOutput n1_o2 = n1_result.outputs[1]; + + // Node with 2 inputs and 1 output + NodeAddedResult n2_result = g.add_node({n1_o1, n1_o2}, 1); + Node n2 = n2_result.node; + DataflowOutput n2_o1 = n2_result.outputs[0]; + + // Node with 1 input and 2 outputs + NodeAddedResult n3_result = g.add_node({n1_o2}, 1); + Node n3 = n3_result.node; + DataflowOutput n3_o1 = n3_result.outputs[0]; + DataflowOutput n3_o2 = n3_result.outputs[1]; + + // Node with 2 inputs and 1 output + NodeAddedResult n4_result = g.add_node({n2_o1, n3_o1}, 1); + Node n4 = n4_result.node; + DataflowOutput n4_o1 = n4_result.outputs[0]; +``` + +which generates the following graph + +```mermaid +flowchart TD + subgraph Node1[ ] + direction TB + N1Process[n1] + n1_o1((n1_o1)) + n1_o2((n1_o2)) + N1Process --> n1_o1 + N1Process --> n1_o2 + end + + subgraph Node2[ ] + direction TB + n2_i1((n2_i1)) + n2_i2((n2_i2)) + N2Process[n2] + n2_o1((o1)) + n2_i1 --> N2Process + n2_i2 --> N2Process + N2Process --> n2_o1 + end + + subgraph Node3[ ] + direction TB + n3_i1((n3_i1)) + N3Process[n3] + n3_o1((n3_o1)) + n3_o2((n3_o2)) + n3_i1 --> N3Process + N3Process --> n3_o1 + N3Process --> n3_o2 + end + + subgraph Node4[ ] + direction TB + n4_i1((n4_i1)) + n4_i2((n4_i2)) + N4Process[n4] + n4_o1((n4_o1)) + n4_i1 --> N4Process + n4_i2 --> N4Process + N4Process --> n4_o1 + end + + n1_o1 --> n2_i1 + n1_o2 --> n2_i2 + n1_o2 --> n3_i1 + n2_o1 --> n4_i1 + n3_o1 --> n4_i2 +``` + + +### Open Dataflow Variant `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h index df5662804a..76be999b54 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h @@ -8,6 +8,9 @@ namespace FlexFlow { std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); +std::unordered_map> + get_incoming_edges(MultiDiGraphView const &g); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h index 6bc73533e7..6a8474673e 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h @@ -8,6 +8,9 @@ namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &, Node const &); +std::unordered_map> + get_outgoing_edges(MultiDiGraphView const &g); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h index a7d53fecfc..0de8aecc19 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -4,6 +4,7 @@ #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/series_parallel/series_reduction.dtg.h" +#include "utils/hash/vector.h" namespace FlexFlow { @@ -14,8 +15,14 @@ Node get_center_node(MultiDiGraphView const &, SeriesReduction const &); SeriesReduction make_series_reduction(MultiDiEdge const &, MultiDiEdge const &); std::optional find_series_reduction(MultiDiGraphView const &); +std::unordered_set> + find_all_extended_series_reductions(MultiDiGraphView const &g); + MultiDiEdge apply_series_reduction(MultiDiGraph &, SeriesReduction const &); +MultiDiEdge apply_extended_series_reduction( + MultiDiGraph &g, std::vector const &series_edges); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index be39dc158f..50818dea2f 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -1,4 +1,8 @@ #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/containers/group_by.h" +#include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -7,4 +11,17 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &g, return g.query_edges(MultiDiEdgeQuery{query_set::matchall(), {n}}); } +std::unordered_map> + get_incoming_edges(MultiDiGraphView const &g) { + std::unordered_map> result = + group_by(get_edges(g), + [&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); }); + + for (Node const &n : get_nodes(g)) { + result[n]; + } + + return result; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index f98c599614..55847cf2af 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -1,5 +1,7 @@ #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" - +#include "utils/containers/group_by.h" +#include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, @@ -7,4 +9,17 @@ std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, return g.query_edges(MultiDiEdgeQuery{{n}, query_set::matchall()}); } +std::unordered_map> + get_outgoing_edges(MultiDiGraphView const &g) { + std::unordered_map> result = + group_by(get_edges(g), + [&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); }); + + for (Node const &n : get_nodes(g)) { + result[n]; + } + + return result; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 7a5cb1ea82..908743fae1 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -35,6 +35,7 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); + std::unordered_map ttsp_edge_to_sp_tree = map_values( inverse_line_graph_result.inverse_edge_to_line_node_bidict @@ -42,27 +43,11 @@ std::optional [](Node const &n) { return SeriesParallelDecomposition{n}; }); while (true) { - std::optional maybe_series_reduction = - find_series_reduction(ttsp); - if (maybe_series_reduction.has_value()) { - SeriesReduction series_reduction = maybe_series_reduction.value(); - MultiDiEdge e1 = series_reduction.first; - MultiDiEdge e2 = series_reduction.second; - MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - - SeriesParallelDecomposition new_tree = serial_composition({ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }); - - ttsp_edge_to_sp_tree.erase(e1); - ttsp_edge_to_sp_tree.erase(e2); - ttsp_edge_to_sp_tree.insert({merged, new_tree}); + int reductions = 0; - continue; - } std::unordered_map> parallel_reductions = find_all_extended_parallel_reductions(ttsp); + if (!parallel_reductions.empty()) { for (auto const &[_, parallel_reduction] : parallel_reductions) { MultiDiEdge merged = @@ -71,18 +56,40 @@ std::optional SeriesParallelDecomposition new_tree = parallel_composition(transform( unordered_multiset_of(parallel_reduction), [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); + for (MultiDiEdge const &e : parallel_reduction) { ttsp_edge_to_sp_tree.erase(e); } ttsp_edge_to_sp_tree.insert({merged, new_tree}); } - continue; + reductions++; } - if (get_nodes(ttsp).size() != 2) { - return std::nullopt; + std::unordered_set> series_reductions = + find_all_extended_series_reductions(ttsp); + if (!series_reductions.empty()) { + for (std::vector series_reduction : series_reductions) { + MultiDiEdge merged = + apply_extended_series_reduction(ttsp, series_reduction); + + SeriesParallelDecomposition new_tree = serial_composition( + transform(series_reduction, [&](MultiDiEdge const &e) { + return ttsp_edge_to_sp_tree.at(e); + })); + + for (MultiDiEdge const &e : series_reduction) { + ttsp_edge_to_sp_tree.erase(e); + } + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + } + reductions++; } - if (get_edges(ttsp).size() != 1) { + + if (reductions > 0) { + continue; + } + + if (get_nodes(ttsp).size() != 2 || get_edges(ttsp).size() != 1) { return std::nullopt; } diff --git a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc index c312bb4a6b..26fabe593c 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,13 +1,21 @@ #include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/contains.h" +#include "utils/containers/contains_key.h" #include "utils/containers/get_only.h" #include "utils/containers/require_same.h" +#include "utils/containers/subvec.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/multidigraph/multidigraph_view.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" #include namespace FlexFlow { @@ -42,6 +50,34 @@ std::optional return std::nullopt; } +std::unordered_set> + find_all_extended_series_reductions(MultiDiGraphView const &g) { + std::unordered_map> incoming_edges = + get_incoming_edges(g); + std::unordered_map> outgoing_edges = + get_outgoing_edges(g); + std::unordered_map> strands; + std::unordered_map node_to_head_of_strand; + for (Node const &n : get_topological_ordering(g)) { + if ((incoming_edges.at(n).size() == 1) && + (outgoing_edges.at(n).size() == 1)) { + MultiDiEdge incoming = get_only(incoming_edges.at(n)); + MultiDiEdge outgoing = get_only(outgoing_edges.at(n)); + Node pre = g.get_multidiedge_src(incoming); + if (contains_key(node_to_head_of_strand, pre)) { + Node head = node_to_head_of_strand.at(pre); + node_to_head_of_strand.emplace(n, head); + strands.at(head).push_back(outgoing); + } else { + node_to_head_of_strand.emplace(n, n); + strands[n].push_back(incoming); + strands[n].push_back(outgoing); + } + } + } + return unordered_set_of(values(strands)); +} + MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { Node pre_node = get_pre_node(g, r); Node center_node = get_center_node(g, r); @@ -51,4 +87,19 @@ MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { return g.add_edge(pre_node, post_node); } +MultiDiEdge apply_extended_series_reduction( + MultiDiGraph &g, std::vector const &series_edges) { + + Node first = g.get_multidiedge_src(series_edges.at(0)); + Node last = g.get_multidiedge_dst(series_edges.at(series_edges.size() - 1)); + + std::vector internal_nodes; + for (MultiDiEdge const &e : subvec(series_edges, std::nullopt, -1)) { + internal_nodes.push_back(g.get_multidiedge_dst(e)); + } + for (Node const &n : internal_nodes) { + g.remove_node(n); + } + return g.add_edge(first, last); +} } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index c6b45ec6ce..3a8a5e9a60 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -1,9 +1,12 @@ #include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/set_minus.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" #include "utils/graph/multidigraph/algorithms/add_nodes.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/node/algorithms.h" #include @@ -234,6 +237,155 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(returned_edge_src == correct_src); } + SUBCASE("dst") { + Node returned_edge_dst = g.get_multidiedge_dst(returned_edge); + Node correct_dst = n.at(5); + CHECK(returned_edge_dst == correct_dst); + } + } + } + } + TEST_CASE("find_all_extended_series_reductions") { + MultiDiGraph g = MultiDiGraph::create(); + + SUBCASE("linear graph") { + std::vector n = add_nodes(g, 4); + std::vector e = add_edges(g, + { + {n.at(0), n.at(1)}, + {n.at(1), n.at(2)}, + {n.at(2), n.at(3)}, + }); + + std::unordered_set> result = + find_all_extended_series_reductions(g); + std::unordered_set> correct = { + {e[0], e[1], e[2]}}; + CHECK(result == correct); + } + + SUBCASE("2 linear strands") { + std::vector n = add_nodes(g, 4); + std::vector e = add_edges(g, + {{n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(3)}, + {n.at(2), n.at(3)}}); + + std::unordered_set> result = + find_all_extended_series_reductions(g); + std::unordered_set> correct = {{e[0], e[2]}, + {e[1], e[3]}}; + CHECK(result == correct); + } + + SUBCASE("graph with multiple separate serial strands") { + std::vector n = add_nodes(g, 9); + std::vector e = add_edges(g, + {{n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(4)}, + {n.at(2), n.at(3)}, + {n.at(2), n.at(5)}, + {n.at(2), n.at(6)}, + {n.at(3), n.at(5)}, + {n.at(4), n.at(7)}, + {n.at(5), n.at(7)}, + {n.at(6), n.at(8)}, + {n.at(7), n.at(8)}}); + + std::unordered_set> result = + find_all_extended_series_reductions(g); + std::unordered_set> correct = { + {e[0], e[2], e[7]}, {e[3], e[6]}, {e[5], e[9]}}; + CHECK(result == correct); + } + } + + TEST_CASE("apply_extended_series_reduction") { + MultiDiGraph g = MultiDiGraph::create(); + + SUBCASE("base case") { + std::vector n = add_nodes(g, 4); + std::vector e = add_edges( + g, {{n.at(0), n.at(1)}, {n.at(1), n.at(2)}, {n.at(2), n.at(3)}}); + + std::vector reduction = {e.at(0), e.at(1), e.at(2)}; + + MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(g); + std::unordered_set correct_nodes = {n.at(0), n.at(3)}; + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(g); + std::unordered_set correct_edges = {returned_edge}; + CHECK(result_edges == correct_edges); + } + + SUBCASE("returned edge") { + SUBCASE("src") { + Node returned_edge_src = g.get_multidiedge_src(returned_edge); + Node correct_src = n.at(0); + CHECK(returned_edge_src == correct_src); + } + + SUBCASE("dst") { + Node returned_edge_dst = g.get_multidiedge_dst(returned_edge); + Node correct_dst = n.at(3); + CHECK(returned_edge_dst == correct_dst); + } + } + } + + SUBCASE("in larger graph") { + std::vector n = add_nodes(g, 8); + std::vector e = add_edges(g, + { + {n.at(0), n.at(2)}, + {n.at(1), n.at(2)}, + {n.at(2), n.at(5)}, + {n.at(2), n.at(3)}, + {n.at(3), n.at(4)}, + {n.at(4), n.at(5)}, + {n.at(5), n.at(6)}, + {n.at(5), n.at(7)}, + }); + + std::vector reduction = {e.at(3), e.at(4), e.at(5)}; + + MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(g); + std::unordered_set correct_nodes = + set_minus(unordered_set_of(n), {n.at(4), n.at(3)}); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(g); + std::unordered_set correct_edges = [&] { + std::unordered_set new_edges = unordered_set_of(e); + new_edges.erase(e.at(3)); + new_edges.erase(e.at(4)); + new_edges.erase(e.at(5)); + new_edges.insert(returned_edge); + return new_edges; + }(); + CHECK(result_edges == correct_edges); + } + + SUBCASE("returned edge") { + SUBCASE("src") { + Node returned_edge_src = g.get_multidiedge_src(returned_edge); + Node correct_src = n.at(2); + CHECK(returned_edge_src == correct_src); + } + SUBCASE("dst") { Node returned_edge_dst = g.get_multidiedge_dst(returned_edge); Node correct_dst = n.at(5);