diff --git a/src/parallel_ops/kernels/allreduce_kernels.cu b/src/parallel_ops/kernels/allreduce_kernels.cu index 2dc1caf19f..0e5c15008e 100644 --- a/src/parallel_ops/kernels/allreduce_kernels.cu +++ b/src/parallel_ops/kernels/allreduce_kernels.cu @@ -142,6 +142,9 @@ void inference_kernel_wrapper(Context ctx, int device_id = m->handle.device_id; ncclComm_t ncclComm = m->handle.ncclComm; DataType dtype = input.data_type; + if (num_elements == 0) { + return; + } tensorrt_llm::AllReduceStrategyType strategy = tensorrt_llm::SelectImplementation(