Skip to content

Commit

Permalink
chore: take comments into account
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Jun 6, 2024
1 parent 5458d2c commit fd1213a
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 103 deletions.
6 changes: 0 additions & 6 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,6 @@ docs/index.rst
.artifacts
execution_time_of_individual_pytest_files.txt

# Docs: Advance Examples MNIST data
docs/advanced_examples/data/MNIST/

# Docs: Advance Examples FHE training deployment files
docs/advanced_examples/fhe_training

# Hybrid model artifacts
use_case_examples/hybrid_model/clients/
use_case_examples/hybrid_model/compiled_models/
Expand Down
5 changes: 5 additions & 0 deletions docs/advanced_examples/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# MNIST data
data/MNIST/

# FHE training deployment files
fhe_training/
52 changes: 27 additions & 25 deletions docs/advanced_examples/LogisticRegressionTraining.ipynb

Large diffs are not rendered by default.

48 changes: 0 additions & 48 deletions src/concrete/ml/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import onnx
import torch
from concrete.fhe import Exactness
from concrete.fhe import Value as EncryptedValue
from concrete.fhe.dtypes import Integer
from sklearn.base import is_classifier, is_regressor

Expand Down Expand Up @@ -681,50 +680,3 @@ def process_rounding_threshold_bits(rounding_threshold_bits):
rounding_threshold_bits = {"n_bits": n_bits_rounding, "method": method}

return rounding_threshold_bits


def serialize_encrypted_values(
    *values_enc: Optional[EncryptedValue],
) -> Union[Optional[bytes], Optional[Tuple[bytes]]]:
    """Serialize encrypted values.

    Each given value is serialized independently through its 'serialize' method. A value that is
    None is not serialized and is kept as None in the output.

    Args:
        values_enc (Optional[EncryptedValue]): The values to serialize.

    Returns:
        Union[Optional[bytes], Optional[Tuple[bytes]]]: The serialized value itself if exactly one
            value was given, else a tuple of the serialized values, in the order they were given.
    """
    serialized = tuple(None if value is None else value.serialize() for value in values_enc)

    # Anything but a single input is returned as a tuple
    if len(serialized) != 1:
        return serialized

    # A single input is returned as is, without being wrapped in a tuple
    (single_serialized,) = serialized
    return single_serialized


def deserialize_encrypted_values(
    *values_serialized: Optional[bytes],
) -> Union[Optional[EncryptedValue], Optional[Tuple[EncryptedValue]]]:
    """Deserialize encrypted values.

    Each given value is deserialized independently through 'EncryptedValue.deserialize'. A value
    that is None is not deserialized and is kept as None in the output.

    Args:
        values_serialized (Optional[bytes]): The values to deserialize.

    Returns:
        Union[Optional[EncryptedValue], Optional[Tuple[EncryptedValue]]]: The deserialized value
            itself if exactly one value was given, else a tuple of the deserialized values, in
            the order they were given.
    """
    deserialized = tuple(
        None if serialized is None else EncryptedValue.deserialize(serialized)
        for serialized in values_serialized
    )

    # Anything but a single input is returned as a tuple
    if len(deserialized) != 1:
        return deserialized

    # A single input is returned as is, without being wrapped in a tuple
    (single_deserialized,) = deserialized
    return single_deserialized
