Skip to content

Commit

Permalink
feat(gpu): signed comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Apr 3, 2024
1 parent 3c39abe commit cc72594
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 140 deletions.
72 changes: 72 additions & 0 deletions tfhe/benches/integer/signed_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,36 @@ mod cuda {
display_name: ne
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: unchecked_gt,
display_name: gt
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: unchecked_ge,
display_name: ge
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: unchecked_lt,
display_name: lt
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: unchecked_le,
display_name: le
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: unchecked_min,
display_name: min
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: unchecked_max,
display_name: max
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_add,
display_name: add,
Expand Down Expand Up @@ -1766,6 +1796,36 @@ mod cuda {
display_name: ne
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: gt,
display_name: gt
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: ge,
display_name: ge
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: lt,
display_name: lt
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: le,
display_name: le
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: min,
display_name: min
);

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: max,
display_name: max
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_add,
display_name: add,
Expand Down Expand Up @@ -1830,6 +1890,12 @@ mod cuda {
cuda_unchecked_rotate_right,
cuda_unchecked_eq,
cuda_unchecked_ne,
cuda_unchecked_gt,
cuda_unchecked_ge,
cuda_unchecked_lt,
cuda_unchecked_le,
cuda_unchecked_min,
cuda_unchecked_max,
);

criterion_group!(
Expand Down Expand Up @@ -1860,6 +1926,12 @@ mod cuda {
cuda_rotate_right,
cuda_eq,
cuda_ne,
cuda_gt,
cuda_ge,
cuda_lt,
cuda_le,
cuda_min,
cuda_max,
);

criterion_group!(
Expand Down
26 changes: 14 additions & 12 deletions tfhe/src/high_level_api/integers/unsigned/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,11 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result =
cuda_key
.key
.max(&self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
let inner_result = cuda_key.key.max(
&*self.ciphertext.on_gpu(),
&*rhs.ciphertext.on_gpu(),
stream,
);
Self::new(inner_result)
}),
})
Expand Down Expand Up @@ -260,10 +261,11 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result =
cuda_key
.key
.min(&self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
let inner_result = cuda_key.key.min(
&*self.ciphertext.on_gpu(),
&*rhs.ciphertext.on_gpu(),
stream,
);
Self::new(inner_result)
}),
})
Expand Down Expand Up @@ -421,7 +423,7 @@ where
let inner_result =
cuda_key
.key
.lt(&self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
.lt(&*self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
FheBool::new(inner_result)
}),
})
Expand Down Expand Up @@ -459,7 +461,7 @@ where
let inner_result =
cuda_key
.key
.le(&self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
.le(&*self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
FheBool::new(inner_result)
}),
})
Expand Down Expand Up @@ -497,7 +499,7 @@ where
let inner_result =
cuda_key
.key
.gt(&self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
.gt(&*self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
FheBool::new(inner_result)
}),
})
Expand Down Expand Up @@ -535,7 +537,7 @@ where
let inner_result =
cuda_key
.key
.ge(&self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
.ge(&*self.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), stream);
FheBool::new(inner_result)
}),
})
Expand Down
Loading

0 comments on commit cc72594

Please sign in to comment.