feat: GLWE mat mul for hybrid model #868
Conversation
andrei-stoian-zama commented Sep 9, 2024 (edited)
- the PR activates GLWE when possible
- the MLP LoRA works (it used to work in @RomanBredehoft's version, I just cleaned it up and added text)
- LoRA is supported for GLWE in the hybrid model
- adds a test that uses GLWE for MLP inference and compares to QM and PyTorch fp32
- optimizes compilation when the hybrid model only has linear layers
- optimizes quantized clear execution (fhe=disable) when the hybrid model only has linear layers (see the usage sketch below)
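As a rough usage sketch of what this enables (the constructor and compile details are assumptions; only the fhe=... call pattern is taken from the tests discussed below):

import torch
from torch import nn
from concrete.ml.torch.hybrid_model import HybridFHEModel  # import path assumed

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
# Wrap the purely linear sub-modules as remote layers (constructor shape assumed)
hybrid = HybridFHEModel(model, module_names=["0", "2"])
x = torch.randn(8, 16)
# Calibration/compilation step omitted here
out = hybrid(x, fhe="disable")    # optimized quantized clear execution for linear-only models
# out = hybrid(x, fhe="execute")  # GLWE-backed execution when concrete-ml-extensions is installed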
Force-pushed from 6716709 to 6abe336
with monkeypatch.context() as m:
    if not transformers_installed:
        m.setitem(sys.modules, "transformers", None)
    if has_pbs_reshape:
        has_pbs = True
    if not glwe_backend_installed:
This simulates environments where concrete-ml-extensions is not available (e.g. macOS).
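For reference, a minimal standalone sketch of that monkeypatch trick (the test name is made up):

import sys
import pytest

def test_without_glwe_backend(monkeypatch):
    # A None entry in sys.modules makes any "import concrete_ml_extensions"
    # raise ImportError, emulating a platform without the GLWE backend (e.g. macOS)
    monkeypatch.setitem(sys.modules, "concrete_ml_extensions", None)
    with pytest.raises(ImportError):
        import concrete_ml_extensions  # noqa: F401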
@@ -75,6 +85,13 @@ def run_hybrid_llm_test(
assert "NoParametersFound" in error.args[0]
pytest.skip(error.args[0])

# Check we can run the simulate locally
if has_pbs or not glwe_backend_installed:
    logits_simulate = hybrid_model(inputs, fhe="simulate").logits
If GLWE is enabled (and the model is fully linear), it can't be simulated: we don't yet know how to simulate the GLWE backend.
assert numpy.all(numpy.allclose(y_torch, y_hybrid_torch, rtol=1, atol=0.001))

# The clear quantization vs fp32 test has more tolerance
threshold_fhe = 0.01
Experimentally adjusted: the accuracy is fine even with this threshold on the logits.
for name in hints:
    if getattr(obj, name) is None:
        raise TypeError(f"Missing quantizer parameter {name}")
all_members_missing = all(getattr(obj, name) is None for name in hints)
I wanted the stats to be optional, so that a quantizer can be created out of analytically determined values, e.g. when no min/max is computed. But I want the caller to either use all the fields in MinMaxStatistics or none of them; see the sketch below.
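A standalone sketch of that all-or-none rule (field and helper names are illustrative, not the PR's exact code):

from dataclasses import dataclass
from typing import Optional, get_type_hints

@dataclass
class MinMaxStatistics:
    rmax: Optional[float] = None
    rmin: Optional[float] = None
    uvalues: Optional[tuple] = None

def check_all_or_none(obj: MinMaxStatistics) -> None:
    hints = get_type_hints(type(obj))
    missing = [name for name in hints if getattr(obj, name) is None]
    # Either all statistics are provided (data-based calibration) or none are
    # (analytically determined quantizer); a partial set is an error
    if missing and len(missing) != len(hints):
        raise TypeError(f"Missing quantizer parameters: {missing}")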
@@ -700,7 +700,7 @@ def _prepare_inputs_with_constants(

return prepared_inputs

def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
def calibrate(self, *inputs: Union[QuantizedArray, numpy.ndarray]) -> numpy.ndarray:
Allow calibration with QuantizedArray inputs, so the calibration can be performed analytically. This is similar to what Brevitas does with QuantTensor, which determines output quantization parameters based on the input ones, without calibration; see the illustration below.
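For instance, Brevitas attaches the quantization metadata to the tensor itself (a small illustration, not related to this PR's code):

import torch
import brevitas.nn as qnn

quant = qnn.QuantIdentity(bit_width=8, return_quant_tensor=True)
qt = quant(torch.randn(2, 4))
# The QuantTensor carries scale and zero_point, so downstream ops can derive
# their output quantization parameters without running calibration data
print(qt.scale, qt.zero_point)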
@@ -720,12 +722,40 @@ def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
if isinstance(raw_result, RawOpOutput):
    return raw_result

quantized_samples = QuantizedArray(self.n_bits, raw_result)
supports_analytical_quant = all(
In theory all mixing ops could be calibrated analytically, but only Gemm implements the analytical formulas; all the others will throw some kind of E_NOTIMPL error.
# These output quantization parameters are only used
# for operations that produce graph output operation
# and are a non-linear
quantized_samples = QuantizedArray(self.n_bits, raw_result)
This is the previous "data-based calibration" that was done for op outputs. These values aren't actually used except for univariate functions (mixing ops apply their input quantization).
that can be used to de-quantize these values.
"""

q_input1 = inputs[0]
Implements the famous Gemm calibration formula by identifying terms.
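A self-contained numpy sketch of that term identification (toy values, not the PR's implementation):

import numpy as np

# Quantized operands: real ≈ scale * (q - zero_point)
s1, z1 = 0.03, 5        # input quantizer
s2, z2 = 0.01, 0        # weight quantizer (often symmetric, so z2 = 0)
q_a = np.random.randint(0, 256, size=(4, 8))
q_b = np.random.randint(-128, 128, size=(8, 3))
k = q_a.shape[1]

# Expand (q_a - z1) @ (q_b - z2) and identify the correction terms
prod = q_a @ q_b
sum_a = q_a.sum(axis=1, keepdims=True)   # per-row sums of the input
sum_b = q_b.sum(axis=0, keepdims=True)   # per-column sums of the weights
acc = prod - z2 * sum_a - z1 * sum_b + k * z1 * z2

# De-quantization only needs the analytically known output scale s1 * s2
out = (s1 * s2) * acc
expected = (s1 * (q_a - z1)) @ (s2 * (q_b - z2))
assert np.allclose(out, expected)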
I am not sure I understand the whole analytical part on this PR. Why are we doing this?
coloredlogs, 15.0.1, MIT License
concrete-ml-extensions, 0.1.2, BSD-3-Clause-Clear
Only for Linux; we'll see in the weekly run what happens on macOS (it should fall back to Concrete Python compilation).
Did you run make licenses? I thought we were only supposed to do that through GitHub Actions. Or was it for another purpose?
Yes, I generated the license files in the action, but this library only works on Linux, so only the Linux license file shows it.
@@ -674,9 +680,6 @@ def _prepare_inputs_with_constants(
new_input.quantizer.is_qat
and not input_.quantizer.is_precomputed_qat
and self.error_tracker is not None
Basically:
- only mixing ops and reshape ops (ops that quantize their inputs) call this function with quantize_actual_values
- only QuantizedArrays that have is_precomputed_qat are now accepted by these functions if is_qat (QAT import was requested by the user)
- you don't need QuantIdentity before reshape/mixing layers applied on other mixing ops' outputs because of the "passthrough" for is_precomputed_qat implemented above
Force-pushed from b625cbc to 2ec8685
self.act_f = nn.ReLU()
self.fc2 = qnn.QuantLinear(in_features=input_output, out_features=input_output, weight_bit_width=3, bias=True)

def forward(self, x):
    return self.fc2(self.act_f(self.fc1(self.quant_input(x))))
    return self.fc2(self.quant_2(self.act_f(self.fc1(self.quant_input(x)))))
this was bad all along and only worked because the QAT quantization was guessed from the calibration data
I'm really not a fan of using QuantIdentity layers after quant layers. It works fine without them, and including them just adds unnecessary complexity in my opinion.
It used to work because of some automatic detection of quantization parameters that we were doing, but I removed that detection since it was slow. So now we need QuantIdentity in the right places.
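A minimal Brevitas sketch of the fixed pattern (mirrors the test model above; bit-widths and sizes are illustrative):

from torch import nn
import brevitas.nn as qnn

class TinyQATModel(nn.Module):
    def __init__(self, n_feat=10, n_bits=3):
        super().__init__()
        self.quant_input = qnn.QuantIdentity(bit_width=n_bits)
        self.fc1 = qnn.QuantLinear(n_feat, n_feat, weight_bit_width=n_bits, bias=True)
        self.act_f = nn.ReLU()
        # Explicit re-quantization before the next QuantLinear: the importer reads
        # these parameters instead of guessing them from calibration data
        self.quant_2 = qnn.QuantIdentity(bit_width=n_bits)
        self.fc2 = qnn.QuantLinear(n_feat, n_feat, weight_bit_width=n_bits, bias=True)

    def forward(self, x):
        return self.fc2(self.quant_2(self.act_f(self.fc1(self.quant_input(x)))))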
x_block = x[:, i, :]
q1_out = self.quant_1(x_block)
fc1_out = self.fc1(q1_out)
q_x = self.quant_1(x)
This was wrong: the input to the slicing should be quantized with QuantIdentity first.
if quant_opts.is_equal(input_.quantizer.quant_options, ignore_sign_qat=True):
if (
    quant_opts.is_equal(input_.quantizer.quant_options, ignore_sign_qat=True)
    or input_.quantizer.quant_options.is_precomputed_qat
Before, we were overriding the op input quantization mainly for the model input. Now we override every time there is a BrevitasQuant layer before the op.
Thanks for the explanation, maybe you should update the comment accordingly.
@@ -603,7 +603,7 @@ def test_compile_torch_or_onnx_activations(
@pytest.mark.parametrize(
    "model",
    [
        pytest.param(SimpleQAT),
        pytest.param(StepFunctionPTQ),
This was a non-Brevitas QAT model for which the QAT quantization had to be guessed. We don't support that anymore, so we use PTQ here.
@@ -109,6 +124,152 @@ def convert_conv1d_to_linear(layer_or_module):
return layer_or_module


class OptimizedLinearLayerExecutor:
Is hybrid_model the proper place for this class? It looks like it could be used in our qops as well, and it's very unrelated to the hybrid model in the end. I propose we move this to a glwe_backend.py or linear_fhe_backend.py, and use GLWELinearProcessor (or similar) for the class name.
try:
    optimized_linear_layer_executor = _optimized_linear_executor.get()
except LookupError:
    optimized_linear_layer_executor = None
You do that a few times, so we probably need a _get_optimized_executor helper or something; see the sketch below.
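A sketch of what that helper could look like (the ContextVar mirrors the one used in this file; the wiring is assumed):

import contextvars
from typing import Optional

_optimized_linear_executor: contextvars.ContextVar = contextvars.ContextVar("optimized_linear_executor")

def _get_optimized_executor() -> Optional["OptimizedLinearLayerExecutor"]:
    # Return the executor bound to the current context, or None if none was set
    try:
        return _optimized_linear_executor.get()
    except LookupError:
        return None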
except LookupError:
    optimized_linear_layer_executor = None

assert optimized_linear_layer_executor is None, (
Not sure this is relevant. In remote mode, the remote layer should already know if it will query a GLWE server. Why are we checking optimized_linear_layer_executor?
@@ -364,7 +550,7 @@ def __init__(
if not isinstance(model, torch.nn.Module):
    raise TypeError("The model must be a PyTorch or Brevitas model.")

self.model = model
self.model = deepcopy(model)
I don't think we should do that here. This will have lots of impact in our use cases
Yes, it adds bugs too :) We'll figure it out in the new API PR.
fixed
try:
    import concrete_ml_extensions as fhext

    _HAS_GLWE_BACKEND = True
except ImportError:  # pragma: no cover
    fhext = None
    _HAS_GLWE_BACKEND = False
We could move that to another file. Then we can do something like:

from [NEW_FILE] import OptimizedLinearLayerExecutor, is_glwe_backend_available
You can also use importlib for conditional imports, like:

import importlib.util

fhext = importlib.import_module("x") if importlib.util.find_spec("x") else None
_HAS_GLWE_BACKEND = fhext is not None
We used the try/except pattern for the transformers import too, so I stuck to that.
# Validate the FHE mode
fhe_mode = HybridFHEMode(fhe)

if _HAS_GLWE_BACKEND and self._all_layers_are_pure_linear:
    if fhe_mode == HybridFHEMode.SIMULATE:
        raise AssertionError(
            "When the HybridFHEModel is instantiated with only "
            "linear remote layers, fhe=simulate is not supported for now.",
        )

if fhe_mode in (HybridFHEMode.EXECUTE, HybridFHEMode.REMOTE, HybridFHEMode.DISABLE):
    # If all layers are pure linear, enable the GLWE optimization for all layers
    # and generate an encryption and compression key for all layers
    # as they share crypto-parameters
    private_key, compression_key = None, None
    if fhe_mode != HybridFHEMode.DISABLE and self._all_layers_are_pure_linear:
        # pylint: disable-next=no-member
        fhext_glwe_crypto_params = fhext.MatmulCryptoParameters.deserialize(
            json.dumps(self.default_crypto_params_glwe)
        )
        # pylint: disable-next=no-member
        private_key, compression_key = fhext.create_private_key(
            fhext_glwe_crypto_params
        )

    _optimized_linear_executor.set(
        OptimizedLinearLayerExecutor(
            self.default_crypto_params_glwe,
            private_key=private_key,
            compression_key=compression_key,
        )
    )

result = self.model(x)

_optimized_linear_executor.set(None)

return result
We are adding a lot of stuff here. Maybe we could use a context manager for the _optimized_linear_executor.set calls, so that we can just do:

with self._glwe_executor_context():
    result = self.model(x)
How can we get a reference to that executor context in the forward() of self with this setting? With contextvars there is a global object.
We can do this I think:

@contextlib.contextmanager
def _glwe_executor_context(self):
    try:
        _optimized_linear_executor.set(self.glwe_executor)
        yield
    finally:
        _optimized_linear_executor.set(None)

but up to you
I'll keep it in mind for the refactor
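For that refactor, a generic contextvars-based variant that restores the previous value instead of clobbering it with None (plain Python, not tied to this PR):

import contextlib
import contextvars

_optimized_linear_executor: contextvars.ContextVar = contextvars.ContextVar("optimized_linear_executor")

@contextlib.contextmanager
def glwe_executor_context(executor):
    token = _optimized_linear_executor.set(executor)
    try:
        yield
    finally:
        # reset() restores whatever value was active before entering the context
        _optimized_linear_executor.reset(token)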
self.default_crypto_params_glwe = (
    json.loads(fhext.default_params())  # pylint: disable=no-member
    if _HAS_GLWE_BACKEND
    else None
)
That should be in the GLWE class.
# Save the model state dict due to a Brevitas issue
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4572
torch.save(self.model.state_dict(), model_path.resolve())
torch.save(self.model, model_path.resolve())
thanks!
# At this point, the hybrid model does not have
# the parameters necessary to run the module_names
module_names = module_names if isinstance(module_names, list) else [module_names]
if not has_pbs and glwe_backend_installed:
don't we have a better way to check if glwe is used?
This is not checking; it's enabling/disabling the GLWE backend based on the value of glwe_backend_installed.
tests/torch/test_hybrid_converter.py
assert numpy.all(numpy.allclose(y_torch, y_glwe, rtol=1, atol=threshold_fhe))

# Check accuracy between fp32 and glwe
assert numpy.abs(acc_fp32 - acc_glwe) < 0.01
threshold_fhe?
Can we have a function to check this? I think you do it multiple times, here and in the tests, in different ways.
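Something like this hypothetical helper could factor out the two checks (names are illustrative):

import numpy

def assert_glwe_close_to_fp32(y_torch, y_glwe, acc_fp32, acc_glwe, threshold_fhe=0.01):
    # Element-wise closeness of the outputs, then closeness of the aggregate accuracies
    assert numpy.all(numpy.allclose(y_torch, y_glwe, rtol=1, atol=threshold_fhe))
    assert numpy.abs(acc_fp32 - acc_glwe) < threshold_fhe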
@@ -77,7 +77,7 @@ concrete_clf.compile(X, debug_config)

#### 3. Quantization import failed

**Error message**: `Error occurred during quantization aware training (QAT) import [...] Could not determine a unique scale for the quantization!`.
**Error message**: `Error occurred during quantization aware training (QAT) import [...] Are you missing a QuantIdentity layer in your Brevitas model?`.
We also have a non-intuitive error that occurs when using 'view' instead of 'reshape'.
Error: ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size ... is different from ...)
Maybe we could add it ? CC @andrei-stoian-zama
That's not related; it's due to shape inference during the ONNX export.
that can be used to de-quantize these values.
"""

q_input1 = inputs[0]
Sanity check: assert len(inputs) >= 2, no?
# In the operation Y = alpha * A' * B' + beta * C, q_bias is used for
# generalised matrix multiplication. q_bias is set to None for standard
# matrix multiplication (beta == 0 or only two inputs)
q_bias = None if len(inputs) == 2 or self.attrs["beta"] == 0 else inputs[2]
In which scenario is len(inputs) != 2?
If a bias is specified for this op, there are 3 inputs.
assert q_input1.quantizer.scale is not None
assert q_input2.quantizer.scale is not None
m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale
I would have added a comment (even though it's trivial, it might help in the future):

out_scale = scale_input1 × scale_input2
assert q_input2.quantizer.scale is not None
m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale

input2_q_values = (
I would have grouped those rows, for better clarity:

# Transpose input2, if needed
transpose_inputs2 = self.attrs.get("transB", False)
input2_q_values = (
    numpy.transpose(q_input2.qvalues) if transpose_inputs2 else q_input2.qvalues
)
assert q_input1.quantizer.scale is not None
assert q_input2.quantizer.scale is not None
m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale
I would have added a comment to show the formula of out_zp (even though it's trivial, it might help in the future):

out_zp = sum_weights − final_term
out_zp = zp_input1 × ∑ input2_i − p × zp_input1 × zp_input2
that can be used to de-quantize these values.
"""

q_input1 = inputs[0]
I am not sure I understand the whole analytical part on this PR. Why are we doing this?
quantizer = self.calibrate_analytical_output(*q_prepared_inputs)
self.output_quant_params = quantizer.quant_params
self.output_quant_stats = quantizer.quant_stats
Why are we doing this? We should always have the output based on the calibration-derived quantization params, since we want a specific bit-width, no? This analytical output quantizer would mean we don't re-quantize the output to a lower bit-width.
Yes, exactly: the analytically computed quantizers are only used for de-quantization.
I am not sure I understand the whole analytical part on this PR. Why are we doing this?
Because doing this avoids having to actually perform the computation of the layer outputs for calibration, so it's much faster. Since we only need the parameters for de-quantization, we can skip the actual computation.
Alright, so that's an optimization that will only be applied for linear-only circuits. Though I am surprised it works, since the condition is just for the op to be a QuantizedMixingOp. Why isn't e.g. QuantizedAdd making our tests fail, since it doesn't have an implementation for calibrate_analytical_output?
Because we only test the QuantizedGemm / QuantizedMatmul case, while for the other mixing ops we test that an error is raised if one attempts to call calibrate_analytically on them.
The quantized add must be tested in torch. Why isn't this failing?
# Check if the model has only GLWE supported linear layers.
# In this case, use analytical calibration which is much faster
fast_calibration = True
for node in graph.node:
That's adding another full traversal of the ONNX graph here. We should probably exit the loop when we encounter an op that's not supported.
Yes, you're right. Traversing the graph in this way is pretty fast though; I'm guessing the graph is in memory and we're just processing some pointers.
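A sketch of the early-exit variant (the set of supported op types is an assumption):

from onnx import GraphProto

def uses_only_glwe_linear_ops(graph: GraphProto) -> bool:
    # Assumption: Gemm/MatMul stand in for the GLWE-supported linear ops
    supported = {"Gemm", "MatMul"}
    for node in graph.node:
        if node.op_type not in supported:
            return False  # early exit: one unsupported op disables fast calibration
    return True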
supports_analytical_quant = all(
    isinstance(qv, QuantizedArray) for qv in inputs
) and isinstance(self, QuantizedMixingOp)
Shouldn't this be in QuantizedMixingOp?
Yes, it could be a base function that returns False, with an overload in QuantizedMixingOp that implements the condition; see the sketch below.
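A minimal sketch of that base/override split (class and method names follow the discussion; the QuantizedArray import path is assumed):

from concrete.ml.quantization.quantizers import QuantizedArray  # import path assumed

class QuantizedOp:
    def supports_analytical_calibration(self, *inputs) -> bool:
        # Default: ops cannot be calibrated analytically
        return False

class QuantizedMixingOp(QuantizedOp):
    def supports_analytical_calibration(self, *inputs) -> bool:
        # Mixing ops can, provided every input already carries quantization metadata
        return all(isinstance(qv, QuantizedArray) for qv in inputs)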
@@ -720,12 +725,40 @@ def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
if isinstance(raw_result, RawOpOutput):
    return raw_result

quantized_samples = QuantizedArray(self.n_bits, raw_result)
supports_analytical_quant = all(
Can you explain in a comment when we are supposed to have this true or false?
fix: use deai matmul
fix: refactor glwe executor in hybrid model
feat: use fhe execution in mlp lora example
fix: better notebook printing
chore: remove mlp lora example
fix: update to use cml extensions from pypi
feat: add glwe library, optimize compilation of hybrid model
fix: handle simulate/disable in hybrid model full linear
fix: pcc
fix: pcc
fix: bad assert
fix: test
fix: bad link
fix: pcc, tests, glwe extensions only on linux
chore: update licenses
fix: tests
fix: revert gpt2
fix: precomputed qat handled properly
fix: codeblock test
fix: readme link
fix: deepcopy
fix: refactoring
fix: comments
Force-pushed from c91ee33 to ba292fb
Coverage passed ✅
Thanks!