Skip to content

Commit

Permalink
use std::forward when parameter is universal reference
Browse files Browse the repository at this point in the history
As discussed in #264
  • Loading branch information
kedixa authored and ben-clayton committed Dec 18, 2023
1 parent 535d491 commit dbf097e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
8 changes: 6 additions & 2 deletions include/marl/dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class DAGBase {
struct Node {
MARL_NO_EXPORT inline Node() = default;
MARL_NO_EXPORT inline Node(Work&& work);
MARL_NO_EXPORT inline Node(const Work& work);

// The work to perform for this node in the graph.
Work work;
Expand Down Expand Up @@ -137,6 +138,9 @@ class DAGBase {
template <typename T>
DAGBase<T>::Node::Node(Work&& work) : work(std::move(work)) {}

template <typename T>
DAGBase<T>::Node::Node(const Work& work) : work(work) {}

template <typename T>
void DAGBase<T>::initCounters(RunContext* ctx, Allocator* allocator) {
auto numCounters = initialCounters.size();
Expand Down Expand Up @@ -233,7 +237,7 @@ DAGNodeBuilder<T>::DAGNodeBuilder(DAGBuilder<T>* builder, NodeIndex index)
template <typename T>
template <typename F>
DAGNodeBuilder<T> DAGNodeBuilder<T>::then(F&& work) {
auto node = builder->node(std::move(work));
auto node = builder->node(std::forward<F>(work));
builder->addDependency(*this, node);
return node;
}
Expand Down Expand Up @@ -323,7 +327,7 @@ DAGNodeBuilder<T> DAGBuilder<T>::node(
"NodeBuilder vectors out of sync");
auto index = dag->nodes.size();
numIns.emplace_back(0);
dag->nodes.emplace_back(Node{std::move(work)});
dag->nodes.emplace_back(Node{std::forward<F>(work)});
auto node = DAGNodeBuilder<T>{this, index};
for (auto in : after) {
addDependency(in, node);
Expand Down
15 changes: 15 additions & 0 deletions src/dag_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,18 @@ TEST_P(WithBoundScheduler, DAGFanOutFanIn) {
UnorderedElementsAre("E0", "E1", "E2", "E3"));
ASSERT_THAT(data.order[11], "F");
}

TEST_P(WithBoundScheduler, DAGForwardFunc) {
marl::DAG<void>::Builder builder;
std::function<void()> func([](){});

ASSERT_TRUE(func);

auto a = builder.root()
.then(func)
.then(func);

builder.node(func, {a});

ASSERT_TRUE(func);
}

0 comments on commit dbf097e

Please sign in to comment.