From 8687b69769ab3957ffe841e670a248573bdb27a4 Mon Sep 17 00:00:00 2001
From: Agnes Leroy
Date: Fri, 6 Dec 2024 10:31:12 +0100
Subject: [PATCH] fix(gpu): fix single gpu on device other than 0

---
 tfhe/src/core_crypto/gpu/slice.rs             | 100 +++++++++---------
 tfhe/src/core_crypto/gpu/vec.rs               |  12 +--
 .../src/integer/gpu/server_key/radix/ilog2.rs |  53 ++++------
 tfhe/src/integer/gpu/server_key/radix/mod.rs  |  48 ++-------
 4 files changed, 86 insertions(+), 127 deletions(-)

diff --git a/tfhe/src/core_crypto/gpu/slice.rs b/tfhe/src/core_crypto/gpu/slice.rs
index 8f13e22354..668c608f29 100644
--- a/tfhe/src/core_crypto/gpu/slice.rs
+++ b/tfhe/src/core_crypto/gpu/slice.rs
@@ -45,11 +45,11 @@ where
     ///
     /// The caller must ensure that the slice outlives the pointer this function returns,
     /// or else it will end up pointing to garbage.
-    pub(crate) unsafe fn as_c_ptr(&self, gpu_index: u32) -> *const c_void {
-        self.ptrs[gpu_index as usize]
+    pub(crate) unsafe fn as_c_ptr(&self, index: usize) -> *const c_void {
+        self.ptrs[index]
     }
 
-    pub(crate) fn gpu_index(&self, index: u32) -> GpuIndex {
-        self.gpu_indexes[index as usize]
+    pub(crate) fn gpu_index(&self, index: usize) -> GpuIndex {
+        self.gpu_indexes[index]
     }
 }
 
@@ -75,16 +75,16 @@ where
     ///
     /// The caller must ensure that the slice outlives the pointer this function returns,
     /// or else it will end up pointing to garbage.
-    pub(crate) unsafe fn as_mut_c_ptr(&mut self, gpu_index: u32) -> *mut c_void {
-        self.ptrs[gpu_index as usize]
+    pub(crate) unsafe fn as_mut_c_ptr(&mut self, index: usize) -> *mut c_void {
+        self.ptrs[index]
     }
 
     /// # Safety
     ///
     /// The caller must ensure that the slice outlives the pointer this function returns,
     /// or else it will end up pointing to garbage.
-    pub(crate) unsafe fn as_c_ptr(&self, index: u32) -> *const c_void {
-        self.ptrs[index as usize].cast_const()
+    pub(crate) unsafe fn as_c_ptr(&self, index: usize) -> *const c_void {
+        self.ptrs[index].cast_const()
     }
 
     /// Copies data between two `CudaSlice`
     ///
     /// # Safety
@@ -93,24 +93,20 @@ where
     ///
     /// - [CudaStreams::synchronize] __must__ be called after the copy as soon as synchronization is
     ///   required.
-    pub unsafe fn copy_from_gpu_async(
-        &mut self,
-        src: &Self,
-        streams: &CudaStreams,
-        stream_index: u32,
-    ) where
+    pub unsafe fn copy_from_gpu_async(&mut self, src: &Self, streams: &CudaStreams, index: usize)
+    where
         T: Numeric,
     {
-        assert_eq!(self.len(stream_index), src.len(stream_index));
-        let size = src.len(stream_index) * std::mem::size_of::<T>();
+        assert_eq!(self.len(index), src.len(index));
+        let size = src.len(index) * std::mem::size_of::<T>();
         // We check that src is not empty to avoid invalid pointers
         if size > 0 {
             cuda_memcpy_async_gpu_to_gpu(
-                self.as_mut_c_ptr(stream_index),
-                src.as_c_ptr(stream_index),
+                self.as_mut_c_ptr(index),
+                src.as_c_ptr(index),
                 size as u64,
-                streams.ptr[stream_index as usize],
-                streams.gpu_indexes[stream_index as usize].0,
+                streams.ptr[index],
+                streams.gpu_indexes[index].0,
             );
         }
     }
@@ -122,107 +118,107 @@ where
     ///
     /// - [CudaStreams::synchronize] __must__ be called after the copy as soon as synchronization is
     ///   required.
-    pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], streams: &CudaStreams, stream_index: u32)
+    pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], streams: &CudaStreams, index: usize)
     where
         T: Numeric,
     {
-        assert_eq!(self.len(stream_index), dest.len());
-        let size = self.len(stream_index) * std::mem::size_of::<T>();
+        assert_eq!(self.len(index), dest.len());
+        let size = self.len(index) * std::mem::size_of::<T>();
         // We check that src is not empty to avoid invalid pointers
         if size > 0 {
             cuda_memcpy_async_to_cpu(
                 dest.as_mut_ptr().cast::<c_void>(),
-                self.as_c_ptr(stream_index),
+                self.as_c_ptr(index),
                 size as u64,
-                streams.ptr[stream_index as usize],
-                streams.gpu_indexes[stream_index as usize].0,
+                streams.ptr[index],
+                streams.gpu_indexes[index].0,
             );
         }
     }
 
     /// Returns the number of elements in the vector, also referred to as its ‘length’.
-    pub fn len(&self, index: u32) -> usize {
-        self.lengths[index as usize]
+    pub fn len(&self, index: usize) -> usize {
+        self.lengths[index]
     }
 
     /// Returns true if the ptr is empty
-    pub fn is_empty(&self, index: u32) -> bool {
-        self.lengths[index as usize] == 0
+    pub fn is_empty(&self, index: usize) -> bool {
+        self.lengths[index] == 0
     }
 
-    pub(crate) fn get_mut<R>(&mut self, range: R, index: GpuIndex) -> Option<CudaSliceMut<T>>
+    pub(crate) fn get_mut<R>(&mut self, range: R, index: usize) -> Option<CudaSliceMut<T>>
     where
         R: std::ops::RangeBounds<usize>,
         T: Numeric,
     {
-        let (start, end) = range_bounds_to_start_end(self.len(index.0), range).into_inner();
+        let (start, end) = range_bounds_to_start_end(self.len(index), range).into_inner();
 
         // Check the range is compatible with the vec
-        if end <= start || end > self.lengths[index.0 as usize] - 1 {
+        if end <= start || end > self.lengths[index] - 1 {
             None
         } else {
             // Shift ptr
             let shifted_ptr: *mut c_void =
-                self.ptrs[index.0 as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
+                self.ptrs[index].wrapping_byte_add(start * std::mem::size_of::<T>());
 
             // Compute the length
             let new_len = end - start + 1;
 
             // Create the slice
-            Some(unsafe {
-                CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index.0 as usize])
-            })
+            Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
         }
     }
 
     pub(crate) fn split_at_mut(
         &mut self,
         mid: usize,
-        index: GpuIndex,
+        index: usize,
     ) -> (Option<CudaSliceMut<T>>, Option<CudaSliceMut<T>>)
     where
         T: Numeric,
     {
         // Check the index is compatible with the vec
-        if mid > self.lengths[index.0 as usize] - 1 {
+        if mid > self.lengths[index] - 1 {
             (None, None)
         } else if mid == 0 {
             (
                 None,
                 Some(unsafe {
                     CudaSliceMut::new(
-                        self.ptrs[index.0 as usize],
-                        self.lengths[index.0 as usize],
-                        index,
+                        self.ptrs[index],
+                        self.lengths[index],
+                        self.gpu_indexes[index],
                     )
                 }),
             )
-        } else if mid == self.lengths[index.0 as usize] - 1 {
+        } else if mid == self.lengths[index] - 1 {
             (
                 Some(unsafe {
                     CudaSliceMut::new(
-                        self.ptrs[index.0 as usize],
-                        self.lengths[index.0 as usize],
-                        index,
+                        self.ptrs[index],
+                        self.lengths[index],
+                        self.gpu_indexes[index],
                     )
                 }),
                 None,
             )
         } else {
             let new_len_1 = mid;
-            let new_len_2 = self.lengths[index.0 as usize] - mid;
+            let new_len_2 = self.lengths[index] - mid;
             // Shift ptr
             let shifted_ptr: *mut c_void =
-                self.ptrs[index.0 as usize].wrapping_byte_add(mid * std::mem::size_of::<T>());
+                self.ptrs[index].wrapping_byte_add(mid * std::mem::size_of::<T>());
             // Create the slice
             (
-                Some(unsafe { CudaSliceMut::new(self.ptrs[index.0 as usize], new_len_1, index) }),
-                Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len_2, index) }),
+                Some(unsafe {
+                    CudaSliceMut::new(self.ptrs[index], new_len_1, self.gpu_indexes[index])
+                }),
+                Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len_2, self.gpu_indexes[index]) }),
             )
         }
     }
 
-    pub(crate) fn gpu_index(&self, index: u32) -> GpuIndex {
-        self.gpu_indexes[index as usize]
+    pub(crate) fn gpu_index(&self, index: usize) -> GpuIndex {
+        self.gpu_indexes[index]
     }
 }

diff --git a/tfhe/src/core_crypto/gpu/vec.rs b/tfhe/src/core_crypto/gpu/vec.rs
index 8063457f53..d9a84ca205 100644
--- a/tfhe/src/core_crypto/gpu/vec.rs
+++ b/tfhe/src/core_crypto/gpu/vec.rs
@@ -361,7 +361,7 @@ impl<T: Numeric> CudaVec<T> {
         self.ptr[index as usize].cast_const()
     }
 
-    pub(crate) fn as_slice<R>(&self, range: R, index: u32) -> Option<CudaSlice<T>>
+    pub(crate) fn as_slice<R>(&self, range: R, index: usize) -> Option<CudaSlice<T>>
     where
         R: std::ops::RangeBounds<usize>,
         T: Numeric,
@@ -374,19 +374,19 @@ impl<T: Numeric> CudaVec<T> {
         } else {
             // Shift ptr
             let shifted_ptr: *mut c_void =
-                self.ptr[index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
+                self.ptr[index].wrapping_byte_add(start * std::mem::size_of::<T>());
 
             // Compute the length
             let new_len = end - start + 1;
 
             // Create the slice
-            Some(unsafe { CudaSlice::new(shifted_ptr, new_len, GpuIndex(index)) })
+            Some(unsafe { CudaSlice::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
         }
     }
 
     // clippy complains as we only manipulate pointers, but we want to keep rust semantics
     #[allow(clippy::needless_pass_by_ref_mut)]
-    pub(crate) fn as_mut_slice<R>(&mut self, range: R, index: u32) -> Option<CudaSliceMut<T>>
+    pub(crate) fn as_mut_slice<R>(&mut self, range: R, index: usize) -> Option<CudaSliceMut<T>>
     where
         R: std::ops::RangeBounds<usize>,
         T: Numeric,
@@ -399,13 +399,13 @@ impl<T: Numeric> CudaVec<T> {
         } else {
             // Shift ptr
             let shifted_ptr: *mut c_void =
-                self.ptr[index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
+                self.ptr[index].wrapping_byte_add(start * std::mem::size_of::<T>());
 
             // Compute the length
             let new_len = end - start + 1;
 
             // Create the slice
-            Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, GpuIndex(index)) })
+            Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
         }
     }
 
diff --git a/tfhe/src/integer/gpu/server_key/radix/ilog2.rs b/tfhe/src/integer/gpu/server_key/radix/ilog2.rs
index 9c0fb78920..2697ed268d 100644
--- a/tfhe/src/integer/gpu/server_key/radix/ilog2.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/ilog2.rs
@@ -40,8 +40,7 @@ impl CudaServerKey {
         let lwe_size = ct.as_ref().d_blocks.0.lwe_dimension.to_lwe_size().0;
 
         // Allocate the necessary amount of memory
-        let mut tmp_radix =
-            CudaVec::new_async(num_ct_blocks * lwe_size, streams, streams.gpu_indexes[0].0);
+        let mut tmp_radix = CudaVec::new_async(num_ct_blocks * lwe_size, streams, 0);
 
         let lut = match direction {
             Direction::Trailing => self.generate_lookup_table(|x| {
@@ -70,13 +69,9 @@ impl CudaServerKey {
             }),
         };
 
-        tmp_radix.copy_from_gpu_async(
-            &ct.as_ref().d_blocks.0.d_vec,
-            streams,
-            streams.gpu_indexes[0].0,
-        );
+        tmp_radix.copy_from_gpu_async(&ct.as_ref().d_blocks.0.d_vec, streams, 0);
         let mut output_slice = tmp_radix
-            .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0].0)
+            .as_mut_slice(0..lwe_size * num_ct_blocks, 0)
             .unwrap();
 
         let input_slice = ct
@@ -84,7 +79,7 @@ impl CudaServerKey {
             .d_blocks
             .0
             .d_vec
-            .as_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0].0)
+            .as_slice(0..lwe_size * num_ct_blocks, 0)
             .unwrap();
 
         // Assign to each block its number of leading/trailing zeros/ones
@@ -175,7 +170,7 @@ impl CudaServerKey {
         );
 
         let mut generates_or_propagates = tmp_radix
-            .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0].0)
+            .as_mut_slice(0..lwe_size * num_ct_blocks, 0)
             .unwrap();
 
         match &self.bootstrapping_key {
@@ -185,7 +180,7 @@ impl CudaServerKey {
                     &mut output_cts
                         .0
                         .d_vec
-                        .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0].0)
+                        .as_mut_slice(0..lwe_size * num_ct_blocks, 0)
                         .unwrap(),
                     &mut generates_or_propagates,
                     sum_lut.acc.acc.as_ref(),
@@ -214,7 +209,7 @@ impl CudaServerKey {
                     &mut output_cts
                         .0
                         .d_vec
-                        .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0].0)
+                        .as_mut_slice(0..lwe_size * num_ct_blocks, 0)
                         .unwrap(),
                     &mut generates_or_propagates,
                     sum_lut.acc.acc.as_ref(),
@@ -294,16 +289,13 @@ impl CudaServerKey {
                 .d_blocks
                 .0
                 .d_vec
-                .as_mut_slice(0..lwe_size, streams.gpu_indexes[0].0)
+                .as_mut_slice(0..lwe_size, 0)
                 .unwrap();
 
             let src_slice = leading_count_per_blocks
                 .0
                 .d_vec
-                .as_mut_slice(
-                    (i * lwe_size)..((i + 1) * lwe_size),
-                    streams.gpu_indexes[0].0,
-                )
+                .as_mut_slice((i * lwe_size)..((i + 1) * lwe_size), 0)
                 .unwrap();
             dest_slice.copy_from_gpu_async(&src_slice, streams, 0);
             cts.push(new_item);
@@ -537,16 +529,13 @@ impl CudaServerKey {
                 .d_blocks
                 .0
                 .d_vec
-                .as_mut_slice(0..lwe_size, streams.gpu_indexes[0].0)
+                .as_mut_slice(0..lwe_size, 0)
                 .unwrap();
 
             let src_slice = leading_zeros_per_blocks
                 .0
                 .d_vec
-                .as_mut_slice(
-                    (i * lwe_size)..((i + 1) * lwe_size),
-                    streams.gpu_indexes[0].0,
-                )
+                .as_mut_slice((i * lwe_size)..((i + 1) * lwe_size), 0)
                 .unwrap();
             dest_slice.copy_from_gpu_async(&src_slice, streams, 0);
             cts.push(new_item);
@@ -582,14 +571,14 @@ impl CudaServerKey {
         let mut message_blocks_slice = message_blocks
             .0
             .d_vec
-            .as_mut_slice(0..lwe_size * counter_num_blocks, streams.gpu_indexes[0].0)
+            .as_mut_slice(0..lwe_size * counter_num_blocks, 0)
             .unwrap();
         let result_slice = result
             .as_mut()
             .d_blocks
             .0
             .d_vec
-            .as_slice(0..lwe_size * counter_num_blocks, streams.gpu_indexes[0].0)
+            .as_slice(0..lwe_size * counter_num_blocks, 0)
             .unwrap();
 
         match &self.bootstrapping_key {
@@ -666,7 +655,7 @@ impl CudaServerKey {
             .d_blocks
             .0
             .d_vec
-            .as_mut_slice(0..lwe_size, streams.gpu_indexes[0].0)
+            .as_mut_slice(0..lwe_size, 0)
             .unwrap();
 
         let mut carry_blocks_last = carry_blocks
@@ -674,16 +663,16 @@ impl CudaServerKey {
             .0
             .d_vec
             .as_mut_slice(
                 lwe_size * (counter_num_blocks - 1)..lwe_size * counter_num_blocks,
-                streams.gpu_indexes[0].0,
+                0,
             )
             .unwrap();
-        carry_blocks_last.copy_from_gpu_async(&trivial_last_block_slice, streams, 0u32);
+        carry_blocks_last.copy_from_gpu_async(&trivial_last_block_slice, streams, 0);
 
         let mut carry_blocks_slice = carry_blocks
             .0
             .d_vec
-            .as_mut_slice(0..lwe_size * counter_num_blocks, streams.gpu_indexes[0].0)
+            .as_mut_slice(0..lwe_size * counter_num_blocks, 0)
             .unwrap();
         unsafe {
             match &self.bootstrapping_key {
@@ -747,13 +736,13 @@ impl CudaServerKey {
             .d_blocks
             .0
             .d_vec
-            .as_mut_slice(0..counter_num_blocks * lwe_size, streams.gpu_indexes[0].0)
+            .as_mut_slice(0..counter_num_blocks * lwe_size, 0)
             .unwrap();
 
         let src_slice = message_blocks
             .0
             .d_vec
-            .as_mut_slice(0..(counter_num_blocks * lwe_size), streams.gpu_indexes[0].0)
+            .as_mut_slice(0..(counter_num_blocks * lwe_size), 0)
             .unwrap();
         dest_slice.copy_from_gpu_async(&src_slice, streams, 0);
 
@@ -767,13 +756,13 @@ impl CudaServerKey {
             .d_blocks
             .0
             .d_vec
-            .as_mut_slice(0..counter_num_blocks * lwe_size, streams.gpu_indexes[0].0)
+            .as_mut_slice(0..counter_num_blocks * lwe_size, 0)
             .unwrap();
 
         let src_slice = carry_blocks
             .0
             .d_vec
-            .as_mut_slice(0..(counter_num_blocks * lwe_size), streams.gpu_indexes[0].0)
+            .as_mut_slice(0..(counter_num_blocks * lwe_size), 0)
            .unwrap();
         dest_slice.copy_from_gpu_async(&src_slice, streams, 0);
 
diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs
index 5d4d17b641..0ea12c791d 100644
--- a/tfhe/src/integer/gpu/server_key/radix/mod.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs
@@ -889,17 +889,9 @@ impl CudaServerKey {
             .d_blocks
             .0
             .d_vec
-            .as_slice(
-                lwe_size * block_range.start..lwe_size * block_range.end,
-                streams.gpu_indexes[0].0,
-            )
-            .unwrap();
-        let mut output_slice = output
-            .d_blocks
-            .0
-            .d_vec
-            .as_mut_slice(.., streams.gpu_indexes[0].0)
+            .as_slice(lwe_size * block_range.start..lwe_size * block_range.end, 0)
             .unwrap();
+        let mut output_slice = output.d_blocks.0.d_vec.as_mut_slice(.., 0).unwrap();
 
         let num_ct_blocks = block_range.len() as u32;
         match &self.bootstrapping_key {
@@ -1054,22 +1046,16 @@ impl CudaServerKey {
         let lwe_dimension = input.d_blocks.lwe_dimension();
         let lwe_size = lwe_dimension.to_lwe_size().0;
 
-        let input_slice = input
-            .d_blocks
-            .0
-            .d_vec
-            .as_slice(.., streams.gpu_indexes[0].0)
-            .unwrap();
+        let input_slice = input.d_blocks.0.d_vec.as_slice(.., 0).unwrap();
 
         // The accumulator has been rotated, we can now proceed with the various sample extractions
         let function_count = lut.function_count();
         let num_ct_blocks = input.d_blocks.lwe_ciphertext_count().0;
         let total_radixes_size = num_ct_blocks * lwe_size * function_count;
 
-        let mut output_radixes =
-            CudaVec::new(total_radixes_size, streams, streams.gpu_indexes[0].0);
+        let mut output_radixes = CudaVec::new(total_radixes_size, streams, 0);
         let mut output_slice = output_radixes
-            .as_mut_slice(0..total_radixes_size, streams.gpu_indexes[0].0)
+            .as_mut_slice(0..total_radixes_size, 0)
             .unwrap();
 
         match &self.bootstrapping_key {
@@ -1132,19 +1118,11 @@ impl CudaServerKey {
         for i in 0..function_count {
             let slice_size = num_ct_blocks * lwe_size;
             let mut ct = input.duplicate(streams);
-            let mut ct_slice = ct
-                .d_blocks
-                .0
-                .d_vec
-                .as_mut_slice(0..slice_size, streams.gpu_indexes[0].0)
-                .unwrap();
+            let mut ct_slice = ct.d_blocks.0.d_vec.as_mut_slice(0..slice_size, 0).unwrap();
 
             let slice_size = num_ct_blocks * lwe_size;
             let output_slice = output_radixes
-                .as_mut_slice(
-                    slice_size * i..slice_size * (i + 1),
-                    streams.gpu_indexes[0].0,
-                )
+                .as_mut_slice(slice_size * i..slice_size * (i + 1), 0)
                 .unwrap();
 
             ct_slice.copy_from_gpu_async(&output_slice, streams, 0);
@@ -1197,16 +1175,12 @@ impl CudaServerKey {
             .d_blocks
             .0
             .d_vec
-            .as_slice(lwe_size * (num_ct_blocks - 1).., streams.gpu_indexes[0].0)
+            .as_slice(lwe_size * (num_ct_blocks - 1).., 0)
             .unwrap();
         let mut output_slice = output_radix
-            .as_mut_slice(
-                lwe_size * num_ct_blocks..lwe_size * new_num_ct_blocks,
-                streams.gpu_indexes[0].0,
-            )
+            .as_mut_slice(lwe_size * num_ct_blocks..lwe_size * new_num_ct_blocks, 0)
             .unwrap();
-        let (padding_block, new_blocks) =
-            output_slice.split_at_mut(lwe_size, streams.gpu_indexes[0]);
+        let (padding_block, new_blocks) = output_slice.split_at_mut(lwe_size, 0);
 
         let mut padding_block = padding_block.unwrap();
         let mut new_blocks = new_blocks.unwrap();
@@ -1262,7 +1236,7 @@ impl CudaServerKey {
         }
         for i in 0..num_blocks - 1 {
            let mut output_block = new_blocks
-                .get_mut(lwe_size * i..lwe_size * (i + 1), streams.gpu_indexes[0])
+                .get_mut(lwe_size * i..lwe_size * (i + 1), 0)
                .unwrap();
            output_block.copy_from_gpu_async(&padding_block, streams, 0);
        }
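
Note on the fix, with a hedged illustration. The call sites above used to pass
streams.gpu_indexes[0].0 (the physical CUDA device id of the first stream)
where the CudaVec/CudaSlice APIs expect a position into their per-stream
arrays (ptrs, lengths, gpu_indexes). On a single-GPU machine those arrays
have length 1, so when the one configured device is, say, device 3, indexing
with the device id reaches past the end. The patch therefore switches these
parameters to a plain usize position and looks the GpuIndex up internally.
The sketch below is standalone Rust, not tfhe-rs code (the struct shapes are
simplified assumptions); it only reproduces the failure mode and the post-fix
calling convention:

// Minimal standalone sketch, NOT tfhe-rs code: simplified shapes, kept only
// to illustrate the indexing bug this patch fixes. A physical GPU id
// (GpuIndex) must not be reused as a position into per-stream arrays,
// because a single-GPU setup running on device N still has arrays of len 1.

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct GpuIndex(u32); // physical CUDA device id

struct CudaStreams {
    // One entry per stream: position `i` targets device `gpu_indexes[i]`.
    gpu_indexes: Vec<GpuIndex>,
}

impl CudaStreams {
    // Post-fix convention: callers pass the position (usize) into the
    // per-stream arrays; the device id is looked up from that position.
    fn gpu_index(&self, index: usize) -> GpuIndex {
        self.gpu_indexes[index]
    }
}

fn main() {
    // A single GPU, but it is physical device 3 rather than device 0.
    let streams = CudaStreams {
        gpu_indexes: vec![GpuIndex(3)],
    };

    // Post-fix: position 0 is the only valid index; it resolves to device 3.
    assert_eq!(streams.gpu_index(0), GpuIndex(3));

    // Pre-fix pattern: the device id (3) was passed where a position was
    // expected, indexing position 3 of a length-1 array, which panics.
    let bad_position = streams.gpu_indexes[0].0 as usize;
    assert!(streams.gpu_indexes.get(bad_position).is_none());
}

With the usize convention, 0 always means "the first (and, on a single-GPU
setup, only) stream", whatever physical device it targets, which is exactly
the literal 0 the rewritten call sites above now pass.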