Skip to content

Commit

Permalink
fix(gpu): fix single gpu on device other than 0
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Dec 6, 2024
1 parent 802e98c commit b2072f1
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 128 deletions.
100 changes: 48 additions & 52 deletions tfhe/src/core_crypto/gpu/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ where
///
/// The caller must ensure that the slice outlives the pointer this function returns,
/// or else it will end up pointing to garbage.
pub(crate) unsafe fn as_c_ptr(&self, gpu_index: u32) -> *const c_void {
self.ptrs[gpu_index as usize]
pub(crate) unsafe fn as_c_ptr(&self, index: usize) -> *const c_void {
self.ptrs[index]
}
pub(crate) fn gpu_index(&self, index: u32) -> GpuIndex {
self.gpu_indexes[index as usize]
pub(crate) fn gpu_index(&self, index: usize) -> GpuIndex {
self.gpu_indexes[index]
}
}

Expand All @@ -75,16 +75,16 @@ where
///
/// The caller must ensure that the slice outlives the pointer this function returns,
/// or else it will end up pointing to garbage.
pub(crate) unsafe fn as_mut_c_ptr(&mut self, gpu_index: u32) -> *mut c_void {
self.ptrs[gpu_index as usize]
pub(crate) unsafe fn as_mut_c_ptr(&mut self, index: usize) -> *mut c_void {
self.ptrs[index]
}

/// # Safety
///
/// The caller must ensure that the slice outlives the pointer this function returns,
/// or else it will end up pointing to garbage.
pub(crate) unsafe fn as_c_ptr(&self, index: u32) -> *const c_void {
self.ptrs[index as usize].cast_const()
pub(crate) unsafe fn as_c_ptr(&self, index: usize) -> *const c_void {
self.ptrs[index].cast_const()
}

/// Copies data between two `CudaSlice`
Expand All @@ -93,24 +93,20 @@ where
///
/// - [CudaStreams::synchronize] __must__ be called after the copy as soon as synchronization is
/// required.
pub unsafe fn copy_from_gpu_async(
&mut self,
src: &Self,
streams: &CudaStreams,
stream_index: u32,
) where
pub unsafe fn copy_from_gpu_async(&mut self, src: &Self, streams: &CudaStreams, index: usize)
where
T: Numeric,
{
assert_eq!(self.len(stream_index), src.len(stream_index));
let size = src.len(stream_index) * std::mem::size_of::<T>();
assert_eq!(self.len(index), src.len(index));
let size = src.len(index) * std::mem::size_of::<T>();
// We check that src is not empty to avoid invalid pointers
if size > 0 {
cuda_memcpy_async_gpu_to_gpu(
self.as_mut_c_ptr(stream_index),
src.as_c_ptr(stream_index),
self.as_mut_c_ptr(index),
src.as_c_ptr(index),
size as u64,
streams.ptr[stream_index as usize],
streams.gpu_indexes[stream_index as usize].0,
streams.ptr[index],
streams.gpu_indexes[index].0,
);
}
}
Expand All @@ -122,107 +118,107 @@ where
///
/// - [CudaStreams::synchronize] __must__ be called after the copy as soon as synchronization is
/// required.
pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], streams: &CudaStreams, stream_index: u32)
pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], streams: &CudaStreams, index: usize)
where
T: Numeric,
{
assert_eq!(self.len(stream_index), dest.len());
let size = self.len(stream_index) * std::mem::size_of::<T>();
assert_eq!(self.len(index), dest.len());
let size = self.len(index) * std::mem::size_of::<T>();
// We check that self (the copy source) is not empty to avoid invalid pointers
if size > 0 {
cuda_memcpy_async_to_cpu(
dest.as_mut_ptr().cast::<c_void>(),
self.as_c_ptr(stream_index),
self.as_c_ptr(index),
size as u64,
streams.ptr[stream_index as usize],
streams.gpu_indexes[stream_index as usize].0,
streams.ptr[index],
streams.gpu_indexes[index].0,
);
}
}

