Skip to content

Commit

Permalink
add reduction support to dynamic_forall
Browse files Browse the repository at this point in the history
  • Loading branch information
artv3 committed Oct 31, 2024
1 parent 2fcd22e commit 01da58c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 38 deletions.
33 changes: 15 additions & 18 deletions examples/dynamic-forall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,28 +77,25 @@ int main(int argc, char *argv[])

//----------------------------------------------------------------------------//

std::cout << "\n Running C-style vector addition...\n";

// _cstyle_vector_add_start
for (int i = 0; i < N; ++i) {
c[i] = a[i] + b[i];
}
// _cstyle_vector_add_end

checkResult(c, N);
//printResult(c, N);


//----------------------------------------------------------------------------//
// Example of dynamic policy selection for forall
//----------------------------------------------------------------------------//
std::cout << "\n Running dynamic forall vector addition and reductions...\n";

int sum = 0;
using VAL_INT_SUM = RAJA::expt::ValOp<int, RAJA::operators::plus>;

RAJA::RangeSegment range(0, N);

//policy is chosen from the list
RAJA::expt::dynamic_forall<policy_list>(pol, RAJA::RangeSegment(0, N), [=] RAJA_HOST_DEVICE (int i) {
RAJA::expt::dynamic_forall<policy_list>(pol, range,
RAJA::expt::Reduce<RAJA::operators::plus>(&sum),
RAJA::expt::KernelName("RAJA dynamic forall"),
[=] RAJA_HOST_DEVICE (int i, VAL_INT_SUM &_sum) {

c[i] = a[i] + b[i];
_sum += 1;
});
// _rajaseq_vector_add_end

std::cout<<"Sum = "<<sum<<", expected sum: "<<N<<std::endl;
checkResult(c, N);
//printResult(c, N);

Expand Down Expand Up @@ -126,9 +123,9 @@ void checkResult(int* res, int len)
if ( res[i] != 0 ) { correct = false; }
}
if ( correct ) {
std::cout << "\n\t result -- PASS\n";
std::cout << "\n\t Vector sum result -- PASS\n";
} else {
std::cout << "\n\t result -- FAIL\n";
std::cout << "\n\t Vector sum result -- FAIL\n";
}
}

Expand Down
40 changes: 20 additions & 20 deletions include/RAJA/pattern/forall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,63 +653,63 @@ namespace expt
template<camp::idx_t IDX, typename POLICY_LIST>
struct dynamic_helper
{
template<typename SEGMENT, typename BODY>
static void invoke_forall(const int pol, SEGMENT const &seg, BODY const &body)
template<typename SEGMENT, typename... PARAMS>
static void invoke_forall(const int pol, SEGMENT const &seg, PARAMS&&... params)
{
if(IDX==pol){
using t_pol = typename camp::at<POLICY_LIST,camp::num<IDX>>::type;
RAJA::forall<t_pol>(seg, body);
RAJA::forall<t_pol>(seg, params...);
return;
}
dynamic_helper<IDX-1, POLICY_LIST>::invoke_forall(pol, seg, body);
dynamic_helper<IDX-1, POLICY_LIST>::invoke_forall(pol, seg, params...);
}

template<typename SEGMENT, typename BODY>
template<typename SEGMENT, typename... PARAMS>
static resources::EventProxy<resources::Resource>
invoke_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, BODY const &body)
invoke_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, PARAMS&&... params)
{

using t_pol = typename camp::at<POLICY_LIST,camp::num<IDX>>::type;
using resource_type = typename resources::get_resource<t_pol>::type;

if(IDX==pol){
RAJA::forall<t_pol>(r.get<resource_type>(), seg, body);
RAJA::forall<t_pol>(r.get<resource_type>(), seg, params...);

//Return a generic event proxy from r,
//because forall returns a typed event proxy
return {r};
}

return dynamic_helper<IDX-1, POLICY_LIST>::invoke_forall(r, pol, seg, body);
return dynamic_helper<IDX-1, POLICY_LIST>::invoke_forall(r, pol, seg, params...);
}

};

template<typename POLICY_LIST>
struct dynamic_helper<0, POLICY_LIST>
{
template<typename SEGMENT, typename BODY>
template<typename SEGMENT, typename... PARAMS>
static void
invoke_forall(const int pol, SEGMENT const &seg, BODY const &body)
invoke_forall(const int pol, SEGMENT const &seg, PARAMS&&... params)
{
if(0==pol){
using t_pol = typename camp::at<POLICY_LIST,camp::num<0>>::type;
RAJA::forall<t_pol>(seg, body);
RAJA::forall<t_pol>(seg, params...);
return;
}
RAJA_ABORT_OR_THROW("Policy enum not supported ");
}

template<typename SEGMENT, typename BODY>
template<typename SEGMENT, typename... PARAMS>
static resources::EventProxy<resources::Resource>
invoke_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, BODY const &body)
invoke_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, PARAMS&&... params)
{
if(pol != 0) RAJA_ABORT_OR_THROW("Policy value out of range ");

using t_pol = typename camp::at<POLICY_LIST,camp::num<0>>::type;
using resource_type = typename resources::get_resource<t_pol>::type;

RAJA::forall<t_pol>(r.get<resource_type>(), seg, body);
RAJA::forall<t_pol>(r.get<resource_type>(), seg, params...);

//Return a generic event proxy from r,
//because forall returns a typed event proxy
Expand All @@ -718,21 +718,21 @@ namespace expt

};

template<typename POLICY_LIST, typename SEGMENT, typename BODY>
void dynamic_forall(const int pol, SEGMENT const &seg, BODY const &body)
template<typename POLICY_LIST, typename SEGMENT, typename... PARAMS>
void dynamic_forall(const int pol, SEGMENT const &seg, PARAMS&&... params)
{
constexpr int N = camp::size<POLICY_LIST>::value;
static_assert(N > 0, "RAJA policy list must not be empty");

if(pol > N-1) {
RAJA_ABORT_OR_THROW("Policy enum not supported");
}
dynamic_helper<N-1, POLICY_LIST>::invoke_forall(pol, seg, body);
dynamic_helper<N-1, POLICY_LIST>::invoke_forall(pol, seg, params...);
}

template<typename POLICY_LIST, typename SEGMENT, typename BODY>
template<typename POLICY_LIST, typename SEGMENT, typename... PARAMS>
resources::EventProxy<resources::Resource>
dynamic_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, BODY const &body)
dynamic_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, PARAMS&&... params)
{
constexpr int N = camp::size<POLICY_LIST>::value;
static_assert(N > 0, "RAJA policy list must not be empty");
Expand All @@ -741,7 +741,7 @@ namespace expt
RAJA_ABORT_OR_THROW("Policy value out of range");
}

return dynamic_helper<N-1, POLICY_LIST>::invoke_forall(r, pol, seg, body);
return dynamic_helper<N-1, POLICY_LIST>::invoke_forall(r, pol, seg, params...);
}

} // namespace expt
Expand Down

0 comments on commit 01da58c

Please sign in to comment.