Skip to content

Commit

Permalink
chore(hlapi): add tests booleans on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Feb 29, 2024
1 parent 46a87c6 commit fab719c
Show file tree
Hide file tree
Showing 6 changed files with 349 additions and 91 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/aws_tfhe_gpu_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ jobs:
run: |
make test_c_api_gpu
- name: Run High Level API Tests
run: |
make test_high_level_api_gpu
- name: Slack Notification
if: ${{ always() }}
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,11 @@ test_high_level_api: install_rs_build_toolchain
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache -p $(TFHE_SPEC) \
-- high_level_api::

test_high_level_api_gpu: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer,internal-keycache,gpu -p $(TFHE_SPEC) \
-E "test(/high_level_api::.*gpu.*/)"

.PHONY: test_user_doc # Run tests from the .md documentation
test_user_doc: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) --doc \
Expand Down
32 changes: 14 additions & 18 deletions tfhe/src/high_level_api/booleans/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,13 @@ impl FheEq<bool> for FheBool {
InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
// with_thread_local_cuda_stream(|stream| {
// let inner =
// cuda_key
// .key
// .scalar_eq(&self.ciphertext.on_gpu(), u8::from(other),
// stream); InnerBoolean::Cuda(inner)
todo!()
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner =
cuda_key
.key
.scalar_eq(&self.ciphertext.on_gpu(), u8::from(other), stream);
InnerBoolean::Cuda(inner)
}),
});
Self::new(ciphertext)
}
Expand Down Expand Up @@ -366,15 +364,13 @@ impl FheEq<bool> for FheBool {
InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
// with_thread_local_cuda_stream(|stream| {
// let inner =
// cuda_key
// .key
// .scalar_ne(&self.ciphertext.on_gpu(), u8::from(other),
// stream); InnerBoolean::Cuda(inner)
todo!()
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner =
cuda_key
.key
.scalar_ne(&self.ciphertext.on_gpu(), u8::from(other), stream);
InnerBoolean::Cuda(inner)
}),
});
Self::new(ciphertext)
}
Expand Down
8 changes: 6 additions & 2 deletions tfhe/src/high_level_api/booleans/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ impl CompactFheBool {
assert_eq!(ct.blocks.len(), 1);
let mut block = BooleanBlock::new_unchecked(ct.blocks.into_iter().next().unwrap());
block.0.degree = Degree::new(1);
FheBool::new(block)
let mut ciphertext = FheBool::new(block);
ciphertext.ciphertext.move_to_device_of_server_key_if_set();
ciphertext
}
}

Expand Down Expand Up @@ -131,7 +133,9 @@ impl CompactFheBoolList {
assert_eq!(ct.blocks.len(), 1);
let mut block = BooleanBlock::new_unchecked(ct.blocks.into_iter().next().unwrap());
block.0.degree = Degree::new(1);
FheBool::new(block)
let mut ciphertext = FheBool::new(block);
ciphertext.ciphertext.move_to_device_of_server_key_if_set();
ciphertext
})
.collect::<Vec<_>>()
}
Expand Down
6 changes: 4 additions & 2 deletions tfhe/src/high_level_api/booleans/compressed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ impl CompressedFheBool {
///
/// See [CompressedFheBool] example.
pub fn decompress(&self) -> FheBool {
FheBool::new(BooleanBlock::new_unchecked(
let mut ciphertext = FheBool::new(BooleanBlock::new_unchecked(
self.ciphertext.clone().decompress(),
))
));
ciphertext.ciphertext.move_to_device_of_server_key_if_set();
ciphertext
}
}

Expand Down
Loading

0 comments on commit fab719c

Please sign in to comment.