Skip to content

Commit

Permalink
feat: chunk combination of multiexponentiation partials (PROOF-923) (#…
Browse files Browse the repository at this point in the history
…207)

* fill in combine reduce

* rework combine reduce

* fill in tests

* rework combine reduce

* add tests

* reformat

* drop unused header

* add assertion checks
  • Loading branch information
rnburn authored Dec 13, 2024
1 parent 6f83d9e commit ff56d21
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 0 deletions.
28 changes: 28 additions & 0 deletions sxt/multiexp/pippenger2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,34 @@ sxt_cc_component(
],
)

sxt_cc_component(
name = "combine_reduce",
test_deps = [
"//sxt/base/curve:example_element",
"//sxt/base/device:stream",
"//sxt/base/device:synchronization",
"//sxt/base/test:unit_test",
"//sxt/execution/schedule:scheduler",
"//sxt/memory/resource:managed_device_resource",
],
deps = [
"//sxt/algorithm/iteration:for_each",
"//sxt/base/container:span",
"//sxt/base/container:span_utility",
"//sxt/base/curve:element",
"//sxt/base/error:assert",
"//sxt/base/iterator:split",
"//sxt/base/macro:cuda_callable",
"//sxt/base/type:raw_stream",
"//sxt/execution/async:coroutine",
"//sxt/execution/device:copy",
"//sxt/execution/device:for_each",
"//sxt/execution/device:synchronization",
"//sxt/memory/management:managed_array",
"//sxt/memory/resource:async_device_resource",
],
)

sxt_cc_component(
name = "partition_product",
test_deps = [
Expand Down
17 changes: 17 additions & 0 deletions sxt/multiexp/pippenger2/combine_reduce.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/multiexp/pippenger2/combine_reduce.h"
195 changes: 195 additions & 0 deletions sxt/multiexp/pippenger2/combine_reduce.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cassert>
#include <numeric>

#include "sxt/algorithm/iteration/for_each.h"
#include "sxt/base/container/span.h"
#include "sxt/base/container/span_utility.h"
#include "sxt/base/curve/element.h"
#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/property.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/error/assert.h"
#include "sxt/base/iterator/split.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/async/future.h"
#include "sxt/execution/device/copy.h"
#include "sxt/execution/device/for_each.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/async_device_resource.h"

namespace sxt::mtxpp2 {
//--------------------------------------------------------------------------------------------------
// combine_reduce_chunk_kernel
//--------------------------------------------------------------------------------------------------
template <bascrv::element T>
__device__ void combine_reduce_chunk_kernel(T* __restrict__ res, const T* __restrict__ partials,
const unsigned* __restrict__ bit_table_partial_sums,
unsigned num_partials, unsigned reduction_size,
unsigned partials_offset,
unsigned output_index) noexcept {
auto output_index_p = umax(output_index, 1u) - 1u;
auto output_correction = bit_table_partial_sums[output_index_p] * (output_index != 0) +
partials_offset * (output_index == 0);
auto bit_width = bit_table_partial_sums[output_index] - output_correction;
assert(bit_width > 0);

// adjust points
res += output_index;
partials += bit_table_partial_sums[output_index] - partials_offset;

// combine reduce
unsigned bit_index = bit_width - 1u;
--partials;
T e = *partials;
for (unsigned reduction_index = 1; reduction_index < reduction_size; ++reduction_index) {
auto ep = partials[reduction_index * num_partials];
add_inplace(e, ep);
}
for (; bit_index-- > 0u;) {
--partials;
double_element(e, e);
for (unsigned reduction_index = 0; reduction_index < reduction_size; ++reduction_index) {
auto ep = partials[reduction_index * num_partials];
add_inplace(e, ep);
}
}
*res = e;
}

//--------------------------------------------------------------------------------------------------
// combine_reduce_chunk
//--------------------------------------------------------------------------------------------------
template <bascrv::element T>
xena::future<> combine_reduce_chunk(basct::span<T> res,
basct::cspan<unsigned> output_bit_table_partial_sums,
basct::cspan<T> partial_products, unsigned reduction_size,
unsigned partials_offset) noexcept {
auto num_partials = partial_products.size() / reduction_size;
auto num_outputs = output_bit_table_partial_sums.size();
unsigned slice_num_partials = output_bit_table_partial_sums[num_outputs - 1] - partials_offset;
SXT_RELEASE_ASSERT(
// clang-format off
num_outputs > 0 &&
res.size() == num_outputs &&
output_bit_table_partial_sums.size() == num_outputs &&
partial_products.size() == num_partials * reduction_size &&
partials_offset < output_bit_table_partial_sums[num_outputs-1]
// clang-format on
);
basdv::stream stream;

// copy data
memr::async_device_resource resource{stream};
memmg::managed_array<T> partials_dev_data{&resource};
basct::cspan<T> partials_dev = partial_products;
if (!basdv::is_active_device_pointer(partials_dev.data())) {
partials_dev_data.resize(slice_num_partials * reduction_size);
co_await xendv::strided_copy_host_to_device<T>(partials_dev_data, stream, partial_products,
num_partials, slice_num_partials,
partials_offset);
partials_dev = partials_dev_data;
} else {
SXT_RELEASE_ASSERT(partial_products.size() == slice_num_partials * reduction_size);
}
memmg::managed_array<unsigned> bit_table_partial_sums_dev{num_outputs, &resource};
basdv::async_copy_host_to_device(bit_table_partial_sums_dev, output_bit_table_partial_sums,
stream);

// combine reduce chunk
memmg::managed_array<T> res_dev{num_outputs, &resource};
auto f = [
// clang-format off
num_partials = slice_num_partials,
reduction_size = reduction_size,
partials_offset = partials_offset,
res = res_dev.data(),
partials = partials_dev.data(),
bit_table_partial_sums = bit_table_partial_sums_dev.data()
// clang-format on
] __device__
__host__(unsigned /*num_outputs*/, unsigned output_index) noexcept {
combine_reduce_chunk_kernel(res, partials, bit_table_partial_sums, num_partials,
reduction_size, partials_offset, output_index);
};
algi::launch_for_each_kernel(stream, f, num_outputs);
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
}

//--------------------------------------------------------------------------------------------------
// combine_reduce
//--------------------------------------------------------------------------------------------------
template <bascrv::element T>
xena::future<> combine_reduce(basct::span<T> res, const basit::split_options& split_options,
basct::cspan<unsigned> output_bit_table,
basct::cspan<T> partial_products) noexcept {
auto num_outputs = output_bit_table.size();
SXT_RELEASE_ASSERT(
// clang-format off
res.size() == num_outputs &&
output_bit_table.size() == num_outputs
// clang-format on
);

if (res.empty()) {
co_return;
}

// partials
memmg::managed_array<unsigned> bit_table_partial_sums(num_outputs);
std::partial_sum(output_bit_table.begin(), output_bit_table.end(),
bit_table_partial_sums.begin());
auto reduction_size = partial_products.size() / bit_table_partial_sums[num_outputs - 1];

// don't split if partials are already in device memory
if (basdv::is_active_device_pointer(partial_products.data())) {
co_return co_await combine_reduce_chunk(res, bit_table_partial_sums, partial_products,
reduction_size, 0);
}

// split
auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, num_outputs}, split_options);

// combine reduce
co_await xendv::concurrent_for_each(
chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> {
auto output_first = rng.a();

auto res_chunk = res.subspan(output_first, rng.size());
auto bit_table_partial_sums_chunk =
basct::subspan(bit_table_partial_sums, output_first, rng.size());
auto partials_offset = output_first > 0 ? bit_table_partial_sums[output_first - 1] : 0u;

co_await combine_reduce_chunk(res_chunk, bit_table_partial_sums_chunk, partial_products,
reduction_size, partials_offset);
});
}

template <bascrv::element T>
xena::future<> combine_reduce(basct::span<T> res, basct::cspan<unsigned> output_bit_table,
basct::cspan<T> partial_products) noexcept {
basit::split_options split_options{
.max_chunk_size = 1024,
.split_factor = basdv::get_num_devices(),
};
co_await combine_reduce(res, split_options, output_bit_table, partial_products);
}
} // namespace sxt::mtxpp2
105 changes: 105 additions & 0 deletions sxt/multiexp/pippenger2/combine_reduce.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/multiexp/pippenger2/combine_reduce.h"

