diff --git a/conftest.py b/conftest.py index edf9a2ee15..c1d1ddb5c1 100644 --- a/conftest.py +++ b/conftest.py @@ -60,7 +60,10 @@ def pytest_addoption(parser): ) parser.addoption( - "--no-flaky", action="store_true", default=False, help="Don't run known flaky tests." + "--no-flaky", + action="store_true", + default=False, + help="Don't run known flaky tests.", ) @@ -382,8 +385,12 @@ def check_float_array_equal_impl( a, b, rtol=0, atol=0.001, error_information: Optional[str] = "" ): + max_atol = numpy.abs(a - b).max() + max_rtol = (numpy.abs(a - b) / numpy.abs(b)).max() + error_message = ( f"Not equal to tolerance rtol={rtol}, atol={atol}\na: {a}\nb: {b}\n" + f"Found {max_atol=}, {max_rtol=}\n" f"{error_information}" ) diff --git a/deps_licenses/licenses_mac_silicon_user.txt b/deps_licenses/licenses_mac_silicon_user.txt index 66c6556898..a293f1613f 100644 --- a/deps_licenses/licenses_mac_silicon_user.txt +++ b/deps_licenses/licenses_mac_silicon_user.txt @@ -1,18 +1,20 @@ Name, Version, License +Jinja2, 3.1.4, BSD License +MarkupSafe, 2.1.5, BSD License PyYAML, 6.0.1, MIT License -brevitas, 0.8.0, UNKNOWN -certifi, 2024.6.2, Mozilla Public License 2.0 (MPL 2.0) +brevitas, 0.10.2, UNKNOWN +certifi, 2024.7.4, Mozilla Public License 2.0 (MPL 2.0) charset-normalizer, 3.3.2, MIT License coloredlogs, 15.0.1, MIT License concrete-python, 2.7.0, BSD-3-Clause dependencies, 2.0.1, BSD License dill, 0.3.8, BSD License -filelock, 3.15.3, The Unlicense (Unlicense) +filelock, 3.15.4, The Unlicense (Unlicense) flatbuffers, 24.3.25, Apache Software License -fsspec, 2024.6.0, BSD License +fsspec, 2024.6.1, BSD License huggingface-hub, 0.23.4, Apache Software License humanfriendly, 10.0, MIT License -hummingbird-ml, 0.4.8, MIT License +hummingbird-ml, 0.4.11, MIT License idna, 3.7, BSD License importlib_resources, 6.4.0, Apache Software License joblib, 1.4.2, BSD License @@ -22,7 +24,7 @@ networkx, 3.1, BSD License numpy, 1.23.5, BSD License onnx, 1.16.1, Apache License v2.0 onnxconverter-common, 1.13.0, MIT License -onnxmltools, 1.11.0, Apache Software License +onnxmltools, 1.12.0, Apache Software License onnxoptimizer, 0.3.13, Apache License v2.0 onnxruntime, 1.18.0, MIT License packaging, 24.1, Apache Software License; BSD License @@ -35,16 +37,18 @@ requests, 2.32.3, Apache Software License scikit-learn, 1.1.3, BSD License scipy, 1.10.1, BSD License six, 1.16.0, MIT License -skl2onnx, 1.12, Apache Software License +skl2onnx, 1.17.0, Apache Software License skops, 0.5.0, MIT skorch, 0.11.0, new BSD 3-Clause -sympy, 1.12.1, BSD License +sympy, 1.13.0, BSD License tabulate, 0.8.10, MIT License threadpoolctl, 3.5.0, BSD License -torch, 1.13.1, BSD License +torch, 2.3.1, BSD License tqdm, 4.66.4, MIT License; Mozilla Public License 2.0 (MPL 2.0) -typing_extensions, 4.5.0, Python Software Foundation License +typing_extensions, 4.12.2, Python Software Foundation License tzdata, 2024.1, Apache Software License +unfoldNd, 0.2.2, MIT License urllib3, 2.2.2, MIT License xgboost, 1.6.2, Apache Software License z3-solver, 4.13.0.0, MIT License +zipp, 3.19.2, MIT License diff --git a/deps_licenses/licenses_mac_silicon_user.txt.md5 b/deps_licenses/licenses_mac_silicon_user.txt.md5 index 4dfc9a8918..8a918180a0 100644 --- a/deps_licenses/licenses_mac_silicon_user.txt.md5 +++ b/deps_licenses/licenses_mac_silicon_user.txt.md5 @@ -1 +1 @@ -adb925c3b7be3e651975febcf49b6543 +6d367701c3ef5eff8763f4e994e03681 diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py index 56545e5df6..0529cea1c6 100644 --- 
a/src/concrete/ml/onnx/convert.py +++ b/src/concrete/ml/onnx/convert.py @@ -86,7 +86,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto): # Create a GEMM node which combines the MatMul and Add operations gemm_node = helper.make_node( "Gemm", # op_type - [matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name], # inputs + [ + matmul_node.input[0], + matmul_node.input[1], + bias_other_input_node_name, + ], # inputs [add_node.output[0]], # outputs name="Gemm_Node", alpha=1.0, @@ -149,9 +153,14 @@ def get_equivalent_numpy_forward_from_torch( arguments = list(inspect.signature(torch_module.forward).parameters) + if isinstance(dummy_input, torch.Tensor): + dummy_input = dummy_input.to("cpu") + else: + dummy_input = tuple(elt.to("cpu") for elt in dummy_input) + # Export to ONNX torch.onnx.export( - torch_module, + torch_module.to("cpu"), dummy_input, str(output_onnx_file_path), opset_version=OPSET_VERSION_FOR_ONNX_EXPORT, diff --git a/src/concrete/ml/pytest/torch_models.py b/src/concrete/ml/pytest/torch_models.py index fccc67eb11..ad74ce0b61 100644 --- a/src/concrete/ml/pytest/torch_models.py +++ b/src/concrete/ml/pytest/torch_models.py @@ -6,11 +6,16 @@ import brevitas.nn as qnn import numpy import torch -from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat, IntBias +from brevitas.core.restrict_val import FloatRestrictValue, RestrictValueType +from brevitas.quant import (Int8AccumulatorAwareWeightQuant, + Int8AccumulatorAwareZeroCenterWeightQuant, + Int8ActPerTensorFloat, Int8WeightPerTensorFloat, + IntBias, Uint8ActPerTensorFloat) from torch import nn from torch.nn.utils import prune -from concrete.ml.quantization.qat_quantizers import Int8ActPerTensorPoT, Int8WeightPerTensorPoT +from concrete.ml.quantization.qat_quantizers import (Int8ActPerTensorPoT, + Int8WeightPerTensorPoT) # pylint: disable=too-many-lines @@ -38,7 +43,7 @@ def forward(self, x, y): return x + y + self.value, (x - y) ** 2 -class SimpleNet(torch.nn.Module): +class SimpleNet(nn.Module): """Fake torch model used to generate some onnx.""" def __init__(self) -> None: @@ -267,7 +272,9 @@ def __init__(self, input_output, activation_function, groups): super().__init__() self.activation_function = activation_function() - self.conv1 = nn.Conv2d(input_output, 3, 3, stride=1, padding=1, dilation=1, groups=groups) + self.conv1 = nn.Conv2d( + input_output, 3, 3, stride=1, padding=1, dilation=1, groups=groups + ) self.pool = nn.AvgPool2d(2, 2) self.conv2 = nn.Conv2d(3, 3, 1, stride=1, padding=0, dilation=1, groups=3) self.fc1 = nn.Linear(3 * 3 * 3, 5) @@ -292,7 +299,7 @@ def forward(self, x): return x -class NetWithLoops(torch.nn.Module): +class NetWithLoops(nn.Module): """Torch model, where we reuse some elements in a loop. 
Torch model, where we reuse some elements in a loop in the forward and don't expect the @@ -328,7 +335,9 @@ def forward(self, x): class MultiInputNN(nn.Module): """Torch model to test multiple inputs forward.""" - def __init__(self, input_output, activation_function): # pylint: disable=unused-argument + def __init__( + self, input_output, activation_function + ): # pylint: disable=unused-argument super().__init__() self.act = activation_function() @@ -351,7 +360,9 @@ class MultiInputNNConfigurable(nn.Module): layer1: nn.Module layer2: nn.Module - def __init__(self, use_conv, use_qat, input_output, n_bits): # pylint: disable=unused-argument + def __init__( + self, use_conv, use_qat, input_output, n_bits + ): # pylint: disable=unused-argument super().__init__() if use_conv: @@ -538,7 +549,7 @@ def step(x, bias): return x -class NetWithConcatUnsqueeze(torch.nn.Module): +class NetWithConcatUnsqueeze(nn.Module): """Torch model to test the concat and unsqueeze operators.""" def __init__(self, activation_function, input_output, n_fc_layers): @@ -616,7 +627,9 @@ def __init__(self, input_output, act): for idx in range(self.n_layers): out_features = in_features if idx == self.n_layers - 1 else in_features layer_name = f"fc{idx}" - layer = nn.Linear(in_features=in_features, out_features=out_features, bias=False) + layer = nn.Linear( + in_features=in_features, out_features=out_features, bias=False + ) self.feat.add_module(layer_name, layer) in_features = out_features @@ -654,7 +667,9 @@ def __init__(self, input_output, act): for idx in range(self.n_layers): out_features = in_features if idx == self.n_layers - 1 else in_features layer_name = f"fc{idx}" - layer = nn.Linear(in_features=in_features, out_features=out_features, bias=False) + layer = nn.Linear( + in_features=in_features, out_features=out_features, bias=False + ) self.feat.add_module(layer_name, layer) in_features = out_features @@ -713,7 +728,9 @@ class TinyQATCNN(nn.Module): should help keep the accumulator bit-width low. """ - def __init__(self, n_classes, n_bits, n_active, signed, narrow, power_of_two_scaling) -> None: + def __init__( + self, n_classes, n_bits, n_active, signed, narrow, power_of_two_scaling + ) -> None: """Construct the CNN with a configurable number of classes. Args: @@ -786,8 +803,12 @@ def __init__(self, n_classes, n_bits, n_active, signed, narrow, power_of_two_sca bias_quant=bias_quant, ) - self.quant4 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True, **q_args) - self.fc1 = qnn.QuantLinear(16, n_classes, weight_bit_width=3, bias=True, **q_args) + self.quant4 = qnn.QuantIdentity( + bit_width=a_bits, return_quant_tensor=True, **q_args + ) + self.fc1 = qnn.QuantLinear( + 16, n_classes, weight_bit_width=3, bias=True, **q_args + ) # Enable pruning, prepared for training self.toggle_pruning(True) @@ -853,7 +874,9 @@ def forward(self, x): class SimpleQAT(nn.Module): """Torch model implements a step function that needs Greater, Cast and Where.""" - def __init__(self, input_output, activation_function, n_bits=2, disable_bit_check=False): + def __init__( + self, input_output, activation_function, n_bits=2, disable_bit_check=False + ): super().__init__() self.act = activation_function() @@ -867,7 +890,9 @@ def __init__(self, input_output, activation_function, n_bits=2, disable_bit_chec n_bits_weights = n_bits # Generate the pattern 0, 1, ..., 2^N-1, 0, 1, .. 2^N-1, 0, 1.. 
- all_weights = numpy.mod(numpy.arange(numpy.prod(self.fc1.weight.shape)), 2**n_bits_weights) + all_weights = numpy.mod( + numpy.arange(numpy.prod(self.fc1.weight.shape)), 2**n_bits_weights + ) # Shuffle the pattern and reshape to weight shape numpy.random.shuffle(all_weights) @@ -877,7 +902,9 @@ def __init__(self, input_output, activation_function, n_bits=2, disable_bit_chec # but we may want to disable it to check that the QAT import catches the error if not disable_bit_check: # Ensure we have the correct max/min that produces the correct scale in Quantized Array - assert numpy.max(int_weights) - numpy.min(int_weights) == (2**n_bits_weights - 1) + assert numpy.max(int_weights) - numpy.min(int_weights) == ( + 2**n_bits_weights - 1 + ) # We want signed weights, so offset the generated weights int_weights = int_weights - 2 ** (n_bits_weights - 1) @@ -984,7 +1011,9 @@ def __init__(self, use_conv, use_qat, inp_size, n_bits): if use_qat: self.mixing_layer = nn.Sequential( qnn.QuantIdentity(bit_width=n_bits), - qnn.QuantConv2d(1, 1, 3, stride=1, bias=True, weight_bit_width=n_bits), + qnn.QuantConv2d( + 1, 1, 3, stride=1, bias=True, weight_bit_width=n_bits + ), ) layer_obj = self.mixing_layer[1] else: @@ -996,7 +1025,9 @@ def __init__(self, use_conv, use_qat, inp_size, n_bits): if use_qat: self.mixing_layer = nn.Sequential( qnn.QuantIdentity(bit_width=n_bits), - qnn.QuantLinear(inp_size, inp_size, bias=True, weight_bit_width=n_bits), + qnn.QuantLinear( + inp_size, inp_size, bias=True, weight_bit_width=n_bits + ), ) layer_obj = self.mixing_layer[1] else: @@ -1004,6 +1035,7 @@ def __init__(self, use_conv, use_qat, inp_size, n_bits): layer_obj = self.mixing_layer layer_obj.weight.data = torch.from_numpy(np_weights).float() + assert layer_obj.bias is not None layer_obj.bias.data = torch.rand(size=(1,)) def forward(self, x): @@ -1025,7 +1057,9 @@ class DoubleQuantQATMixNet(nn.Module): Used to test that it keeps the input TLU. 
""" - def __init__(self, use_conv, use_qat, inp_size, n_bits): # pylint: disable=unused-argument + def __init__( + self, use_conv, use_qat, inp_size, n_bits + ): # pylint: disable=unused-argument super().__init__() # A first quantizer @@ -1216,12 +1250,12 @@ def forward(self, x): # for example a 4d tensor NCHW, padded with [1, 2, 2, 3] is padded # along the last 2 dimensions, with 1 cell to the left and 2 to the right (dimension 4: W) # and 2 cells at the top and 3 at the bottom (dimension 3: H) - x = torch.nn.functional.pad(x, (3, 2)) - x = torch.nn.functional.pad(x, (1, 2, 3, 4)) + x = nn.functional.pad(x, (3, 2)) + x = nn.functional.pad(x, (1, 2, 3, 4)) # Concrete ML only supports padding on the last two dimensions as this is the # most common setting - x = torch.nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0)) + x = nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0)) return x @@ -1264,7 +1298,9 @@ def __init__( return_quant_tensor=True, ) - self.relu1 = qnn.QuantReLU(return_quant_tensor=True, bit_width=n_bits, act_quant=act_quant) + self.relu1 = qnn.QuantReLU( + return_quant_tensor=True, bit_width=n_bits, act_quant=act_quant + ) self.linear2 = qnn.QuantLinear( in_features=hidden_shape, out_features=hidden_shape, @@ -1275,7 +1311,9 @@ def __init__( return_quant_tensor=True, ) - self.relu2 = qnn.QuantReLU(return_quant_tensor=True, bit_width=n_bits, act_quant=act_quant) + self.relu2 = qnn.QuantReLU( + return_quant_tensor=True, bit_width=n_bits, act_quant=act_quant + ) self.linear3 = qnn.QuantLinear( in_features=hidden_shape, @@ -1340,7 +1378,12 @@ class ConcatFancyIndexing(nn.Module): """Concat with fancy indexing.""" def __init__( - self, input_shape, hidden_shape, output_shape, n_bits: int = 4, n_blocks: int = 3 + self, + input_shape, + hidden_shape, + output_shape, + n_bits: int = 4, + n_blocks: int = 3, ) -> None: """Torch Model. @@ -1355,17 +1398,26 @@ def __init__( self.n_blocks = n_blocks self.quant_1 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True) - self.fc1 = qnn.QuantLinear(input_shape, hidden_shape, bias=False, weight_bit_width=n_bits) + self.fc1 = qnn.QuantLinear( + input_shape, hidden_shape, bias=False, weight_bit_width=n_bits + ) - self.quant_concat = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True) + self.quant_concat = qnn.QuantIdentity( + bit_width=n_bits, return_quant_tensor=True + ) self.quant_2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True) self.fc2 = qnn.QuantLinear( - hidden_shape * self.n_blocks, hidden_shape, bias=True, weight_bit_width=n_bits + hidden_shape * self.n_blocks, + hidden_shape, + bias=True, + weight_bit_width=n_bits, ) self.quant_3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True) - self.fc4 = qnn.QuantLinear(hidden_shape, output_shape, bias=True, weight_bit_width=n_bits) + self.fc4 = qnn.QuantLinear( + hidden_shape, output_shape, bias=True, weight_bit_width=n_bits + ) def forward(self, x): """Forward pass. 
@@ -1393,7 +1445,7 @@ def forward(self, x): return x -class PartialQATModel(torch.nn.Module): +class PartialQATModel(nn.Module): """A model with a QAT Module.""" def __init__(self, input_shape: int, output_shape: int, n_bits: int): @@ -1442,7 +1494,7 @@ def forward(self, input1): return output -class ManualLogisticRegressionTraining(torch.nn.Module): +class ManualLogisticRegressionTraining(nn.Module): """PyTorch module for performing SGD training.""" def __init__(self, learning_rate=0.1): @@ -1497,7 +1549,9 @@ def predict(x, weights, bias): class AddNet(nn.Module): """Torch model that performs a simple addition between two inputs.""" - def __init__(self, use_conv, use_qat, input_output, n_bits): # pylint: disable=unused-argument + def __init__( + self, use_conv, use_qat, input_output, n_bits + ): # pylint: disable=unused-argument super().__init__() # No initialization needed for simple addition @@ -1605,7 +1659,9 @@ def forward(self, x): # pylint: disable-next=no-self-use class TorchDivide(torch.nn.Module): """Torch model that performs a encrypted division between two inputs.""" - def __init__(self, input_output, activation_function): # pylint: disable=unused-argument + def __init__( + self, input_output, activation_function + ): # pylint: disable=unused-argument super().__init__() @staticmethod @@ -1625,7 +1681,9 @@ def forward(x, y): class TorchMultiply(torch.nn.Module): """Torch model that performs a encrypted multiplication between two inputs.""" - def __init__(self, input_output, activation_function): # pylint: disable=unused-argument + def __init__( + self, input_output, activation_function + ): # pylint: disable=unused-argument super().__init__() @staticmethod @@ -1665,3 +1723,186 @@ def forward(self, x): x = self.relu(x) x = self.linear(x) return x + + +# pylint: disable-next=too-many-ancestors +class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat): + """CommonIntWeightPerChannelQuant.""" + + scaling_per_output_channel = True + + +# pylint: disable-next=too-many-ancestors +class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): + """CommonIntAccumulatorAwareWeightQuant.""" + + restrict_scaling_impl = FloatRestrictValue # backwards compatibility + bit_width = None + + +# pylint: disable-next=too-many-ancestors +class CommonIntAccumulatorAwareZeroCenterWeightQuant( + Int8AccumulatorAwareZeroCenterWeightQuant +): + """CommonIntAccumulatorAwareZeroCenterWeightQuant.""" + + bit_width = None + + +# pylint: disable-next=too-many-ancestors +class CommonUintActQuant(Uint8ActPerTensorFloat): + """CommonUintActQuant.""" + + bit_width = None + restrict_scaling_type = RestrictValueType.LOG_FP + + +def weight_init(layer: nn.Module): + """Initialize layer weights. 
+ + Arguments: + layer (nn.Module): a conv2d layer + """ + + if isinstance(layer, nn.Conv2d): + nn.init.kaiming_normal_(layer.weight, nn.init.calculate_gain("relu")) + if layer.bias is not None: + layer.bias.data.zero_() + + +# pylint: disable-next=too-many-instance-attributes +class FloatLeNet(nn.Module): + """Floating point LeNet.""" + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d( + in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0 + ) + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d( + in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0 + ) + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu2 = nn.ReLU(inplace=True) + + self.fc1 = nn.Linear(400, 120, bias=True) + self.relu3 = nn.ReLU() + self.fc2 = nn.Linear(120, 84, bias=True) + self.relu4 = nn.ReLU() + self.fc3 = nn.Linear(84, 10, bias=True) + + self.apply(weight_init) + + def forward(self, x: torch.Tensor): + """Forward function. + + Arguments: + x (torch.Tensor): input image + + Returns: + Neural network prediction + """ + x = self.pool1(self.relu1(self.conv1(x))) + x = self.pool2(self.relu2(self.conv2(x))) + x = torch.flatten(x, 1) + x = self.relu3(self.fc1(x)) + x = self.relu4(self.fc2(x)) + x = self.fc3(x) + return x + + +# pylint: disable-next=too-many-instance-attributes +class QuantLeNet(FloatLeNet): + """Quantized LeNet with per-channel quantization.""" + + def __init__( + self, + weight_bit_width=4, + act_bit_width=4, + acc_bit_width=32, + weight_quant=CommonIntAccumulatorAwareWeightQuant, + ): + super().__init__() + + self.conv1 = qnn.QuantConv2d( + bias=False, + in_channels=1, + out_channels=6, + kernel_size=5, + stride=1, + padding=0, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu1 = qnn.QuantReLU( + inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width + ) + + self.conv2 = qnn.QuantConv2d( + bias=False, + in_channels=6, + out_channels=16, + kernel_size=5, + stride=1, + padding=0, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu2 = qnn.QuantReLU( + inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width + ) + + self.fc1 = qnn.QuantLinear( + 400, + 120, + bias=True, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + self.relu3 = qnn.QuantReLU( + act_quant=CommonUintActQuant, bit_width=act_bit_width + ) + self.fc2 = qnn.QuantLinear( + 120, + 84, + bias=True, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + self.relu4 = qnn.QuantReLU( + act_quant=CommonUintActQuant, bit_width=act_bit_width + ) + self.fc3 = qnn.QuantLinear( + 84, + 10, + bias=True, + input_bit_width=act_bit_width, + 
input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + + self.apply(weight_init) diff --git a/src/concrete/ml/quantization/base_quantized_op.py b/src/concrete/ml/quantization/base_quantized_op.py index 86c818fb8b..7aadb7aaf1 100644 --- a/src/concrete/ml/quantization/base_quantized_op.py +++ b/src/concrete/ml/quantization/base_quantized_op.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, Type, Union, cast import numpy +import numpy.typing as npt from concrete import fhe @@ -122,6 +123,7 @@ def __init__( input_quant_opts: Optional[QuantizationOptions] = None, **attrs, ) -> None: + self.n_bits = n_bits_output if input_quant_opts is not None: @@ -916,7 +918,7 @@ def can_fuse(self) -> bool: def make_output_quant_parameters( self, q_values: Union[numpy.ndarray, Any], - scale: numpy.float64, + scale: npt.NDArray[numpy.float64], zero_point: Union[int, float, numpy.ndarray], ) -> QuantizedArray: """Build a quantized array from quantized integer results of the op and quantization params. @@ -1016,6 +1018,9 @@ def cnp_round( # Rounding to low bit-width with approximate can cause issues with overflow protection # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4345 x = fhe.round_bit_pattern( - x, lsbs_to_remove=lsbs_value, exactness=exactness, overflow_protection=False + x, + lsbs_to_remove=lsbs_value, + exactness=exactness, + overflow_protection=False, ) return x diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py index 2d3f54f9fa..eae2330e3d 100644 --- a/src/concrete/ml/quantization/post_training.py +++ b/src/concrete/ml/quantization/post_training.py @@ -393,14 +393,14 @@ def _calibrate_layers_activation( assert isinstance(quant_result, QuantizedArray) return ( quant_result.dequant(), - quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None, + (quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None), ) # For QAT, the calibration is performed on raw data, performing # calibration on quantized that would confound inferred QAT and PTQ. return ( raw_result, - quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None, + (quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None), ) @abstractmethod @@ -625,7 +625,9 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray): for input_name in variable_input_names ) output_calibration_data, layer_quantizer = self._process_layer( - quantized_op_instance, *curr_calibration_data, quantizers=layer_quant + quantized_op_instance, + *curr_calibration_data, + quantizers=layer_quant, ) node_results[output_name] = output_calibration_data node_override_quantizer[output_name] = layer_quantizer @@ -724,7 +726,9 @@ def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule: return quantized_module def _process_input_quantizers( - self, quantized_module: QuantizedModule, calibration_data: Tuple[numpy.ndarray, ...] + self, + quantized_module: QuantizedModule, + calibration_data: Tuple[numpy.ndarray, ...], ): # pylint: disable=too-many-branches """Determine the quantizers for a quantized module. 
diff --git a/src/concrete/ml/quantization/quantized_ops.py b/src/concrete/ml/quantization/quantized_ops.py index 73cd18193a..6b720b425e 100644 --- a/src/concrete/ml/quantization/quantized_ops.py +++ b/src/concrete/ml/quantization/quantized_ops.py @@ -8,30 +8,20 @@ from typing import Any, Dict, Optional, Sequence, Set, Union import numpy +import numpy.typing as npt from concrete.fhe import conv as fhe_conv from concrete.fhe import maxpool as fhe_maxpool from concrete.fhe import tag, univariate, zeros from typing_extensions import SupportsIndex from ..common.debugging import assert_false, assert_true -from ..onnx.onnx_impl_utils import ( - compute_onnx_pool_padding, - numpy_onnx_pad, - onnx_avgpool_compute_norm_const, -) +from ..onnx.onnx_impl_utils import (compute_onnx_pool_padding, numpy_onnx_pad, + onnx_avgpool_compute_norm_const) from ..onnx.ops_impl import RawOpOutput -from .base_quantized_op import ( - ONNXOpInputOutputType, - QuantizedMixingOp, - QuantizedOp, - QuantizedOpUnivariateOfEncrypted, -) -from .quantizers import ( - QuantizationOptions, - QuantizedArray, - UniformQuantizationParameters, - UniformQuantizer, -) +from .base_quantized_op import (ONNXOpInputOutputType, QuantizedMixingOp, + QuantizedOp, QuantizedOpUnivariateOfEncrypted) +from .quantizers import (QuantizationOptions, QuantizedArray, + UniformQuantizationParameters, UniformQuantizer) def _check_op_input_zero_point(zero_point: Any, op_name: Optional[str]): @@ -162,7 +152,7 @@ def __init__( f"Got alpha == {alpha} and beta == {beta}.", ) - # pylint: disable-next=too-many-statements,too-many-locals + # pylint: disable-next=too-many-statements,too-many-locals,too-many-branches def q_impl( self, *q_inputs: ONNXOpInputOutputType, @@ -420,7 +410,16 @@ def copy_function(x): # Note that here we do not rescale to the output_scale and we do not add a zero-point # Any following Gemm/MatMul/Conv layers will do the rescaling (during re-quantization) # by calling _prepare_inputs_with_constants(...quantize_real_values=True) - m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale + m_matmul: npt.NDArray[numpy.float64] + + # Handle channel-wise quantization + if q_input2.quantizer.scale.shape == tuple(): + m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale + else: + assert q_input2.quantizer.scale.shape == (q_input2.qvalues.shape[0], 1) + weight_quant_scale = numpy.transpose(q_input2.quantizer.scale, axes=(1, 0)) + assert isinstance(weight_quant_scale, numpy.ndarray) + m_matmul = q_input1.quantizer.scale * weight_quant_scale # If this operation's result are network outputs, return # directly the integer values and a appropriate quantization parameters that @@ -566,7 +565,7 @@ def q_impl( # If this operator is the last one in the graph, # we rescale using the smallest scale to keep all information if self.produces_graph_output: - common_scale = min(q_input_0.quantizer.scale, q_input_1.quantizer.scale) + common_scale = numpy.minimum(q_input_0.quantizer.scale, q_input_1.quantizer.scale) # Otherwise we use the output op quantization scale else: common_scale = self.output_quant_params.scale @@ -953,7 +952,18 @@ def q_impl( # This is going to be compiled with a PBS (along with the following activation function) # Note that we don't re-quantize the output of the conv, this will be done by # any Gemm/Add/Conv layers that follow - m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale + m_matmul: npt.NDArray[numpy.float64] + if q_weights.quantizer.scale.shape == tuple(): + m_matmul = q_input.quantizer.scale * 
q_weights.quantizer.scale + else: + expected_scale_shape = (q_weights.qvalues.shape[0], *(1 for _ in q_weights.qvalues.shape[1:])) + assert q_weights.quantizer.scale.shape == expected_scale_shape + weight_quant_scale = numpy.transpose( + q_weights.quantizer.scale, + axes=(1, 0, *(index for index in range(2, len(expected_scale_shape)))), + ) + assert isinstance(weight_quant_scale, numpy.ndarray) + m_matmul = q_input.quantizer.scale * weight_quant_scale bias_shape = (1, -1, 1) if is_conv1d else (1, -1, 1, 1) @@ -1046,7 +1056,10 @@ def __init__( self.auto_pad = attrs.get("auto_pad", "NOTSET") self.kernel_shape = attrs.get("kernel_shape", None) - assert_true(self.kernel_shape is not None, "Setting parameter 'kernel_shape' is required.") + assert_true( + self.kernel_shape is not None, + "Setting parameter 'kernel_shape' is required.", + ) self.count_include_pad = attrs.get("count_include_pad", 1) self.pads = attrs.get("pads", tuple([0] * 2 * (len(self.kernel_shape) - 2))) @@ -1365,7 +1378,10 @@ def q_impl( assert_true(pads.size == 4, "Not currently supporting padding of 3D tensors") pad_value = 0 if prepared_inputs[2] is None else prepared_inputs[2] - assert_true(pad_value == 0, "Concrete ML only supports padding with constant zero values") + assert_true( + pad_value == 0, + "Concrete ML only supports padding with constant zero values", + ) assert q_input.quantizer.zero_point is not None q_input_pad = numpy_onnx_pad(q_input.qvalues, pads, q_input.quantizer.zero_point, True) @@ -2289,7 +2305,7 @@ def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray: n_bits = int(self.constant_inputs[3]) self.output_quant_params = UniformQuantizationParameters( - scale=numpy.float64(self.constant_inputs[1]), + scale=numpy.array(self.constant_inputs[1], dtype=float), zero_point=int(self.constant_inputs[2]), offset=2 ** (n_bits - 1) if self.is_signed else 0, ) @@ -2913,7 +2929,11 @@ def q_impl( # Compute padding with floor and apply it to the input, pad with the input zero-point pool_pads = compute_onnx_pool_padding( - q_input.qvalues.shape, self.kernel_shape, self.pads, self.strides, ceil_mode=0 + q_input.qvalues.shape, + self.kernel_shape, + self.pads, + self.strides, + ceil_mode=0, ) # Can only pad with scalar zero-points, but zero-points can be float in special cases @@ -2929,7 +2949,14 @@ def q_impl( with tag(self.op_instance_name + ".unfold"): sum_result = fhe_conv( - q_input_pad, kernels, None, fake_pads, self.strides, None, None, n_in_channels + q_input_pad, + kernels, + None, + fake_pads, + self.strides, + None, + None, + n_in_channels, ) if self.debug_value_tracker is not None: diff --git a/src/concrete/ml/quantization/quantizers.py b/src/concrete/ml/quantization/quantizers.py index c1bd058a04..fc30d1ee83 100644 --- a/src/concrete/ml/quantization/quantizers.py +++ b/src/concrete/ml/quantization/quantizers.py @@ -8,6 +8,7 @@ import numpy from concrete.fhe.tracing.tracer import Tracer +from numpy import typing as npt from ..common.debugging import assert_true from ..common.serialization.dumpers import dump, dumps @@ -385,13 +386,13 @@ class UniformQuantizationParameters: The parameters are computed from quantization options and quantization statistics. 
""" - scale: Optional[numpy.float64] = None + scale: Optional[npt.NDArray[numpy.float64]] = None zero_point: Optional[Union[int, float, numpy.ndarray]] = None offset: Optional[int] = None def __init__( self, - scale: Optional[numpy.float64] = None, + scale: Optional[npt.NDArray[numpy.float64]] = None, zero_point: Optional[Union[int, float, numpy.ndarray]] = None, offset: Optional[int] = None, ): @@ -512,7 +513,7 @@ def compute_quantization_parameters( if numpy.abs(stats.rmax) < STABILITY_CONST: # If the value is a 0 we cannot do it since the scale would become 0 as well # resulting in division by 0 - self.scale = numpy.float64(1.0) + self.scale = numpy.array(1.0, dtype=float) # Ideally we should get rid of round here but it is risky # regarding the FHE compilation. # Indeed, the zero_point value for the weights has to be an integer @@ -521,7 +522,7 @@ def compute_quantization_parameters( else: # If the value is not a 0 we can tweak the scale factor so that # the value quantizes to 1 - self.scale = numpy.float64(stats.rmax) + self.scale = numpy.array(stats.rmax, dtype=float) self.zero_point = 0 else: if options.is_symmetric: @@ -556,13 +557,16 @@ def compute_quantization_parameters( "This can occur with a badly trained model.", ) unique_scales = numpy.unique(numpy.diff(stats.uvalues)) - self.scale = numpy.float64(unique_scales[0]) + self.scale = numpy.array(unique_scales[0], dtype=float) if self.scale is None: - self.scale = numpy.float64( - (stats.rmax - stats.rmin) / (2**options.n_bits - 1) - if stats.rmax != stats.rmin - else 1.0 + self.scale = numpy.array( + ( + (stats.rmax - stats.rmin) / (2**options.n_bits - 1) + if stats.rmax != stats.rmin + else 1.0 + ), + dtype=float, ) if options.is_qat: @@ -618,7 +622,7 @@ def __init__( # Force scale to be a float64 if self.scale is not None: - self.scale = numpy.float64(self.scale) + self.scale = numpy.array(self.scale, dtype=float) def __eq__(self, other) -> bool: @@ -908,7 +912,7 @@ def _values_setup( elif isinstance(values, Tracer): self.values = values else: - self.values = numpy.array(values) + self.values = numpy.array(values, dtype=float) # If no stats are provided, compute them. 
# Note that this cannot be done during tracing @@ -944,7 +948,7 @@ def _values_setup( elif isinstance(values, Tracer): self.qvalues = values else: - self.qvalues = numpy.array(values) # pragma: no cover + self.qvalues = numpy.array(values, dtype=float) # pragma: no cover # Populate self.values self.dequant() @@ -1014,7 +1018,7 @@ def update_values(self, values: Union[numpy.ndarray, Tracer]) -> Union[numpy.nda elif isinstance(values, Tracer): # pragma: no cover self.values = values else: # pragma: no cover - self.values = numpy.array(values) + self.values = numpy.array(values, dtype=float) return self.quant() def update_quantized_values( @@ -1033,7 +1037,7 @@ def update_quantized_values( elif isinstance(qvalues, Tracer): # pragma: no cover self.qvalues = qvalues else: # pragma: no cover - self.qvalues = numpy.array(qvalues) + self.qvalues = numpy.array(qvalues, dtype=float) return self.dequant() def quant(self) -> Union[numpy.ndarray, Tracer]: diff --git a/tests/torch/test_brevitas_qat.py b/tests/torch/test_brevitas_qat.py index 32331fc42d..5bdf779c20 100644 --- a/tests/torch/test_brevitas_qat.py +++ b/tests/torch/test_brevitas_qat.py @@ -23,11 +23,15 @@ from concrete.ml.pytest.torch_models import ( NetWithConstantsFoldedBeforeOps, QuantCustomModel, + QuantLeNet, TinyQATCNN, ) from concrete.ml.quantization.base_quantized_op import QuantizedMixingOp from concrete.ml.quantization.post_training import PowerOfTwoScalingRoundPBSAdapter -from concrete.ml.quantization.qat_quantizers import Int8ActPerTensorPoT, Int8WeightPerTensorPoT +from concrete.ml.quantization.qat_quantizers import ( + Int8ActPerTensorPoT, + Int8WeightPerTensorPoT, +) from concrete.ml.sklearn import _get_sklearn_neural_net_models from concrete.ml.sklearn.qnn_module import SparseQuantNeuralNetwork from concrete.ml.torch.compile import compile_brevitas_qat_model @@ -97,7 +101,11 @@ def train_brevitas_network_tinymnist(is_cnn, qat_bits, signed, narrow, pot_scali x_all = numpy.expand_dims(x_all.reshape((-1, 8, 8)), 1) x_train, x_test, y_train, y_test = train_test_split( - x_all, y_all, test_size=0.25, shuffle=True, random_state=numpy.random.randint(0, 2**15) + x_all, + y_all, + test_size=0.25, + shuffle=True, + random_state=numpy.random.randint(0, 2**15), ) def train_one_epoch(net, optimizer, train_loader): @@ -130,7 +138,9 @@ def train_one_epoch(net, optimizer, train_loader): while not trained_ok: # Create the tiny CNN module with 10 output classes if is_cnn: - net = TinyQATCNN(10, qat_bits, 4 if qat_bits <= 3 else 20, signed, narrow, pot_scaling) + net = TinyQATCNN( + 10, qat_bits, 4 if qat_bits <= 3 else 20, signed, narrow, pot_scaling + ) else: if pot_scaling: act_quant = Int8ActPerTensorPoT @@ -141,7 +151,9 @@ def train_one_epoch(net, optimizer, train_loader): weight_quant = Int8WeightPerTensorFloat bias_quant = None - net = QuantCustomModel(64, 10, 100, qat_bits, act_quant, weight_quant, bias_quant) + net = QuantCustomModel( + 64, 10, 100, qat_bits, act_quant, weight_quant, bias_quant + ) # Train a single epoch to have a fast test, accuracy should still be the same for both # FHE simulation and torch @@ -170,7 +182,9 @@ def train_one_epoch(net, optimizer, train_loader): # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3933 @pytest.mark.flaky @pytest.mark.parametrize("qat_bits", [3]) -@pytest.mark.parametrize("signed, narrow", [(True, False), (True, True), (False, False)]) +@pytest.mark.parametrize( + "signed, narrow", [(True, False), (True, True), (False, False)] +) def test_brevitas_tinymnist_cnn( qat_bits, 
signed, @@ -261,8 +275,12 @@ def test_with_concrete(quantized_module, test_loader, use_fhe_simulation): @pytest.mark.parametrize("n_outputs", [5]) @pytest.mark.parametrize("input_dim", [100]) @pytest.mark.parametrize("model_class", _get_sklearn_neural_net_models()) -@pytest.mark.parametrize("signed, narrow", [(True, False), (False, False), (True, True)]) -@pytest.mark.skip(reason="Torch dtype setting interferes with parallel test launch, and flaky test") +@pytest.mark.parametrize( + "signed, narrow", [(True, False), (False, False), (True, True)] +) +@pytest.mark.skip( + reason="Torch dtype setting interferes with parallel test launch, and flaky test" +) def test_brevitas_intermediary_values( n_layers, n_bits_w_a, @@ -415,9 +433,13 @@ def forward(self, x): ] # Make sure the quantization options were well set by brevitas - assert len(set(dbg_model.narrow_range_inp)) > 0 and dbg_model.narrow_range_inp[0] == narrow assert ( - len(set(dbg_model.narrow_range_weight)) > 0 and dbg_model.narrow_range_weight[0] == narrow + len(set(dbg_model.narrow_range_inp)) > 0 + and dbg_model.narrow_range_inp[0] == narrow + ) + assert ( + len(set(dbg_model.narrow_range_weight)) > 0 + and dbg_model.narrow_range_weight[0] == narrow ) # Iterate across conv/linear layers @@ -425,7 +447,9 @@ def forward(self, x): # pylint: disable-next=consider-using-enumerate for idx in range(len(cml_intermediary_values)): # Check if any activations are different between Brevitas and Concrete ML - diff_inp = numpy.abs(cml_intermediary_values[idx] - dbg_model.intermediary_values[idx]) + diff_inp = numpy.abs( + cml_intermediary_values[idx] - dbg_model.intermediary_values[idx] + ) error = "" if numpy.any(diff_inp) > 0: # If any mismatches, then extract them and print them @@ -487,7 +511,9 @@ def test_brevitas_constant_folding(default_configuration): model = NetWithConstantsFoldedBeforeOps(config, 2) - with pytest.raises(ValueError, match=".*Error occurred during quantization aware training.*"): + with pytest.raises( + ValueError, match=".*Error occurred during quantization aware training.*" + ): compile_brevitas_qat_model( model.to("cpu"), torch_inputset=data, @@ -517,7 +543,9 @@ def test_brevitas_power_of_two( the user's round PBS configuration. 
""" - net, x_all, _ = train_brevitas_network_tinymnist(is_cnn, n_bits, True, False, power_of_two) + net, x_all, _ = train_brevitas_network_tinymnist( + is_cnn, n_bits, True, False, power_of_two + ) utils.QUANT_ROUND_LIKE_ROUND_PBS = True @@ -535,19 +563,27 @@ def test_brevitas_power_of_two( num_round_pbs_layers = 0 for _, node_op in quantized_module.quant_layers_dict.values(): if isinstance(node_op, QuantizedMixingOp): - num_round_pbs_layers += 1 if node_op.rounding_threshold_bits is not None else 0 + num_round_pbs_layers += ( + 1 if node_op.rounding_threshold_bits is not None else 0 + ) if pot_should_be_applied: lsbs_to_remove = ( node_op.lsbs_to_remove["matmul"] - if (node_op.lsbs_to_remove is not None) and ("matmul" in node_op.lsbs_to_remove) + if (node_op.lsbs_to_remove is not None) + and ("matmul" in node_op.lsbs_to_remove) else None ) assert node_op.rounding_threshold_bits == lsbs_to_remove elif manual_rounding: # If manual rounding was set, LSBs_to_remove must be equal # to the accumulator size minus the requested rounding_threshold_bits - assert node_op.rounding_threshold_bits.get("n_bits", None) == manual_rounding - assert node_op.produces_graph_output or node_op.lsbs_to_remove is not None + assert ( + node_op.rounding_threshold_bits.get("n_bits", None) + == manual_rounding + ) + assert ( + node_op.produces_graph_output or node_op.lsbs_to_remove is not None + ) # The power-of-two optimization will only work # when Relu activations are used and scaling factors are forced to be 2**s @@ -586,7 +622,9 @@ def test_brevitas_power_of_two( # the number of rounding nodes should be equal num_rounding_mlir = quantized_module.fhe_circuit.mlir.count(".round") - assert num_rounding_mlir == 2, "Power-of-to adapter: Rounding nodes not found in MLIR" + assert ( + num_rounding_mlir == 2 + ), "Power-of-to adapter: Rounding nodes not found in MLIR" # Remove rounding in the network to perform inference without the optimization. # We expect a network that was optimized with the power-of-two adapter @@ -606,3 +644,21 @@ def test_brevitas_power_of_two( check_array_equal(y_pred_sim_round, y_pred_clear_round) check_array_equal(y_pred_clear_round, y_pred_clear_no_round) + + +def test_brevitas_channel_wise(check_float_array_equal): + """Make sure that we can compile brevitas channel-wise quantization""" + model = QuantLeNet() + model.eval() + + with torch.no_grad(): + batch_size = 3 + image_size = 1, 32, 32 + images = torch.rand((batch_size, *image_size)) + out = model(images).detach().numpy() + quantized_module = compile_brevitas_qat_model( + model, images, rounding_threshold_bits=6 + ) + out_qm = quantized_module(images.detach().numpy()) + + check_float_array_equal(out, out_qm, atol=0.01, rtol=1.0)