diff --git a/tfhe/docs/guides/run_on_gpu.md b/tfhe/docs/guides/run_on_gpu.md index 2ee6b5356d..0269032b82 100644 --- a/tfhe/docs/guides/run_on_gpu.md +++ b/tfhe/docs/guides/run_on_gpu.md @@ -256,3 +256,78 @@ fn main() { } ``` + +## Array types + +It is possible to use array types on GPU, just as [on CPU](array.md). Here is an example showing how to do it: +```rust +use tfhe::{ConfigBuilder, set_server_key, ClearArray, ClientKey, CompressedServerKey}; +use tfhe::array::GpuFheUint32Array; +use tfhe::prelude::*; + +fn main() { + let config = ConfigBuilder::default().build(); + + let cks = ClientKey::generate(config); + let compressed_server_key = CompressedServerKey::new(&cks); + + let gpu_key = compressed_server_key.decompress_to_gpu(); + set_server_key(gpu_key); + + let num_elems = 4 * 4; + let clear_xs = (0..num_elems as u32).collect::>(); + let clear_ys = vec![1u32; num_elems]; + + // Encrypted 2D array with values + // [[ 0, 1, 2, 3] + // [ 4, 5, 6, 7] + // [ 8, 9, 10, 11] + // [ 12, 13, 14, 15]] + let xs = GpuFheUint32Array::try_encrypt((clear_xs.as_slice(), vec![4, 4]), &cks).unwrap(); + // Encrypted 2D array with values + // [[ 1, 1, 1, 1] + // [ 1, 1, 1, 1] + // [ 1, 1, 1, 1] + // [ 1, 1, 1, 1]] + let ys = GpuFheUint32Array::try_encrypt((clear_ys.as_slice(), vec![4, 4]), &cks).unwrap(); + + assert_eq!(xs.num_dim(), 2); + assert_eq!(xs.shape(), &[4, 4]); + assert_eq!(ys.num_dim(), 2); + assert_eq!(ys.shape(), &[4, 4]); + + // Take a sub slice + // [[ 10, 11] + // [ 14, 15]] + let xss = xs.slice(&[2..4, 2..4]); + // Take a sub slice + // [[ 1, 1] + // [ 1, 1]] + let yss = ys.slice(&[2..4, 2..4]); + + assert_eq!(xss.num_dim(), 2); + assert_eq!(xss.shape(), &[2, 2]); + assert_eq!(yss.num_dim(), 2); + assert_eq!(yss.shape(), &[2, 2]); + + let r = &xss + &yss; + + // Result is + // [[ 11, 12] + // [ 15, 16]] + let result: Vec = r.decrypt(&cks); + assert_eq!(result, vec![11, 12, 15, 16]); + + // Clear 2D array with values + // [[ 10, 20] + // [ 30, 40]] + let clear_array = ClearArray::new(vec![10u32, 20u32, 30u32, 40u32], vec![2, 2]); + let r = &xss + &clear_array; + + // Result is + // [[ 20, 31] + // [ 44, 55]] + let r: Vec = r.decrypt(&cks); + assert_eq!(r, vec![20, 31, 44, 55]); +} +```