#include <vector>

#include "sxt/base/curve/example_element.h"
#include "sxt/base/test/unit_test.h"
#include "sxt/execution/schedule/scheduler.h"
#include "sxt/memory/resource/managed_device_resource.h"

using namespace sxt;
using namespace sxt::mtxpp2;

TEST_CASE("we can combine and reduce partial products") {
using E = bascrv::element97;

std::vector<unsigned> output_bit_table;
std::vector<E> partial_products;
std::vector<E> res(1);

SECTION("we handle no outputs") {
res.clear();
auto fut = combine_reduce<E>(res, output_bit_table, partial_products);
REQUIRE(fut.ready());
}

SECTION("we can combine and reduce a single element") {
output_bit_table = {1};
partial_products = {3u};
auto fut = combine_reduce<E>(res, output_bit_table, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u);
}

SECTION("we can combine and reduce elements already on device") {
output_bit_table = {1};
partial_products = {3u};
std::pmr::vector<E> partial_products_dev{partial_products.begin(), partial_products.end(),
memr::get_managed_device_resource()};
auto fut = combine_reduce<E>(res, output_bit_table, partial_products_dev);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u);
}

SECTION("we can combine and reduce a single output with a reduction size of two") {
output_bit_table = {1};
partial_products = {3u, 4u};
auto fut = combine_reduce<E>(res, output_bit_table, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 7u);
}

SECTION("we can combine and reduce an output with a bit width of 2") {
output_bit_table = {2};
partial_products = {3u, 4u};
auto fut = combine_reduce<E>(res, output_bit_table, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 11u);
}

SECTION("we can combine and reduce multiple outputs") {
output_bit_table = {1, 1};
partial_products = {3u, 4u};
res.resize(2);
auto fut = combine_reduce<E>(res, output_bit_table, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u);
REQUIRE(res[1] == 4u);
}

SECTION("we can combine and reduce in chunks") {
output_bit_table = {1, 1};
partial_products = {3u, 4u};
res.resize(2);
basit::split_options split_options{
.max_chunk_size = 1u,
.split_factor = 2u,
};
auto fut = combine_reduce<E>(res, split_options, output_bit_table, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u);
REQUIRE(res[1] == 4u);
}
}

0 comments on commit ff56d21

Please sign in to comment.