/// Returns the number of elements in the slice at position `index`, also referred to as its ‘length’.
pub fn len(&self, index: u32) -> usize {
self.lengths[index as usize]
pub fn len(&self, index: usize) -> usize {
self.lengths[index]
}

/// Returns true if the slice at position `index` contains no elements
pub fn is_empty(&self, index: u32) -> bool {
self.lengths[index as usize] == 0
pub fn is_empty(&self, index: usize) -> bool {
self.lengths[index] == 0
}

pub(crate) fn get_mut<R>(&mut self, range: R, index: GpuIndex) -> Option<CudaSliceMut<T>>
pub(crate) fn get_mut<R>(&mut self, range: R, index: usize) -> Option<CudaSliceMut<T>>
where
R: std::ops::RangeBounds<usize>,
T: Numeric,
{
let (start, end) = range_bounds_to_start_end(self.len(index.0), range).into_inner();
let (start, end) = range_bounds_to_start_end(self.len(index), range).into_inner();

// Check the range is compatible with the slice
if end <= start || end > self.lengths[index.0 as usize] - 1 {
if end <= start || end > self.lengths[index] - 1 {
None
} else {
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptrs[index.0 as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
self.ptrs[index].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;

// Create the slice
Some(unsafe {
CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index.0 as usize])
})
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
}
}

pub(crate) fn split_at_mut(
&mut self,
mid: usize,
index: GpuIndex,
index: usize,
) -> (Option<CudaSliceMut<T>>, Option<CudaSliceMut<T>>)
where
T: Numeric,
{
// Check that `mid` is compatible with the length of the slice
if mid > self.lengths[index.0 as usize] - 1 {
if mid > self.lengths[index] - 1 {
(None, None)
} else if mid == 0 {
(
None,
Some(unsafe {
CudaSliceMut::new(
self.ptrs[index.0 as usize],
self.lengths[index.0 as usize],
index,
self.ptrs[index],
self.lengths[index],
self.gpu_indexes[index],
)
}),
)
} else if mid == self.lengths[index.0 as usize] - 1 {
} else if mid == self.lengths[index] - 1 {
(
Some(unsafe {
CudaSliceMut::new(
self.ptrs[index.0 as usize],
self.lengths[index.0 as usize],
index,
self.ptrs[index],
self.lengths[index],
self.gpu_indexes[index],
)
}),
None,
)
} else {
let new_len_1 = mid;
let new_len_2 = self.lengths[index.0 as usize] - mid;
let new_len_2 = self.lengths[index] - mid;
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptrs[index.0 as usize].wrapping_byte_add(mid * std::mem::size_of::<T>());
self.ptrs[index].wrapping_byte_add(mid * std::mem::size_of::<T>());

// Create the slice
(
Some(unsafe { CudaSliceMut::new(self.ptrs[index.0 as usize], new_len_1, index) }),
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len_2, index) }),
Some(unsafe {
CudaSliceMut::new(self.ptrs[index], new_len_1, self.gpu_indexes[index])
}),
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len_2, self.gpu_indexes[index]) }),
)
}
}
pub(crate) fn gpu_index(&self, index: u32) -> GpuIndex {
self.gpu_indexes[index as usize]
pub(crate) fn gpu_index(&self, index: usize) -> GpuIndex {
self.gpu_indexes[index]
}
}
12 changes: 6 additions & 6 deletions tfhe/src/core_crypto/gpu/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ impl<T: Numeric> CudaVec<T> {
self.ptr[index as usize].cast_const()
}

pub(crate) fn as_slice<R>(&self, range: R, index: u32) -> Option<CudaSlice<T>>
pub(crate) fn as_slice<R>(&self, range: R, index: usize) -> Option<CudaSlice<T>>
where
R: std::ops::RangeBounds<usize>,
T: Numeric,
Expand All @@ -374,19 +374,19 @@ impl<T: Numeric> CudaVec<T> {
} else {
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptr[index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
self.ptr[index].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;

// Create the slice
Some(unsafe { CudaSlice::new(shifted_ptr, new_len, GpuIndex(index)) })
Some(unsafe { CudaSlice::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
}
}

// clippy complains as we only manipulate pointers, but we want to keep rust semantics
#[allow(clippy::needless_pass_by_ref_mut)]
pub(crate) fn as_mut_slice<R>(&mut self, range: R, index: u32) -> Option<CudaSliceMut<T>>
pub(crate) fn as_mut_slice<R>(&mut self, range: R, index: usize) -> Option<CudaSliceMut<T>>
where
R: std::ops::RangeBounds<usize>,
T: Numeric,
Expand All @@ -399,13 +399,13 @@ impl<T: Numeric> CudaVec<T> {
} else {
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptr[index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
self.ptr[index].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;

// Create the slice
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, GpuIndex(index)) })
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
}
}

Expand Down
Loading

0 comments on commit b2072f1

Please sign in to comment.