Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph Library #1521

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lib/utils/include/utils/containers/find.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FIND_H

#include <algorithm>
#include <unordered_set>

namespace FlexFlow {

Expand All @@ -11,6 +12,12 @@ typename Container::const_iterator
return std::find(c.cbegin(), c.cend(), e);
}

template <typename V>
typename std::unordered_set<V>::const_iterator
find(std::unordered_set<V> const &c, V const &e) {
return c.find(e);
}

} // namespace FlexFlow

#endif
260 changes: 144 additions & 116 deletions lib/utils/include/utils/graph/README.md

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ struct DataflowGraph : virtual public DataflowGraphView {
private:
IDataflowGraph &get_interface();
IDataflowGraph const &get_interface() const;

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
15 changes: 15 additions & 0 deletions lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,22 @@

namespace FlexFlow {

/**
* @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory)
*
* @note By definition, the root node dominates every node and every node
* dominates itself.
*
*/
std::unordered_set<Node> get_dominators(DiGraphView const &, Node const &);

/**
* @brief Returns the intersection of the dominators of the given set of nodes.
* @note This is conceptually equivalent to merging the given set of nodes and
* then finding the set of dominators of the new merged node (where merged means
* that all edges belonging to the set of nodes now pass through a single
* unified node).
*/
std::unordered_set<Node> get_dominators(DiGraphView const &,
std::unordered_set<Node> const &);

Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/digraph/digraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ struct DiGraph : virtual DiGraphView {
private:
IDiGraph &get_ptr();
IDiGraph const &get_ptr() const;

friend struct GraphInternal;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph);

Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/digraph/digraph_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ struct DiGraphView : virtual public GraphView {

private:
IDiGraphView const &get_ptr() const;

friend struct GraphInternal;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ namespace FlexFlow {
std::unordered_set<MultiDiEdge> get_incoming_edges(MultiDiGraphView const &,
Node const &);

std::unordered_map<Node, std::unordered_set<MultiDiEdge>>
get_incoming_edges(MultiDiGraphView const &g);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ namespace FlexFlow {
std::unordered_set<MultiDiEdge> get_outgoing_edges(MultiDiGraphView const &,
Node const &);

std::unordered_map<Node, std::unordered_set<MultiDiEdge>>
get_outgoing_edges(MultiDiGraphView const &g);

} // namespace FlexFlow

#endif
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/multidigraph/multidigraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ struct MultiDiGraph : virtual public MultiDiGraphView {
private:
IMultiDiGraph &get_interface();
IMultiDiGraph const &get_interface() const;

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/node/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ struct Graph : virtual GraphView {
private:
IGraph const &get_ptr() const;
IGraph &get_ptr();

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/node/graph_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ struct GraphView {
GraphView();
cow_ptr_t<IGraphView> ptr;
GraphView(cow_ptr_t<IGraphView> ptr);

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
1 change: 1 addition & 0 deletions lib/utils/include/utils/graph/node/node.struct.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ features = [
"hash",
"fmt",
"json",
"rapidcheck",
]

includes = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &,
std::optional<ParallelReduction>
find_parallel_reduction(MultiDiGraphView const &);

std::unordered_map<DirectedEdge, std::unordered_set<MultiDiEdge>>
find_all_extended_parallel_reductions(MultiDiGraphView const &);

MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &);

MultiDiEdge
apply_extended_parallel_reduction(MultiDiGraph &,
std::unordered_set<MultiDiEdge> const &);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,25 @@ std::unordered_multiset<Node> get_nodes(SeriesSplit const &);
std::unordered_multiset<Node> get_nodes(ParallelSplit const &);
std::unordered_multiset<Node> get_nodes(Node const &);

bool is_empty(Node const &node);
bool is_empty(SeriesSplit const &serial);
bool is_empty(ParallelSplit const &parallel);
bool is_empty(SeriesParallelDecomposition const &sp);

bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp);

SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp,
Node const &node);

// duplicate nodes within `sp` are counted multiple times
size_t num_nodes(SeriesParallelDecomposition const &sp);

SeriesParallelDecomposition serial_composition(
std::vector<SeriesParallelDecomposition> const &sp_compositions);
SeriesParallelDecomposition parallel_composition(
std::unordered_multiset<SeriesParallelDecomposition> const
&sp_compositions);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "utils/graph/multidigraph/multidiedge.dtg.h"
#include "utils/graph/multidigraph/multidigraph.h"
#include "utils/graph/series_parallel/series_reduction.dtg.h"
#include "utils/hash/vector.h"

namespace FlexFlow {

Expand All @@ -14,8 +15,14 @@ Node get_center_node(MultiDiGraphView const &, SeriesReduction const &);
SeriesReduction make_series_reduction(MultiDiEdge const &, MultiDiEdge const &);
std::optional<SeriesReduction> find_series_reduction(MultiDiGraphView const &);

std::unordered_set<std::vector<MultiDiEdge>>
find_all_extended_series_reductions(MultiDiGraphView const &g);

MultiDiEdge apply_series_reduction(MultiDiGraph &, SeriesReduction const &);

MultiDiEdge apply_extended_series_reduction(
MultiDiGraph &g, std::vector<MultiDiEdge> const &series_edges);

} // namespace FlexFlow

#endif
27 changes: 3 additions & 24 deletions lib/utils/include/utils/graph/undirected/undirected_edge.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,12 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H

#include "utils/graph/node/node.dtg.h"
namespace FlexFlow {

struct UndirectedEdge {
public:
UndirectedEdge() = delete;
UndirectedEdge(Node const &src, Node const &dst);
#include "utils/graph/undirected/undirected_edge.dtg.h"

bool operator==(UndirectedEdge const &) const;
bool operator!=(UndirectedEdge const &) const;
bool operator<(UndirectedEdge const &) const;

public:
Node smaller;
Node bigger;
};
namespace FlexFlow {

bool is_connected_to(UndirectedEdge const &, Node const &);
bool is_connected_to(UndirectedEdge const &e, Node const &n);

} // namespace FlexFlow

namespace std {

template <>
struct hash<::FlexFlow::UndirectedEdge> {
size_t operator()(::FlexFlow::UndirectedEdge const &) const;
};

} // namespace std

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace = "FlexFlow"
name = "UndirectedEdge"
features = [
"eq",
"ord",
"hash",
"fmt",
"rapidcheck"
]

includes = [
"utils/commutative_pair.h",
"utils/graph/node/node.dtg.h",
]

[[fields]]
name = "endpoints"
type = "::FlexFlow::commutative_pair<::FlexFlow::Node>"
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/undirected/undirected_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ struct UndirectedGraph : virtual UndirectedGraphView {

using UndirectedGraphView::UndirectedGraphView;

friend struct GraphInternal;

private:
IUndirectedGraph const &get_ptr() const;
IUndirectedGraph &get_ptr();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ struct UndirectedGraphView : virtual GraphView {

using GraphView::GraphView;

friend struct GraphInternal;

private:
IUndirectedGraphView const &get_ptr() const;
};
Expand Down
Loading
Loading