52 changes: 52 additions & 0 deletions src/concrete/ml/deployment/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Utility functions for deployment."""

from typing import Optional, Tuple, Union

from concrete import fhe


def serialize_encrypted_values(
    *values_enc: Optional[fhe.Value],
) -> Union[Optional[bytes], Optional[Tuple[bytes]]]:
    """Serialize encrypted values.

    Each given value is serialized independently through its 'serialize' method. A value that is
    None is not serialized and is kept as None in the output.

    Args:
        values_enc (Optional[fhe.Value]): The values to serialize.

    Returns:
        Union[Optional[bytes], Optional[Tuple[bytes]]]: The serialized value itself if exactly one
            value was given, else a tuple of the serialized values, in the order they were given.
    """
    serialized = []
    for value in values_enc:
        serialized.append(None if value is None else value.serialize())

    # A single input is returned as is, without being wrapped in a tuple
    if len(serialized) == 1:
        return serialized[0]

    return tuple(serialized)


def deserialize_encrypted_values(
    *values_serialized: Optional[bytes],
) -> Union[Optional[fhe.Value], Optional[Tuple[fhe.Value]]]:
    """Deserialize encrypted values.

    Each given value is deserialized independently through 'fhe.Value.deserialize'. A value that
    is None is not deserialized and is kept as None in the output.

    Args:
        values_serialized (Optional[bytes]): The values to deserialize.

    Returns:
        Union[Optional[fhe.Value], Optional[Tuple[fhe.Value]]]: The deserialized value itself if
            exactly one value was given, else a tuple of the deserialized values, in the order
            they were given.
    """
    deserialized = []
    for serialized in values_serialized:
        if serialized is None:
            deserialized.append(None)
        else:
            deserialized.append(fhe.Value.deserialize(serialized))

    # A single input is returned as is, without being wrapped in a tuple
    if len(deserialized) == 1:
        return deserialized[0]

    return tuple(deserialized)
48 changes: 28 additions & 20 deletions src/concrete/ml/deployment/fhe_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from ..common.debugging.custom_assert import assert_true
from ..common.serialization.dumpers import dump
from ..common.serialization.loaders import load
from ..common.utils import deserialize_encrypted_values, serialize_encrypted_values, to_tuple
from ..common.utils import to_tuple
from ..quantization import QuantizedModule
from ..version import __version__ as CML_VERSION
from ._utils import deserialize_encrypted_values, serialize_encrypted_values

try:
# 3.8 and above
Expand Down Expand Up @@ -155,13 +157,13 @@ def run(
# TODO: make desr / ser optional
assert_true(self.server is not None, "Model has not been loaded.")

q_data_enc = to_tuple(serialized_encrypted_quantized_data)
input_quant_encrypted = to_tuple(serialized_encrypted_quantized_data)

# Make sure no inputs are None, to avoid any crash in Concrete
assert not any(x is None for x in q_data_enc), "No input values should be None"
assert not any(x is None for x in input_quant_encrypted), "No input values should be None"

inputs_are_serialized = all(isinstance(x, bytes) for x in q_data_enc)
inputs_are_encrypted_values = all(isinstance(x, fhe.Value) for x in q_data_enc)
inputs_are_serialized = all(isinstance(x, bytes) for x in input_quant_encrypted)
inputs_are_encrypted_values = all(isinstance(x, fhe.Value) for x in input_quant_encrypted)

# Make sure inputs are either only serialized values or encrypted values
assert (
Expand All @@ -170,22 +172,24 @@ def run(

# Deserialize the values if they are all serialized
if inputs_are_serialized:
q_data_enc = to_tuple(deserialize_encrypted_values(*q_data_enc))
input_quant_encrypted = to_tuple(deserialize_encrypted_values(*input_quant_encrypted))

# Deserialize the evaluation keys if they are serialized
evaluation_keys = serialized_evaluation_keys
if isinstance(evaluation_keys, bytes):
evaluation_keys = fhe.EvaluationKeys.deserialize(evaluation_keys)

q_result_enc = self.server.run(*q_data_enc, evaluation_keys=evaluation_keys)
result_quant_encrypted = self.server.run(
*input_quant_encrypted, evaluation_keys=evaluation_keys
)

# If inputs were serialized, return serialized values as well
if inputs_are_serialized:
q_result_enc = serialize_encrypted_values(*to_tuple(q_result_enc))
result_quant_encrypted = serialize_encrypted_values(*to_tuple(result_quant_encrypted))

# Mypy complains because the outputs of `serialize_encrypted_values` can be None, but here
# we already made sure this is not the case
return q_result_enc # type: ignore[return-value]
return result_quant_encrypted # type: ignore[return-value]


class FHEModelDev:
Expand Down Expand Up @@ -400,15 +404,15 @@ def quantize_encrypt_serialize(
"""

# Quantize the values
q_x = to_tuple(self.model.quantize_input(*x))
x_quant = to_tuple(self.model.quantize_input(*x))

