fix: full gpu hybrid model
andrei-stoian-zama committed Dec 18, 2024
1 parent: 8014ec5 · commit: 43601e1
Showing 4 changed files with 65 additions and 35 deletions.
26 changes: 13 additions & 13 deletions src/concrete/ml/quantization/linear_op_glwe_backend.py
@@ -3,6 +3,7 @@
import json

import numpy
import torch

from ..common.utils import HybridFHEMode, to_tuple
from .quantized_module import QuantizedModule
@@ -55,8 +56,8 @@ def keygen(self):
)

def forward(
self, x: numpy.ndarray, q_module: QuantizedModule, fhe: HybridFHEMode
) -> numpy.ndarray:
self, x: torch.Tensor, q_module: QuantizedModule, fhe: HybridFHEMode
) -> torch.Tensor:
"""Perform the inference of this linear layer.
Args:
@@ -91,23 +92,23 @@ def forward(
assert weight_bias[0].quantizer.quant_params.zero_point == 0

# Retrieve quantized weights
q_weight = weight_bias[0].qvalues
q_weight = weight_bias[0].values
assert(isinstance(q_weight, numpy.ndarray))
assert(q_weight.dtype == numpy.float32)

q_weight = numpy.transpose(q_weight) if transpose_inputs2 else q_weight

q_x = q_module.quantize_input(x)
assert q_x is not None
assert isinstance(q_x, numpy.ndarray)

q_x = numpy.transpose(q_x) if transpose_inputs1 else q_x
q_x = q_module.quantize_input(x, dtype=numpy.float32 if fhe == HybridFHEMode.DISABLE else None)
q_x = torch.transpose(q_x) if transpose_inputs1 else q_x

if fhe == HybridFHEMode.DISABLE:
# There is no need to add the bias to the de-quantized values
# as the bias is already included in the output quantizer
# zero-point, in the analytical calibration
q_x = q_x.astype(numpy.float32)
q_weight = q_weight.astype(numpy.float32)
y = q_module.dequantize_output(*to_tuple(numpy.matmul(q_x, q_weight)))

q_w = torch.from_numpy(q_weight).to(q_x.device)
mm = torch.matmul(q_x, q_w)
y = q_module.dequantize_output(*to_tuple(mm))
else:
# Need to slice the last GLWE (this will be improved in later cml-extensions)
num_valid_glwe_values_in_last_ciphertext = (
@@ -162,7 +163,6 @@ def forward(
if return_2d:
y = numpy.squeeze(y)

# Only single outputs are supported
assert isinstance(y, numpy.ndarray)
y = y.astype(numpy.float32)

return y
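
Note on this hunk: in DISABLE (clear) mode the GLWE backend no longer casts to int64 numpy arrays; the weights are read as float32 values and the matmul runs with torch on whatever device the input tensor lives on. Below is a minimal sketch of that clear path, assuming made-up scales and stand-in quantization parameters rather than the real QuantizedModule API:

# Minimal sketch of the clear (HybridFHEMode.DISABLE) path after this change:
# quantization stays in float32 and the matmul runs with torch on the input's
# device. scale_x, zp_x and scale_y are hypothetical stand-ins for the module's
# calibrated quantizers, not the Concrete ML API.
import numpy
import torch

def clear_linear_forward(x: torch.Tensor, q_weight: numpy.ndarray,
                         scale_x: float, zp_x: float, scale_y: float) -> torch.Tensor:
    # Quantize the activations in float32 (no int64 cast on the GPU path)
    q_x = torch.round(x / scale_x + zp_x)
    # Move the float32 quantized weights onto the same device as the input
    q_w = torch.from_numpy(q_weight).to(x.device)
    # Integer-valued matmul carried out in float32
    mm = torch.matmul(q_x, q_w)
    # De-quantize back to real values (bias/zero-point corrections omitted here)
    return scale_y * mm

x = torch.randn(2, 4, device="cuda" if torch.cuda.is_available() else "cpu")
w = numpy.random.randint(-8, 8, size=(4, 3)).astype(numpy.float32)
y = clear_linear_forward(x, w, scale_x=0.05, zp_x=0.0, scale_y=0.05)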
16 changes: 10 additions & 6 deletions src/concrete/ml/quantization/quantized_module.py
@@ -8,6 +8,7 @@

import numpy
import onnx
import torch
from concrete.fhe.compilation.artifacts import DebugArtifacts
from concrete.fhe.compilation.circuit import Circuit
from concrete.fhe.compilation.compiler import Compiler
@@ -702,7 +703,7 @@ def _fhe_forward(
return q_results

def quantize_input(
self, *x: Optional[numpy.ndarray]
self, *x: Optional[Union[numpy.ndarray, torch.Tensor]], dtype=numpy.int64
) -> Union[numpy.ndarray, Tuple[Optional[numpy.ndarray], ...]]:
"""Take the inputs in fp32 and quantize it using the learned quantization parameters.
@@ -729,7 +730,7 @@ def quantize_input(
# cannot be None
q_x = tuple(
(
self.input_quantizers[idx].quant(x[idx]) # type: ignore[arg-type]
self.input_quantizers[idx].quant(x[idx], dtype) # type: ignore[arg-type]
if x[idx] is not None
else None
)
@@ -738,7 +739,7 @@ def quantize_input(

# Make sure all inputs are quantized to int64
assert all_values_are_of_dtype(
*q_x, dtypes="int64", allow_none=True
*q_x, dtypes=numpy.dtype(dtype).name, allow_none=True
), "Inputs were not quantized to int64"

if len(q_x) == 1:
@@ -749,8 +750,8 @@ def quantize_input(
return q_x

def dequantize_output(
self, *q_y_preds: numpy.ndarray
) -> Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]:
self, *q_y_preds: Union[numpy.ndarray, torch.Tensor]
) -> Union[Union[numpy.ndarray, torch.Tensor], Tuple[Union[numpy.ndarray, torch.Tensor], ...]]:
"""Take the last layer q_out and use its de-quant function.
Args:
@@ -767,10 +768,13 @@ def dequantize_output(
)

y_preds = tuple(
numpy.array(output_quantizer.dequant(q_y_pred))
output_quantizer.dequant(q_y_pred)
for q_y_pred, output_quantizer in zip(q_y_preds, self.output_quantizers)
)

if not isinstance(q_y_preds[0], torch.Tensor):
y_preds = tuple(map(numpy.array, y_preds))

if len(y_preds) == 1:
return y_preds[0]

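Note on this hunk: dequantize_output now returns torch tensors unchanged when the quantized predictions are torch tensors, and only wraps results in numpy.array for numpy inputs. A standalone sketch of that behaviour, with illustrative scale and zero-point values instead of the QuantizedModule API:

# Sketch of the type-aware de-quantization: torch tensors are de-quantized
# with torch ops and stay on their device, numpy inputs keep the previous
# numpy.array wrapping. Scale and zero-point are made up for illustration.
from typing import Union

import numpy
import torch

def dequantize(q_values: Union[numpy.ndarray, torch.Tensor],
               scale: float, zero_point: float):
    if isinstance(q_values, torch.Tensor):
        zp = torch.tensor(float(zero_point), device=q_values.device)
        return scale * (q_values - zp)  # stays a torch.Tensor, on device
    return numpy.array(scale * (q_values - zero_point))  # numpy path unchanged

print(dequantize(numpy.array([10, 20]), scale=0.05, zero_point=0))
print(dequantize(torch.tensor([10.0, 20.0]), scale=0.05, zero_point=0))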
34 changes: 25 additions & 9 deletions src/concrete/ml/quantization/quantizers.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, Optional, TextIO, Union, get_type_hints

import numpy
import torch
from concrete.fhe.tracing.tracer import Tracer

from ..common.debugging import assert_true
@@ -671,7 +672,7 @@ def dump(self, file: TextIO) -> None:
"""
dump(self, file)

def quant(self, values: numpy.ndarray) -> numpy.ndarray:
def quant(self, values: Union[numpy.ndarray, torch.Tensor], dtype=numpy.int64) -> numpy.ndarray:
"""Quantize values.
Args:
@@ -686,10 +687,17 @@ def quant(self, values: numpy.ndarray) -> numpy.ndarray:
assert self.offset is not None
assert self.scale is not None

if QUANT_ROUND_LIKE_ROUND_PBS:
qvalues = numpy.floor(values / self.scale + self.zero_point + 0.5) # pragma: no cover
assert dtype in (numpy.int64, numpy.int32, numpy.float32, numpy.float64)

delta = 0.5 if QUANT_ROUND_LIKE_ROUND_PBS else 0
if isinstance(values, numpy.ndarray):
round_func = numpy.floor if QUANT_ROUND_LIKE_ROUND_PBS else numpy.rint
clip_func = numpy.clip
else:
qvalues = numpy.rint(values / self.scale + self.zero_point)
round_func = torch.floor if QUANT_ROUND_LIKE_ROUND_PBS else torch.round
clip_func = torch.clip

qvalues = round_func(values / self.scale + self.zero_point + delta)

# Clipping must be performed for PTQ and for precomputed (for now only Brevitas) QAT
# (where quantizer parameters are available in ONNX layers).
@@ -705,11 +713,15 @@ def quant(self, values: numpy.ndarray) -> numpy.ndarray:
if self.is_narrow:
min_value += 1

qvalues = qvalues.clip(min_value, 2 ** (self.n_bits) - 1 - self.offset)
qvalues = clip_func(qvalues, min_value, 2 ** (self.n_bits) - 1 - self.offset)

return qvalues.astype(numpy.int64)
# Only cast for numpy usage for Concrete circuits
if isinstance(values, numpy.ndarray):
qvalues = qvalues.astype(dtype)

def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer]:
return qvalues

def dequant(self, qvalues: Union[numpy.ndarray, torch.Tensor]) -> Union[float, numpy.ndarray, torch.Tensor, Tracer]:
"""De-quantize values.
Args:
@@ -731,9 +743,13 @@ def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer]
+ ((" " + str(self.scale.dtype)) if isinstance(self.scale, numpy.ndarray) else ""),
)

values = self.scale * (qvalues - numpy.asarray(self.zero_point, dtype=numpy.float64))
prepared_zp = numpy.asarray(self.zero_point, dtype=numpy.float64)
if isinstance(qvalues, torch.Tensor):
prepared_zp = torch.from_numpy(prepared_zp).float().to(qvalues.device)

values = self.scale * (qvalues - prepared_zp)

assert isinstance(values, (float, numpy.ndarray, Tracer)), f"{values=}, {type(values)=}"
assert isinstance(values, (float, numpy.ndarray, torch.Tensor, Tracer)), f"{values=}, {type(values)=}"
return values


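Note on these hunks: quant() and dequant() now dispatch on the input type, using torch rounding and clipping for tensors and keeping the integer cast only on the numpy path that feeds Concrete circuits; dequant() also moves the zero-point onto the tensor's device. A simplified, self-contained sketch of the quant() dispatch (clipping range, offset and flags reduced to stand-ins):

# Simplified sketch of the numpy/torch dispatch in quant(): pick the rounding
# and clipping functions from the input's library, and only cast to an integer
# dtype for numpy inputs. The clipping bounds below are a plain signed range,
# not the quantizer's offset/narrow-range logic.
from typing import Union

import numpy
import torch

QUANT_ROUND_LIKE_ROUND_PBS = False

def quant(values: Union[numpy.ndarray, torch.Tensor], scale: float,
          zero_point: float, n_bits: int = 8, dtype=numpy.int64):
    delta = 0.5 if QUANT_ROUND_LIKE_ROUND_PBS else 0.0
    if isinstance(values, numpy.ndarray):
        round_func = numpy.floor if QUANT_ROUND_LIKE_ROUND_PBS else numpy.rint
        clip_func = numpy.clip
    else:
        round_func = torch.floor if QUANT_ROUND_LIKE_ROUND_PBS else torch.round
        clip_func = torch.clip
    qvalues = round_func(values / scale + zero_point + delta)
    qvalues = clip_func(qvalues, -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1)
    # Torch tensors stay float and on their device; numpy is cast for circuits
    return qvalues.astype(dtype) if isinstance(values, numpy.ndarray) else qvalues

print(quant(numpy.array([0.3, -0.2]), scale=0.01, zero_point=0))
print(quant(torch.tensor([0.3, -0.2]), scale=0.01, zero_point=0))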
24 changes: 17 additions & 7 deletions src/concrete/ml/torch/hybrid_model.py
@@ -247,16 +247,15 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:

if self.executor:
# Delegate to the optimized GLWE executor
y = torch.Tensor(
self.executor.forward(
x.detach().numpy(), self.private_q_module, self.fhe_local_mode
y = self.executor.forward(
x.detach(), self.private_q_module, self.fhe_local_mode
)
)
else:
device = x.device
# Delegate to the quantized module for all fhe modes
y = torch.Tensor(
self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
)
self.private_q_module.forward(x.cpu().detach().numpy(), fhe=self.fhe_local_mode.value)
).to(device)

elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
# Calling torch + gathering calibration data
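
Note on this hunk: in the remote module's forward, the GLWE executor path now receives the detached tensor directly so it can stay on the GPU, while the generic QuantizedModule fallback still needs numpy and therefore moves the tensor to CPU and the result back to the original device. A sketch with dummy stand-ins for the executor and quantized module (not the RemoteModule API itself):

# Sketch of the device handling: no .numpy() on the executor path, explicit
# CPU round-trip on the fallback path. _DummyQModule stands in for a compiled
# QuantizedModule.
import torch

def remote_forward(x: torch.Tensor, executor=None, q_module=None, fhe_mode="disable"):
    if executor is not None:
        # GPU-friendly path: the executor works on torch tensors directly
        return executor.forward(x.detach(), q_module, fhe_mode)
    device = x.device
    y_np = q_module.forward(x.cpu().detach().numpy(), fhe=fhe_mode)
    return torch.Tensor(y_np).to(device)

class _DummyQModule:
    def forward(self, x_np, fhe="disable"):
        return x_np * 2.0

y = remote_forward(torch.randn(2, 3), q_module=_DummyQModule())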
@@ -568,7 +567,9 @@ def compile_model(

self.configuration = configuration

for name in self.module_names:
from tqdm import tqdm

for name in tqdm(self.module_names):
remote_module = self._get_module_by_name(self.model, name)
assert isinstance(remote_module, RemoteModule)

@@ -596,6 +597,13 @@ def compile_model(
n_bits=n_bits,
rounding_threshold_bits=rounding_threshold_bits,
)

vals = self.private_q_modules[name].quant_layers_dict.values()
_, q_op = next(iter(vals))
const_inp = q_op.constant_inputs[1] # Get the weights, the bias is in [2]
const_inp.values = const_inp.qvalues.astype(numpy.float32)

self.private_q_modules[name]._onnx_model = None
else:
self.private_q_modules[name] = compile_torch_model(
self.private_modules[name],
Expand All @@ -608,6 +616,8 @@ def compile_model(

self.remote_modules[name].private_q_module = self.private_q_modules[name]

remote_module.calibration_data = None

def _save_fhe_circuit(self, path: Path, via_mlir=False):
"""Private method that saves the FHE circuits.

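Note on the compile_model hunks: after compiling a module for the GLWE backend, the linear layer's quantized weights are kept as a float32 view (so the clear GPU path can matmul them directly), the internal ONNX model is dropped, and the per-module calibration data is cleared, which reduces memory pressure when many layers are outsourced. A sketch of that cleanup step, assuming objects shaped like the ones in the diff (attribute names follow the diff, not a public API):

# Hypothetical helper mirroring the cleanup done in compile_model for the GLWE
# backend; q_module and remote_module are assumed to come from the hybrid
# model's compilation loop.
import numpy

def slim_down_glwe_module(q_module, remote_module):
    _, q_op = next(iter(q_module.quant_layers_dict.values()))
    weights = q_op.constant_inputs[1]  # index 1: the weights (bias is at index 2)
    weights.values = weights.qvalues.astype(numpy.float32)
    q_module._onnx_model = None        # the ONNX graph is no longer needed
    remote_module.calibration_data = None  # calibration tensors can be freed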