Skip to content

Commit

Permalink
feat(frontend): support modules in TFHE-rs bridge
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Dec 18, 2024
1 parent 04d7fb2 commit fda7af5
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 82 deletions.
234 changes: 153 additions & 81 deletions frontends/concrete-python/concrete/fhe/tfhers/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,100 +3,118 @@
"""

# pylint: disable=import-error,no-member,no-name-in-module
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

from concrete.compiler import LweSecretKey, TfhersExporter, TfhersFheIntDescription

from concrete import fhe
import concrete.fhe as fhe
from concrete.fhe.compilation.value import Value

from .dtypes import EncryptionKeyChoice, TFHERSIntegerType


class Bridge:
"""TFHErs Bridge extend a Circuit with TFHErs functionalities.
input_types (List[Optional[TFHERSIntegerType]]): maps every input to a type. None means
a non-tfhers type
output_types (List[Optional[TFHERSIntegerType]]): maps every output to a type. None means
a non-tfhers type
input_shapes (List[Optional[Tuple[int, ...]]]): maps every input to a shape. None means
a non-tfhers type
output_shapes (List[Optional[Tuple[int, ...]]]): maps every output to a shape. None means
a non-tfhers type
"""TFHErs Bridge extend an Module with TFHErs functionalities.
input_types_per_func (Dict[str, List[Optional[TFHERSIntegerType]]]):
maps every input to a type for every function in the module. None means a non-tfhers type
output_types_per_func (Dict[str, List[Optional[TFHERSIntegerType]]]):
maps every output to a type for every function in the module. None means a non-tfhers type
input_shapes_per_func (Dict[str, List[Optional[Tuple[int, ...]]]]):
maps every input to a shape for every function in the module. None means a non-tfhers type
output_shapes_per_func (Dict[str, List[Optional[Tuple[int, ...]]]]):
maps every output to a shape for every function in the module. None means a non-tfhers type
"""

circuit: "fhe.Circuit"
input_types: List[Optional[TFHERSIntegerType]]
output_types: List[Optional[TFHERSIntegerType]]
input_shapes: List[Optional[Tuple[int, ...]]]
output_shapes: List[Optional[Tuple[int, ...]]]
module: "fhe.Module"
default_function: Optional[str]
input_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]]
output_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]]
input_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]]
output_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]]

def __init__(
self,
circuit: "fhe.Circuit",
input_types: List[Optional[TFHERSIntegerType]],
output_types: List[Optional[TFHERSIntegerType]],
input_shapes: List[Optional[Tuple[int, ...]]],
output_shapes: List[Optional[Tuple[int, ...]]],
module: "fhe.Module",
input_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]],
output_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]],
input_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]],
output_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]],
):
self.circuit = circuit
self.input_types = input_types
self.output_types = output_types
self.input_shapes = input_shapes
self.output_shapes = output_shapes
if module.function_count == 1:
self.default_function = list(module.graphs.keys())[0]
else:
self.default_function = None
self.module = module
self.input_types_per_func = input_types_per_func
self.output_types_per_func = output_types_per_func
self.input_shapes_per_func = input_shapes_per_func
self.output_shapes_per_func = output_shapes_per_func

def _get_default_func_or_raise_error(self, calling_func: str) -> str:
if self.default_function is not None:
return self.default_function
else:
raise RuntimeError(
"Module contains more than one function, so please provide 'func_name' while "
f"calling '{calling_func}'"
)

def _input_type(self, input_idx: int) -> Optional[TFHERSIntegerType]:
def _input_type(self, func_name: str, input_idx: int) -> Optional[TFHERSIntegerType]:
"""Return the type of a certain input.
Args:
func_name (str): name of the function the input belongs to
input_idx (int): the input index to get the type of
Returns:
Optional[TFHERSIntegerType]: input type. None means a non-tfhers type
"""
return self.input_types[input_idx]
return self.input_types_per_func[func_name][input_idx]

def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]:
def _output_type(self, func_name: str, output_idx: int) -> Optional[TFHERSIntegerType]:
"""Return the type of a certain output.
Args:
func_name (str): name of the function the output belongs to
output_idx (int): the output index to get the type of
Returns:
Optional[TFHERSIntegerType]: output type. None means a non-tfhers type
"""
return self.output_types[output_idx]
return self.output_types_per_func[func_name][output_idx]

def _input_shape(self, input_idx: int) -> Optional[Tuple[int, ...]]:
def _input_shape(self, func_name: str, input_idx: int) -> Optional[Tuple[int, ...]]:
"""Return the shape of a certain input.
Args:
func_name (str): name of the function the input belongs to
input_idx (int): the input index to get the shape of
Returns:
Optional[Tuple[int, ...]]: input shape. None means a non-tfhers type
"""
return self.input_shapes[input_idx]
return self.input_shapes_per_func[func_name][input_idx]

