diff --git a/sxt/scalar25/operation/inner_product.cc b/sxt/scalar25/operation/inner_product.cc index d4eda1fe..cc6bf686 100644 --- a/sxt/scalar25/operation/inner_product.cc +++ b/sxt/scalar25/operation/inner_product.cc @@ -74,8 +74,14 @@ xena::future async_inner_product_impl(basct::cspan basct::cspan rhs, size_t split_factor, size_t min_chunk_size, size_t max_chunk_size) noexcept { + SXT_DEBUG_ASSERT( + (basdv::is_host_pointer(lhs.data()) && basdv::is_host_pointer(rhs.data())) || + (basdv::is_active_device_pointer(lhs.data()) && basdv::is_active_device_pointer(rhs.data()))); auto n = std::min(lhs.size(), rhs.size()); SXT_DEBUG_ASSERT(n > 0); + if (basdv::is_active_device_pointer(lhs.data())) { + co_return co_await async_inner_product_partial(lhs.subspan(0, n), rhs.subspan(0, n)); + } s25t::element res = s25t::element::identity(); basit::split_options split_options{ diff --git a/sxt/scalar25/operation/inner_product.t.cc b/sxt/scalar25/operation/inner_product.t.cc index 15681527..96ff4e0a 100644 --- a/sxt/scalar25/operation/inner_product.t.cc +++ b/sxt/scalar25/operation/inner_product.t.cc @@ -121,24 +121,20 @@ TEST_CASE("we can compute inner products asynchronously on the GPU") { REQUIRE(res.value() == expected_res); } - SECTION("async inner product works with both device and host points") { + SECTION("we can split a GPU inner product into smaller chunks") { size_t n = 100; make_dataset(a_host, b_host, a_dev, b_dev, rng, n); - auto res1 = async_inner_product(a_dev, b_host); - auto res2 = async_inner_product(a_host, b_dev); - auto res3 = async_inner_product(a_host, b_host); + auto res = async_inner_product_impl(a_host, b_host, 4, 1, 10); s25t::element expected_res; inner_product(expected_res, a_host, b_host); xens::get_scheduler().run(); - REQUIRE(res1.value() == expected_res); - REQUIRE(res2.value() == expected_res); - REQUIRE(res3.value() == expected_res); + REQUIRE(res.value() == expected_res); } - SECTION("we can split a GPU inner product into smaller chunks") { + SECTION("we don't split when inputs are already in device memory") { size_t n = 100; make_dataset(a_host, b_host, a_dev, b_dev, rng, n); - auto res = async_inner_product_impl(a_dev, b_host, 4, 1, 10); + auto res = async_inner_product_impl(a_dev, b_dev, 4, 1, 10); s25t::element expected_res; inner_product(expected_res, a_host, b_host); xens::get_scheduler().run();