diff --git a/src/concrete/ml/quantization/linear_op_glwe_backend.py b/src/concrete/ml/quantization/linear_op_glwe_backend.py index 66604928b..98b2b4527 100644 --- a/src/concrete/ml/quantization/linear_op_glwe_backend.py +++ b/src/concrete/ml/quantization/linear_op_glwe_backend.py @@ -3,6 +3,7 @@ import json import numpy +import torch from ..common.utils import HybridFHEMode, to_tuple from .quantized_module import QuantizedModule @@ -55,8 +56,8 @@ def keygen(self): ) def forward( - self, x: numpy.ndarray, q_module: QuantizedModule, fhe: HybridFHEMode - ) -> numpy.ndarray: + self, x: torch.Tensor, q_module: QuantizedModule, fhe: HybridFHEMode + ) -> torch.Tensor: """Perform the inference of this linear layer. Args: @@ -91,23 +92,23 @@ def forward( assert weight_bias[0].quantizer.quant_params.zero_point == 0 # Retrieve quantized weights - q_weight = weight_bias[0].qvalues + q_weight = weight_bias[0].values + assert(isinstance(q_weight, numpy.ndarray)) + assert(q_weight.dtype == numpy.float32) q_weight = numpy.transpose(q_weight) if transpose_inputs2 else q_weight - q_x = q_module.quantize_input(x) - assert q_x is not None - assert isinstance(q_x, numpy.ndarray) - - q_x = numpy.transpose(q_x) if transpose_inputs1 else q_x + q_x = q_module.quantize_input(x, dtype=numpy.float32 if fhe == HybridFHEMode.DISABLE else None) + q_x = torch.transpose(q_x) if transpose_inputs1 else q_x if fhe == HybridFHEMode.DISABLE: # There is no need to add the bias to the de-quantized values # as the bias is already included in the output quantizer # zero-point, in the analytical calibration - q_x = q_x.astype(numpy.float32) - q_weight = q_weight.astype(numpy.float32) - y = q_module.dequantize_output(*to_tuple(numpy.matmul(q_x, q_weight))) + + q_w = torch.from_numpy(q_weight).to(q_x.device) + mm = torch.matmul(q_x, q_w) + y = q_module.dequantize_output(*to_tuple(mm)) else: # Need to slice the last GLWE (this will be improved in later cml-extensions) num_valid_glwe_values_in_last_ciphertext = ( @@ -162,7 +163,6 @@ def forward( if return_2d: y = numpy.squeeze(y) - # Only single outputs are supported - assert isinstance(y, numpy.ndarray) + y = y.astype(numpy.float32) return y diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py index 7761d7bbe..a14996d5a 100644 --- a/src/concrete/ml/quantization/quantized_module.py +++ b/src/concrete/ml/quantization/quantized_module.py @@ -8,6 +8,7 @@ import numpy import onnx +import torch from concrete.fhe.compilation.artifacts import DebugArtifacts from concrete.fhe.compilation.circuit import Circuit from concrete.fhe.compilation.compiler import Compiler @@ -702,7 +703,7 @@ def _fhe_forward( return q_results def quantize_input( - self, *x: Optional[numpy.ndarray] + self, *x: Optional[Union[numpy.ndarray, torch.Tensor]], dtype=numpy.int64 ) -> Union[numpy.ndarray, Tuple[Optional[numpy.ndarray], ...]]: """Take the inputs in fp32 and quantize it using the learned quantization parameters. @@ -729,7 +730,7 @@ def quantize_input( # cannot be None q_x = tuple( ( - self.input_quantizers[idx].quant(x[idx]) # type: ignore[arg-type] + self.input_quantizers[idx].quant(x[idx], dtype) # type: ignore[arg-type] if x[idx] is not None else None ) @@ -738,7 +739,7 @@ def quantize_input( # Make sure all inputs are quantized to int64 assert all_values_are_of_dtype( - *q_x, dtypes="int64", allow_none=True + *q_x, dtypes=numpy.dtype(dtype).name, allow_none=True ), "Inputs were not quantized to int64" if len(q_x) == 1: @@ -749,8 +750,8 @@ def quantize_input( return q_x def dequantize_output( - self, *q_y_preds: numpy.ndarray - ) -> Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]: + self, *q_y_preds: Union[numpy.ndarray, torch.Tensor] + ) -> Union[Union[numpy.ndarray, torch.Tensor], Tuple[Union[numpy.ndarray, torch.Tensor], ...]]: """Take the last layer q_out and use its de-quant function. Args: @@ -767,10 +768,13 @@ def dequantize_output( ) y_preds = tuple( - numpy.array(output_quantizer.dequant(q_y_pred)) + output_quantizer.dequant(q_y_pred) for q_y_pred, output_quantizer in zip(q_y_preds, self.output_quantizers) ) + if not isinstance(q_y_preds[0], torch.Tensor): + y_preds = tuple(map(numpy.array, y_preds)) + if len(y_preds) == 1: return y_preds[0] diff --git a/src/concrete/ml/quantization/quantizers.py b/src/concrete/ml/quantization/quantizers.py index 8e65b54d3..c42706abc 100644 --- a/src/concrete/ml/quantization/quantizers.py +++ b/src/concrete/ml/quantization/quantizers.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Optional, TextIO, Union, get_type_hints import numpy +import torch from concrete.fhe.tracing.tracer import Tracer from ..common.debugging import assert_true @@ -671,7 +672,7 @@ def dump(self, file: TextIO) -> None: """ dump(self, file) - def quant(self, values: numpy.ndarray) -> numpy.ndarray: + def quant(self, values: Union[numpy.ndarray, torch.Tensor], dtype=numpy.int64) -> numpy.ndarray: """Quantize values. Args: @@ -686,10 +687,17 @@ def quant(self, values: numpy.ndarray) -> numpy.ndarray: assert self.offset is not None assert self.scale is not None - if QUANT_ROUND_LIKE_ROUND_PBS: - qvalues = numpy.floor(values / self.scale + self.zero_point + 0.5) # pragma: no cover + assert dtype in (numpy.int64, numpy.int32, numpy.float32, numpy.float64) + + delta = 0.5 if QUANT_ROUND_LIKE_ROUND_PBS else 0 + if isinstance(values, numpy.ndarray): + round_func = numpy.floor if QUANT_ROUND_LIKE_ROUND_PBS else numpy.rint + clip_func = numpy.clip else: - qvalues = numpy.rint(values / self.scale + self.zero_point) + round_func = torch.floor if QUANT_ROUND_LIKE_ROUND_PBS else torch.round + clip_func = torch.clip + + qvalues = round_func(values / self.scale + self.zero_point + delta) # Clipping must be performed for PTQ and for precomputed (for now only Brevitas) QAT # (where quantizer parameters are available in ONNX layers). @@ -705,11 +713,15 @@ def quant(self, values: numpy.ndarray) -> numpy.ndarray: if self.is_narrow: min_value += 1 - qvalues = qvalues.clip(min_value, 2 ** (self.n_bits) - 1 - self.offset) + qvalues = clip_func(qvalues, min_value, 2 ** (self.n_bits) - 1 - self.offset) - return qvalues.astype(numpy.int64) + # Only cast for numpy usage for Concrete circuits + if isinstance(values, numpy.ndarray): + qvalues = qvalues.astype(dtype) - def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer]: + return qvalues + + def dequant(self, qvalues: Union[numpy.ndarray, torch.Tensor]) -> Union[float, numpy.ndarray, torch.Tensor, Tracer]: """De-quantize values. Args: @@ -731,9 +743,13 @@ def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer] + ((" " + str(self.scale.dtype)) if isinstance(self.scale, numpy.ndarray) else ""), ) - values = self.scale * (qvalues - numpy.asarray(self.zero_point, dtype=numpy.float64)) + prepared_zp = numpy.asarray(self.zero_point, dtype=numpy.float64) + if isinstance(qvalues, torch.Tensor): + prepared_zp = torch.from_numpy(prepared_zp).float().to(qvalues.device) + + values = self.scale * (qvalues - prepared_zp) - assert isinstance(values, (float, numpy.ndarray, Tracer)), f"{values=}, {type(values)=}" + assert isinstance(values, (float, numpy.ndarray, torch.Tensor, Tracer)), f"{values=}, {type(values)=}" return values diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py index 5aa58e5a0..aa21490db 100644 --- a/src/concrete/ml/torch/hybrid_model.py +++ b/src/concrete/ml/torch/hybrid_model.py @@ -247,16 +247,15 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]: if self.executor: # Delegate to the optimized GLWE executor - y = torch.Tensor( - self.executor.forward( - x.detach().numpy(), self.private_q_module, self.fhe_local_mode + y = self.executor.forward( + x.detach(), self.private_q_module, self.fhe_local_mode ) - ) else: + device = x.device # Delegate to the quantized module for all fhe modes y = torch.Tensor( - self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value) - ) + self.private_q_module.forward(x.cpu().detach().numpy(), fhe=self.fhe_local_mode.value) + ).to(device) elif self.fhe_local_mode == HybridFHEMode.CALIBRATE: # Calling torch + gathering calibration data @@ -568,7 +567,9 @@ def compile_model( self.configuration = configuration - for name in self.module_names: + from tqdm import tqdm + + for name in tqdm(self.module_names): remote_module = self._get_module_by_name(self.model, name) assert isinstance(remote_module, RemoteModule) @@ -596,6 +597,13 @@ def compile_model( n_bits=n_bits, rounding_threshold_bits=rounding_threshold_bits, ) + + vals = self.private_q_modules[name].quant_layers_dict.values() + _, q_op = next(iter(vals)) + const_inp = q_op.constant_inputs[1] # Get the weights, the bias is in [2] + const_inp.values = const_inp.qvalues.astype(numpy.float32) + + self.private_q_modules[name]._onnx_model = None else: self.private_q_modules[name] = compile_torch_model( self.private_modules[name], @@ -608,6 +616,8 @@ def compile_model( self.remote_modules[name].private_q_module = self.private_q_modules[name] + remote_module.calibration_data = None + def _save_fhe_circuit(self, path: Path, via_mlir=False): """Private method that saves the FHE circuits.