Skip to content

Commit

Permalink
fix: don't split inner products that are already on device memory (PR…
Browse files Browse the repository at this point in the history
…OOF-923) (#206)

* fix inner product

* reformat

* add test
  • Loading branch information
rnburn authored Dec 13, 2024
1 parent f785e9a commit 6f83d9e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
6 changes: 6 additions & 0 deletions sxt/scalar25/operation/inner_product.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,14 @@ xena::future<s25t::element> async_inner_product_impl(basct::cspan<s25t::element>
basct::cspan<s25t::element> rhs,
size_t split_factor, size_t min_chunk_size,
size_t max_chunk_size) noexcept {
SXT_DEBUG_ASSERT(
(basdv::is_host_pointer(lhs.data()) && basdv::is_host_pointer(rhs.data())) ||
(basdv::is_active_device_pointer(lhs.data()) && basdv::is_active_device_pointer(rhs.data())));
auto n = std::min(lhs.size(), rhs.size());
SXT_DEBUG_ASSERT(n > 0);
if (basdv::is_active_device_pointer(lhs.data())) {
co_return co_await async_inner_product_partial(lhs.subspan(0, n), rhs.subspan(0, n));
}
s25t::element res = s25t::element::identity();

basit::split_options split_options{
Expand Down
14 changes: 5 additions & 9 deletions sxt/scalar25/operation/inner_product.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,20 @@ TEST_CASE("we can compute inner products asynchronously on the GPU") {
REQUIRE(res.value() == expected_res);
}

SECTION("async inner product works with both device and host points") {
SECTION("we can split a GPU inner product into smaller chunks") {
size_t n = 100;
make_dataset(a_host, b_host, a_dev, b_dev, rng, n);
auto res1 = async_inner_product(a_dev, b_host);
auto res2 = async_inner_product(a_host, b_dev);
auto res3 = async_inner_product(a_host, b_host);
auto res = async_inner_product_impl(a_host, b_host, 4, 1, 10);
s25t::element expected_res;
inner_product(expected_res, a_host, b_host);
xens::get_scheduler().run();
REQUIRE(res1.value() == expected_res);
REQUIRE(res2.value() == expected_res);
REQUIRE(res3.value() == expected_res);
REQUIRE(res.value() == expected_res);
}

SECTION("we can split a GPU inner product into smaller chunks") {
SECTION("we don't split when inputs are already in device memory") {
size_t n = 100;
make_dataset(a_host, b_host, a_dev, b_dev, rng, n);
auto res = async_inner_product_impl(a_dev, b_host, 4, 1, 10);
auto res = async_inner_product_impl(a_dev, b_dev, 4, 1, 10);
s25t::element expected_res;
inner_product(expected_res, a_host, b_host);
xens::get_scheduler().run();
Expand Down

0 comments on commit 6f83d9e

Please sign in to comment.