From 8b3590dd287fd57e184eb62417dcd3f7de101a90 Mon Sep 17 00:00:00 2001 From: Bourgerie Quentin Date: Fri, 3 May 2024 16:14:27 +0200 Subject: [PATCH] test(frontend-python): Add minimal test of public key encryption API --- .../concrete/compiler/public_key_set.py | 2 +- .../concrete-python/concrete/fhe/__init__.py | 8 +- .../concrete/fhe/compilation/client.py | 7 +- .../concrete/fhe/compilation/configuration.py | 7 +- .../fhe/compilation/value_exporter.py | 85 ++++++++++++------- frontends/concrete-python/tests/conftest.py | 14 +++ .../tests/execution/test_public_keys.py | 55 ++++++++++++ third_party/llvm-project | 2 +- 8 files changed, 137 insertions(+), 43 deletions(-) create mode 100644 frontends/concrete-python/tests/execution/test_public_keys.py diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_key_set.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_key_set.py index 7ccfb42d1f..9eb2d9354a 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_key_set.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_key_set.py @@ -60,4 +60,4 @@ def deserialize(serialized_key_set: bytes) -> "PublicKeySet": raise TypeError( f"serialized_key_set must be of type bytes, not {type(serialized_key_set)}" ) - return PublicKeySet.wrap(_PublicKeySet.deserialize(serialized_key_set)) \ No newline at end of file + return PublicKeySet.wrap(_PublicKeySet.deserialize(serialized_key_set)) diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index a684c5f681..add92bf46e 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -4,7 +4,13 @@ # pylint: disable=import-error,no-name-in-module -from concrete.compiler import EvaluationKeys, Parameter, PublicArguments, PublicResult, PublicKeyKind +from concrete.compiler import ( + EvaluationKeys, + Parameter, + PublicArguments, + PublicKeyKind, + PublicResult, +) from .compilation import ( DEFAULT_GLOBAL_P_ERROR, diff --git a/frontends/concrete-python/concrete/fhe/compilation/client.py b/frontends/concrete-python/concrete/fhe/compilation/client.py index 334a95a54e..907c4e2f1f 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/client.py +++ b/frontends/concrete-python/concrete/fhe/compilation/client.py @@ -10,7 +10,7 @@ from typing import List, Optional, Tuple, Union import numpy as np -from concrete.compiler import EvaluationKeys, ValueDecrypter, PublicKeySet +from concrete.compiler import EvaluationKeys, PublicKeySet, ValueDecrypter from .keys import Keys from .specs import ClientSpecs @@ -139,8 +139,9 @@ def encrypt( self.keygen(force=False) keyset = self.keys._keyset # pylint: disable=protected-access - return ValueExporter.new_private(keyset, self.specs.client_parameters, function_name).encrypt(*args) - + return ValueExporter.new_private( + keyset, self.specs.client_parameters, function_name + ).encrypt(*args) def decrypt( self, diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index de0fab9bc7..042b81f6de 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -9,16 +9,13 @@ from typing import List, Optional, Tuple, Union, get_type_hints import numpy as np +from mlir._mlir_libs._concretelang._compiler import PublicKeyKind from ..dtypes import Integer from ..representation import GraphProcessor from ..values import ValueDescription from .utils import friendly_type_format -from mlir._mlir_libs._concretelang._compiler import ( - PublicKeyKind -) - MAXIMUM_TLU_BIT_WIDTH = 16 DEFAULT_P_ERROR = None @@ -1060,7 +1057,7 @@ def __init__( enable_tlu_fusing: bool = True, print_tlu_fusing: bool = False, optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8, - with_public_keys: PublicKeyKind = PublicKeyKind.NONE + with_public_keys: PublicKeyKind = PublicKeyKind.NONE, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode diff --git a/frontends/concrete-python/concrete/fhe/compilation/value_exporter.py b/frontends/concrete-python/concrete/fhe/compilation/value_exporter.py index 24bfce8c3b..35247362a5 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/value_exporter.py +++ b/frontends/concrete-python/concrete/fhe/compilation/value_exporter.py @@ -5,32 +5,37 @@ # pylint: disable=import-error,no-name-in-module import json -from typing import Optional, Union, List, Tuple, Dict +from typing import Dict, List, Optional, Tuple, Union import numpy as np -from concrete.compiler import ValueExporter as _ValueExporter, SimulatedValueExporter, KeySet, PublicKeySet, ClientParameters +from concrete.compiler import ClientParameters, KeySet, PublicKeySet, SimulatedValueExporter +from concrete.compiler import ValueExporter as _ValueExporter from ..dtypes import SignedInteger, UnsignedInteger -from .value import Value from ..values import ValueDescription - +from .value import Value # pylint: enable=import-error,no-name-in-module class ValueExporter: - _exporter: Union[_ValueExporter, SimulatedValueExporter] _client_parameters: ClientParameters _function_name: str - def __init__(self, client_parameters: ClientParameters, exporter: Union[_ValueExporter, SimulatedValueExporter], function_name: str): - - if not (isinstance(exporter, _ValueExporter) or isinstance(exporter, SimulatedValueExporter)) : + def __init__( + self, + client_parameters: ClientParameters, + exporter: Union[_ValueExporter, SimulatedValueExporter], + function_name: str, + ): + if not ( + isinstance(exporter, _ValueExporter) or isinstance(exporter, SimulatedValueExporter) + ): raise TypeError( f"value_exporter must be of type SimulatedValueExporter or ValueExporter, not {type(exporter)}" ) - if not isinstance(client_parameters, ClientParameters) : + if not isinstance(client_parameters, ClientParameters): raise TypeError( f"client_parameters must be of type ClientParameters or ValueExporter, not {type(client_parameters)}" ) @@ -40,7 +45,7 @@ def __init__(self, client_parameters: ClientParameters, exporter: Union[_ValueEx def new_private(keyset: KeySet, client_parameters: ClientParameters, function_name: str): """ - Create a new value exporter for private encryption + Create a new value exporter for private encryption Args: function_name (str): @@ -50,11 +55,15 @@ def new_private(keyset: KeySet, client_parameters: ClientParameters, function_na Optional[Union[Value, Tuple[Optional[Value], ...]]]: encrypted argument(s) for evaluation """ - return ValueExporter(client_parameters, _ValueExporter.new(keyset, client_parameters, function_name), function_name) + return ValueExporter( + client_parameters, + _ValueExporter.new(keyset, client_parameters, function_name), + function_name, + ) def new_public(keyset: PublicKeySet, client_parameters: ClientParameters, function_name: str): """ - Create a new value exporter for private encryption + Create a new value exporter for private encryption Args: function_name (str): @@ -64,11 +73,15 @@ def new_public(keyset: PublicKeySet, client_parameters: ClientParameters, functi Optional[Union[Value, Tuple[Optional[Value], ...]]]: encrypted argument(s) for evaluation """ - return ValueExporter(client_parameters, _ValueExporter.new_public(keyset, client_parameters, function_name), function_name) + return ValueExporter( + client_parameters, + _ValueExporter.new_public(keyset, client_parameters, function_name), + function_name, + ) def new_simulated(client_parameters: ClientParameters, function_name: str): """ - Create a new value exporter for simulate encryption + Create a new value exporter for simulate encryption Args: function_name (str): @@ -78,7 +91,11 @@ def new_simulated(client_parameters: ClientParameters, function_name: str): Optional[Union[Value, Tuple[Optional[Value], ...]]]: encrypted argument(s) for evaluation """ - return ValueExporter(client_parameters, SimulatedValueExporter.new(client_parameters, function_name), function_name) + return ValueExporter( + client_parameters, + SimulatedValueExporter.new(client_parameters, function_name), + function_name, + ) def encrypt( self, @@ -86,7 +103,9 @@ def encrypt( function_name: str = "main", ) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]: print(args) - ordered_sanitized_args = ValueExporter._validate_input_args(self._client_parameters, *args, function_name=function_name) + ordered_sanitized_args = ValueExporter._validate_input_args( + self._client_parameters, *args, function_name=function_name + ) print(ordered_sanitized_args) exported = [ None @@ -100,21 +119,21 @@ def encrypt( ] return tuple(exported) if len(exported) != 1 else exported[0] - + def _validate_input_args( client_parameters: ClientParameters, *args: Optional[Union[int, np.ndarray, List]], function_name: str = "main", ) -> List[Optional[Union[int, np.ndarray]]]: """Validate input arguments. - + Args: client_specs (ClientSpecs): client specification *args (Optional[Union[int, np.ndarray, List]]): argument(s) for evaluation function_name (str): name of the function to verify - + Returns: List[Optional[Union[int, np.ndarray]]]: ordered validated args """ @@ -127,20 +146,20 @@ def _validate_input_args( if len(args) != len(input_specs): message = f"Expected {len(input_specs)} inputs but got {len(args)}" raise ValueError(message) - + sanitized_args: Dict[int, Optional[Union[int, np.ndarray]]] = {} for index, (arg, spec) in enumerate(zip(args, input_specs)): if arg is None: sanitized_args[index] = None continue - + if isinstance(arg, list): arg = np.array(arg) - + is_valid = isinstance(arg, (int, np.integer)) or ( isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer) ) - + if "lweCiphertext" in spec["typeInfo"].keys(): type_info = spec["typeInfo"]["lweCiphertext"] is_encrypted = True @@ -157,38 +176,40 @@ def _validate_input_args( else: message = f"Expected a valid type in {spec['typeInfo'].keys()}" raise ValueError(message) - + expected_dtype = SignedInteger(width) if is_signed else UnsignedInteger(width) expected_value = ValueDescription(expected_dtype, shape, is_encrypted) if is_valid: expected_min = expected_dtype.min() expected_max = expected_dtype.max() - + if not is_encrypted: # clear integers are signless # (e.g., 8-bit clear integer can be in range -128, 255) expected_min = -(expected_max // 2) - 1 - + actual_min = arg if isinstance(arg, int) else arg.min() actual_max = arg if isinstance(arg, int) else arg.max() actual_shape = () if isinstance(arg, int) else arg.shape - + is_valid = ( actual_min >= expected_min and actual_max <= expected_max and actual_shape == expected_value.shape ) - + if is_valid: sanitized_args[index] = arg - + if not is_valid: try: actual_value = str(ValueDescription.of(arg, is_encrypted=is_encrypted)) except ValueError: actual_value = type(arg).__name__ - message = f"Expected argument {index} to be {expected_value} but it's {actual_value}" + message = ( + f"Expected argument {index} to be {expected_value} but it's {actual_value}" + ) raise ValueError(message) - + ordered_sanitized_args = [sanitized_args[i] for i in range(len(sanitized_args))] - return ordered_sanitized_args \ No newline at end of file + return ordered_sanitized_args diff --git a/frontends/concrete-python/tests/conftest.py b/frontends/concrete-python/tests/conftest.py index dc0b288395..33339f52c4 100644 --- a/frontends/concrete-python/tests/conftest.py +++ b/frontends/concrete-python/tests/conftest.py @@ -263,6 +263,7 @@ def check_execution( sample: Union[Any, List[Any]], retries: int = 1, only_simulation: bool = False, + with_public_keys: bool = False, ): """ Assert that `circuit` behaves the same as `function` on `sample`. @@ -282,6 +283,9 @@ def check_execution( only_simulation (bool, default = False): whether to just check simulation but not execution + + with_public_keys (bool, default = False): + whether to check with public key encryption """ if not isinstance(sample, list): sample = [sample] @@ -304,6 +308,16 @@ def sanitize(values): if not only_simulation: for i in range(retries): expected = sanitize(function(*deepcopy(sample))) + encrypter = ( + fhe.ValueExporter.new_public( + circuit.keys.public_key_set, + circuit.server.client_specs.client_parameters, + "main", + ) + if with_public_keys + else circuit + ) + encrypted_args = encrypter.encrypt(*deepcopy(sample)) actual = sanitize(circuit.encrypt_run_decrypt(*deepcopy(sample))) if all(np.array_equal(e, a) for e, a in zip(expected, actual)): diff --git a/frontends/concrete-python/tests/execution/test_public_keys.py b/frontends/concrete-python/tests/execution/test_public_keys.py new file mode 100644 index 0000000000..466864d88c --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_public_keys.py @@ -0,0 +1,55 @@ +import numpy as np +import pytest + +from concrete import fhe + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: x + y, + id="x + y", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 60], "status": "clear"}, + "y": {"range": [0, 60], "status": "encrypted"}, + }, + { + "x": {"range": [0, 60], "status": "encrypted"}, + "y": {"range": [0, 60], "status": "clear"}, + }, + { + "x": {"range": [0, 60], "status": "encrypted"}, + "y": {"range": [0, 60], "status": "encrypted"}, + }, + { + "x": {"range": [0, 60], "status": "clear", "shape": (3,)}, + "y": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + "y": {"range": [0, 60], "status": "clear", "shape": (3,)}, + }, + ], +) +def test_add(function, parameters, helpers): + """ + Test add where both of the operators are dynamic. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + configuration.with_public_keys = fhe.PublicKeyKind.COMPACT + compiler = fhe.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + sample = helpers.generate_sample(parameters) + print(circuit.server.client_specs.client_parameters.serialize()) + helpers.check_execution(circuit, function, sample, 1, False, True) diff --git a/third_party/llvm-project b/third_party/llvm-project index f5aec278e8..7a73da8849 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit f5aec278e8dfd68bfb1a15c3abb6f0852cea7046 +Subproject commit 7a73da884993674d8aa8f3f286feaaba719dd593