Skip to content

Commit

Permalink
try a fix
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Feb 29, 2024
1 parent 38dfd87 commit 355d4b4
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/runtime/optimizer_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,12 @@ __host__ void AdamOptimizer::nccl_unified_update_task_gpu(
cudaEventRecord(t_start1, stream);
cudaEventRecord(t_start2, stream);

void *workSpace_ptr = meta->handle.workSpace;
void *allocate_ptr;
// = meta->handle.workSpace;
checkCUDA(
cudaMalloc(&allocate_ptr,meta->handle.workSpaceSize));

void *workSpace_ptr = allocate_ptr;

for (int i = 0; i < op->parameters_num; i++) {
cudaMemcpyAsync(workSpace_ptr,
Expand Down Expand Up @@ -272,7 +277,8 @@ __host__ void AdamOptimizer::nccl_unified_update_task_gpu(
cudaEventDestroy(t_start2);
printf("[optimizer] allreduce time = %.2lfms\n", elapsed);

workSpace_ptr = static_cast<char *>(meta->handle.workSpace);
// workSpace_ptr = static_cast<char *>(meta->handle.workSpace);
workSpace_ptr = static_cast<char *>(allocate_ptr);
float alpha_t = op->alpha_t;
float beta1_t = op->beta1_t;
float beta2_t = op->beta2_t;
Expand Down Expand Up @@ -304,6 +310,7 @@ __host__ void AdamOptimizer::nccl_unified_update_task_gpu(
checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end));
cudaEventDestroy(t_start);
cudaEventDestroy(t_end);
checkCUDA(cudaFree(allocate_ptr));
printf("[optimizer] total time = %.2lfms\n", elapsed);
}
#endif
Expand Down

0 comments on commit 355d4b4

Please sign in to comment.