# Encrypt the values
q_x_enc = to_tuple(self.client.encrypt(*q_x))
x_quant_encrypted = to_tuple(self.client.encrypt(*x_quant))

# Serialize the encrypted values to be sent to the server
q_x_enc_serialized = serialize_encrypted_values(*q_x_enc)
x_quant_encrypted_serialized = serialize_encrypted_values(*x_quant_encrypted)

return q_x_enc_serialized
return x_quant_encrypted_serialized

# We should find a better name for `serialized_encrypted_quantized_result`
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4476
Expand All @@ -425,14 +429,14 @@ def deserialize_decrypt(
Union[Any, Tuple[Any, ...]]: The decrypted and deserialized values.
"""
# Deserialize the encrypted values
q_result_enc = to_tuple(
result_quant_encrypted = to_tuple(
deserialize_encrypted_values(*serialized_encrypted_quantized_result)
)

# Decrypt the values
q_result = self.client.decrypt(*q_result_enc)
result_quant = self.client.decrypt(*result_quant_encrypted)

return q_result
return result_quant

# We should find a better name for `serialized_encrypted_quantized_result`
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4476
Expand All @@ -449,17 +453,21 @@ def deserialize_decrypt_dequantize(
Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]: The clear float values.
"""
# Decrypt and deserialize the values
q_result = to_tuple(self.deserialize_decrypt(*serialized_encrypted_quantized_result))
result_quant = to_tuple(self.deserialize_decrypt(*serialized_encrypted_quantized_result))

# De-quantize the values
f_result = to_tuple(self.model.dequantize_output(*q_result))
result = to_tuple(self.model.dequantize_output(*result_quant))

# Apply post-processing to the de-quantized values
# Side note: `post_processing` method from built-in models (not Quantized Modules) only
# handles a single input. Calling the following is however not an issue for now as we expect
# 'f_result' to be a tuple of length 1 in this case anyway. Still, we need to make sure this
# 'result' to be a tuple of length 1 in this case anyway. Still, we need to make sure this
# does not break in the future if any built-in model starts to handle multiple outputs:
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4474
result = self.model.post_processing(*f_result)
assert len(result) == 1 or isinstance(
self.model, QuantizedModule
), "Only 'QuantizedModule' instances can handle multi-outputs."

result = self.model.post_processing(*result)

return result
8 changes: 5 additions & 3 deletions src/concrete/ml/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def __init__(
# Concrete ML attributes for FHE training
# These values are hardcoded for now
# We don't expose them in the __init__ arguments but they are taken
# into account when training, so wecan just modify them manually.
# into account when training, so we can just modify them manually.
# The number of bits used for training should be adjusted according to n-bits
# but for now we use these hardcoded values.
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4205
Expand Down Expand Up @@ -433,7 +433,9 @@ def _fit_encrypted(
This is the underlying function that fits the model in FHE if 'fit_encrypted' is enabled.
A quantized module is first built in order to generate the FHE circuit needed for training.
Then, the method iterates over it in the clear.
Then, the method iterates over it in the clear so that outputs of an iteration are used as
inputs for the following iteration. Thanks to Concrete's composition feature, no
encryption/decryption steps are needed when the training is executed in FHE.
For more details on some of these arguments please refer to:
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html
Expand Down Expand Up @@ -603,7 +605,7 @@ def _fit_encrypted(
# A partial fit is similar to running a fit with a single iteration
max_iter = 1 if is_partial_fit else self.max_iter

# Quantize and encrypt the batches
# Iterate on the batches in order to quantize and encrypt them
X_batches_enc, y_batches_enc = [], []
for _ in range(max_iter):

Expand Down
2 changes: 1 addition & 1 deletion src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def remote_call(self, x: torch.Tensor) -> torch.Tensor: # pragma:no cover
inferences.append(decrypted_prediction)

# Concatenate results and move them back to proper device
return torch.Tensor(inferences).to(device=base_device)
return torch.Tensor(numpy.array(inferences)).to(device=base_device)


class HybridFHEModel:
Expand Down

0 comments on commit fd1213a

Please sign in to comment.