feat: GLWE mat mul for hybrid model #868
@@ -1 +1 @@
-ac76836858506534a0dc01cae9341f7d
+8ea8aec4f5aac03565c2dcb9f3f8a1da

@@ -1 +1 @@
-ac76836858506534a0dc01cae9341f7d
+8ea8aec4f5aac03565c2dcb9f3f8a1da

@@ -1 +1 @@
-ac76836858506534a0dc01cae9341f7d
+8ea8aec4f5aac03565c2dcb9f3f8a1da
@@ -77,7 +77,7 @@ concrete_clf.compile(X, debug_config)

 #### 3. Quantization import failed

-**Error message**: `Error occurred during quantization aware training (QAT) import [...] Could not determine a unique scale for the quantization!`.
+**Error message**: `Error occurred during quantization aware training (QAT) import [...] Are you missing a QuantIdentity layer in your Brevitas model?`.

 **Cause**: This error occurs when a model imported as a quantization-aware training (QAT) model lacks quantization operators. See [this guide](../deep-learning/fhe_friendly_models.md) on how to use Brevitas layers. This error message indicates that some layers do not take inputs quantized through `QuantIdentity` layers.

Comment: We also have a non-intuitive error that occurs when using `view` instead of `reshape`: `ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size ... is different from ...)`. Maybe we could add it? CC @andrei-stoian-zama

Reply: That's not related; it's due to shape inference during the ONNX export.
@@ -112,11 +112,12 @@ class FCSmall(nn.Module):
         super().__init__()
         self.quant_input = qnn.QuantIdentity(bit_width=3)
         self.fc1 = qnn.QuantLinear(in_features=input_output, out_features=input_output, weight_bit_width=3, bias=True)
+        self.quant_2 = qnn.QuantIdentity(bit_width=3)
         self.act_f = nn.ReLU()
         self.fc2 = qnn.QuantLinear(in_features=input_output, out_features=input_output, weight_bit_width=3, bias=True)

     def forward(self, x):
-        return self.fc2(self.act_f(self.fc1(self.quant_input(x))))
+        return self.fc2(self.quant_2(self.act_f(self.fc1(self.quant_input(x)))))

 torch_model = FCSmall(3)

Comment: This was bad all along and only worked because the QAT quantization was guessed from the calibration data.

Reply: I'm really not a fan of using QuantIdentity layers after quant layers. It works fine without them, and including them just adds unnecessary complexity in my opinion.

Reply: It used to work because of some automatic detection of quantization parameters that we were doing, but I removed that detection since it was slow. So now we need QuantIdentity in the right places.
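To make the fix concrete, here is a minimal sketch of the placement the documentation and the review discussion describe: a `QuantIdentity` feeds the model input and another one re-quantizes the activation before the second `QuantLinear`. The model, names, and bit-widths are hypothetical illustrations, assuming only the public `brevitas.nn` and `torch.nn` APIs, not code from this repository.

```python
import torch.nn as nn
import brevitas.nn as qnn


class QATFriendlyMLP(nn.Module):
    """Hypothetical QAT model sketch: every linear layer reads a quantized input."""

    def __init__(self, n_features: int, n_bits: int = 3):
        super().__init__()
        self.quant_input = qnn.QuantIdentity(bit_width=n_bits)
        self.fc1 = qnn.QuantLinear(n_features, n_features, bias=True, weight_bit_width=n_bits)
        self.act = nn.ReLU()
        # Without this QuantIdentity, fc2 would receive an un-quantized activation
        # and the QAT import could fail with the "missing QuantIdentity" message above.
        self.quant_act = qnn.QuantIdentity(bit_width=n_bits)
        self.fc2 = qnn.QuantLinear(n_features, n_features, bias=True, weight_bit_width=n_bits)

    def forward(self, x):
        return self.fc2(self.quant_act(self.act(self.fc1(self.quant_input(x)))))
```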
@@ -63,11 +63,13 @@ def forward(self, inputs):
 class FCSmall(nn.Module):
     """Torch model for the tests."""

-    def __init__(self, input_output, activation_function):
+    def __init__(self, input_output, activation_function, hidden=None):
         super().__init__()
-        self.fc1 = nn.Linear(in_features=input_output, out_features=input_output)
+
+        hidden_size = input_output if hidden is None else hidden
+        self.fc1 = nn.Linear(in_features=input_output, out_features=hidden_size)
         self.act_f = activation_function()
-        self.fc2 = nn.Linear(in_features=input_output, out_features=input_output)
+        self.fc2 = nn.Linear(in_features=hidden_size, out_features=input_output)

     def forward(self, x):
         """Forward pass.

@@ -850,7 +852,7 @@ def forward(self, x):
         return x


-class SimpleQAT(nn.Module):
+class StepFunctionPTQ(nn.Module):
     """Torch model implements a step function that needs Greater, Cast and Where."""

     def __init__(self, input_output, activation_function, n_bits=2, disable_bit_check=False):

@@ -1354,17 +1356,17 @@ def __init__(
         super().__init__()

         self.n_blocks = n_blocks
-        self.quant_1 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
+        self.quant_1 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=False)
         self.fc1 = qnn.QuantLinear(input_shape, hidden_shape, bias=False, weight_bit_width=n_bits)

-        self.quant_concat = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
+        self.quant_concat = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=False)

-        self.quant_2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
+        self.quant_2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=False)
         self.fc2 = qnn.QuantLinear(
             hidden_shape * self.n_blocks, hidden_shape, bias=True, weight_bit_width=n_bits
         )

-        self.quant_3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
+        self.quant_3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=False)
         self.fc4 = qnn.QuantLinear(hidden_shape, output_shape, bias=True, weight_bit_width=n_bits)

     def forward(self, x):

@@ -1379,9 +1381,9 @@ def forward(self, x):
         x_pre = []

         for i in range(self.n_blocks):
-            x_block = x[:, i, :]
-            q1_out = self.quant_1(x_block)
-            fc1_out = self.fc1(q1_out)
+            q_x = self.quant_1(x)
+
+            q_x_block = q_x[:, i, :]
+            fc1_out = self.fc1(q_x_block)
             q_concat_out = self.quant_concat(fc1_out)

             x_pre.append(q_concat_out)

Comment: This was wrong: the input to the slicing should be quantized with QuantIdentity.
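A short sketch of the pattern the comment above describes, with hypothetical names and shapes (not the repository's test model): the tensor is quantized once by a `QuantIdentity` and only then sliced per block, so every slice shares the same input quantizer instead of each block being quantized separately after slicing.

```python
import torch
import torch.nn as nn
import brevitas.nn as qnn


class QuantizeThenSlice(nn.Module):
    """Illustrative sketch: quantize the whole tensor, then slice per block."""

    def __init__(self, n_blocks: int, in_features: int, out_features: int, n_bits: int = 4):
        super().__init__()
        self.n_blocks = n_blocks
        self.quant_1 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=False)
        self.fc1 = qnn.QuantLinear(in_features, out_features, bias=False, weight_bit_width=n_bits)

    def forward(self, x):
        # Quantize the full (batch, n_blocks, features) tensor first...
        q_x = self.quant_1(x)
        outputs = []
        for i in range(self.n_blocks):
            # ...then slice: every block sees the same input quantization.
            outputs.append(self.fc1(q_x[:, i, :]))
        return torch.stack(outputs, dim=1)
```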
@@ -18,6 +18,7 @@
     QuantizationOptions,
     QuantizedArray,
     UniformQuantizationParameters,
+    UniformQuantizer,
 )

 # pylint: disable=too-many-lines
@@ -559,7 +560,10 @@ def _prepare_quantized_input(self, input_: QuantizedArray) -> QuantizedArray:
         # but when parsing the ONNX graph, some options can be overwritten. Thus
         # when evaluating QAT layers we ignore one of these options to allow the
         # override.
-        if quant_opts.is_equal(input_.quantizer.quant_options, ignore_sign_qat=True):
+        if (
+            quant_opts.is_equal(input_.quantizer.quant_options, ignore_sign_qat=True)
+            or input_.quantizer.quant_options.is_precomputed_qat
+        ):
             # Pass-through the input quantizer when the input is already quantized in
             # the manner that this op requires: this makes the op use the qvalues directly,
             # in q_impl and will avoid a TLU to re-quantize.

Comment: Before, we were overriding the op input quantization mainly for the model input. Now we override every time there is a Brevitas quantization layer before the op.

Reply: Thanks for the explanation; maybe you should update the comment accordingly.
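As a rough illustration of the decision this condition makes (a toy sketch with made-up classes, not the library's actual `QuantizationOptions` API): when the incoming value's quantization already matches what the op requires, or it was produced by an upstream Brevitas quantizer (`is_precomputed_qat`), the quantizer is passed through and the quantized values are reused; otherwise the value must be re-quantized, which costs a TLU in FHE.

```python
from dataclasses import dataclass


@dataclass
class ToyQuantOptions:
    """Hypothetical stand-in for the real quantization options."""

    n_bits: int
    is_signed: bool
    is_precomputed_qat: bool = False

    def is_equal(self, other: "ToyQuantOptions") -> bool:
        # Toy equality check standing in for QuantizationOptions.is_equal.
        return self.n_bits == other.n_bits and self.is_signed == other.is_signed


def prepare_input(op_opts: ToyQuantOptions, input_opts: ToyQuantOptions) -> str:
    # Mirrors the shape of the changed condition: pass through when the input
    # already matches, or when a Brevitas quantizer already produced it.
    if op_opts.is_equal(input_opts) or input_opts.is_precomputed_qat:
        return "pass-through: reuse q-values, no TLU re-quantization"
    return "re-quantize the input (costs a TLU)"


print(prepare_input(ToyQuantOptions(8, True), ToyQuantOptions(4, True, is_precomputed_qat=True)))
```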
@@ -661,7 +665,9 @@ def _prepare_inputs_with_constants(
         elif calibrate or is_clear_value:
             # This is used during calibration with numpy.ndarrays
             # or then the input is raw (not quantized)
-            prepared_inputs[curr_input_fill_idx] = input_
+            prepared_inputs[curr_input_fill_idx] = (
+                input_.values if isinstance(input_, QuantizedArray) else input_
+            )
         elif quantize_actual_values:
             # This is used by mixing (conv/gemm) or value re-arranging ops (reshape)
             input_ = cast(QuantizedArray, input_)

@@ -674,9 +680,6 @@ def _prepare_inputs_with_constants(
                 new_input.quantizer.is_qat
                 and not input_.quantizer.is_precomputed_qat
                 and self.error_tracker is not None
-                and not new_input.quantizer.check_is_uniform_quantized(
-                    new_input.quantizer.quant_options
-                )
             ):
                 self.error_tracker.append(input_idx)

Comment: Basically:
@@ -700,7 +703,7 @@ def _prepare_inputs_with_constants(

         return prepared_inputs

-    def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
+    def calibrate(self, *inputs: Union[QuantizedArray, numpy.ndarray]) -> numpy.ndarray:
         """Create corresponding QuantizedArray for the output of the activation function.

         Args:

Comment: Allow calibration with QuantizedArray inputs; this is similar to what Brevitas does with [...]

@@ -712,6 +715,8 @@ def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:

         # Here we need the actual values of the constants, we need to pass through
         # the numpy.ndarrays in the computation graph
+        # Mixing ops may be calibrated using QuantizedArray inputs, in order
+        # to pre-compute analytical output quantization
         prepared_inputs = self._prepare_inputs_with_constants(
             *inputs, calibrate=True, quantize_actual_values=False
         )
@@ -720,12 +725,48 @@ def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
         if isinstance(raw_result, RawOpOutput):
             return raw_result

-        quantized_samples = QuantizedArray(self.n_bits, raw_result)
+        # If the caller passes only QuantizedArray it means
+        # that they are asking to quantize using analytical
+        # formulas
+        requested_analytical_quant = all(
+            isinstance(qv, QuantizedArray) for qv in inputs
+        ) and isinstance(self, QuantizedMixingOp)
+        if requested_analytical_quant:
+            assert_true(
+                self.supported_by_linear_backend(),
+                "Calibration using QuantizedArray is only possible"
+                " for operations that can calibrate analytically",
+            )
+            q_prepared_inputs = self._prepare_inputs_with_constants(
+                *inputs, calibrate=False, quantize_actual_values=True
+            )
+            quantizer = self.calibrate_analytical_output(*q_prepared_inputs)
+            self.output_quant_params = quantizer.quant_params
+            self.output_quant_stats = quantizer.quant_stats
+        else:
+            # These output quantization parameters are only used
+            # for operations that produce graph outputs and are non-linear
+            quantized_samples = QuantizedArray(self.n_bits, raw_result)

-        self.output_quant_params = quantized_samples.quantizer.quant_params
-        self.output_quant_stats = quantized_samples.quantizer.quant_stats
+            self.output_quant_params = quantized_samples.quantizer.quant_params
+            self.output_quant_stats = quantized_samples.quantizer.quant_stats

-        return quantized_samples.values
+        return raw_result
+
+    def calibrate_analytical_output(self, *inputs: QuantizedArray) -> UniformQuantizer:
+        """Calibrate output quantization based on analytical formulas.
+
+        Args:
+            *inputs (QuantizedArray): quantized operation inputs. Quantized weights
+                are stored in the op instance
+
+        Raises:
+            AssertionError: if the operation does not support analytical calibration
+        """
+        raise AssertionError(
+            f"calibrate_analytical_output: not implemented for {self._impl_for_op_named} op"
+        )

     def prepare_output(self, qoutput_activation: numpy.ndarray) -> QuantizedArray:
         """Quantize the output of the activation function.

Comment on lines +743 to +745 (the `quantizer = self.calibrate_analytical_output(...)` call and the two assignments that follow):

Comment: Why are we doing this? We should always have the output based on the calibration quantization parameters since we want a specific bit-width, no? This analytical output quantizer would mean we don't re-quantize the output to a lower bit-width.

Reply: Yes, exactly; the analytically computed quantizers are only used for dequantization.

Reply: Doing this avoids having to actually perform the computation of the layer outputs for the calibration, so it's much faster. Since we only need the parameters for dequantization, we can skip the actual computation.

Comment: Alright, so that's an optimization that will only be applied to linear-only circuits. Though I am surprised it works, since the condition only requires the op to be a QuantizedMixingOp. Why isn't, e.g., QuantizedAdd making our tests fail, since it doesn't have an implementation for calibrate_analytical_output?

Reply: Because we only test the QuantizedGemm / QuantizedMatMul case, while for the other mixing ops we test that an error is raised if one attempts to call calibrate_analytical_output on them.

Comment: QuantizedAdd must be tested in torch. Why isn't this failing?

Comment: This is the previous "data-based calibration" that was done for op outputs. These values aren't actually used except for univariate functions (mixing ops apply their input quantization).
@@ -817,6 +858,15 @@ def _get_output_quant_opts(self):
         output_quant_opts.is_qat = False
         return output_quant_opts

+    @classmethod
+    def supported_by_linear_backend(cls) -> bool:
+        """Indicate if this op can be executed on the GLWE linear backend.
+
+        Returns:
+            bool: True if the op can be executed with GLWE.
+        """
+        return False
+

 class QuantizedOpUnivariateOfEncrypted(QuantizedOp, is_utility=True):
     """An univariate operator of an encrypted value.
@@ -931,11 +981,6 @@ def make_output_quant_parameters(
         Returns:
             QuantizedArray: the quantized array that will be passed to the QuantizedModule output.
         """
-
-        out_opts = self._get_output_quant_opts()
-        out_opts.is_signed = False
-        out_opts.is_symmetric = False
-
         # Since we don't know the real bit-width of these quantized values,
         # return a quantizer that has zero offset
         out_params = UniformQuantizationParameters(
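To make the analytical-calibration discussion concrete, here is a hedged sketch of what such a formula can look like for a matmul-style op. This is an assumption for illustration, not necessarily the exact formula used by QuantizedGemm/QuantizedMatMul: with uniform quantization and zero zero-points, the integer accumulator of `x @ w` de-quantizes with scale `s_y = s_x * s_w`, so an output quantizer can be derived from the input and weight quantizers without running any calibration data through the layer. As the review thread notes, this quantizer is only used for de-quantization; a concrete op opting into this path would also override `supported_by_linear_backend()` to return True, which is what the assert in `calibrate()` checks.

```python
import numpy as np


def analytical_matmul_output_scale(input_scale: float, weight_scale: float) -> float:
    # With uniform quantization x = s_x * q_x and w = s_w * q_w (zero zero-points
    # assumed for brevity), the integer accumulator of x @ w de-quantizes with
    # scale s_y = s_x * s_w.
    return input_scale * weight_scale


# Quick numeric check of the identity on random integer data.
rng = np.random.default_rng(0)
s_x, s_w = 0.05, 0.02
q_x = rng.integers(-8, 8, size=(3, 4))
q_w = rng.integers(-8, 8, size=(4, 2))
acc = q_x @ q_w  # integer accumulator, as computed homomorphically
assert np.allclose(
    (q_x * s_x) @ (q_w * s_w),
    acc * analytical_matmul_output_scale(s_x, s_w),
)
```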
Comment: Only for Linux; we'll see in the weekly what happens on Mac (it should fall back to CP compilation).

Comment: Did you run make licenses? I thought we were only supposed to do that through GitHub Actions. Or was it for another purpose?

Reply: Yes, I generated the license files in the action, but this library only works on Linux, so only the Linux license file shows it.