def _output_shape(self, output_idx: int) -> Optional[Tuple[int, ...]]: # pragma: no cover
def _output_shape(
self, func_name: str, output_idx: int
) -> Optional[Tuple[int, ...]]: # pragma: no cover
"""Return the shape of a certain output.
Args:
func_name (str): name of the function the output belongs to
output_idx (int): the output index to get the shape of
Returns:
Optional[Tuple[int, ...]]: output shape. None means a non-tfhers type
"""
return self.output_shapes[output_idx]
return self.output_shapes_per_func[func_name][output_idx]

def _input_keyid(self, input_idx: int) -> int:
return self.circuit.client.specs.program_info.input_keyid_at(
input_idx, self.circuit.function_name
)
def _input_keyid(self, func_name: str, input_idx: int) -> int:
return self.module.client.specs.program_info.input_keyid_at(input_idx, func_name)

def _input_variance(self, input_idx: int) -> float:
input_type = self._input_type(input_idx)
def _input_variance(self, func_name: str, input_idx: int) -> float:
input_type = self._input_type(func_name, input_idx)
if input_type is None: # pragma: no cover
msg = "input at 'input_idx' is not a TFHErs value"
raise ValueError(msg)
Expand Down Expand Up @@ -133,38 +151,48 @@ def _description_from_type(
ks_first,
)

def import_value(self, buffer: bytes, input_idx: int) -> Value:
def import_value(self, buffer: bytes, input_idx: int, func_name: Optional[str] = None) -> Value:
"""Import a serialized TFHErs integer as a Value.
Args:
buffer (bytes): serialized integer
input_idx (int): the index of the input expecting this value
func_name (Optional[str]): name of the function the value belongs to.
Doesn't need to be provided if there is a single function.
Returns:
fhe.TransportValue: imported value
"""
input_type = self._input_type(input_idx)
input_shape = self._input_shape(input_idx)
if func_name is None:
func_name = self._get_default_func_or_raise_error("import_value")

input_type = self._input_type(func_name, input_idx)
input_shape = self._input_shape(func_name, input_idx)
if input_type is None or input_shape is None: # pragma: no cover
msg = "input at 'input_idx' is not a TFHErs value"
raise ValueError(msg)

fheint_desc = self._description_from_type(input_type)
keyid = self._input_keyid(input_idx)
variance = self._input_variance(input_idx)
keyid = self._input_keyid(func_name, input_idx)
variance = self._input_variance(func_name, input_idx)
return Value(TfhersExporter.import_int(buffer, fheint_desc, keyid, variance, input_shape))

def export_value(self, value: Value, output_idx: int) -> bytes:
def export_value(self, value: Value, output_idx: int, func_name: Optional[str] = None) -> bytes:
"""Export a value as a serialized TFHErs integer.
Args:
value (TransportValue): value to export
output_idx (int): the index corresponding to this output
func_name (Optional[str]): name of the function the value belongs to.
Doesn't need to be provided if there is a single function.
Returns:
bytes: serialized fheuint8
"""
output_type = self._output_type(output_idx)
if func_name is None:
func_name = self._get_default_func_or_raise_error("export_value")

output_type = self._output_type(func_name, output_idx)
if output_type is None: # pragma: no cover
msg = "output at 'output_idx' is not a TFHErs value"
raise ValueError(msg)
Expand All @@ -174,26 +202,31 @@ def export_value(self, value: Value, output_idx: int) -> bytes:
value._inner, fheint_desc # pylint: disable=protected-access
)

def serialize_input_secret_key(self, input_idx: int) -> bytes:
def serialize_input_secret_key(self, input_idx: int, func_name: Optional[str] = None) -> bytes:
"""Serialize secret key used for a specific input.
Args:
input_idx (int): input index corresponding to the key to serialize
func_name (Optional[str]): name of the function the key belongs to.
Doesn't need to be provided if there is a single function.
Returns:
bytes: serialized key
"""
keyid = self._input_keyid(input_idx)
if func_name is None:
func_name = self._get_default_func_or_raise_error("serialize_input_secret_key")

keyid = self._input_keyid(func_name, input_idx)
# pylint: disable=protected-access
keys = self.circuit.client.keys
keys = self.module.client.keys
assert keys is not None
secret_key = keys._keyset.get_client_keys().get_secret_keys()[keyid] # type: ignore
# pylint: enable=protected-access
return secret_key.serialize()

