diff --git a/src/concrete/ml/quantization/quantized_ops.py b/src/concrete/ml/quantization/quantized_ops.py index a275e04a59..a89c7c9eb7 100644 --- a/src/concrete/ml/quantization/quantized_ops.py +++ b/src/concrete/ml/quantization/quantized_ops.py @@ -1612,6 +1612,49 @@ class QuantizedBatchNormalization(QuantizedOp): _impl_for_op_named: str = "BatchNormalization" + def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray: + """Create corresponding QuantizedArray for the output of the activation function. + + Args: + *inputs (numpy.ndarray): Calibration sample inputs. + + Returns: + numpy.ndarray: the output values for the provided calibration samples. + """ + + # Here we need the actual values of the constants, we need to pass through + # the numpy.ndarrays in the computation graph + prepared_inputs = self._prepare_inputs_with_constants( + *inputs, calibrate=True, quantize_actual_values=False + ) + + raw_result = self.call_impl(*prepared_inputs, **self.attrs) + if isinstance(raw_result, RawOpOutput): + return raw_result + + # Check if batch normalization is applied per channel + scale = self.constant_inputs[self._params_name_to_input_idx["scale"]].values + bias = self.constant_inputs[self._params_name_to_input_idx["bias"]].values + if scale.size > 1 or bias.size > 1: + # Per channel batchnorm behave poorly with low bit-width quantization. + # This is because per channel batchnorm can have paramters (scale / bias / mean / var) + # order of magnitude different between channels. But our tensor based quantization + # only allows us to provide a global scale and offset. Any error in quantized values + # could be dramatic. + # To avoid this, we use percentiles to clip the extreme values. + + lower_bound = numpy.percentile(raw_result, 0.1) + upper_bound = numpy.percentile(raw_result, 99.9) + + raw_result = numpy.clip(raw_result, lower_bound, upper_bound) + + quantized_samples = QuantizedArray(self.n_bits, raw_result) + + self.output_quant_params = quantized_samples.quantizer.quant_params + self.output_quant_stats = quantized_samples.quantizer.quant_stats + + return quantized_samples.values + class QuantizedFlatten(QuantizedOp): """Quantized flatten for encrypted inputs."""