Skip to content

Commit

Permalink
feat: support embedding layers (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery authored Jul 26, 2024
1 parent c8908fa commit 296bc8c
Show file tree
Hide file tree
Showing 15 changed files with 468 additions and 90 deletions.
1 change: 1 addition & 0 deletions docs/deep-learning/onnx_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ Concrete ML supports the following operators for evaluation and conversion to an
- Mul
- Neg
- Not
- OneHot
- Or
- PRelu
- Pad
Expand Down
4 changes: 4 additions & 0 deletions src/concrete/ml/common/serialization/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from numpy.random import RandomState
from skorch.dataset import ValidSplit

from ...common.utils import Exactness
from ...quantization.base_quantized_op import ALL_QUANTIZED_OPS
from ...quantization.quantized_module import QuantizedModule
from ...quantization.quantizers import (
Expand Down Expand Up @@ -202,6 +203,9 @@ def object_hook(d: Any) -> Any:
model_class.__name__: model_class for model_class in serializable_classes
}

if type_name == "Exactness":
return Exactness(serialized_value)

# If the value reaches this point and the initial object was properly serialized, we
# expect it to be a class from Concrete ML that implements a `load_dict` method
if type_name in SERIALIZABLE_CLASSES:
Expand Down
4 changes: 4 additions & 0 deletions src/concrete/ml/common/serialization/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from concrete import fhe

from ...common.utils import Exactness
from . import USE_SKOPS

# If USE_SKOPS is False or Skops can't be imported, default to pickle
Expand Down Expand Up @@ -214,6 +215,9 @@ def default(self, o: Any) -> Any:
if isinstance(o, tuple):
return dump_name_and_value("tuple", list(o))

if isinstance(o, Exactness):
return dump_name_and_value("Exactness", o.value)

# Dump the numpy integer value along its dtype
if isinstance(o, numpy.integer):
kwargs = {"dtype": str(o.dtype)}
Expand Down
74 changes: 55 additions & 19 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
import onnxoptimizer
import torch
from onnx import helper
from typing_extensions import TypeAlias

from ..common.debugging import assert_true
from ..onnx.onnx_model_manipulations import convert_first_gather_to_matmul
from .onnx_utils import (
IMPLEMENTED_ONNX_OPS,
check_onnx_model,
Expand All @@ -21,6 +23,11 @@
get_op_type,
)

NumpyForwardCallable: TypeAlias = Callable[..., Tuple[numpy.ndarray, ...]]
ONNXAndNumpyForwards: TypeAlias = Tuple[
NumpyForwardCallable, Optional[onnx.ModelProto], NumpyForwardCallable, onnx.ModelProto
]

OPSET_VERSION_FOR_ONNX_EXPORT = 14


Expand Down Expand Up @@ -120,7 +127,7 @@ def get_equivalent_numpy_forward_from_torch(
torch_module: torch.nn.Module,
dummy_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
output_onnx_file: Union[None, Path, str] = None,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
) -> ONNXAndNumpyForwards:
"""Get the numpy equivalent forward of the provided torch Module.
Args:
Expand All @@ -132,9 +139,8 @@ def get_equivalent_numpy_forward_from_torch(
Defaults to None.
Returns:
Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.GraphProto]: The function that will
execute the equivalent numpy code to the passed torch_module and the generated ONNX
model.
ONNXAndNumpyForwards: The function that will execute the equivalent numpy code to the
passed torch_module and the generated ONNX model.
"""
output_onnx_file_path = Path(
tempfile.mkstemp(suffix=".onnx")[1] if output_onnx_file is None else output_onnx_file
Expand Down Expand Up @@ -165,20 +171,24 @@ def get_equivalent_numpy_forward_from_torch(
if use_tempfile:
output_onnx_file_path.unlink()

equivalent_numpy_forward, equivalent_onnx_model = get_equivalent_numpy_forward_from_onnx(
equivalent_onnx_model
numpy_preprocessing, onnx_preprocessing, equivalent_numpy_forward, equivalent_onnx_model = (
get_equivalent_numpy_forward_from_onnx(equivalent_onnx_model)
)
with output_onnx_file_path.open("wb") as file:
file.write(equivalent_onnx_model.SerializeToString())

return (
numpy_preprocessing,
onnx_preprocessing,
equivalent_numpy_forward,
equivalent_onnx_model,
)


def preprocess_onnx_model(onnx_model: onnx.ModelProto, check_model: bool) -> onnx.ModelProto:
"""Get the numpy equivalent forward of the provided ONNX model.
def preprocess_onnx_model(
onnx_model: onnx.ModelProto, check_model: bool
) -> Tuple[Optional[onnx.ModelProto], onnx.ModelProto]:
"""Preprocess the ONNX model to be used for numpy execution.
Args:
onnx_model (onnx.ModelProto): the ONNX model for which to get the equivalent numpy
Expand All @@ -191,7 +201,9 @@ def preprocess_onnx_model(onnx_model: onnx.ModelProto, check_model: bool) -> onn
model to numpy.
Returns:
onnx.ModelProto: The preprocessed ONNX model.
Tuple[Optional[onnx.ModelProto], onnx.ModelProto]: The preprocessing ONNX model and
preprocessed ONNX model. The preprocessing model is None if there is no preprocessing
required.
"""

# All onnx models should be checked, "check_model" parameter must be removed
Expand Down Expand Up @@ -236,13 +248,21 @@ def preprocess_onnx_model(onnx_model: onnx.ModelProto, check_model: bool) -> onn
f"Available ONNX operators: {', '.join(sorted(IMPLEMENTED_ONNX_OPS))}"
)

return equivalent_onnx_model
# Convert the first Gather node to a matrix multiplication with one-hot encoding
# In FHE, embedding is either a TLU or a matmul with a one-hot.
# The second case allows for leveled operation thus much faster.
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4532
onnx_preprocessing, equivalent_onnx_model = convert_first_gather_to_matmul(
equivalent_onnx_model
)

return onnx_preprocessing, equivalent_onnx_model


def get_equivalent_numpy_forward_from_onnx(
onnx_model: onnx.ModelProto,
check_model: bool = True,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
) -> ONNXAndNumpyForwards:
"""Get the numpy equivalent forward of the provided ONNX model.
Args:
Expand All @@ -252,23 +272,39 @@ def get_equivalent_numpy_forward_from_onnx(
Defaults to True.
Returns:
Callable[..., Tuple[numpy.ndarray, ...]]: The function that will execute
the equivalent numpy function.
ONNXAndNumpyForwards: The function that will execute the equivalent numpy function.
"""

equivalent_onnx_model = preprocess_onnx_model(onnx_model, check_model)
onnx_preprocessing, equivalent_onnx_model = preprocess_onnx_model(onnx_model, check_model)

def create_numpy_forward(model: Optional[onnx.ModelProto]) -> NumpyForwardCallable:
"""Create numpy forward function.
Args:
model (onnx.ModelProto): The ONNX model to execute.
Returns:
NumpyForwardCallable: The numpy equivalent of the ONNX model.
"""
if model is None:
# Return the inputs as is
return lambda *args: args
return lambda *args: execute_onnx_with_numpy(model.graph, *args)

# Return lambda of numpy equivalent of onnx execution
return (
lambda *args: execute_onnx_with_numpy(equivalent_onnx_model.graph, *args)
), equivalent_onnx_model
create_numpy_forward(onnx_preprocessing),
onnx_preprocessing,
create_numpy_forward(equivalent_onnx_model),
equivalent_onnx_model,
)


def get_equivalent_numpy_forward_from_onnx_tree(
onnx_model: onnx.ModelProto,
check_model: bool = True,
lsbs_to_remove_for_trees: Optional[Tuple[int, int]] = None,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
) -> Tuple[NumpyForwardCallable, onnx.ModelProto]:
"""Get the numpy equivalent forward of the provided ONNX model for tree-based models only.
Args:
Expand All @@ -283,11 +319,11 @@ def get_equivalent_numpy_forward_from_onnx_tree(
comparison operation. Default to None, as it is not applicable to other types of models.
Returns:
Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]: The function that will
Tuple[NumpyForwardCallable, onnx.ModelProto]: The function that will
execute the equivalent numpy function.
"""

equivalent_onnx_model = preprocess_onnx_model(onnx_model, check_model)
_, equivalent_onnx_model = preprocess_onnx_model(onnx_model, check_model)

# Return lambda of numpy equivalent of onnx execution
return (
Expand Down
103 changes: 102 additions & 1 deletion src/concrete/ml/onnx/onnx_model_manipulations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Some code to manipulate models."""

from copy import deepcopy
from typing import Iterable, List
from typing import Iterable, List, Optional, Tuple

import onnx

Expand Down Expand Up @@ -289,3 +289,104 @@ def _clean_graph_at_node_name(

# Keep the output node
keep_following_outputs_discard_others(onnx_model, [output_to_follow])


# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4532
# Function to convert the first Gather nodes
# to matrix multiplications with one-hot encoding as a pre-processing step
def convert_first_gather_to_matmul(
onnx_model: onnx.ModelProto,
) -> Tuple[Optional[onnx.ModelProto], onnx.ModelProto]:
"""Convert the first Gather node to a matrix multiplication node.
In FHE, Gather is a costly operation since it can involve many PBS.
When it appears first in the onnx model, we can remove it and replace it by a matrix
multiplication node by converting the indices to a one-hot encoding.
Args:
onnx_model (onnx.ModelProto): The onnx model.
Returns:
Tuple[Optional[onnx.ModelProto], onnx.ModelProto]: The pre-processing model and the modified
onnx model.
"""
pre_processing_nodes = []
modified = False
depth_tensors = []
gather_depths = {}

for node in onnx_model.graph.node:
if node.op_type == "Gather" and node.input[1] in [
input.name for input in onnx_model.graph.input
]:
# Extract the inputs and output of the Gather node
data_input = node.input[0]
indices_input = node.input[1]
gather_output = node.output[0]

# Find the shape of the data_input (embedding matrix)
data_shape_initializer = next(
(init for init in onnx_model.graph.initializer if init.name == data_input), None
)
assert data_shape_initializer is not None, f"Shape of {data_input} not found"

# Extract the depth arg for the OneHot node using the embedding matrix
depth = data_shape_initializer.dims[0]
depth_name = f"depth_{node.name}"
gather_depths[node.name] = depth

# Create a node for OneHot operation
pre_processed_x = f"pre_processed_{indices_input}"
pre_processing_nodes.append(
onnx.helper.make_node(
"OneHot",
inputs=[indices_input, depth_name, "values"],
outputs=[pre_processed_x],
)
)

# Store the depth tensor for this Gather node
depth_tensor = onnx.helper.make_tensor(depth_name, onnx.TensorProto.INT64, [1], [depth])
depth_tensors.append(depth_tensor)

# Replace Gather node with MatMul node
matmul_node = onnx.helper.make_node(
"MatMul", inputs=[indices_input, data_input], outputs=[gather_output]
)
onnx_model.graph.node.remove(node)

# Insert the new node at the beginning of the graph
onnx_model.graph.node.insert(0, matmul_node)
modified = True

if not modified:
return None, onnx_model

# Create a tensor for the values of the OneHot node
values_tensor = onnx.helper.make_tensor("values", onnx.TensorProto.FLOAT, [2], [0.0, 1.0])

# Create pre-processing graph
pre_processing_graph = onnx.helper.make_graph(
pre_processing_nodes,
"pre_processing_graph",
[
onnx.helper.make_tensor_value_info(node.input[0], onnx.TensorProto.INT64, [None])
for node in pre_processing_nodes
],
[
onnx.helper.make_tensor_value_info(
node.output[0], onnx.TensorProto.FLOAT, [None, depth]
)
for node, depth in zip(pre_processing_nodes, gather_depths.values())
],
depth_tensors + [values_tensor],
)

pre_processing_onnx = onnx.helper.make_model(
pre_processing_graph, opset_imports=[onnx.helper.make_opsetid("", 14)]
)

# Check the pre-processing onnx
onnx.checker.check_model(pre_processing_onnx)

return pre_processing_onnx, onnx_model
Loading

0 comments on commit 296bc8c

Please sign in to comment.