Skip to content

Commit

Permalink
test(frontend-python): Add minimal test of public key encryption API
Browse files Browse the repository at this point in the history
  • Loading branch information
BourgerieQuentin committed May 14, 2024
1 parent 94d3a4b commit 8b3590d
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ def deserialize(serialized_key_set: bytes) -> "PublicKeySet":
raise TypeError(
f"serialized_key_set must be of type bytes, not {type(serialized_key_set)}"
)
return PublicKeySet.wrap(_PublicKeySet.deserialize(serialized_key_set))
return PublicKeySet.wrap(_PublicKeySet.deserialize(serialized_key_set))
8 changes: 7 additions & 1 deletion frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

# pylint: disable=import-error,no-name-in-module

from concrete.compiler import EvaluationKeys, Parameter, PublicArguments, PublicResult, PublicKeyKind
from concrete.compiler import (
EvaluationKeys,
Parameter,
PublicArguments,
PublicKeyKind,
PublicResult,
)

from .compilation import (
DEFAULT_GLOBAL_P_ERROR,
Expand Down
7 changes: 4 additions & 3 deletions frontends/concrete-python/concrete/fhe/compilation/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import List, Optional, Tuple, Union

import numpy as np
from concrete.compiler import EvaluationKeys, ValueDecrypter, PublicKeySet
from concrete.compiler import EvaluationKeys, PublicKeySet, ValueDecrypter

from .keys import Keys
from .specs import ClientSpecs
Expand Down Expand Up @@ -139,8 +139,9 @@ def encrypt(

self.keygen(force=False)
keyset = self.keys._keyset # pylint: disable=protected-access
return ValueExporter.new_private(keyset, self.specs.client_parameters, function_name).encrypt(*args)

return ValueExporter.new_private(
keyset, self.specs.client_parameters, function_name
).encrypt(*args)

def decrypt(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@
from typing import List, Optional, Tuple, Union, get_type_hints

import numpy as np
from mlir._mlir_libs._concretelang._compiler import PublicKeyKind

from ..dtypes import Integer
from ..representation import GraphProcessor
from ..values import ValueDescription
from .utils import friendly_type_format

from mlir._mlir_libs._concretelang._compiler import (
PublicKeyKind
)

MAXIMUM_TLU_BIT_WIDTH = 16

DEFAULT_P_ERROR = None
Expand Down Expand Up @@ -1060,7 +1057,7 @@ def __init__(
enable_tlu_fusing: bool = True,
print_tlu_fusing: bool = False,
optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8,
with_public_keys: PublicKeyKind = PublicKeyKind.NONE
with_public_keys: PublicKeyKind = PublicKeyKind.NONE,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,37 @@
# pylint: disable=import-error,no-name-in-module

import json
from typing import Optional, Union, List, Tuple, Dict
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from concrete.compiler import ValueExporter as _ValueExporter, SimulatedValueExporter, KeySet, PublicKeySet, ClientParameters
from concrete.compiler import ClientParameters, KeySet, PublicKeySet, SimulatedValueExporter
from concrete.compiler import ValueExporter as _ValueExporter

from ..dtypes import SignedInteger, UnsignedInteger
from .value import Value
from ..values import ValueDescription

from .value import Value

# pylint: enable=import-error,no-name-in-module


class ValueExporter:

_exporter: Union[_ValueExporter, SimulatedValueExporter]
_client_parameters: ClientParameters
_function_name: str

def __init__(self, client_parameters: ClientParameters, exporter: Union[_ValueExporter, SimulatedValueExporter], function_name: str):

if not (isinstance(exporter, _ValueExporter) or isinstance(exporter, SimulatedValueExporter)) :
def __init__(
self,
client_parameters: ClientParameters,
exporter: Union[_ValueExporter, SimulatedValueExporter],
function_name: str,
):
if not (
isinstance(exporter, _ValueExporter) or isinstance(exporter, SimulatedValueExporter)
):
raise TypeError(
f"value_exporter must be of type SimulatedValueExporter or ValueExporter, not {type(exporter)}"
)
if not isinstance(client_parameters, ClientParameters) :
if not isinstance(client_parameters, ClientParameters):
raise TypeError(
f"client_parameters must be of type ClientParameters or ValueExporter, not {type(client_parameters)}"
)
Expand All @@ -40,7 +45,7 @@ def __init__(self, client_parameters: ClientParameters, exporter: Union[_ValueEx

def new_private(keyset: KeySet, client_parameters: ClientParameters, function_name: str):
"""
Create a new value exporter for private encryption
Create a new value exporter for private encryption
Args:
function_name (str):
Expand All @@ -50,11 +55,15 @@ def new_private(keyset: KeySet, client_parameters: ClientParameters, function_na
Optional[Union[Value, Tuple[Optional[Value], ...]]]:
encrypted argument(s) for evaluation
"""
return ValueExporter(client_parameters, _ValueExporter.new(keyset, client_parameters, function_name), function_name)
return ValueExporter(
client_parameters,
_ValueExporter.new(keyset, client_parameters, function_name),
function_name,
)

def new_public(keyset: PublicKeySet, client_parameters: ClientParameters, function_name: str):
"""
Create a new value exporter for private encryption
Create a new value exporter for private encryption
Args:
function_name (str):
Expand All @@ -64,11 +73,15 @@ def new_public(keyset: PublicKeySet, client_parameters: ClientParameters, functi
Optional[Union[Value, Tuple[Optional[Value], ...]]]:
encrypted argument(s) for evaluation
"""
return ValueExporter(client_parameters, _ValueExporter.new_public(keyset, client_parameters, function_name), function_name)
return ValueExporter(
client_parameters,
_ValueExporter.new_public(keyset, client_parameters, function_name),
function_name,
)

def new_simulated(client_parameters: ClientParameters, function_name: str):
"""
Create a new value exporter for simulate encryption
Create a new value exporter for simulate encryption
Args:
function_name (str):
Expand All @@ -78,15 +91,21 @@ def new_simulated(client_parameters: ClientParameters, function_name: str):
Optional[Union[Value, Tuple[Optional[Value], ...]]]:
encrypted argument(s) for evaluation
"""
return ValueExporter(client_parameters, SimulatedValueExporter.new(client_parameters, function_name), function_name)
return ValueExporter(
client_parameters,
SimulatedValueExporter.new(client_parameters, function_name),
function_name,
)

def encrypt(
self,
*args: Optional[Union[int, np.ndarray, List]],
function_name: str = "main",
) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]:
print(args)
ordered_sanitized_args = ValueExporter._validate_input_args(self._client_parameters, *args, function_name=function_name)
ordered_sanitized_args = ValueExporter._validate_input_args(
self._client_parameters, *args, function_name=function_name
)
print(ordered_sanitized_args)
exported = [
None
Expand All @@ -100,21 +119,21 @@ def encrypt(
]

return tuple(exported) if len(exported) != 1 else exported[0]

def _validate_input_args(
client_parameters: ClientParameters,
*args: Optional[Union[int, np.ndarray, List]],
function_name: str = "main",
) -> List[Optional[Union[int, np.ndarray]]]:
"""Validate input arguments.
Args:
client_specs (ClientSpecs):
client specification
*args (Optional[Union[int, np.ndarray, List]]):
argument(s) for evaluation
function_name (str): name of the function to verify
Returns:
List[Optional[Union[int, np.ndarray]]]: ordered validated args
"""
Expand All @@ -127,20 +146,20 @@ def _validate_input_args(
if len(args) != len(input_specs):
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
raise ValueError(message)

sanitized_args: Dict[int, Optional[Union[int, np.ndarray]]] = {}
for index, (arg, spec) in enumerate(zip(args, input_specs)):
if arg is None:
sanitized_args[index] = None
continue

if isinstance(arg, list):
arg = np.array(arg)

is_valid = isinstance(arg, (int, np.integer)) or (
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
)

if "lweCiphertext" in spec["typeInfo"].keys():
type_info = spec["typeInfo"]["lweCiphertext"]
is_encrypted = True
Expand All @@ -157,38 +176,40 @@ def _validate_input_args(
else:
message = f"Expected a valid type in {spec['typeInfo'].keys()}"
raise ValueError(message)

expected_dtype = SignedInteger(width) if is_signed else UnsignedInteger(width)
expected_value = ValueDescription(expected_dtype, shape, is_encrypted)
if is_valid:
expected_min = expected_dtype.min()
expected_max = expected_dtype.max()

if not is_encrypted:
# clear integers are signless
# (e.g., 8-bit clear integer can be in range -128, 255)
expected_min = -(expected_max // 2) - 1

actual_min = arg if isinstance(arg, int) else arg.min()
actual_max = arg if isinstance(arg, int) else arg.max()
actual_shape = () if isinstance(arg, int) else arg.shape

is_valid = (
actual_min >= expected_min
and actual_max <= expected_max
and actual_shape == expected_value.shape
)

if is_valid:
sanitized_args[index] = arg

if not is_valid:
try:
actual_value = str(ValueDescription.of(arg, is_encrypted=is_encrypted))
except ValueError:
actual_value = type(arg).__name__
message = f"Expected argument {index} to be {expected_value} but it's {actual_value}"
message = (
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
)
raise ValueError(message)

ordered_sanitized_args = [sanitized_args[i] for i in range(len(sanitized_args))]
return ordered_sanitized_args
return ordered_sanitized_args
14 changes: 14 additions & 0 deletions frontends/concrete-python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def check_execution(
sample: Union[Any, List[Any]],
retries: int = 1,
only_simulation: bool = False,
with_public_keys: bool = False,
):
"""
Assert that `circuit` behaves the same as `function` on `sample`.
Expand All @@ -282,6 +283,9 @@ def check_execution(
only_simulation (bool, default = False):
whether to just check simulation but not execution
with_public_keys (bool, default = False):
whether to check with public key encryption
"""
if not isinstance(sample, list):
sample = [sample]
Expand All @@ -304,6 +308,16 @@ def sanitize(values):
if not only_simulation:
for i in range(retries):
expected = sanitize(function(*deepcopy(sample)))
encrypter = (
fhe.ValueExporter.new_public(
circuit.keys.public_key_set,
circuit.server.client_specs.client_parameters,
"main",
)
if with_public_keys
else circuit
)
encrypted_args = encrypter.encrypt(*deepcopy(sample))
actual = sanitize(circuit.encrypt_run_decrypt(*deepcopy(sample)))

if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
Expand Down
55 changes: 55 additions & 0 deletions frontends/concrete-python/tests/execution/test_public_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
import pytest

from concrete import fhe


@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: x + y,
id="x + y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 60], "status": "clear"},
"y": {"range": [0, 60], "status": "encrypted"},
},
{
"x": {"range": [0, 60], "status": "encrypted"},
"y": {"range": [0, 60], "status": "clear"},
},
{
"x": {"range": [0, 60], "status": "encrypted"},
"y": {"range": [0, 60], "status": "encrypted"},
},
{
"x": {"range": [0, 60], "status": "clear", "shape": (3,)},
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 60], "status": "clear", "shape": (3,)},
},
],
)
def test_add(function, parameters, helpers):
"""
Test add where both of the operators are dynamic.
"""

parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
configuration.with_public_keys = fhe.PublicKeyKind.COMPACT
compiler = fhe.Compiler(function, parameter_encryption_statuses)

inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
print(circuit.server.client_specs.client_parameters.serialize())
helpers.check_execution(circuit, function, sample, 1, False, True)
2 changes: 1 addition & 1 deletion third_party/llvm-project

0 comments on commit 8b3590d

Please sign in to comment.