Skip to content

Commit

Permalink
fix hip
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Sep 13, 2024
1 parent 355d4b4 commit 8488ba0
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
5 changes: 4 additions & 1 deletion include/flexflow/utils/cuda_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ template <typename DT>
__global__ void copy_kernel(DT *dst, const DT *src, Legion::coord_t size);

template <typename DT>
__global__ void copy_kernel_with_replicate(DT *dst, const DT *src, Legion::coord_t origin_size, Legion::coord_t size);
__global__ void copy_kernel_with_replicate(DT *dst,
const DT *src,
Legion::coord_t origin_size,
Legion::coord_t size);

template <typename T>
__global__ void add_kernel(T *data_ptr, T const *grad_ptr, size_t size);
Expand Down
6 changes: 6 additions & 0 deletions include/flexflow/utils/hip_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ __global__ void assign_kernel(DT *ptr, Legion::coord_t size, DT value);
template <typename DT>
__global__ void copy_kernel(DT *dst, const DT *src, Legion::coord_t size);

template <typename DT>
__global__ void copy_kernel_with_replicate(DT *dst,
const DT *src,
Legion::coord_t origin_size,
Legion::coord_t size);

template <typename T>
__global__ void add_kernel(T *data_ptr, T const *grad_ptr, size_t size);

Expand Down
10 changes: 7 additions & 3 deletions src/dataloader/dataloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ void SingleDataLoader::load_input(Task const *task,
int num_dims = full_input_domain.get_dim();
assert(num_dims + 1 == batch_input_domain.get_dim());
// assert the leading replica dim has a degree of one
assert(batch_input_domain.hi()[num_dims] ==
batch_input_domain.lo()[num_dims]);
// assert(batch_input_domain.hi()[num_dims] ==
// batch_input_domain.lo()[num_dims]);
coord_t batch_size = batch_input_domain.hi()[num_dims - 1] -
batch_input_domain.lo()[num_dims - 1] + 1;
coord_t replicate_num =
batch_input_domain.hi()[num_dims] - batch_input_domain.lo()[num_dims] + 1;
coord_t num_elements_per_batch = batch_input_domain.get_volume() / batch_size;
// FIXME: currently assume continous indices
assert(batch_size == meta->num_samples);
Expand All @@ -61,13 +63,15 @@ void SingleDataLoader::load_input(Task const *task,
// printf("ptr(%p, %p), idx0 %d nb_elements_per_batch %d, batch_size %d,
// %d\n", acc_full_input.ptr, input_zc, start_idx, num_elements_per_batch,
// batch_size, start_idx * num_elements_per_batch);
hipLaunchKernelGGL(HIP_KERNEL_NAME(copy_kernel<DT>),
assert(batch_input_domain.get_volume() % replicate_num == 0);
hipLaunchKernelGGL(HIP_KERNEL_NAME(copy_kernel_with_replicate<DT>),
GET_BLOCKS(batch_input_domain.get_volume()),
CUDA_NUM_THREADS,
0,
stream,
batch_input_ptr,
input_zc,
batch_input_domain.get_volume() / replicate_num,
batch_input_domain.get_volume());
checkCUDA(hipDeviceSynchronize());
}
Expand Down
20 changes: 20 additions & 0 deletions src/runtime/hip_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ __global__ void copy_kernel(DT *dst, const DT *src, coord_t size) {
}
}

template <typename DT>
__global__ void copy_kernel_with_replicate(DT *dst,
const DT *src,
coord_t origin_size,
coord_t size) {
CUDA_KERNEL_LOOP(i, size) {
dst[i] = src[i % origin_size];
}
}

template <typename DT>
__global__ void reluBackward(DT *grad_ptr, const DT *output, size_t n) {
CUDA_KERNEL_LOOP(i, n) {
Expand Down Expand Up @@ -404,6 +414,16 @@ template __global__ void

template __global__ void
copy_kernel<float>(float *dst, float const *src, coord_t size);

template __global__ void copy_kernel_with_replicate<float>(float *dst,
float const *src,
coord_t origin_size,
coord_t size);
template __global__ void copy_kernel_with_replicate<int32_t>(
int32_t *dst, int32_t const *src, coord_t origin_size, coord_t size);
template __global__ void copy_kernel_with_replicate<int64_t>(
int64_t *dst, int64_t const *src, coord_t origin_size, coord_t size);

template __global__ void
copy_kernel<int32_t>(int32_t *dst, int32_t const *src, coord_t size);
template __global__ void
Expand Down

0 comments on commit 8488ba0

Please sign in to comment.