diff --git a/conftest.py b/conftest.py
index edf9a2ee1..1cf5308be 100644
--- a/conftest.py
+++ b/conftest.py
@@ -60,7 +60,10 @@ def pytest_addoption(parser):
     )
 
     parser.addoption(
-        "--no-flaky", action="store_true", default=False, help="Don't run known flaky tests."
+        "--no-flaky",
+        action="store_true",
+        default=False,
+        help="Don't run known flaky tests.",
     )
diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py
index 56545e5df..0529cea1c 100644
--- a/src/concrete/ml/onnx/convert.py
+++ b/src/concrete/ml/onnx/convert.py
@@ -86,7 +86,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
         # Create a GEMM node which combines the MatMul and Add operations
         gemm_node = helper.make_node(
             "Gemm",  # op_type
-            [matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name],  # inputs
+            [
+                matmul_node.input[0],
+                matmul_node.input[1],
+                bias_other_input_node_name,
+            ],  # inputs
             [add_node.output[0]],  # outputs
             name="Gemm_Node",
             alpha=1.0,
@@ -149,9 +153,14 @@ def get_equivalent_numpy_forward_from_torch(
     arguments = list(inspect.signature(torch_module.forward).parameters)
 
+    if isinstance(dummy_input, torch.Tensor):
+        dummy_input = dummy_input.to("cpu")
+    else:
+        dummy_input = tuple(elt.to("cpu") for elt in dummy_input)
+
     # Export to ONNX
     torch.onnx.export(
-        torch_module,
+        torch_module.to("cpu"),
         dummy_input,
         str(output_onnx_file_path),
         opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
diff --git a/src/concrete/ml/pytest/torch_models.py b/src/concrete/ml/pytest/torch_models.py
index fccc67eb1..84a05d492 100644
--- a/src/concrete/ml/pytest/torch_models.py
+++ b/src/concrete/ml/pytest/torch_models.py
@@ -6,7 +6,15 @@
 import brevitas.nn as qnn
 import numpy
 import torch
-from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat, IntBias
+from brevitas.core.restrict_val import FloatRestrictValue, RestrictValueType
+from brevitas.quant import (
+    Int8AccumulatorAwareWeightQuant,
+    Int8AccumulatorAwareZeroCenterWeightQuant,
+    Int8ActPerTensorFloat,
+    Int8WeightPerTensorFloat,
+    IntBias,
+    Uint8ActPerTensorFloat,
+)
 from torch import nn
 from torch.nn.utils import prune
 
@@ -38,7 +46,7 @@ def forward(self, x, y):
         return x + y + self.value, (x - y) ** 2
 
 
-class SimpleNet(torch.nn.Module):
+class SimpleNet(nn.Module):
     """Fake torch model used to generate some onnx."""
 
     def __init__(self) -> None:
@@ -292,7 +300,7 @@ def forward(self, x):
         return x
 
 
-class NetWithLoops(torch.nn.Module):
+class NetWithLoops(nn.Module):
     """Torch model, where we reuse some elements in a loop.
 
     Torch model, where we reuse some elements in a loop in the forward and don't expect the
@@ -538,7 +546,7 @@ def step(x, bias):
         return x
 
 
-class NetWithConcatUnsqueeze(torch.nn.Module):
+class NetWithConcatUnsqueeze(nn.Module):
     """Torch model to test the concat and unsqueeze operators."""
 
     def __init__(self, activation_function, input_output, n_fc_layers):
@@ -1004,6 +1012,7 @@ def __init__(self, use_conv, use_qat, inp_size, n_bits):
 
         layer_obj = self.mixing_layer
         layer_obj.weight.data = torch.from_numpy(np_weights).float()
+        assert layer_obj.bias is not None
         layer_obj.bias.data = torch.rand(size=(1,))
 
     def forward(self, x):
@@ -1216,12 +1225,12 @@ def forward(self, x):
         # for example a 4d tensor NCHW, padded with [1, 2, 2, 3] is padded
         # along the last 2 dimensions, with 1 cell to the left and 2 to the right (dimension 4: W)
         # and 2 cells at the top and 3 at the bottom (dimension 3: H)
-        x = torch.nn.functional.pad(x, (3, 2))
-        x = torch.nn.functional.pad(x, (1, 2, 3, 4))
+        x = nn.functional.pad(x, (3, 2))
+        x = nn.functional.pad(x, (1, 2, 3, 4))
 
         # Concrete ML only supports padding on the last two dimensions as this is the
         # most common setting
-        x = torch.nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0))
+        x = nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0))
         return x
 
@@ -1340,7 +1349,12 @@ class ConcatFancyIndexing(nn.Module):
     """Concat with fancy indexing."""
 
     def __init__(
-        self, input_shape, hidden_shape, output_shape, n_bits: int = 4, n_blocks: int = 3
+        self,
+        input_shape,
+        hidden_shape,
+        output_shape,
+        n_bits: int = 4,
+        n_blocks: int = 3,
     ) -> None:
         """Torch Model.
 
@@ -1361,7 +1375,10 @@ def __init__(
 
         self.quant_2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
         self.fc2 = qnn.QuantLinear(
-            hidden_shape * self.n_blocks, hidden_shape, bias=True, weight_bit_width=n_bits
+            hidden_shape * self.n_blocks,
+            hidden_shape,
+            bias=True,
+            weight_bit_width=n_bits,
         )
 
         self.quant_3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
@@ -1393,7 +1410,7 @@ def forward(self, x):
         return x
 
 
-class PartialQATModel(torch.nn.Module):
+class PartialQATModel(nn.Module):
     """A model with a QAT Module."""
 
     def __init__(self, input_shape: int, output_shape: int, n_bits: int):
@@ -1442,7 +1459,7 @@ def forward(self, input1):
         return output
 
 
-class ManualLogisticRegressionTraining(torch.nn.Module):
+class ManualLogisticRegressionTraining(nn.Module):
     """PyTorch module for performing SGD training."""
 
     def __init__(self, learning_rate=0.1):
@@ -1665,3 +1682,176 @@ def forward(self, x):
         x = self.relu(x)
         x = self.linear(x)
         return x
+
+
+# pylint: disable-next=too-many-ancestors
+class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
+    """CommonIntWeightPerChannelQuant."""
+
+    scaling_per_output_channel = True
+
+
+# pylint: disable-next=too-many-ancestors
+class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
+    """CommonIntAccumulatorAwareWeightQuant."""
+
+    restrict_scaling_impl = FloatRestrictValue  # backwards compatibility
+    bit_width = None
+
+
+# pylint: disable-next=too-many-ancestors
+class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
+    """CommonIntAccumulatorAwareZeroCenterWeightQuant."""
+
+    bit_width = None
+
+
+# pylint: disable-next=too-many-ancestors
+class CommonUintActQuant(Uint8ActPerTensorFloat):
+    """CommonUintActQuant."""
+
+    bit_width = None
+    restrict_scaling_type = RestrictValueType.LOG_FP
+
+
+def weight_init(layer: nn.Module):
+    """Initialize layer weights.
+
+    Arguments:
+        layer (nn.Module): a conv2d layer
+    """
+
+    if isinstance(layer, nn.Conv2d):
+        nn.init.kaiming_normal_(layer.weight, nn.init.calculate_gain("relu"))
+        if layer.bias is not None:
+            layer.bias.data.zero_()
+
+
+# pylint: disable-next=too-many-instance-attributes
+class FloatLeNet(nn.Module):
+    """Floating point LeNet."""
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0)
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.fc1 = nn.Linear(400, 120, bias=True)
+        self.relu3 = nn.ReLU()
+        self.fc2 = nn.Linear(120, 84, bias=True)
+        self.relu4 = nn.ReLU()
+        self.fc3 = nn.Linear(84, 10, bias=True)
+
+        self.apply(weight_init)
+
+    def forward(self, x: torch.Tensor):
+        """Forward function.
+
+        Arguments:
+            x (torch.Tensor): input image
+
+        Returns:
+            Neural network prediction
+        """
+        x = self.pool1(self.relu1(self.conv1(x)))
+        x = self.pool2(self.relu2(self.conv2(x)))
+        x = torch.flatten(x, 1)
+        x = self.relu3(self.fc1(x))
+        x = self.relu4(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+# pylint: disable-next=too-many-instance-attributes
+class QuantLeNet(FloatLeNet):
+    """Quantized LeNet with per-channel quantization."""
+
+    def __init__(
+        self,
+        weight_bit_width=4,
+        act_bit_width=4,
+        acc_bit_width=32,
+        weight_quant=CommonIntAccumulatorAwareWeightQuant,
+    ):
+        super().__init__()
+
+        self.conv1 = qnn.QuantConv2d(
+            bias=False,
+            in_channels=1,
+            out_channels=6,
+            kernel_size=5,
+            stride=1,
+            padding=0,
+            input_bit_width=act_bit_width,
+            input_quant=CommonUintActQuant,
+            weight_accumulator_bit_width=acc_bit_width,
+            weight_bit_width=weight_bit_width,
+            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
+            weight_quant=weight_quant,
+        )
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu1 = qnn.QuantReLU(
+            inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
+        )
+
+        self.conv2 = qnn.QuantConv2d(
+            bias=False,
+            in_channels=6,
+            out_channels=16,
+            kernel_size=5,
+            stride=1,
+            padding=0,
+            input_bit_width=act_bit_width,
+            input_quant=CommonUintActQuant,
+            weight_accumulator_bit_width=acc_bit_width,
+            weight_bit_width=weight_bit_width,
+            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
+            weight_quant=weight_quant,
+        )
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu2 = qnn.QuantReLU(
+            inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
+        )
+
+        self.fc1 = qnn.QuantLinear(
+            400,
+            120,
+            bias=True,
+            input_bit_width=act_bit_width,
+            input_quant=CommonUintActQuant,
+            weight_accumulator_bit_width=acc_bit_width,
+            weight_bit_width=weight_bit_width,
+            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
+            weight_quant=weight_quant,
+        )
+        self.relu3 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
+        self.fc2 = qnn.QuantLinear(
+            120,
+            84,
+            bias=True,
+            input_bit_width=act_bit_width,
+            input_quant=CommonUintActQuant,
+            weight_accumulator_bit_width=acc_bit_width,
+            weight_bit_width=weight_bit_width,
+            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
+            weight_quant=weight_quant,
+        )
+        self.relu4 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
+        self.fc3 = qnn.QuantLinear(
+            84,
+            10,
+            bias=True,
+            input_bit_width=act_bit_width,
+            input_quant=CommonUintActQuant,
+            weight_accumulator_bit_width=acc_bit_width,
+            weight_bit_width=weight_bit_width,
+            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
+            weight_quant=weight_quant,
+        )
+
+        self.apply(weight_init)
diff --git a/src/concrete/ml/quantization/base_quantized_op.py b/src/concrete/ml/quantization/base_quantized_op.py
index 86c818fb8..7aadb7aaf 100644
--- a/src/concrete/ml/quantization/base_quantized_op.py
+++ b/src/concrete/ml/quantization/base_quantized_op.py
@@ -5,6 +5,7 @@
 from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, Type, Union, cast
 
 import numpy
+import numpy.typing as npt
 
 from concrete import fhe
 
@@ -122,6 +123,7 @@ def __init__(
         input_quant_opts: Optional[QuantizationOptions] = None,
         **attrs,
     ) -> None:
+
         self.n_bits = n_bits_output
 
         if input_quant_opts is not None:
@@ -916,7 +918,7 @@ def can_fuse(self) -> bool:
     def make_output_quant_parameters(
         self,
         q_values: Union[numpy.ndarray, Any],
-        scale: numpy.float64,
+        scale: npt.NDArray[numpy.float64],
         zero_point: Union[int, float, numpy.ndarray],
     ) -> QuantizedArray:
         """Build a quantized array from quantized integer results of the op and quantization params.
@@ -1016,6 +1018,9 @@ def cnp_round(
             # Rounding to low bit-width with approximate can cause issues with overflow protection
            # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4345
            x = fhe.round_bit_pattern(
-                x, lsbs_to_remove=lsbs_value, exactness=exactness, overflow_protection=False
+                x,
+                lsbs_to_remove=lsbs_value,
+                exactness=exactness,
+                overflow_protection=False,
            )
        return x
diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py
index 2d3f54f9f..eae2330e3 100644
--- a/src/concrete/ml/quantization/post_training.py
+++ b/src/concrete/ml/quantization/post_training.py
@@ -393,14 +393,14 @@ def _calibrate_layers_activation(
             assert isinstance(quant_result, QuantizedArray)
             return (
                 quant_result.dequant(),
-                quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None,
+                (quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None),
             )
 
         # For QAT, the calibration is performed on raw data, performing
         # calibration on quantized that would confound inferred QAT and PTQ.
         return (
             raw_result,
-            quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None,
+            (quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None),
         )
 
     @abstractmethod
@@ -625,7 +625,9 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray):
                 for input_name in variable_input_names
             )
             output_calibration_data, layer_quantizer = self._process_layer(
-                quantized_op_instance, *curr_calibration_data, quantizers=layer_quant
+                quantized_op_instance,
+                *curr_calibration_data,
+                quantizers=layer_quant,
             )
             node_results[output_name] = output_calibration_data
             node_override_quantizer[output_name] = layer_quantizer
@@ -724,7 +726,9 @@ def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule:
         return quantized_module
 
     def _process_input_quantizers(
-        self, quantized_module: QuantizedModule, calibration_data: Tuple[numpy.ndarray, ...]
+        self,
+        quantized_module: QuantizedModule,
+        calibration_data: Tuple[numpy.ndarray, ...],
     ):
         # pylint: disable=too-many-branches
         """Determine the quantizers for a quantized module.
diff --git a/src/concrete/ml/quantization/quantized_ops.py b/src/concrete/ml/quantization/quantized_ops.py
index 73cd18193..eda6e6004 100644
--- a/src/concrete/ml/quantization/quantized_ops.py
+++ b/src/concrete/ml/quantization/quantized_ops.py
@@ -8,6 +8,7 @@
 from typing import Any, Dict, Optional, Sequence, Set, Union
 
 import numpy
+import numpy.typing as npt
 from concrete.fhe import conv as fhe_conv
 from concrete.fhe import maxpool as fhe_maxpool
 from concrete.fhe import tag, univariate, zeros
@@ -162,7 +163,7 @@ def __init__(
             f"Got alpha == {alpha} and beta == {beta}.",
         )
 
-    # pylint: disable-next=too-many-statements,too-many-locals
+    # pylint: disable-next=too-many-statements,too-many-locals,too-many-branches
    def q_impl(
        self,
        *q_inputs: ONNXOpInputOutputType,
@@ -420,7 +421,16 @@ def copy_function(x):
         # Note that here we do not rescale to the output_scale and we do not add a zero-point
         # Any following Gemm/MatMul/Conv layers will do the rescaling (during re-quantization)
         # by calling _prepare_inputs_with_constants(...quantize_real_values=True)
-        m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale
+        m_matmul: npt.NDArray[numpy.float64]
+
+        # Handle channel-wise quantization
+        if q_input2.quantizer.scale.shape == tuple():
+            m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale
+        else:
+            assert q_input2.quantizer.scale.shape == (q_input2.qvalues.shape[0], 1)
+            weight_quant_scale = numpy.transpose(q_input2.quantizer.scale, axes=(1, 0))
+            assert isinstance(weight_quant_scale, numpy.ndarray)
+            m_matmul = q_input1.quantizer.scale * weight_quant_scale
 
         # If this operation's result are network outputs, return
         # directly the integer values and a appropriate quantization parameters that
@@ -566,7 +576,7 @@ def q_impl(
         # If this operator is the last one in the graph,
         # we rescale using the smallest scale to keep all information
         if self.produces_graph_output:
-            common_scale = min(q_input_0.quantizer.scale, q_input_1.quantizer.scale)
+            common_scale = numpy.minimum(q_input_0.quantizer.scale, q_input_1.quantizer.scale)
         # Otherwise we use the output op quantization scale
         else:
             common_scale = self.output_quant_params.scale
@@ -953,7 +963,21 @@ def q_impl(
         # This is going to be compiled with a PBS (along with the following activation function)
         # Note that we don't re-quantize the output of the conv, this will be done by
         # any Gemm/Add/Conv layers that follow
-        m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale
+        m_matmul: npt.NDArray[numpy.float64]
+        if q_weights.quantizer.scale.shape == tuple():
+            m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale
+        else:
+            expected_scale_shape = (
+                q_weights.qvalues.shape[0],
+                *(1 for _ in q_weights.qvalues.shape[1:]),
+            )
+            assert q_weights.quantizer.scale.shape == expected_scale_shape
+            weight_quant_scale = numpy.transpose(
+                q_weights.quantizer.scale,
+                axes=(1, 0, *(index for index in range(2, len(expected_scale_shape)))),
+            )
+            assert isinstance(weight_quant_scale, numpy.ndarray)
+            m_matmul = q_input.quantizer.scale * weight_quant_scale
 
         bias_shape = (1, -1, 1) if is_conv1d else (1, -1, 1, 1)
 
@@ -1046,7 +1070,10 @@ def __init__(
         self.auto_pad = attrs.get("auto_pad", "NOTSET")
 
         self.kernel_shape = attrs.get("kernel_shape", None)
-        assert_true(self.kernel_shape is not None, "Setting parameter 'kernel_shape' is required.")
+        assert_true(
+            self.kernel_shape is not None,
+            "Setting parameter 'kernel_shape' is required.",
+        )
         self.count_include_pad = attrs.get("count_include_pad", 1)
         self.pads = attrs.get("pads", tuple([0] * 2 * (len(self.kernel_shape) - 2)))
@@ -1365,7 +1392,10 @@ def q_impl(
         assert_true(pads.size == 4, "Not currently supporting padding of 3D tensors")
 
         pad_value = 0 if prepared_inputs[2] is None else prepared_inputs[2]
-        assert_true(pad_value == 0, "Concrete ML only supports padding with constant zero values")
+        assert_true(
+            pad_value == 0,
+            "Concrete ML only supports padding with constant zero values",
+        )
 
         assert q_input.quantizer.zero_point is not None
         q_input_pad = numpy_onnx_pad(q_input.qvalues, pads, q_input.quantizer.zero_point, True)
@@ -2289,7 +2319,7 @@ def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
         n_bits = int(self.constant_inputs[3])
 
         self.output_quant_params = UniformQuantizationParameters(
-            scale=numpy.float64(self.constant_inputs[1]),
+            scale=numpy.array(self.constant_inputs[1], dtype=float),
             zero_point=int(self.constant_inputs[2]),
             offset=2 ** (n_bits - 1) if self.is_signed else 0,
         )
@@ -2913,7 +2943,11 @@ def q_impl(
 
         # Compute padding with floor and apply it to the input, pad with the input zero-point
         pool_pads = compute_onnx_pool_padding(
-            q_input.qvalues.shape, self.kernel_shape, self.pads, self.strides, ceil_mode=0
+            q_input.qvalues.shape,
+            self.kernel_shape,
+            self.pads,
+            self.strides,
+            ceil_mode=0,
         )
 
         # Can only pad with scalar zero-points, but zero-points can be float in special cases
@@ -2929,7 +2963,14 @@ def q_impl(
 
         with tag(self.op_instance_name + ".unfold"):
             sum_result = fhe_conv(
-                q_input_pad, kernels, None, fake_pads, self.strides, None, None, n_in_channels
+                q_input_pad,
+                kernels,
+                None,
+                fake_pads,
+                self.strides,
+                None,
+                None,
+                n_in_channels,
             )
 
         if self.debug_value_tracker is not None:
diff --git a/src/concrete/ml/quantization/quantizers.py b/src/concrete/ml/quantization/quantizers.py
index c1bd058a0..fc30d1ee8 100644
--- a/src/concrete/ml/quantization/quantizers.py
+++ b/src/concrete/ml/quantization/quantizers.py
@@ -8,6 +8,7 @@
 
 import numpy
 from concrete.fhe.tracing.tracer import Tracer
+from numpy import typing as npt
 
 from ..common.debugging import assert_true
 from ..common.serialization.dumpers import dump, dumps
@@ -385,13 +386,13 @@ class UniformQuantizationParameters:
     The parameters are computed from quantization options and quantization statistics.
     """
 
-    scale: Optional[numpy.float64] = None
+    scale: Optional[npt.NDArray[numpy.float64]] = None
     zero_point: Optional[Union[int, float, numpy.ndarray]] = None
     offset: Optional[int] = None
 
     def __init__(
         self,
-        scale: Optional[numpy.float64] = None,
+        scale: Optional[npt.NDArray[numpy.float64]] = None,
         zero_point: Optional[Union[int, float, numpy.ndarray]] = None,
         offset: Optional[int] = None,
     ):
@@ -512,7 +513,7 @@ def compute_quantization_parameters(
             if numpy.abs(stats.rmax) < STABILITY_CONST:
                 # If the value is a 0 we cannot do it since the scale would become 0 as well
                 # resulting in division by 0
-                self.scale = numpy.float64(1.0)
+                self.scale = numpy.array(1.0, dtype=float)
                 # Ideally we should get rid of round here but it is risky
                 # regarding the FHE compilation.
                 # Indeed, the zero_point value for the weights has to be an integer
@@ -521,7 +522,7 @@
             else:
                 # If the value is not a 0 we can tweak the scale factor so that
                 # the value quantizes to 1
-                self.scale = numpy.float64(stats.rmax)
+                self.scale = numpy.array(stats.rmax, dtype=float)
                 self.zero_point = 0
         else:
             if options.is_symmetric:
@@ -556,13 +557,16 @@ def compute_quantization_parameters(
                         "This can occur with a badly trained model.",
                     )
                     unique_scales = numpy.unique(numpy.diff(stats.uvalues))
-                    self.scale = numpy.float64(unique_scales[0])
+                    self.scale = numpy.array(unique_scales[0], dtype=float)
 
             if self.scale is None:
-                self.scale = numpy.float64(
-                    (stats.rmax - stats.rmin) / (2**options.n_bits - 1)
-                    if stats.rmax != stats.rmin
-                    else 1.0
+                self.scale = numpy.array(
+                    (
+                        (stats.rmax - stats.rmin) / (2**options.n_bits - 1)
+                        if stats.rmax != stats.rmin
+                        else 1.0
+                    ),
+                    dtype=float,
                 )
 
             if options.is_qat:
@@ -618,7 +622,7 @@ def __init__(
 
         # Force scale to be a float64
         if self.scale is not None:
-            self.scale = numpy.float64(self.scale)
+            self.scale = numpy.array(self.scale, dtype=float)
 
     def __eq__(self, other) -> bool:
@@ -908,7 +912,7 @@ def _values_setup(
         elif isinstance(values, Tracer):
             self.values = values
         else:
-            self.values = numpy.array(values)
+            self.values = numpy.array(values, dtype=float)
 
         # If no stats are provided, compute them.
         # Note that this cannot be done during tracing
@@ -944,7 +948,7 @@ def _values_setup(
         elif isinstance(values, Tracer):
             self.qvalues = values
         else:
-            self.qvalues = numpy.array(values)  # pragma: no cover
+            self.qvalues = numpy.array(values, dtype=float)  # pragma: no cover
 
         # Populate self.values
         self.dequant()
@@ -1014,7 +1018,7 @@ def update_values(self, values: Union[numpy.ndarray, Tracer]) -> Union[numpy.nda
         elif isinstance(values, Tracer):  # pragma: no cover
             self.values = values
         else:  # pragma: no cover
-            self.values = numpy.array(values)
+            self.values = numpy.array(values, dtype=float)
 
         return self.quant()
 
     def update_quantized_values(
@@ -1033,7 +1037,7 @@ def update_quantized_values(
         elif isinstance(qvalues, Tracer):  # pragma: no cover
             self.qvalues = qvalues
         else:  # pragma: no cover
-            self.qvalues = numpy.array(qvalues)
+            self.qvalues = numpy.array(qvalues, dtype=float)
 
         return self.dequant()
 
     def quant(self) -> Union[numpy.ndarray, Tracer]:
diff --git a/tests/torch/test_brevitas_qat.py b/tests/torch/test_brevitas_qat.py
index 32331fc42..c3075ace4 100644
--- a/tests/torch/test_brevitas_qat.py
+++ b/tests/torch/test_brevitas_qat.py
@@ -23,6 +23,7 @@
 from concrete.ml.pytest.torch_models import (
     NetWithConstantsFoldedBeforeOps,
     QuantCustomModel,
+    QuantLeNet,
     TinyQATCNN,
 )
 from concrete.ml.quantization.base_quantized_op import QuantizedMixingOp
@@ -56,7 +57,6 @@ def forward_test_torch(net, test_loader):
         # Accumulate the ground truth labels
         endidx = idx + target.shape[0]
         all_targets[idx:endidx] = target.numpy()
-
         # Run forward and get the raw predictions first
         raw_pred = net(data).detach().numpy()
 
@@ -97,7 +97,11 @@ def train_brevitas_network_tinymnist(is_cnn, qat_bits, signed, narrow, pot_scaling):
     x_all = numpy.expand_dims(x_all.reshape((-1, 8, 8)), 1)
 
     x_train, x_test, y_train, y_test = train_test_split(
-        x_all, y_all, test_size=0.25, shuffle=True, random_state=numpy.random.randint(0, 2**15)
+        x_all,
+        y_all,
+        test_size=0.25,
+        shuffle=True,
+        random_state=numpy.random.randint(0, 2**15),
     )
 
     def train_one_epoch(net, optimizer, train_loader):
@@ -584,6 +588,7 @@ def test_brevitas_power_of_two(
 
     # Ensure rounding was compiled in the circuit
     # the number of rounding nodes should be equal
+    assert quantized_module.fhe_circuit is not None
     num_rounding_mlir = quantized_module.fhe_circuit.mlir.count(".round")
 
     assert num_rounding_mlir == 2, "Power-of-to adapter: Rounding nodes not found in MLIR"
@@ -606,3 +611,21 @@ def test_brevitas_power_of_two(
 
     check_array_equal(y_pred_sim_round, y_pred_clear_round)
     check_array_equal(y_pred_clear_round, y_pred_clear_no_round)
+
+
+@pytest.mark.flaky
+def test_brevitas_channel_wise(check_float_array_equal):
+    """Make sure that we can compile brevitas channel-wise quantization"""
+    model = QuantLeNet()
+    model.eval()
+
+    with torch.no_grad():
+        batch_size = 3
+        image_size = 1, 32, 32
+        images = torch.rand((batch_size, *image_size))
+        out = model(images).detach().numpy()
+        quantized_module = compile_brevitas_qat_model(model, images, rounding_threshold_bits=6)
+        out_qm = quantized_module(images.detach().numpy())
+
+        # The following values are quite arbitrary
+        check_float_array_equal(out, out_qm, atol=0.02, rtol=0.5)