Skip to content

Commit

Permalink
chore(gpu): try resetting device to gpu 0 after sync local streams
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Nov 25, 2024
1 parent 2a18d6f commit de3e09d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ __host__ void host_unsigned_integer_div_rem_kb(
cuda_synchronize_stream(mem_ptr->sub_streams_1[j], gpu_indexes[j]);
cuda_synchronize_stream(mem_ptr->sub_streams_2[j], gpu_indexes[j]);
}
cudaSetDevice(gpu_indexes[0]);
}

template <typename Torus>
Expand Down Expand Up @@ -688,6 +689,7 @@ __host__ void host_integer_div_rem_kb(cudaStream_t const *streams,
cuda_synchronize_stream(int_mem_ptr->sub_streams_1[j], gpu_indexes[j]);
cuda_synchronize_stream(int_mem_ptr->sub_streams_2[j], gpu_indexes[j]);
}
cudaSetDevice(gpu_indexes[0]);
} else {
host_unsigned_integer_div_rem_kb<Torus>(
streams, gpu_indexes, gpu_count, quotient, remainder, numerator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
cuda_synchronize_stream(mem->local_streams_1[j], gpu_indexes[j]);
cuda_synchronize_stream(mem->local_streams_2[j], gpu_indexes[j]);
}
cudaSetDevice(gpu_indexes[0]);
}
} else {
PANIC("Cuda error (scalar shift): left scalar shift is never of the "
Expand Down

0 comments on commit de3e09d

Please sign in to comment.