Skip to content

Commit

Permalink
refactor: simplify chunking code (PROOF-923) (#205)
Browse files Browse the repository at this point in the history
* rework split

* refactor split

* refactor split

* refactor multiexponentiation

* refactor split code

* drop cruft

* refactor splitting code

* refactor splitting code

* refactor split code

* refactor chunking code

* refactor split

* refactor splitting code

* drop dead code

* refactor

* drop dead code

* refactor

* refactor

* refactor

* drop dead code

* drop dead code

* refactor

* drop dead code

* drop dead code

* refactor

* drop dead code

* drop dead code

* reformat
  • Loading branch information
rnburn authored Dec 12, 2024
1 parent 07f0e71 commit f785e9a
Show file tree
Hide file tree
Showing 40 changed files with 197 additions and 318 deletions.
3 changes: 1 addition & 2 deletions benchmark/memory/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ sxt_cc_binary(
"//sxt/base/container:span_utility",
"//sxt/base/device:memory_utility",
"//sxt/base/device:stream",
"//sxt/base/iterator:index_range",
"//sxt/base/iterator:index_range_utility",
"//sxt/base/iterator:split",
"//sxt/execution/async:coroutine",
"//sxt/execution/async:future",
"//sxt/execution/device:copy",
Expand Down
9 changes: 5 additions & 4 deletions benchmark/memory/copy.m.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/pinned_buffer_pool.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/iterator/index_range.h"
#include "sxt/base/iterator/index_range_iterator.h"
#include "sxt/base/iterator/index_range_utility.h"
#include "sxt/base/iterator/split.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/async/future.h"
#include "sxt/execution/device/copy.h"
Expand Down Expand Up @@ -147,7 +145,10 @@ static double run_benchmark(benchmark_fn f, unsigned n, unsigned m,
}

// chunk
auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n}, split_factor);
basit::split_options split_options{
.split_factor = split_factor,
};
auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n}, split_options);

