diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.decl.h b/lib/utils/include/utils/graph/labelled/labelled_open.decl.h index cdd22b7847..d62f0d2757 100644 --- a/lib/utils/include/utils/graph/labelled/labelled_open.decl.h +++ b/lib/utils/include/utils/graph/labelled/labelled_open.decl.h @@ -61,6 +61,7 @@ struct LabelledOpenMultiDiGraph { OutputLabel>() const; operator OpenMultiDiGraphView() const; + operator LabelledOpenMultiDiGraphView() const; friend void swap(LabelledOpenMultiDiGraph &lhs, LabelledOpenMultiDiGraph &rhs) { @@ -111,7 +112,7 @@ struct LabelledOpenMultiDiGraph { create(); private: - LabelledOpenMultiDiGraph(cow_ptr_t ptr); + LabelledOpenMultiDiGraph(cow_ptr_t ptr): ptr(ptr) {} private: cow_ptr_t ptr; diff --git a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h index f7af522b3c..9db6db1ca9 100644 --- a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h @@ -6,6 +6,7 @@ #include "output_labelled_interfaces.h" #include "standard_labelled_interfaces.h" #include "utils/graph/open_graphs.h" +#include "views.h" namespace FlexFlow { @@ -107,6 +108,16 @@ struct UnorderedLabelledOpenMultiDiGraph this->add_edge(e); this->output_map.insert({e, label}); } + + void add_edge(InputMultiDiEdge const &e) { + OpenMultiDiEdge edge{e}; + this->base_graph.add_edge(edge); + } + + void add_edge(OutputMultiDiEdge const &e) { + OpenMultiDiEdge edge{e}; + this->base_graph.add_edge(edge); + } InputLabel const &at(InputMultiDiEdge const &e) const { return this->input_map.at(e);