diff --git a/lib/utils/include/utils/containers/recurse_n.h b/lib/utils/include/utils/containers/recurse_n.h new file mode 100644 index 0000000000..8dc22cb8a8 --- /dev/null +++ b/lib/utils/include/utils/containers/recurse_n.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H + +#include "utils/exception.h" + +namespace FlexFlow { + +/** + * @brief + * Applies function `f` to value `initial_value` n times recursively. + * + * @example + * auto add_three = [](int x) { return x + 3; }; + * int result = recurse_n(add_three, 3, 5); + * result -> f(f(f(5))) = ((5+3)+3)+3 = 14 + * + * @throws RuntimeError if n is negative + */ +template +T recurse_n(F const &f, int n, T const &initial_value) { + if (n < 0) { + throw mk_runtime_error( + fmt::format("Supplied n={} should be non-negative", n)); + } + T t = initial_value; + for (int i = 0; i < n; i++) { + t = f(t); + } + return t; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/containers/recurse_n.cc b/lib/utils/src/utils/containers/recurse_n.cc new file mode 100644 index 0000000000..d00ddcae8f --- /dev/null +++ b/lib/utils/src/utils/containers/recurse_n.cc @@ -0,0 +1 @@ +#include "utils/containers/recurse_n.h" diff --git a/lib/utils/test/src/utils/containers/recurse_n.cc b/lib/utils/test/src/utils/containers/recurse_n.cc new file mode 100644 index 0000000000..1805ee891f --- /dev/null +++ b/lib/utils/test/src/utils/containers/recurse_n.cc @@ -0,0 +1,29 @@ +#include "utils/containers/recurse_n.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("recurse_n") { + auto append_bar = [](std::string const &x) { + return x + std::string("Bar"); + }; + + SUBCASE("n = 0") { + std::string result = recurse_n(append_bar, 0, std::string("Foo")); + std::string correct = "Foo"; + CHECK(result == correct); + } + + SUBCASE("n = 3") { + std::string result = recurse_n(append_bar, 3, std::string("Foo")); + std::string correct = "FooBarBarBar"; + CHECK(result == correct); + } + + SUBCASE("n < 0") { + CHECK_THROWS(recurse_n(append_bar, -1, std::string("Foo"))); + } + } +}