diff --git a/docs/guides/tfhers/shared-key.md b/docs/guides/tfhers/shared-key.md index 1951c1dc10..066ca9fa8e 100644 --- a/docs/guides/tfhers/shared-key.md +++ b/docs/guides/tfhers/shared-key.md @@ -97,7 +97,7 @@ decoded = tfhers_type.decode(result) We are going to create a TFHE-rs bridge that facilitates the seamless transfer of ciphertexts and keys between Concrete and TFHE-rs. ```python -tfhers_bridge = tfhers.new_bridge(circuit=circuit) +tfhers_bridge = tfhers.new_bridge(circuit) ``` ## Key generation diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index 873e2a4cf8..aee643bdda 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -39,13 +39,6 @@ def __init__(self, module: FheModule): def _function(self) -> FheFunction: return getattr(self._module, self._name) - @property - def function_name(self) -> str: - """ - Return the name of the circuit. - """ - return self._name - def __str__(self): return self._function.graph.format() diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index 02e4d3f85b..5c469ea1da 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -3,100 +3,118 @@ """ # pylint: disable=import-error,no-member,no-name-in-module -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from concrete.compiler import LweSecretKey, TfhersExporter, TfhersFheIntDescription -from concrete import fhe +import concrete.fhe as fhe from concrete.fhe.compilation.value import Value from .dtypes import EncryptionKeyChoice, TFHERSIntegerType class Bridge: - """TFHErs Bridge extend a Circuit with TFHErs functionalities. - - input_types (List[Optional[TFHERSIntegerType]]): maps every input to a type. None means - a non-tfhers type - output_types (List[Optional[TFHERSIntegerType]]): maps every output to a type. None means - a non-tfhers type - input_shapes (List[Optional[Tuple[int, ...]]]): maps every input to a shape. None means - a non-tfhers type - output_shapes (List[Optional[Tuple[int, ...]]]): maps every output to a shape. None means - a non-tfhers type + """TFHErs Bridge extend an Module with TFHErs functionalities. + + input_types_per_func (Dict[str, List[Optional[TFHERSIntegerType]]]): + maps every input to a type for every function in the module. None means a non-tfhers type + output_types_per_func (Dict[str, List[Optional[TFHERSIntegerType]]]): + maps every output to a type for every function in the module. None means a non-tfhers type + input_shapes_per_func (Dict[str, List[Optional[Tuple[int, ...]]]]): + maps every input to a shape for every function in the module. None means a non-tfhers type + output_shapes_per_func (Dict[str, List[Optional[Tuple[int, ...]]]]): + maps every output to a shape for every function in the module. None means a non-tfhers type """ - circuit: "fhe.Circuit" - input_types: List[Optional[TFHERSIntegerType]] - output_types: List[Optional[TFHERSIntegerType]] - input_shapes: List[Optional[Tuple[int, ...]]] - output_shapes: List[Optional[Tuple[int, ...]]] + module: "fhe.Module" + default_function: Optional[str] + input_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]] + output_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]] + input_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]] + output_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]] def __init__( self, - circuit: "fhe.Circuit", - input_types: List[Optional[TFHERSIntegerType]], - output_types: List[Optional[TFHERSIntegerType]], - input_shapes: List[Optional[Tuple[int, ...]]], - output_shapes: List[Optional[Tuple[int, ...]]], + module: "fhe.Module", + input_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]], + output_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]], + input_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]], + output_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]], ): - self.circuit = circuit - self.input_types = input_types - self.output_types = output_types - self.input_shapes = input_shapes - self.output_shapes = output_shapes + if module.function_count == 1: + self.default_function = next(iter(module.graphs.keys())) + else: + self.default_function = None + self.module = module + self.input_types_per_func = input_types_per_func + self.output_types_per_func = output_types_per_func + self.input_shapes_per_func = input_shapes_per_func + self.output_shapes_per_func = output_shapes_per_func + + def _get_default_func_or_raise_error(self, calling_func: str) -> str: + if self.default_function is not None: + return self.default_function + else: + raise RuntimeError( + "Module contains more than one function, so please provide 'func_name' while " + f"calling '{calling_func}'" + ) - def _input_type(self, input_idx: int) -> Optional[TFHERSIntegerType]: + def _input_type(self, func_name: str, input_idx: int) -> Optional[TFHERSIntegerType]: """Return the type of a certain input. Args: + func_name (str): name of the function the input belongs to input_idx (int): the input index to get the type of Returns: Optional[TFHERSIntegerType]: input type. None means a non-tfhers type """ - return self.input_types[input_idx] + return self.input_types_per_func[func_name][input_idx] - def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]: + def _output_type(self, func_name: str, output_idx: int) -> Optional[TFHERSIntegerType]: """Return the type of a certain output. Args: + func_name (str): name of the function the output belongs to output_idx (int): the output index to get the type of Returns: Optional[TFHERSIntegerType]: output type. None means a non-tfhers type """ - return self.output_types[output_idx] + return self.output_types_per_func[func_name][output_idx] - def _input_shape(self, input_idx: int) -> Optional[Tuple[int, ...]]: + def _input_shape(self, func_name: str, input_idx: int) -> Optional[Tuple[int, ...]]: """Return the shape of a certain input. Args: + func_name (str): name of the function the input belongs to input_idx (int): the input index to get the shape of Returns: Optional[Tuple[int, ...]]: input shape. None means a non-tfhers type """ - return self.input_shapes[input_idx] + return self.input_shapes_per_func[func_name][input_idx] - def _output_shape(self, output_idx: int) -> Optional[Tuple[int, ...]]: # pragma: no cover + def _output_shape( + self, func_name: str, output_idx: int + ) -> Optional[Tuple[int, ...]]: # pragma: no cover """Return the shape of a certain output. Args: + func_name (str): name of the function the output belongs to output_idx (int): the output index to get the shape of Returns: Optional[Tuple[int, ...]]: output shape. None means a non-tfhers type """ - return self.output_shapes[output_idx] + return self.output_shapes_per_func[func_name][output_idx] - def _input_keyid(self, input_idx: int) -> int: - return self.circuit.client.specs.program_info.input_keyid_at( - input_idx, self.circuit.function_name - ) + def _input_keyid(self, func_name: str, input_idx: int) -> int: + return self.module.client.specs.program_info.input_keyid_at(input_idx, func_name) - def _input_variance(self, input_idx: int) -> float: - input_type = self._input_type(input_idx) + def _input_variance(self, func_name: str, input_idx: int) -> float: + input_type = self._input_type(func_name, input_idx) if input_type is None: # pragma: no cover msg = "input at 'input_idx' is not a TFHErs value" raise ValueError(msg) @@ -133,38 +151,48 @@ def _description_from_type( ks_first, ) - def import_value(self, buffer: bytes, input_idx: int) -> Value: + def import_value(self, buffer: bytes, input_idx: int, func_name: Optional[str] = None) -> Value: """Import a serialized TFHErs integer as a Value. Args: buffer (bytes): serialized integer input_idx (int): the index of the input expecting this value + func_name (Optional[str]): name of the function the value belongs to. + Doesn't need to be provided if there is a single function. Returns: fhe.TransportValue: imported value """ - input_type = self._input_type(input_idx) - input_shape = self._input_shape(input_idx) + if func_name is None: + func_name = self._get_default_func_or_raise_error("import_value") + + input_type = self._input_type(func_name, input_idx) + input_shape = self._input_shape(func_name, input_idx) if input_type is None or input_shape is None: # pragma: no cover msg = "input at 'input_idx' is not a TFHErs value" raise ValueError(msg) fheint_desc = self._description_from_type(input_type) - keyid = self._input_keyid(input_idx) - variance = self._input_variance(input_idx) + keyid = self._input_keyid(func_name, input_idx) + variance = self._input_variance(func_name, input_idx) return Value(TfhersExporter.import_int(buffer, fheint_desc, keyid, variance, input_shape)) - def export_value(self, value: Value, output_idx: int) -> bytes: + def export_value(self, value: Value, output_idx: int, func_name: Optional[str] = None) -> bytes: """Export a value as a serialized TFHErs integer. Args: value (TransportValue): value to export output_idx (int): the index corresponding to this output + func_name (Optional[str]): name of the function the value belongs to. + Doesn't need to be provided if there is a single function. Returns: bytes: serialized fheuint8 """ - output_type = self._output_type(output_idx) + if func_name is None: + func_name = self._get_default_func_or_raise_error("export_value") + + output_type = self._output_type(func_name, output_idx) if output_type is None: # pragma: no cover msg = "output at 'output_idx' is not a TFHErs value" raise ValueError(msg) @@ -174,18 +202,23 @@ def export_value(self, value: Value, output_idx: int) -> bytes: value._inner, fheint_desc # pylint: disable=protected-access ) - def serialize_input_secret_key(self, input_idx: int) -> bytes: + def serialize_input_secret_key(self, input_idx: int, func_name: Optional[str] = None) -> bytes: """Serialize secret key used for a specific input. Args: input_idx (int): input index corresponding to the key to serialize + func_name (Optional[str]): name of the function the key belongs to. + Doesn't need to be provided if there is a single function. Returns: bytes: serialized key """ - keyid = self._input_keyid(input_idx) + if func_name is None: + func_name = self._get_default_func_or_raise_error("serialize_input_secret_key") + + keyid = self._input_keyid(func_name, input_idx) # pylint: disable=protected-access - keys = self.circuit.client.keys + keys = self.module.client.keys assert keys is not None secret_key = keys._keyset.get_client_keys().get_secret_keys()[keyid] # type: ignore # pylint: enable=protected-access @@ -193,7 +226,7 @@ def serialize_input_secret_key(self, input_idx: int) -> bytes: def keygen_with_initial_keys( self, - input_idx_to_key_buffer: Dict[int, bytes], + input_idx_to_key_buffer: Dict[Union[Tuple[str, int], int], bytes], force: bool = False, seed: Optional[int] = None, encryption_seed: Optional[int] = None, @@ -210,30 +243,45 @@ def keygen_with_initial_keys( encryption_seed (Optional[int], default = None): seed for encryption randomness - input_idx_to_key_buffer (Dict[int, bytes]): initial keys to set before keygen + input_idx_to_key_buffer (Dict[Union[Tuple[str, int], int], bytes]): + initial keys to set before keygen. Two possible formats: the first is when you have + a single function. Here you can just provide the position of the input as index. + The second is when you have multiple functions. You will need to provide both the + name of the function and the input's position as index. Raises: RuntimeError: if failed to deserialize the key """ initial_keys: Dict[int, LweSecretKey] = {} - for input_idx in input_idx_to_key_buffer: - key_id = self._input_keyid(input_idx) + for idx in input_idx_to_key_buffer: + if isinstance(idx, tuple): + func_name, input_idx = idx + elif isinstance(idx, int) and self.default_function is not None: + input_idx = idx + func_name = self.default_function + else: + raise RuntimeError( + "Module contains more than one function, so please make sure to mention " + "the function name (not just the position) in input_idx_to_key_buffer. " + "An example index would be a tuple ('my_func', 1)." + ) + key_id = self._input_keyid(func_name, input_idx) # no need to deserialize the same key again if key_id in initial_keys: # pragma: no cover continue - key_buffer = input_idx_to_key_buffer[input_idx] - param = self.circuit.client.specs.program_info.get_keyset_info().secret_keys()[key_id] + key_buffer = input_idx_to_key_buffer[idx] + param = self.module.client.specs.program_info.get_keyset_info().secret_keys()[key_id] try: initial_keys[key_id] = LweSecretKey.deserialize(key_buffer, param) except Exception as e: # pragma: no cover msg = ( - f"failed deserializing key for input with index {input_idx}. Make sure the key" + f"failed deserializing key for input with index {idx}. Make sure the key" " is for the right input" ) raise RuntimeError(msg) from e - self.circuit.keygen( + self.module.keygen( force=force, seed=seed, encryption_seed=encryption_seed, @@ -241,33 +289,57 @@ def keygen_with_initial_keys( ) -def new_bridge(circuit: "fhe.Circuit") -> Bridge: - """Create a TFHErs bridge from a circuit. +def new_bridge(circuit_or_module: Union["fhe.Circuit", "fhe.Module"]) -> Bridge: + """Create a TFHErs bridge from a circuit or module. Args: - circuit (Circuit): compiled circuit + circuit (Union[Circuit, Module]): compiled circuit or module Returns: Bridge: TFHErs bridge """ - input_types: List[Optional[TFHERSIntegerType]] = [] - input_shapes: List[Optional[Tuple[int, ...]]] = [] - for input_node in circuit.graph.ordered_inputs(): - if isinstance(input_node.output.dtype, TFHERSIntegerType): - input_types.append(input_node.output.dtype) - input_shapes.append(input_node.output.shape) - else: - input_types.append(None) - input_shapes.append(None) - - output_types: List[Optional[TFHERSIntegerType]] = [] - output_shapes: List[Optional[Tuple[int, ...]]] = [] - for output_node in circuit.graph.ordered_outputs(): - if isinstance(output_node.output.dtype, TFHERSIntegerType): - output_types.append(output_node.output.dtype) - output_shapes.append(output_node.output.shape) - else: # pragma: no cover - output_types.append(None) - output_shapes.append(None) - - return Bridge(circuit, input_types, output_types, input_shapes, output_shapes) + if isinstance(circuit_or_module, fhe.Module): + module = circuit_or_module + else: + assert isinstance(circuit_or_module, fhe.Circuit) + module = circuit_or_module._module + + input_types_per_func = {} + output_types_per_func = {} + input_shapes_per_func = {} + output_shapes_per_func = {} + + for func_name, graph in module.graphs.items(): + input_types: List[Optional[TFHERSIntegerType]] = [] + input_shapes: List[Optional[Tuple[int, ...]]] = [] + for input_node in graph.ordered_inputs(): + if isinstance(input_node.output.dtype, TFHERSIntegerType): + input_types.append(input_node.output.dtype) + input_shapes.append(input_node.output.shape) + else: + input_types.append(None) + input_shapes.append(None) + + input_types_per_func[func_name] = input_types + input_shapes_per_func[func_name] = input_shapes + + output_types: List[Optional[TFHERSIntegerType]] = [] + output_shapes: List[Optional[Tuple[int, ...]]] = [] + for output_node in graph.ordered_outputs(): + if isinstance(output_node.output.dtype, TFHERSIntegerType): + output_types.append(output_node.output.dtype) + output_shapes.append(output_node.output.shape) + else: # pragma: no cover + output_types.append(None) + output_shapes.append(None) + + output_types_per_func[func_name] = output_types + output_shapes_per_func[func_name] = output_shapes + + return Bridge( + module, + input_types_per_func, + output_types_per_func, + input_shapes_per_func, + output_shapes_per_func, + ) diff --git a/frontends/concrete-python/examples/tfhers-ml/example.py b/frontends/concrete-python/examples/tfhers-ml/example.py index bd43d1ce1b..b1f5e078e2 100644 --- a/frontends/concrete-python/examples/tfhers-ml/example.py +++ b/frontends/concrete-python/examples/tfhers-ml/example.py @@ -94,7 +94,7 @@ def ccompilee(): circuit = compiler.compile(inputset) - tfhers_bridge = tfhers.new_bridge(circuit=circuit) + tfhers_bridge = tfhers.new_bridge(circuit) return circuit, tfhers_bridge diff --git a/frontends/concrete-python/examples/tfhers/example.py b/frontends/concrete-python/examples/tfhers/example.py index 953037fa1c..90a829fa98 100644 --- a/frontends/concrete-python/examples/tfhers/example.py +++ b/frontends/concrete-python/examples/tfhers/example.py @@ -51,7 +51,7 @@ def ccompilee(): inputset = [(tfhers_int(120), tfhers_int(120))] circuit = compiler.compile(inputset) - tfhers_bridge = tfhers.new_bridge(circuit=circuit) + tfhers_bridge = tfhers.new_bridge(circuit) return circuit, tfhers_bridge diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index fddc7713c3..e08f3c2d91 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -10,7 +10,7 @@ import numpy as np import pytest -from concrete import fhe +import concrete.fhe as fhe from concrete.fhe import tfhers @@ -50,7 +50,7 @@ def parameterize_partial_dtype(partial_dtype) -> tfhers.TFHERSIntegerType: def is_input_and_output_tfhers( - circuit: fhe.Circuit, + circuit: Union[fhe.Circuit, fhe.Module], lwe_dim: int, tfhers_ins: List[int], tfhers_outs: List[int], @@ -1092,6 +1092,254 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( os.remove(sum_ct_path) +@fhe.module() +class AddModuleOneFunc: + func_count = 1 + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def add(x, y): + x = tfhers.to_native(x) + y = tfhers.to_native(y) + return tfhers.from_native(x + y, TFHERS_UINT_8_3_2_4096) + + +@fhe.module() +class AddModuleTwoFunc: + func_count = 2 + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def add(x, y): + x = tfhers.to_native(x) + y = tfhers.to_native(y) + return tfhers.from_native(x + y, TFHERS_UINT_8_3_2_4096) + + @fhe.function({"x": "encrypted"}) + def inc(x): + return x + 1 + + +@pytest.mark.parametrize( + "module, func_count, parameters, tfhers_value_range", + [ + pytest.param( + AddModuleOneFunc, + 1, + { + "x": {"range": [0, 2**6], "status": "encrypted"}, + "y": {"range": [0, 2**6], "status": "encrypted"}, + }, + [0, 2**6], + id="AddModuleOneFunc", + ), + pytest.param( + AddModuleTwoFunc, + 2, + { + "x": {"range": [0, 2**6], "status": "encrypted"}, + "y": {"range": [0, 2**6], "status": "encrypted"}, + }, + [0, 2**6], + id="AddModuleTwoFunc", + ), + ], +) +def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen_with_modules( + module, func_count, parameters, tfhers_value_range, helpers +): + """ + Test different operations wrapped by tfhers conversion (2 tfhers inputs). + + Encryption/decryption are done in Rust using TFHErs, while Keygen is done in Concrete. + + We use modules. + """ + + # global dtype to use + dtype = TFHERS_UINT_8_3_2_4096 + # global function to use + function = lambda x, y: x + y + + # there is no point of using the cache here as new keys will be generated everytime + config = helpers.configuration().fork( + use_insecure_key_cache=False, insecure_key_cache_location=None + ) + + # Only valid when running in multi + if config.parameter_selection_strategy != fhe.ParameterSelectionStrategy.MULTI: + return + + inputset = [ + tuple(tfhers.TFHERSInteger(dtype, arg) for arg in inpt) + for inpt in helpers.generate_inputset(parameters) + ] + if func_count == 1: + add_module = module.compile({"add": inputset}, config) + else: + assert func_count == 2 + add_module = module.compile({"add": inputset, "inc": [(i,) for i in range(10)]}, config) + + assert is_input_and_output_tfhers( + add_module, + dtype.params.polynomial_size, + [0, 1], + [ + 0, + ], + ) + + sample = helpers.generate_sample(parameters) + + ###### TFHErs Keygen ########################################################## + _, client_key_path = tempfile.mkstemp() + _, server_key_path = tempfile.mkstemp() + _, sk_path = tempfile.mkstemp() + + tfhers_utils = ( + f"{os.path.dirname(os.path.abspath(__file__))}/../tfhers-utils/target/release/tfhers_utils" + ) + + assert ( + os.system( + f"{tfhers_utils} keygen -s {server_key_path} -c {client_key_path} --output-lwe-sk {sk_path}" + ) + == 0 + ) + + ###### Concrete Keygen ######################################################## + tfhers_bridge = tfhers.new_bridge(add_module) + + with open(sk_path, "rb") as f: + sk_buff = f.read() + + if func_count == 1: + # set sk for input 0 and generate the remaining keys + tfhers_bridge.keygen_with_initial_keys({0: sk_buff}, force=True) + else: + assert func_count == 2 + with pytest.raises(RuntimeError, match="Module contains more than one function"): + tfhers_bridge.keygen_with_initial_keys({0: sk_buff}, force=True) + tfhers_bridge.keygen_with_initial_keys({("add", 0): sk_buff}, force=True) + + ###### Full Concrete Execution ################################################ + concrete_encoded_sample = (dtype.encode(v) for v in sample) + concrete_encoded_result = add_module.add.encrypt_run_decrypt(*concrete_encoded_sample) + assert (dtype.decode(concrete_encoded_result) == function(*sample)).all() + + ###### TFHErs Encryption ###################################################### + + # encrypt inputs + ct1, ct2 = sample + _, ct1_path = tempfile.mkstemp() + _, ct2_path = tempfile.mkstemp() + + tfhers_utils = ( + f"{os.path.dirname(os.path.abspath(__file__))}/../tfhers-utils/target/release/tfhers_utils" + ) + assert ( + os.system(f"{tfhers_utils} encrypt-with-key --value={ct1} -c {ct1_path} --lwe-sk {sk_path}") + == 0 + ) + assert ( + os.system(f"{tfhers_utils} encrypt-with-key --value={ct2} -c {ct2_path} --lwe-sk {sk_path}") + == 0 + ) + + # import ciphertexts and run + cts = [] + with open(ct1_path, "rb") as f: + buff = f.read() + if func_count == 1: + cts.append(tfhers_bridge.import_value(buff, 0)) + else: + assert func_count == 2 + with pytest.raises(RuntimeError, match="Module contains more than one function"): + cts.append(tfhers_bridge.import_value(buff, 0)) + cts.append(tfhers_bridge.import_value(buff, 0, func_name="add")) + with open(ct2_path, "rb") as f: + buff = f.read() + if func_count == 1: + cts.append(tfhers_bridge.import_value(buff, 1)) + else: + assert func_count == 2 + with pytest.raises(RuntimeError, match="Module contains more than one function"): + cts.append(tfhers_bridge.import_value(buff, 1)) + cts.append(tfhers_bridge.import_value(buff, 1, func_name="add")) + os.remove(ct1_path) + os.remove(ct2_path) + + tfhers_encrypted_result = add_module.add.run(*cts) + + # concrete decryption should work + decrypted = add_module.add.decrypt(tfhers_encrypted_result) + assert (dtype.decode(decrypted) == function(*sample)).all() # type: ignore + + # tfhers decryption + if func_count == 1: + buff = tfhers_bridge.export_value(tfhers_encrypted_result, output_idx=0) # type: ignore + else: + assert func_count == 2 + with pytest.raises(RuntimeError, match="Module contains more than one function"): + buff = tfhers_bridge.export_value(tfhers_encrypted_result, output_idx=0) # type: ignore + buff = tfhers_bridge.export_value(tfhers_encrypted_result, output_idx=0, func_name="add") # type: ignore + _, ct_out_path = tempfile.mkstemp() + _, pt_path = tempfile.mkstemp() + with open(ct_out_path, "wb") as f: + f.write(buff) + + assert ( + os.system( + f"{tfhers_utils} decrypt-with-key" f" -c {ct_out_path} --lwe-sk {sk_path} -p {pt_path}" + ) + == 0 + ) + + with open(pt_path, "r", encoding="utf-8") as f: + result = int(f.read()) + assert result == function(*sample) + + ###### Compute with TFHErs #################################################### + _, random_ct_path = tempfile.mkstemp() + _, sum_ct_path = tempfile.mkstemp() + + # encrypt random value + random_value = np.random.randint(*tfhers_value_range) + assert ( + os.system( + f"{tfhers_utils} encrypt-with-key --value={random_value} -c {random_ct_path} --client-key {client_key_path}" + ) + == 0 + ) + + # add random value to the result ct + assert ( + os.system( + f"{tfhers_utils} add -c {ct_out_path} {random_ct_path} -s {server_key_path} -o {sum_ct_path}" + ) + == 0 + ) + + # decrypt result + assert ( + os.system( + f"{tfhers_utils} decrypt-with-key -c {sum_ct_path} --lwe-sk {sk_path} -p {pt_path}" + ) + == 0 + ) + + with open(pt_path, "r", encoding="utf-8") as f: + tfhers_result = int(f.read()) + assert result + random_value == tfhers_result + + # close remaining tempfiles + os.remove(client_key_path) + os.remove(server_key_path) + os.remove(sk_path) + os.remove(ct_out_path) + os.remove(pt_path) + os.remove(random_ct_path) + os.remove(sum_ct_path) + + @pytest.mark.parametrize( "function, parameters, tfhers_value_range, dtype", [