chore: fix batch norm per channel with clipping
jfrery committed Jun 3, 2024
1 parent a8f37ac commit 476e939
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions src/concrete/ml/quantization/quantized_ops.py
@@ -1612,6 +1612,49 @@ class QuantizedBatchNormalization(QuantizedOp):

    _impl_for_op_named: str = "BatchNormalization"

    def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
        """Create the corresponding QuantizedArray for the output of this operator.

        Args:
            *inputs (numpy.ndarray): Calibration sample inputs.

        Returns:
            numpy.ndarray: The output values for the provided calibration samples.
        """

        # Here we need the actual values of the constants, so we pass the
        # numpy.ndarrays through the computation graph
        prepared_inputs = self._prepare_inputs_with_constants(
            *inputs, calibrate=True, quantize_actual_values=False
        )

        raw_result = self.call_impl(*prepared_inputs, **self.attrs)
        if isinstance(raw_result, RawOpOutput):
            return raw_result

        # Check whether batch normalization is applied per channel
        scale = self.constant_inputs[self._params_name_to_input_idx["scale"]].values
        bias = self.constant_inputs[self._params_name_to_input_idx["bias"]].values
        if scale.size > 1 or bias.size > 1:
            # Per-channel batch normalization behaves poorly with low bit-width quantization:
            # its parameters (scale / bias / mean / var) can differ by orders of magnitude
            # between channels, while our tensor-based quantization only provides a single
            # global scale and offset, so any error in the quantized values can be dramatic.
            # To avoid this, we use percentiles to clip the extreme values.
            lower_bound = numpy.percentile(raw_result, 0.1)
            upper_bound = numpy.percentile(raw_result, 99.9)

            raw_result = numpy.clip(raw_result, lower_bound, upper_bound)

        quantized_samples = QuantizedArray(self.n_bits, raw_result)

        self.output_quant_params = quantized_samples.quantizer.quant_params
        self.output_quant_stats = quantized_samples.quantizer.quant_stats

        return quantized_samples.values


class QuantizedFlatten(QuantizedOp):
"""Quantized flatten for encrypted inputs."""
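As a standalone illustration of why the percentile clipping helps (a minimal sketch with hypothetical values and a toy quantize_dequantize helper, not the Concrete ML API), the snippet below quantizes a tensor containing one extreme per-channel outlier to 4 bits, with and without 0.1 / 99.9 percentile clipping, and compares the reconstruction error on the remaining values:

    import numpy

    def quantize_dequantize(x: numpy.ndarray, n_bits: int) -> numpy.ndarray:
        # Uniform affine quantization with a single global scale and offset,
        # mimicking tensor-wide (not per-channel) quantization.
        scale = (x.max() - x.min()) / (2**n_bits - 1)
        q = numpy.round((x - x.min()) / scale)
        return q * scale + x.min()

    rng = numpy.random.default_rng(0)
    values = rng.normal(size=1000)
    values[0] = 100.0  # one channel's output is orders of magnitude larger

    clipped = numpy.clip(
        values, numpy.percentile(values, 0.1), numpy.percentile(values, 99.9)
    )

    # Error on the well-behaved values only: the outlier stretches the
    # quantization range, so the un-clipped version loses most of its precision.
    err_raw = numpy.abs(quantize_dequantize(values, 4)[1:] - values[1:]).mean()
    err_clipped = numpy.abs(quantize_dequantize(clipped, 4)[1:] - values[1:]).mean()
    print(f"mean abs error without clipping: {err_raw:.3f}")
    print(f"mean abs error with clipping:    {err_clipped:.3f}")

Clipping trades exact reproduction of the 0.2% most extreme outputs for a much finer quantization grid on everything else, which is the sensible trade-off when a single scale and offset must cover all channels.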
