Skip to content

Commit

Permalink
added recurse_n
Browse files Browse the repository at this point in the history
  • Loading branch information
Pietro Max Marsella committed Dec 21, 2024
1 parent 93298ed commit 29c9848
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
34 changes: 34 additions & 0 deletions lib/utils/include/utils/containers/recurse_n.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H

#include "utils/exception.h"

namespace FlexFlow {

/**
* @brief
* Applies function `f` to value `initial_value` n times recursively.
*
* @example
* auto add_three = [](int x) { return x + 3; };
* int result = recurse_n(add_three, 3, 5);
* result -> f(f(f(5))) = ((5+3)+3)+3 = 14
*
* @throws RuntimeError if n is negative
*/
template <typename F, typename T>
T recurse_n(F const &f, int n, T const &initial_value) {
if (n < 0) {
throw mk_runtime_error(
fmt::format("Supplied n={} should be non-negative", n));
}
T t = initial_value;
for (int i = 0; i < n; i++) {
t = f(t);
}
return t;
}

} // namespace FlexFlow

#endif
1 change: 1 addition & 0 deletions lib/utils/src/utils/containers/recurse_n.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "utils/containers/recurse_n.h"
29 changes: 29 additions & 0 deletions lib/utils/test/src/utils/containers/recurse_n.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "utils/containers/recurse_n.h"
#include <doctest/doctest.h>
#include <string>

using namespace FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("recurse_n") {
auto append_bar = [](std::string const &x) {
return x + std::string("Bar");
};

SUBCASE("n = 0") {
std::string result = recurse_n(append_bar, 0, std::string("Foo"));
std::string correct = "Foo";
CHECK(result == correct);
}

SUBCASE("n = 3") {
std::string result = recurse_n(append_bar, 3, std::string("Foo"));
std::string correct = "FooBarBarBar";
CHECK(result == correct);
}

SUBCASE("n < 0") {
CHECK_THROWS(recurse_n(append_bar, -1, std::string("Foo")));
}
}
}

0 comments on commit 29c9848

Please sign in to comment.