Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(gpu): fix single gpu on device other than 0 #1880

Merged
merged 1 commit into the base branch on Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 48 additions & 52 deletions tfhe/src/core_crypto/gpu/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ where
///
/// The caller must ensure that the slice outlives the pointer this function returns,
/// or else it will end up pointing to garbage.
pub(crate) unsafe fn as_c_ptr(&self, gpu_index: u32) -> *const c_void {
self.ptrs[gpu_index as usize]
pub(crate) unsafe fn as_c_ptr(&self, index: usize) -> *const c_void {
self.ptrs[index]
}
pub(crate) fn gpu_index(&self, index: u32) -> GpuIndex {
self.gpu_indexes[index as usize]
pub(crate) fn gpu_index(&self, index: usize) -> GpuIndex {
self.gpu_indexes[index]
}
}

Expand All @@ -75,16 +75,16 @@ where
///
/// The caller must ensure that the slice outlives the pointer this function returns,
/// or else it will end up pointing to garbage.
pub(crate) unsafe fn as_mut_c_ptr(&mut self, gpu_index: u32) -> *mut c_void {
self.ptrs[gpu_index as usize]
pub(crate) unsafe fn as_mut_c_ptr(&mut self, index: usize) -> *mut c_void {
self.ptrs[index]
}

/// # Safety
///
/// The caller must ensure that the slice outlives the pointer this function returns,
/// or else it will end up pointing to garbage.
pub(crate) unsafe fn as_c_ptr(&self, index: u32) -> *const c_void {
self.ptrs[index as usize].cast_const()
pub(crate) unsafe fn as_c_ptr(&self, index: usize) -> *const c_void {
self.ptrs[index].cast_const()
}

/// Copies data between two `CudaSlice`
Expand All @@ -93,24 +93,20 @@ where
///
/// - [CudaStreams::synchronize] __must__ be called after the copy as soon as synchronization is
/// required.
pub unsafe fn copy_from_gpu_async(
&mut self,
src: &Self,
streams: &CudaStreams,
stream_index: u32,
) where
pub unsafe fn copy_from_gpu_async(&mut self, src: &Self, streams: &CudaStreams, index: usize)
where
T: Numeric,
{
assert_eq!(self.len(stream_index), src.len(stream_index));
let size = src.len(stream_index) * std::mem::size_of::<T>();
assert_eq!(self.len(index), src.len(index));
let size = src.len(index) * std::mem::size_of::<T>();
// We check that src is not empty to avoid invalid pointers
if size > 0 {
cuda_memcpy_async_gpu_to_gpu(
self.as_mut_c_ptr(stream_index),
src.as_c_ptr(stream_index),
self.as_mut_c_ptr(index),
src.as_c_ptr(index),
size as u64,
streams.ptr[stream_index as usize],
streams.gpu_indexes[stream_index as usize].0,
streams.ptr[index],
streams.gpu_indexes[index].0,
);
}
}
Expand All @@ -122,107 +118,107 @@ where
///
/// - [CudaStreams::synchronize] __must__ be called after the copy as soon as synchronization is
/// required.
pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], streams: &CudaStreams, stream_index: u32)
pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], streams: &CudaStreams, index: usize)
where
T: Numeric,
{
assert_eq!(self.len(stream_index), dest.len());
let size = self.len(stream_index) * std::mem::size_of::<T>();
assert_eq!(self.len(index), dest.len());
let size = self.len(index) * std::mem::size_of::<T>();
// We check that src is not empty to avoid invalid pointers
if size > 0 {
cuda_memcpy_async_to_cpu(
dest.as_mut_ptr().cast::<c_void>(),
self.as_c_ptr(stream_index),
self.as_c_ptr(index),
size as u64,
streams.ptr[stream_index as usize],
streams.gpu_indexes[stream_index as usize].0,
streams.ptr[index],
streams.gpu_indexes[index].0,
);
}
}