def keygen_with_initial_keys(
self,
input_idx_to_key_buffer: Dict[int, bytes],
input_idx_to_key_buffer: Dict[Union[Tuple[str, int], int], bytes],
force: bool = False,
seed: Optional[int] = None,
encryption_seed: Optional[int] = None,
Expand All @@ -210,20 +243,35 @@ def keygen_with_initial_keys(
encryption_seed (Optional[int], default = None):
seed for encryption randomness
input_idx_to_key_buffer (Dict[int, bytes]): initial keys to set before keygen
input_idx_to_key_buffer (Dict[Union[Tuple[str, int], int], bytes]):
initial keys to set before keygen. Two possible formats: the first is when you have
a single function. Here you can just provide the position of the input as index.
The second is when you have multiple functions. You will need to provide both the
name of the function and the input's position as index.
Raises:
RuntimeError: if failed to deserialize the key
"""
initial_keys: Dict[int, LweSecretKey] = {}
for input_idx in input_idx_to_key_buffer:
key_id = self._input_keyid(input_idx)
for idx in input_idx_to_key_buffer:
if isinstance(idx, tuple):
func_name, input_idx = idx
elif isinstance(idx, int) and self.default_function is not None:
input_idx = idx
func_name = self.default_function
else:
raise RuntimeError(
"Module contains more than one function, so please make sure to mention "
"the function name (not just the position) in input_idx_to_key_buffer. "
"An example index would be a tuple ('my_func', 1)."
)
key_id = self._input_keyid(func_name, input_idx)
# no need to deserialize the same key again
if key_id in initial_keys: # pragma: no cover
continue

key_buffer = input_idx_to_key_buffer[input_idx]
param = self.circuit.client.specs.program_info.get_keyset_info().secret_keys()[key_id]
param = self.module.client.specs.program_info.get_keyset_info().secret_keys()[key_id]
try:
initial_keys[key_id] = LweSecretKey.deserialize(key_buffer, param)
except Exception as e: # pragma: no cover
Expand All @@ -233,41 +281,65 @@ def keygen_with_initial_keys(
)
raise RuntimeError(msg) from e

self.circuit.keygen(
self.module.keygen(
force=force,
seed=seed,
encryption_seed=encryption_seed,
initial_keys=initial_keys,
)


def new_bridge(circuit: "fhe.Circuit") -> Bridge:
"""Create a TFHErs bridge from a circuit.
def new_bridge(circuit_or_module: Union["fhe.Circuit", "fhe.Module"]) -> Bridge:
"""Create a TFHErs bridge from a circuit or module.
Args:
circuit (Circuit): compiled circuit
circuit (Union[Circuit, Module]): compiled circuit or module
Returns:
Bridge: TFHErs bridge
"""
input_types: List[Optional[TFHERSIntegerType]] = []
input_shapes: List[Optional[Tuple[int, ...]]] = []
for input_node in circuit.graph.ordered_inputs():
if isinstance(input_node.output.dtype, TFHERSIntegerType):
input_types.append(input_node.output.dtype)
input_shapes.append(input_node.output.shape)
else:
input_types.append(None)
input_shapes.append(None)

output_types: List[Optional[TFHERSIntegerType]] = []
output_shapes: List[Optional[Tuple[int, ...]]] = []
for output_node in circuit.graph.ordered_outputs():
if isinstance(output_node.output.dtype, TFHERSIntegerType):
output_types.append(output_node.output.dtype)
output_shapes.append(output_node.output.shape)
else: # pragma: no cover
output_types.append(None)
output_shapes.append(None)

return Bridge(circuit, input_types, output_types, input_shapes, output_shapes)
if isinstance(circuit_or_module, fhe.Module):
module = circuit_or_module
else:
assert isinstance(circuit_or_module, fhe.Circuit)
module = circuit_or_module._module

input_types_per_func = {}
output_types_per_func = {}
input_shapes_per_func = {}
output_shapes_per_func = {}

for func_name, graph in module.graphs.items():
input_types: List[Optional[TFHERSIntegerType]] = []
input_shapes: List[Optional[Tuple[int, ...]]] = []
for input_node in graph.ordered_inputs():
if isinstance(input_node.output.dtype, TFHERSIntegerType):
input_types.append(input_node.output.dtype)
input_shapes.append(input_node.output.shape)
else:
input_types.append(None)
input_shapes.append(None)

input_types_per_func[func_name] = input_types
input_shapes_per_func[func_name] = input_shapes

output_types: List[Optional[TFHERSIntegerType]] = []
output_shapes: List[Optional[Tuple[int, ...]]] = []
for output_node in graph.ordered_outputs():
if isinstance(output_node.output.dtype, TFHERSIntegerType):
output_types.append(output_node.output.dtype)
output_shapes.append(output_node.output.shape)
else: # pragma: no cover
output_types.append(None)
output_shapes.append(None)

output_types_per_func[func_name] = output_types
output_shapes_per_func[func_name] = output_shapes

return Bridge(
module,
input_types_per_func,
output_types_per_func,
input_shapes_per_func,
output_shapes_per_func,
)
2 changes: 1 addition & 1 deletion frontends/concrete-python/examples/tfhers/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def ccompilee():
inputset = [(tfhers_int(120), tfhers_int(120))]
circuit = compiler.compile(inputset)

tfhers_bridge = tfhers.new_bridge(circuit=circuit)
tfhers_bridge = tfhers.new_bridge(circuit)
return circuit, tfhers_bridge


Expand Down

0 comments on commit fda7af5

Please sign in to comment.