feat: add strided host-to-device copy (PROOF-923) (#204)
* work on buffer pool

* test pinned buffer pool

* fill in testing

* fill in strided copy testing

* fill in strided copy testing

* fill in tests

* fill in strided copy

* fill in strided copy

* test strided copy

* test strided copy

* rename

* add benchmark

* fill in benchmark

* reformat
rnburn authored Dec 10, 2024
1 parent 38b002d commit 07f0e71
Showing 6 changed files with 526 additions and 0 deletions.
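The new entry point lives in sxt/execution/device/copy.cc (shown below); it copies count rows of n bytes from host memory, where consecutive source rows are stride bytes apart, into a packed device destination, staging the data through pinned buffers so packing can overlap with the transfer. For orientation, its signature as defined in this commit:

xena::future<> strided_copy_host_to_device(std::byte* dst, const basdv::stream& stream,
                                           const std::byte* src, size_t n, size_t count,
                                           size_t stride) noexcept;

A templated span overload is also used by the benchmark (see sum3 in benchmark/memory/copy.m.cc below); it is presumably declared in sxt/execution/device/copy.h, which is not shown in this view.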
27 changes: 27 additions & 0 deletions benchmark/memory/BUILD
@@ -0,0 +1,27 @@
load("//bazel:sxt_build_system.bzl", "sxt_cc_binary")

sxt_cc_binary(
name = "copy",
srcs = [
"copy.m.cc",
],
visibility = ["//visibility:public"],
deps = [
"//sxt/algorithm/iteration:for_each",
"//sxt/base/container:span",
"//sxt/base/container:span_utility",
"//sxt/base/device:memory_utility",
"//sxt/base/device:stream",
"//sxt/base/iterator:index_range",
"//sxt/base/iterator:index_range_utility",
"//sxt/execution/async:coroutine",
"//sxt/execution/async:future",
"//sxt/execution/device:copy",
"//sxt/execution/device:for_each",
"//sxt/execution/device:synchronization",
"//sxt/execution/schedule:scheduler",
"//sxt/memory/management:managed_array",
"//sxt/memory/resource:async_device_resource",
"//sxt/memory/resource:pinned_resource",
],
)
193 changes: 193 additions & 0 deletions benchmark/memory/copy.m.cc
@@ -0,0 +1,193 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <cstring>
#include <print>
#include <random>

#include "sxt/algorithm/iteration/for_each.h"
#include "sxt/base/container/span.h"
#include "sxt/base/container/span_utility.h"
#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/pinned_buffer_pool.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/iterator/index_range.h"
#include "sxt/base/iterator/index_range_iterator.h"
#include "sxt/base/iterator/index_range_utility.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/async/future.h"
#include "sxt/execution/device/copy.h"
#include "sxt/execution/device/for_each.h"
#include "sxt/execution/device/synchronization.h"
#include "sxt/execution/schedule/scheduler.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/async_device_resource.h"
#include "sxt/memory/resource/pinned_resource.h"

using namespace sxt;

// sum1
static xena::future<> sum1(basct::span<double> res, basct::cspan<double> data, unsigned n,
unsigned a) noexcept {
auto chunk_size = res.size();
auto m = data.size() / n;

basdv::stream stream;
memr::async_device_resource resource{stream};

// copy
memmg::managed_array<double> data_dev{res.size() * m, &resource};
for (unsigned i = 0; i < m; ++i) {
basdv::async_copy_host_to_device(basct::subspan(data_dev, i * chunk_size, chunk_size),
basct::subspan(data, i * n + a, chunk_size), stream);
}

// sum
memmg::managed_array<double> res_dev{chunk_size, &resource};
auto f = [res = res_dev.data(), data = data_dev.data(), m = m] __device__ __host__(
unsigned chunk_size, unsigned i) noexcept {
double sum = 0;
for (unsigned j = 0; j < m; ++j) {
sum += data[i + chunk_size * j];
}
res[i] = sum;
};
algi::launch_for_each_kernel(stream, f, chunk_size);
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
}

// sum2
static xena::future<> sum2(basct::span<double> res, basct::cspan<double> data, unsigned n,
unsigned a) noexcept {
auto chunk_size = res.size();
auto m = data.size() / n;

basdv::stream stream;
memr::async_device_resource resource{stream};

// copy
memmg::managed_array<double> data_p(res.size() * m);
memmg::managed_array<double> data_dev{res.size() * m, &resource};
for (unsigned i = 0; i < m; ++i) {
std::memcpy(static_cast<void*>(data_p.data() + chunk_size * i),
static_cast<const void*>(data.data() + a + n * i), chunk_size * sizeof(double));
}
basdv::async_copy_host_to_device(data_dev, data_p, stream);

// sum
memmg::managed_array<double> res_dev{chunk_size, &resource};
auto f = [res = res_dev.data(), data = data_dev.data(), m = m] __device__ __host__(
unsigned chunk_size, unsigned i) noexcept {
double sum = 0;
for (unsigned j = 0; j < m; ++j) {
sum += data[i + chunk_size * j];
}
res[i] = sum;
};
algi::launch_for_each_kernel(stream, f, chunk_size);
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
}

// sum3
static xena::future<> sum3(basct::span<double> res, basct::cspan<double> data, unsigned n,
unsigned a) noexcept {
auto chunk_size = res.size();
auto m = data.size() / n;

basdv::stream stream;
memr::async_device_resource resource{stream};

// copy
memmg::managed_array<double> data_dev{res.size() * m, &resource};
co_await xendv::strided_copy_host_to_device<double>(data_dev, stream, data, n, chunk_size, a);

// sum
memmg::managed_array<double> res_dev{chunk_size, &resource};
auto f = [res = res_dev.data(), data = data_dev.data(), m = m] __device__ __host__(
unsigned chunk_size, unsigned i) noexcept {
double sum = 0;
for (unsigned j = 0; j < m; ++j) {
sum += data[i + chunk_size * j];
}
res[i] = sum;
};
algi::launch_for_each_kernel(stream, f, chunk_size);
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
}

using benchmark_fn = xena::future<> (*)(basct::span<double>, basct::cspan<double>, unsigned,
unsigned);

// run_benchmark
static double run_benchmark(benchmark_fn f, unsigned n, unsigned m,
unsigned split_factor) noexcept {
// fill data
memmg::managed_array<double> data(n * m);
std::mt19937 rng{0};
std::uniform_real_distribution<double> dist{-1, 1};
for (auto& x : data) {
x = dist(rng);
}

// chunk
auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n}, split_factor);

// invoker
memmg::managed_array<double> sum(n);
auto invoker = [&] {
auto fut = xendv::concurrent_for_each(
chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
co_await f(basct::subspan(sum, rng.a(), rng.size()), data, n, rng.a());
});
xens::get_scheduler().run();
};

// initial run
invoker();

// average
auto avg = 0.0;
unsigned num_iterations = 10;
for (unsigned i = 0; i < num_iterations; ++i) {
auto t1 = std::chrono::steady_clock::now();
invoker();
auto t2 = std::chrono::steady_clock::now();
auto elapse = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count() / 1.0e3;
avg += elapse;
}
return avg / num_iterations;
}

int main() {
const unsigned n = 100'000;
const unsigned m = 32;
const unsigned split_factor = 16;

auto avg_elapse = run_benchmark(sum1, n, m, split_factor);
std::println("sum1: average elapse: {}", avg_elapse);

avg_elapse = run_benchmark(sum2, n, m, split_factor);
std::println("sum2: average elapse: {}", avg_elapse);

avg_elapse = run_benchmark(sum3, n, m, split_factor);
std::println("sum3: average elapse: {}", avg_elapse);

return 0;
}
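The three variants compute the same per-chunk sums and differ only in how the strided window of host data reaches the device: sum1 issues one async_copy_host_to_device per row, sum2 packs all rows into a single pageable host buffer and performs one copy, and sum3 calls the new xendv::strided_copy_host_to_device, which packs rows into pinned buffers. A minimal sketch of calling the new routine directly, assuming only what the sum3 call above shows (n is the host row length in elements, chunk_size the number of elements copied per row, and a the starting column); the wrapper name copy_window_example is hypothetical:

// Sketch only: mirrors the call made by sum3 above.
static xena::future<> copy_window_example(memmg::managed_array<double>& dst,
                                          const basdv::stream& stream,
                                          basct::cspan<double> src, unsigned n,
                                          unsigned chunk_size, unsigned a) noexcept {
  // dst must hold (src.size() / n) * chunk_size doubles: one packed chunk per row.
  // Copies src[i*n + a .. i*n + a + chunk_size) into dst[i*chunk_size ..) for each row i.
  co_await xendv::strided_copy_host_to_device<double>(dst, stream, src, n, chunk_size, a);
}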
24 changes: 24 additions & 0 deletions sxt/execution/device/BUILD
@@ -179,6 +179,30 @@ sxt_cc_component(
],
)

sxt_cc_component(
name = "copy",
impl_deps = [
":synchronization",
"//sxt/base/device:pinned_buffer",
"//sxt/base/device:memory_utility",
"//sxt/base/device:stream",
"//sxt/execution/async:coroutine",
],
test_deps = [
"//sxt/base/device:pinned_buffer",
"//sxt/base/device:stream",
"//sxt/base/device:synchronization",
"//sxt/base/test:unit_test",
"//sxt/execution/schedule:scheduler",
"//sxt/memory/resource:managed_device_resource",
],
deps = [
"//sxt/base/container:span",
"//sxt/base/error:assert",
"//sxt/execution/async:future",
],
)

sxt_cc_component(
name = "synchronization",
test_deps = [
106 changes: 106 additions & 0 deletions sxt/execution/device/copy.cc
@@ -0,0 +1,106 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/execution/device/copy.h"

#include <cassert>
#include <cstring>

#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/pinned_buffer.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/error/assert.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/device/synchronization.h"

namespace sxt::xendv {
//--------------------------------------------------------------------------------------------------
// strided_copy_host_to_device_one_sweep
//--------------------------------------------------------------------------------------------------
static xena::future<> strided_copy_host_to_device_one_sweep(std::byte* dst,
const basdv::stream& stream,
const std::byte* src, size_t n,
size_t count, size_t stride) noexcept {
auto num_bytes = n * count;
if (num_bytes == 0) {
co_return;
}
basdv::pinned_buffer buffer;
auto data = static_cast<std::byte*>(buffer.data());
for (size_t i = 0; i < count; ++i) {
std::memcpy(data, src, n);
data += n;
src += stride;
}
basdv::async_memcpy_host_to_device(static_cast<void*>(dst), buffer.data(), num_bytes, stream);
co_await await_stream(stream);
}

//--------------------------------------------------------------------------------------------------
// strided_copy_host_to_device
//--------------------------------------------------------------------------------------------------
xena::future<> strided_copy_host_to_device(std::byte* dst, const basdv::stream& stream,
const std::byte* src, size_t n, size_t count,
size_t stride) noexcept {
SXT_RELEASE_ASSERT(
// clang-format off
basdv::is_active_device_pointer(dst) &&
basdv::is_host_pointer(src) &&
stride >= n
// clang-format on
);
auto num_bytes = n * count;
if (num_bytes <= basdv::pinned_buffer::size()) {
co_return co_await strided_copy_host_to_device_one_sweep(dst, stream, src, n, count, stride);
}
auto cur_n = n;

auto fill_buffer = [&](basdv::pinned_buffer& buffer) noexcept {
size_t remaining_size = buffer.size();
auto data = static_cast<std::byte*>(buffer.data());
while (remaining_size > 0 && count > 0) {
auto chunk_size = std::min(remaining_size, cur_n);
std::memcpy(data, src, chunk_size);
src += chunk_size;
data += chunk_size;
remaining_size -= chunk_size;
cur_n -= chunk_size;
if (cur_n == 0) {
--count;
cur_n = n;
src += stride - n;
}
}
return buffer.size() - remaining_size;
};

// copy
basdv::pinned_buffer cur_buffer, alt_buffer;
auto chunk_size = fill_buffer(cur_buffer);
SXT_DEBUG_ASSERT(count > 0, "copy can't be done in a single sweep");
while (count > 0) {
basdv::async_memcpy_host_to_device(static_cast<void*>(dst), cur_buffer.data(), chunk_size,
stream);
dst += chunk_size;
chunk_size = fill_buffer(alt_buffer);
co_await await_stream(stream);
std::swap(cur_buffer, alt_buffer);
}
basdv::async_memcpy_host_to_device(static_cast<void*>(dst), cur_buffer.data(), chunk_size,
stream);
co_await await_stream(stream);
}
} // namespace sxt::xendv
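For comparison, the same strided host-to-device pattern can be expressed in one call with CUDA's cudaMemcpy2DAsync; a minimal sketch, not part of this commit, with all sizes in bytes:

#include <cstddef>
#include <cuda_runtime.h>

// Sketch only: dpitch = n because the device rows are packed back to back,
// spitch = stride is the byte distance between consecutive host rows,
// width = n bytes are copied per row, and height = count rows are copied.
cudaError_t strided_copy_with_cuda(std::byte* dst, const std::byte* src, size_t n,
                                   size_t count, size_t stride, cudaStream_t stream) {
  return cudaMemcpy2DAsync(dst, n, src, stride, n, count, cudaMemcpyHostToDevice, stream);
}

When src is ordinary pageable memory, cudaMemcpy2DAsync cannot run fully asynchronously, which is one reason the implementation above stages rows through pinned buffers and double-buffers the packing against the in-flight transfer.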