// invoker
memmg::managed_array<double> sum(n);
Expand Down
6 changes: 1 addition & 5 deletions sxt/algorithm/iteration/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ sxt_cc_component(
":kernel_fit",
"//sxt/algorithm/base:index_functor",
"//sxt/base/device:stream",
"//sxt/base/iterator:chunk_options",
"//sxt/base/num:divide_up",
"//sxt/execution/async:coroutine",
"//sxt/execution/device:synchronization",
Expand All @@ -54,10 +53,7 @@ sxt_cc_component(
"//sxt/base/device:memory_utility",
"//sxt/base/device:stream",
"//sxt/base/error:assert",
"//sxt/base/iterator:chunk_options",
"//sxt/base/iterator:index_range",
"//sxt/base/iterator:index_range_iterator",
"//sxt/base/iterator:index_range_utility",
"//sxt/base/iterator:split",
"//sxt/base/macro:cuda_callable",
"//sxt/base/type:value_type",
"//sxt/execution/async:coroutine",
Expand Down
17 changes: 6 additions & 11 deletions sxt/algorithm/iteration/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/error/assert.h"
#include "sxt/base/iterator/chunk_options.h"
#include "sxt/base/iterator/index_range.h"
#include "sxt/base/iterator/index_range_iterator.h"
#include "sxt/base/iterator/index_range_utility.h"
#include "sxt/base/iterator/split.h"
#include "sxt/base/macro/cuda_callable.h"
#include "sxt/base/type/value_type.h"
#include "sxt/execution/async/coroutine.h"
Expand Down Expand Up @@ -87,7 +84,7 @@ template <class F, class Arg1, class... ArgsRest>
requires algb::transform_functor_factory<F, bast::value_type_t<Arg1>,
bast::value_type_t<ArgsRest>...>
xena::future<> transform(basct::span<bast::value_type_t<Arg1>> res,
basit::chunk_options chunk_options, F make_f, const Arg1& x1,
basit::split_options split_options, F make_f, const Arg1& x1,
const ArgsRest&... xrest) noexcept {
auto n = res.size();
SXT_DEBUG_ASSERT(x1.size() == n && ((xrest.size() == n) && ...));
Expand All @@ -96,11 +93,9 @@ xena::future<> transform(basct::span<bast::value_type_t<Arg1>> res,
}
std::tuple<basct::cspan<bast::value_type_t<Arg1>>, basct::cspan<bast::value_type_t<ArgsRest>>...>
srcs{x1, xrest...};
auto full_rng = basit::index_range{0, n}
.min_chunk_size(chunk_options.min_size)
.max_chunk_size(chunk_options.max_size);
auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n}, split_options);
co_await xendv::concurrent_for_each(
full_rng, [&](const basit::index_range& rng) noexcept -> xena::future<> {
chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
co_await detail::transform_impl(res.subspan(rng.a(), rng.size()), make_f, srcs, rng,
std::make_index_sequence<sizeof...(ArgsRest) + 1>{});
});
Expand All @@ -109,11 +104,11 @@ xena::future<> transform(basct::span<bast::value_type_t<Arg1>> res,
template <class F, class Arg1, class... ArgsRest>
requires algb::transform_functor<F, bast::value_type_t<Arg1>, bast::value_type_t<ArgsRest>...>
xena::future<> transform(basct::span<bast::value_type_t<Arg1>> res,
basit::chunk_options chunk_options, F f, const Arg1& x1,
basit::split_options split_options, F f, const Arg1& x1,
const ArgsRest&... xrest) noexcept {
auto make_f = [&](std::pmr::polymorphic_allocator<> /*alloc*/, basdv::stream& /*stream*/) {
return xena::make_ready_future<F>(F{f});
};
co_await transform(res, chunk_options, make_f, x1, xrest...);
co_await transform(res, split_options, make_f, x1, xrest...);
}
} // namespace sxt::algi
16 changes: 8 additions & 8 deletions sxt/algorithm/iteration/transform.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ using namespace sxt::algi;

TEST_CASE("we can transform contigous regions of memory") {
std::vector<double> res;
basit::chunk_options chunk_options;
basit::split_options split_options;

SECTION("we handle the empty case") {
auto f = [] __device__ __host__(double& x) noexcept { x *= 2; };
auto fut = transform(res, chunk_options, f, res);
auto fut = transform(res, split_options, f, res);
REQUIRE(fut.ready());
}

SECTION("we can transform a vector with a single element") {
res = {123};
auto f = [] __device__ __host__(double& x) noexcept { x *= 2; };
auto fut = transform(res, chunk_options, f, res);
auto fut = transform(res, split_options, f, res);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 246);
Expand All @@ -48,7 +48,7 @@ TEST_CASE("we can transform contigous regions of memory") {
res = {2};
std::vector<double> y = {4};
auto f = [] __device__ __host__(double& x, double& y) noexcept { x = x + y; };
auto fut = transform(res, chunk_options, f, res, y);
auto fut = transform(res, split_options, f, res, y);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 6);
Expand All @@ -57,8 +57,8 @@ TEST_CASE("we can transform contigous regions of memory") {
SECTION("we can split transform across multiple chunks") {
res = {3, 5};
auto f = [] __device__ __host__(double& x) noexcept { x *= 2; };
chunk_options.max_size = 1;
auto fut = transform(res, chunk_options, f, res);
split_options.split_factor = 2;
auto fut = transform(res, split_options, f, res);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 6);
Expand All @@ -82,8 +82,8 @@ TEST_CASE("we can transform contigous regions of memory") {
};

res = {3, 4};
chunk_options.max_size = 1;
auto fut = transform(res, chunk_options, make_f, res);
split_options.split_factor = 2;
auto fut = transform(res, split_options, make_f, res);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 18);
Expand Down
3 changes: 3 additions & 0 deletions sxt/base/device/active_device_guard.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class active_device_guard {
public:
active_device_guard() noexcept;

explicit active_device_guard(unsigned device) noexcept
: active_device_guard{static_cast<int>(device)} {}

explicit active_device_guard(int device) noexcept;

~active_device_guard() noexcept;
Expand Down
2 changes: 1 addition & 1 deletion sxt/base/device/active_device_guard.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ TEST_CASE("we can control the active device") {
}

SECTION("we can set/unset a specific device") {
active_device_guard guard{get_num_devices() - 1};
active_device_guard guard{get_num_devices() - 1u};
}
}
6 changes: 3 additions & 3 deletions sxt/base/device/property.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ namespace sxt::basdv {
//--------------------------------------------------------------------------------------------------
// get_num_devices
//--------------------------------------------------------------------------------------------------
int get_num_devices() noexcept {
unsigned get_num_devices() noexcept {
static int num_devices = []() noexcept {
int res;
auto rcode = cudaGetDeviceCount(&res);
if (rcode != cudaSuccess) {
return 0;
return 0u;
}
return res;
return static_cast<unsigned>(res);
}();
return num_devices;
}
Expand Down
2 changes: 1 addition & 1 deletion sxt/base/device/property.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace sxt::basdv {
//--------------------------------------------------------------------------------------------------
// get_num_devices
//--------------------------------------------------------------------------------------------------
int get_num_devices() noexcept;
unsigned get_num_devices() noexcept;

//--------------------------------------------------------------------------------------------------
// get_latest_cuda_version_supported_by_driver
Expand Down
31 changes: 13 additions & 18 deletions sxt/base/iterator/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ load(
"sxt_cc_component",
)

sxt_cc_component(
name = "chunk_options",
with_test = False,
)

sxt_cc_component(
name = "counting_iterator",
test_deps = [
Expand Down Expand Up @@ -37,30 +32,30 @@ sxt_cc_component(
)

sxt_cc_component(
name = "index_range_utility",
impl_deps = [
name = "index_range_iterator",
test_deps = [
"//sxt/base/test:unit_test",
],
deps = [
":index_range",
":index_range_iterator",
":iterator_facade",
"//sxt/base/error:assert",
"//sxt/base/num:divide_up",
],
test_deps = [
":index_range",
":index_range_iterator",
"//sxt/base/test:unit_test",
"//sxt/base/type:narrow_cast",
],
)

sxt_cc_component(
name = "index_range_iterator",
name = "split",
impl_deps = [
"//sxt/base/error:assert",
"//sxt/base/num:divide_up",
],
test_deps = [
"//sxt/base/test:unit_test",
],
deps = [
":index_range",
":iterator_facade",
"//sxt/base/error:assert",
"//sxt/base/num:divide_up",
"//sxt/base/type:narrow_cast",
":index_range_iterator",
],
)
17 changes: 0 additions & 17 deletions sxt/base/iterator/chunk_options.cc

This file was deleted.

30 changes: 0 additions & 30 deletions sxt/base/iterator/chunk_options.h

This file was deleted.

33 changes: 7 additions & 26 deletions sxt/base/iterator/index_range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,45 +22,26 @@ namespace sxt::basit {
//--------------------------------------------------------------------------------------------------
// constructor
//--------------------------------------------------------------------------------------------------
index_range::index_range(size_t a, size_t b) noexcept
: index_range{a, b, 1, std::numeric_limits<size_t>::max(), 1} {}
index_range::index_range(size_t a, size_t b) noexcept : index_range{a, b, 1} {}

index_range::index_range(size_t a, size_t b, size_t min_chunk_size, size_t max_chunk_size,
size_t chunk_multiple) noexcept
: a_{a}, b_{b}, min_chunk_size_{min_chunk_size}, max_chunk_size_{max_chunk_size},
chunk_multiple_{chunk_multiple} {
index_range::index_range(size_t a, size_t b, size_t chunk_multiple) noexcept
: a_{a}, b_{b}, chunk_multiple_{chunk_multiple} {
SXT_DEBUG_ASSERT(
// clang-format off
0 <= a && a <= b &&
0 < min_chunk_size_ && min_chunk_size_ <= max_chunk_size_
chunk_multiple > 0
// clang-format on
);
}

//--------------------------------------------------------------------------------------------------
// min_chunk_size
//--------------------------------------------------------------------------------------------------
index_range index_range::min_chunk_size(size_t val) const noexcept {
return {
a_, b_, val, max_chunk_size_, chunk_multiple_,
};
}

//--------------------------------------------------------------------------------------------------
// max_chunk_size
//--------------------------------------------------------------------------------------------------
index_range index_range::max_chunk_size(size_t val) const noexcept {
return {
a_, b_, min_chunk_size_, val, chunk_multiple_,
};
}

//--------------------------------------------------------------------------------------------------
// chunk_multiple
//--------------------------------------------------------------------------------------------------
index_range index_range::chunk_multiple(size_t val) const noexcept {
return {
a_, b_, min_chunk_size_, max_chunk_size_, val,
a_,
b_,
val,
};
}
} // namespace sxt::basit
11 changes: 1 addition & 10 deletions sxt/base/iterator/index_range.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class index_range {

index_range(size_t a, size_t b) noexcept;

index_range(size_t a, size_t b, size_t min_chunk_size, size_t max_chunk_size,
size_t chunk_multiple) noexcept;
index_range(size_t a, size_t b, size_t chunk_multiple) noexcept;

size_t a() const noexcept { return a_; }
size_t b() const noexcept { return b_; }
Expand All @@ -39,21 +38,13 @@ class index_range {

bool operator==(const index_range&) const noexcept = default;

size_t min_chunk_size() const noexcept { return min_chunk_size_; }
size_t max_chunk_size() const noexcept { return max_chunk_size_; }
size_t chunk_multiple() const noexcept { return chunk_multiple_; }

[[nodiscard]] index_range min_chunk_size(size_t val) const noexcept;

[[nodiscard]] index_range max_chunk_size(size_t val) const noexcept;

[[nodiscard]] index_range chunk_multiple(size_t val) const noexcept;

private:
size_t a_{0};
size_t b_{0};
size_t min_chunk_size_{1};
size_t max_chunk_size_{std::numeric_limits<size_t>::max()};
size_t chunk_multiple_{1};
};
} // namespace sxt::basit
Loading

0 comments on commit f785e9a

Please sign in to comment.