Skip to content

Commit

Permalink
Adding initial support for Wave32. Currently I don't have a method to…
Browse files Browse the repository at this point in the history
… switch between Wave32 and Wave64
  • Loading branch information
seanofthemillers committed Sep 3, 2024
1 parent 7c81027 commit 345baee
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 7 deletions.
6 changes: 5 additions & 1 deletion include/RAJA/policy/hip/policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,12 @@ struct DeviceConstants
// values for HIP warp size and max block size.
//
// Select the compile-time DeviceConstants instance for the HIP platform in use.
// The first constructor argument is presumably the warp (wavefront) size and
// the second the max block size — TODO confirm against the DeviceConstants
// constructor, which is not visible here.
#if defined(__HIP_PLATFORM_AMD__)
constexpr DeviceConstants device_constants(64, 1024, 64); // MI300A
// NOTE(review): this branch is guarded by RAJA_HIP_WAVE64 but takes its first
// argument from RAJA_HIP_WAVESIZE. Neither macro is defined in the visible
// code — confirm that RAJA_HIP_WAVESIZE is always defined (and equal to 64)
// whenever RAJA_HIP_WAVE64 is defined.
#if defined(RAJA_HIP_WAVE64)
constexpr DeviceConstants device_constants(RAJA_HIP_WAVESIZE, 1024, 64); // MI300A
// constexpr DeviceConstants device_constants(64, 1024, 128); // MI250X
#else
// Fallback when RAJA_HIP_WAVE64 is not defined: first argument is 32.
constexpr DeviceConstants device_constants(32, 1024, 64); // Radeon cards (e.g. gfx1100)
#endif
#elif defined(__HIP_PLATFORM_NVIDIA__)
constexpr DeviceConstants device_constants(32, 1024, 32); // V100
#endif
Expand Down
10 changes: 6 additions & 4 deletions include/RAJA/policy/tensor/arch/hip/hip_wave.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ namespace expt

public:

static constexpr int s_num_elem = 64;
// static constexpr int s_num_elem = 64;
static constexpr int s_num_elem = policy::hip::device_constants.WARP_SIZE;

/*!
* @brief Default constructor, zeros register contents
Expand Down Expand Up @@ -780,8 +781,8 @@ namespace expt

// Third: mask off everything but output_segment
// this is because all output segments are valid at this point
// (log2_warp_size-segbits): log2_warp_size is log2 of the warp width, e.g. 5 for 32 == 1<<5 or 6 for 64 == 1<<6
int our_output_segment = get_lane()>>(6-segbits);
const int log2_warp_size = 32-1-__builtin_clz(warpSize);
int our_output_segment = get_lane()>>(log2_warp_size-segbits);
bool in_output_segment = our_output_segment == output_segment;
if(!in_output_segment){
result.get_raw_value() = 0;
Expand Down Expand Up @@ -828,8 +829,9 @@ namespace expt

// First: tree reduce values within each segment
element_type x = m_value;
const int log2_warp_size = 32-1-__builtin_clz(warpSize);
RAJA_UNROLL
for(int i = 0;i < 6-segbits; ++ i){
for(int i = 0;i < log2_warp_size-segbits; ++ i){

// tree shuffle
int delta = s_num_elem >> (i+1);
Expand Down
3 changes: 2 additions & 1 deletion include/RAJA/policy/tensor/arch/hip/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ namespace expt {
struct RegisterTraits<RAJA::expt::hip_wave_register, T>{
using element_type = T;
using register_policy = RAJA::expt::hip_wave_register;
static constexpr camp::idx_t s_num_elem = 64;
static constexpr camp::idx_t s_num_elem = policy::hip::device_constants.WARP_SIZE;

static constexpr camp::idx_t s_num_bits = sizeof(T) * s_num_elem;
using int_element_type = int32_t;
};
Expand Down
4 changes: 3 additions & 1 deletion test/include/RAJA_test-tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ struct TensorTestHelper<RAJA::expt::hip_wave_register>
void exec(BODY const &body){
hipDeviceSynchronize();

RAJA::forall<RAJA::hip_exec<64>>(RAJA::RangeSegment(0,64),
constexpr int warp_size = RAJA::policy::hip::device_constants.WARP_SIZE;

RAJA::forall<RAJA::hip_exec<warp_size>>(RAJA::RangeSegment(0,warp_size),
[=] RAJA_HOST_DEVICE (int ){
body();
});
Expand Down

0 comments on commit 345baee

Please sign in to comment.