/// Returns the number of elements in the vector, also referred to as its ‘length’.
pub fn len(&self, index: u32) -> usize {
self.lengths[index as usize]
pub fn len(&self, index: usize) -> usize {
self.lengths[index]
}

/// Returns true if the ptr is empty
pub fn is_empty(&self, index: u32) -> bool {
self.lengths[index as usize] == 0
pub fn is_empty(&self, index: usize) -> bool {
self.lengths[index] == 0
}

pub(crate) fn get_mut<R>(&mut self, range: R, index: GpuIndex) -> Option<CudaSliceMut<T>>
pub(crate) fn get_mut<R>(&mut self, range: R, index: usize) -> Option<CudaSliceMut<T>>
where
R: std::ops::RangeBounds<usize>,
T: Numeric,
{
let (start, end) = range_bounds_to_start_end(self.len(index.0), range).into_inner();
let (start, end) = range_bounds_to_start_end(self.len(index), range).into_inner();

// Check the range is compatible with the vec
if end <= start || end > self.lengths[index.0 as usize] - 1 {
if end <= start || end > self.lengths[index] - 1 {
None
} else {
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptrs[index.0 as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
self.ptrs[index].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;

// Create the slice
Some(unsafe {
CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index.0 as usize])
})
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
}
}

pub(crate) fn split_at_mut(
&mut self,
mid: usize,
index: GpuIndex,
index: usize,
) -> (Option<CudaSliceMut<T>>, Option<CudaSliceMut<T>>)
where
T: Numeric,
{
// Check the index is compatible with the vec
if mid > self.lengths[index.0 as usize] - 1 {
if mid > self.lengths[index] - 1 {
(None, None)
} else if mid == 0 {
(
None,
Some(unsafe {
CudaSliceMut::new(
self.ptrs[index.0 as usize],
self.lengths[index.0 as usize],
index,
self.ptrs[index],
self.lengths[index],
self.gpu_indexes[index],
)
}),
)
} else if mid == self.lengths[index.0 as usize] - 1 {
} else if mid == self.lengths[index] - 1 {
(
Some(unsafe {
CudaSliceMut::new(
self.ptrs[index.0 as usize],
self.lengths[index.0 as usize],
index,
self.ptrs[index],
self.lengths[index],
self.gpu_indexes[index],
)
}),
None,
)
} else {
let new_len_1 = mid;
let new_len_2 = self.lengths[index.0 as usize] - mid;
let new_len_2 = self.lengths[index] - mid;
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptrs[index.0 as usize].wrapping_byte_add(mid * std::mem::size_of::<T>());
self.ptrs[index].wrapping_byte_add(mid * std::mem::size_of::<T>());

// Create the slice
(
Some(unsafe { CudaSliceMut::new(self.ptrs[index.0 as usize], new_len_1, index) }),
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len_2, index) }),
Some(unsafe {
CudaSliceMut::new(self.ptrs[index], new_len_1, self.gpu_indexes[index])
}),
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len_2, self.gpu_indexes[index]) }),
)
}
}
pub(crate) fn gpu_index(&self, index: u32) -> GpuIndex {
self.gpu_indexes[index as usize]
pub(crate) fn gpu_index(&self, index: usize) -> GpuIndex {
self.gpu_indexes[index]
}
}
12 changes: 6 additions & 6 deletions tfhe/src/core_crypto/gpu/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ impl<T: Numeric> CudaVec<T> {
self.ptr[index as usize].cast_const()
}

pub(crate) fn as_slice<R>(&self, range: R, index: u32) -> Option<CudaSlice<T>>
pub(crate) fn as_slice<R>(&self, range: R, index: usize) -> Option<CudaSlice<T>>
where
R: std::ops::RangeBounds<usize>,
T: Numeric,
Expand All @@ -374,19 +374,19 @@ impl<T: Numeric> CudaVec<T> {
} else {
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptr[index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
self.ptr[index].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;

// Create the slice
Some(unsafe { CudaSlice::new(shifted_ptr, new_len, GpuIndex(index)) })
Some(unsafe { CudaSlice::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
}
}

// clippy complains as we only manipulate pointers, but we want to keep rust semantics
#[allow(clippy::needless_pass_by_ref_mut)]
pub(crate) fn as_mut_slice<R>(&mut self, range: R, index: u32) -> Option<CudaSliceMut<T>>
pub(crate) fn as_mut_slice<R>(&mut self, range: R, index: usize) -> Option<CudaSliceMut<T>>
where
R: std::ops::RangeBounds<usize>,
T: Numeric,
Expand All @@ -399,13 +399,13 @@ impl<T: Numeric> CudaVec<T> {
} else {
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptr[index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
self.ptr[index].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;

// Create the slice
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, GpuIndex(index)) })
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index]) })
}
}

Expand Down
Loading
Loading