diff --git a/src/concrete/ml/quantization/linear_op_glwe_backend.py b/src/concrete/ml/quantization/linear_op_glwe_backend.py
index 2407b9089..66604928b 100644
--- a/src/concrete/ml/quantization/linear_op_glwe_backend.py
+++ b/src/concrete/ml/quantization/linear_op_glwe_backend.py
@@ -7,13 +7,18 @@
 from ..common.utils import HybridFHEMode, to_tuple
 from .quantized_module import QuantizedModule
 
-try:
-    import concrete_ml_extensions as fhext
-
-    _HAS_GLWE_BACKEND = True
-except ImportError:  # pragma: no cover
-    fhext = None
-    _HAS_GLWE_BACKEND = False
+def has_glwe_backend():
+    """Check if the GLWE backend is installed.
+
+    Returns:
+        bool: True if the GLWE backend is installed, False otherwise.
+    """
+    try:
+        __import__("concrete_ml_extensions")
+        return True
+    except ImportError:
+        return False
 
 
 class GLWELinearLayerExecutor:
@@ -24,6 +29,12 @@ def __init__(
         private_key=None,
         compression_key=None,
     ):
+        assert has_glwe_backend(), "GLWE backend not installed"
+
+        import concrete_ml_extensions as fhext
+
+        self.fhext = fhext
+
         self.compression_key = compression_key
         self.private_key = private_key
 
@@ -39,7 +50,9 @@ def keygen(self):
         """Generate private and compression key."""
         # pylint: disable-next=no-member
-        self.private_key, self.compression_key = fhext.create_private_key(self.glwe_crypto_params)
+        self.private_key, self.compression_key = self.fhext.create_private_key(
+            self.glwe_crypto_params
+        )
 
     def forward(
         self, x: numpy.ndarray, q_module: QuantizedModule, fhe: HybridFHEMode
@@ -123,15 +136,15 @@ def forward(
         for idx, q_x_sample in enumerate(q_x):
-            ciphertext = fhext.encrypt_matrix(  # pylint: disable=no-member
+            ciphertext = self.fhext.encrypt_matrix(  # pylint: disable=no-member
                 pkey=self.private_key, crypto_params=self.glwe_crypto_params, data=q_x_sample
             )
-            encrypted_result = fhext.matrix_multiplication(  # pylint: disable=no-member
+            encrypted_result = self.fhext.matrix_multiplication(  # pylint: disable=no-member
                 encrypted_matrix=ciphertext,
                 data=q_weight.astype(numpy.uint64),
                 compression_key=self.compression_key,
             )
-            q_result = fhext.decrypt_matrix(  # pylint: disable=no-member
+            q_result = self.fhext.decrypt_matrix(  # pylint: disable=no-member
                 encrypted_result,
                 self.private_key,
                 self.glwe_crypto_params,
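Reviewer note on the file above: availability of `concrete-ml-extensions` is now probed at call time by `has_glwe_backend()`, and the `fhext` handle lives on the executor instance instead of a module-level global. A minimal usage sketch of the new API; only the imported names come from this patch, and the input and quantized module are placeholders:

```python
# Hedged sketch, not part of the patch: exercising the lazy-import API above.
# `q_x` and `q_module` are assumed to come from an existing quantization flow
# (e.g. build_quantized_module); everything else is named in the diff.
from concrete.ml.quantization.linear_op_glwe_backend import (
    GLWELinearLayerExecutor,
    has_glwe_backend,
)

if has_glwe_backend():
    executor = GLWELinearLayerExecutor()  # __init__ now asserts the backend imports
    executor.keygen()  # fills private_key / compression_key through self.fhext
    # result = executor.forward(q_x, q_module, fhe_mode) would run the linear op
else:
    print("concrete-ml-extensions is not installed; the GLWE fast path is unavailable")
```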
""" try: - from transformers import Conv1D # pylint: disable=import-outside-toplevel + from transformers.modeling_utils import Conv1D # pylint: disable=import-outside-toplevel except ImportError: # pragma: no cover return layer_or_module @@ -412,13 +412,14 @@ def _replace_modules(self): if is_pure_linear_layer: module = self.private_modules[module_name] # Use weight shape instead of in/out_features - if hasattr(module, "weight"): - input_dim = module.weight.shape[ - 1 - ] # Input dimension is second dimension for Linear layers - output_dim = module.weight.shape[0] # Output dimension is first dimension - else: - input_dim = output_dim = 0 + input_dim, output_dim = ( + ( + module.weight.shape[1], + module.weight.shape[0], + ) + if hasattr(module, "weight") + else (0, 0) + ) is_pure_linear_layer = ( is_pure_linear_layer and input_dim >= 512 and output_dim >= 512 @@ -465,7 +466,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: # Validate the FHE mode fhe_mode = HybridFHEMode(fhe) - if _HAS_GLWE_BACKEND and self._has_only_large_linear_layers: + if has_glwe_backend() and self._has_only_large_linear_layers: if fhe_mode == HybridFHEMode.SIMULATE: raise AssertionError( "When the HybridFHEModel is instantiated with only " @@ -474,8 +475,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: if fhe_mode in (HybridFHEMode.EXECUTE, HybridFHEMode.REMOTE, HybridFHEMode.DISABLE): # Initialize executor only if not already done - if self.executor is None: - self.executor = GLWELinearLayerExecutor() + self.executor = self.executor or GLWELinearLayerExecutor() # Generate keys only if needed and not already done if fhe_mode != HybridFHEMode.DISABLE and self.executor.private_key is None: @@ -589,7 +589,7 @@ def compile_model( # If all layers are linear and the GLWE backend is available # then simply quantize the model without compiling with # Concrete Python. 
- if self._has_only_large_linear_layers and _HAS_GLWE_BACKEND: + if self._has_only_large_linear_layers and has_glwe_backend(): self.executor = GLWELinearLayerExecutor() self.private_q_modules[name] = build_quantized_module( self.private_modules[name], diff --git a/tests/torch/test_hybrid_converter.py b/tests/torch/test_hybrid_converter.py index 67af03037..cd1440417 100644 --- a/tests/torch/test_hybrid_converter.py +++ b/tests/torch/test_hybrid_converter.py @@ -1,6 +1,5 @@ """Tests for the hybrid model converter.""" -import importlib import sys import tempfile from pathlib import Path @@ -13,7 +12,6 @@ from sklearn.model_selection import train_test_split from transformers import GPT2LMHeadModel, GPT2Tokenizer -import concrete.ml.torch.hybrid_model from concrete.ml.pytest.torch_models import FCSmall, PartialQATModel from concrete.ml.torch.hybrid_model import ( HybridFHEModel, @@ -37,166 +35,121 @@ def test_tuple_serialization(tup): assert tup == underscore_str_to_tuple(tuple_to_underscore_str(tup)) -# pylint: disable=too-many-locals, too-many-branches, too-many-statements +# pylint: disable=too-many-arguments, too-many-locals, too-many-statements, too-many-branches def run_hybrid_llm_test( model: torch.nn.Module, inputs: torch.Tensor, - module_names: Union[str, List], - expected_accuracy, + module_names: Union[str, List[str]], + expected_accuracy: float, has_pbs: bool, - has_pbs_reshape: bool, - monkeypatch, - transformers_installed, - glwe_backend_installed, + transformers_installed: bool, + glwe_backend_installed: bool, + monkeypatch: pytest.MonkeyPatch, ): """Run the test for any model with its private module names.""" - # Multi-parameter strategy is used in order to speed-up the FHE executions + # Configure the model configuration = Configuration( single_precision=False, compress_input_ciphertexts=True, ) - logits_simulate = None - - with monkeypatch.context() as m: - if not transformers_installed: - m.setitem(sys.modules, "transformers", None) - if has_pbs_reshape: - has_pbs = True - - # Patching for GLWE backend - if not glwe_backend_installed: - m.setitem(sys.modules, "concrete_ml_extensions", None) - - # Reload the affected modules to ensure the changes take effect - importlib.reload(concrete.ml.quantization.linear_op_glwe_backend) - importlib.reload(concrete.ml.torch.hybrid_model) - - hybrid_model = HybridFHEModel(model, module_names) - is_compiled = False - try: - hybrid_model.compile_model( - inputs, - p_error=10e-40, # compare precisely simulate and disable - n_bits=9, - rounding_threshold_bits=8, - configuration=configuration, - ) - is_compiled = True - except RuntimeError as error: - # When reshaping adds PBSs we sometimes encounter NoParametersFound - # when compiling. In this case we skip the rest since we can't simulate - # without compilation. 
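With both files above, the GLWE fast path is selected per call rather than at import time, so a backend installed into a live process is picked up without `importlib.reload`. A hedged end-to-end sketch of the path these hunks gate; the model, shapes, and `module_names` value are illustrative only, and 512 mirrors the minimum dimension checked in `_replace_modules`:

```python
# Illustrative sketch, not part of the patch. When every remote module is a
# Linear layer with in/out dimensions >= 512, _has_only_large_linear_layers is
# True and forward() routes through GLWELinearLayerExecutor when available.
import torch

from concrete.ml.torch.hybrid_model import HybridFHEModel

model = torch.nn.Sequential(torch.nn.Linear(512, 512))
inputs = torch.randn(4, 512)

hybrid = HybridFHEModel(model, module_names=["0"])  # "0" names the Linear submodule
hybrid.compile_model(inputs, n_bits=9)  # quantizes only, no Concrete Python circuit

out_clear = hybrid(inputs, fhe="disable")  # quantized execution, no keys needed
out_fhe = hybrid(inputs, fhe="execute")    # lazily creates the executor and keygens
# hybrid(inputs, fhe="simulate") raises AssertionError on this GLWE-only path
```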
diff --git a/tests/torch/test_hybrid_converter.py b/tests/torch/test_hybrid_converter.py
index 67af03037..cd1440417 100644
--- a/tests/torch/test_hybrid_converter.py
+++ b/tests/torch/test_hybrid_converter.py
@@ -1,6 +1,5 @@
 """Tests for the hybrid model converter."""
 
-import importlib
 import sys
 import tempfile
 from pathlib import Path
@@ -13,7 +12,6 @@
 from sklearn.model_selection import train_test_split
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
-import concrete.ml.torch.hybrid_model
 from concrete.ml.pytest.torch_models import FCSmall, PartialQATModel
 from concrete.ml.torch.hybrid_model import (
     HybridFHEModel,
@@ -37,166 +35,121 @@ def test_tuple_serialization(tup):
     assert tup == underscore_str_to_tuple(tuple_to_underscore_str(tup))
 
 
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+# pylint: disable=too-many-arguments, too-many-locals, too-many-statements, too-many-branches
 def run_hybrid_llm_test(
     model: torch.nn.Module,
     inputs: torch.Tensor,
-    module_names: Union[str, List],
-    expected_accuracy,
+    module_names: Union[str, List[str]],
+    expected_accuracy: float,
     has_pbs: bool,
-    has_pbs_reshape: bool,
-    monkeypatch,
-    transformers_installed,
-    glwe_backend_installed,
+    transformers_installed: bool,
+    glwe_backend_installed: bool,
+    monkeypatch: pytest.MonkeyPatch,
 ):
     """Run the test for any model with its private module names."""
-    # Multi-parameter strategy is used in order to speed-up the FHE executions
+    # Configure the model
     configuration = Configuration(
         single_precision=False,
         compress_input_ciphertexts=True,
     )
 
-    logits_simulate = None
-
-    with monkeypatch.context() as m:
-        if not transformers_installed:
-            m.setitem(sys.modules, "transformers", None)
-        if has_pbs_reshape:
-            has_pbs = True
-
-        # Patching for GLWE backend
-        if not glwe_backend_installed:
-            m.setitem(sys.modules, "concrete_ml_extensions", None)
-
-            # Reload the affected modules to ensure the changes take effect
-            importlib.reload(concrete.ml.quantization.linear_op_glwe_backend)
-            importlib.reload(concrete.ml.torch.hybrid_model)
-
-        hybrid_model = HybridFHEModel(model, module_names)
-        is_compiled = False
-        try:
-            hybrid_model.compile_model(
-                inputs,
-                p_error=10e-40,  # compare precisely simulate and disable
-                n_bits=9,
-                rounding_threshold_bits=8,
-                configuration=configuration,
-            )
-            is_compiled = True
-        except RuntimeError as error:
-            # When reshaping adds PBSs we sometimes encounter NoParametersFound
-            # when compiling. In this case we skip the rest since we can't simulate
-            # without compilation.
-            # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4183
-            assert "NoParametersFound" in error.args[0]
-            pytest.skip(error.args[0])
-
-        # Check we can run the simulate locally
-        if has_pbs or not glwe_backend_installed:
-            logits_simulate = hybrid_model(inputs, fhe="simulate").logits
+    # Mock sys.modules to simulate missing modules
+    if not transformers_installed:
+        monkeypatch.setitem(sys.modules, "transformers", None)
+    if not glwe_backend_installed:
+        monkeypatch.setitem(sys.modules, "concrete_ml_extensions", None)
+
+    # Initialize and compile the hybrid model
+    hybrid_model = HybridFHEModel(model, module_names)
+
+    try:
+        hybrid_model.compile_model(
+            inputs,
+            p_error=10e-40,
+            n_bits=9,
+            rounding_threshold_bits=8,
+            configuration=configuration,
+        )
+    except RuntimeError as error:
+        # Skip test if NoParametersFound error occurs
+        if "NoParametersFound" in str(error):
+            pytest.skip(str(error))
         else:
-            with pytest.raises(AssertionError, match=".*fhe=simulate is not supported.*"):
-                hybrid_model(inputs, fhe="simulate")
+            raise
 
-    if has_pbs:
-        # Check for non-zero programmable bootstrapping
-        for module in hybrid_model.private_q_modules.values():
-            assert module.fhe_circuit.statistics["programmable_bootstrap_count"] > 0, (
-                "Programmable bootstrap count should be greater than 0, "
-                f"but found {module.fhe_circuit.statistics['programmable_bootstrap_count']}"
-            )
+    # Run the model in different modes
+    logits_simulate = None
+    if has_pbs or not glwe_backend_installed:
+        logits_simulate = hybrid_model(inputs, fhe="simulate").logits
     else:
-        # Check for zero programmable bootstrapping
-        for module in hybrid_model.private_q_modules.values():
-            # The RemoteModule does not have a circuit if it was optimized
-            # (in the case of pure linear remote modules)
-            assert (
-                not module.fhe_circuit
-                or module.fhe_circuit.statistics["programmable_bootstrap_count"] == 0
-            ), (
-                "Programmable bootstrap count should be 0, "
-                f"but found {module.fhe_circuit.statistics['programmable_bootstrap_count']}"
-            )
+        with pytest.raises(AssertionError, match=".*fhe=simulate is not supported.*"):
+            hybrid_model(inputs, fhe="simulate")
 
     logits_disable = hybrid_model(inputs, fhe="disable").logits
     logits_original = hybrid_model(inputs, fhe="torch").logits
 
-    # Compare the topk accuracy of the FHE simulate circuit vs. the original.
-    k = 5
-
-    # Check that the topk next tokens are similar for the different FHE modes
-    # and the original model.
+    # Check programmable bootstrap counts if not glwe backend
+    if not glwe_backend_installed:
+        for module in hybrid_model.private_q_modules.values():
+            pbs_count = module.fhe_circuit.statistics.get("programmable_bootstrap_count", 0)
+            if has_pbs:
+                assert pbs_count > 0, "Expected programmable bootstrap count > 0"
+            else:
+                assert pbs_count == 0, "Expected programmable bootstrap count == 0"
 
-    # Get the topk indices for logits_disable and logits_simulate
+    # Compare top-k accuracy
+    k = 5
     topk_disable = logits_disable.topk(k, dim=-1).indices
     topk_original = logits_original.topk(k, dim=-1).indices
-
-    # Compute accuracy of disable and simulate by checking
-    # how many labels correspond with the topk_original
     accuracy_disable = (topk_disable == topk_original).float().mean().item()
 
-    # Ensure logits_disable and logits_original return the same output for the logits
-    # Assert that both accuracy values are above the expected threshold
     assert (
         accuracy_disable >= expected_accuracy
-    ), f"Disable accuracy {accuracy_disable:.4f} is below the expected {expected_accuracy:.4f}"
+    ), f"Disable accuracy {accuracy_disable:.4f} is below expected {expected_accuracy:.4f}"
 
     if logits_simulate is not None:
-        assert torch.allclose(logits_disable, logits_simulate, atol=1e-7), "Outputs do not match!"
+        assert torch.allclose(logits_disable, logits_simulate, atol=1e-7)
         topk_simulate = logits_simulate.topk(k, dim=-1).indices
         accuracy_simulate = (topk_simulate == topk_original).float().mean().item()
-        assert accuracy_simulate >= expected_accuracy, (
-            f"Simulate accuracy {accuracy_simulate:.4f} is below "
-            f"the expected {expected_accuracy:.4f}"
-        )
+        assert (
+            accuracy_simulate >= expected_accuracy
+        ), f"Simulate accuracy {accuracy_simulate:.4f} is below expected {expected_accuracy:.4f}"
 
+    # Test model saving and deployment
     with tempfile.TemporaryDirectory() as temp_dir:
         temp_dir_path = Path(temp_dir)
-        # Get the temp directory path
 
         if not has_pbs and glwe_backend_installed:
-
-            if is_compiled:
-                # Deployment of GLWE backend hybrid models is not yet supported
-                with pytest.raises(
-                    NotImplementedError, match="GLWE backend deployment is not yet supported"
-                ):
-                    hybrid_model.save_and_clear_private_info(temp_dir_path)
-            else:
-                # Check that we get an error when trying to save a non-compiled model
-                with pytest.raises(
-                    AttributeError,
-                    match="The quantized module is not compiled. Please run compile*",
-                ):
-                    hybrid_model.save_and_clear_private_info(temp_dir_path)
+            with pytest.raises(
+                NotImplementedError, match="GLWE backend deployment is not yet supported"
+            ):
+                hybrid_model.save_and_clear_private_info(temp_dir_path)
         else:
-            hybrid_model.save_and_clear_private_info(temp_dir_path)
+            # If transformers is not installed, skip the saving test
+            if not transformers_installed:
+                pytest.skip("Skipping save test as transformers module is not available")
+            hybrid_model.save_and_clear_private_info(temp_dir_path)
 
             hybrid_model.set_fhe_mode("remote")
 
-        # At this point, the hybrid model does not have
-        # the parameters necessaryto run the module_names
-        module_names = module_names if isinstance(module_names, list) else [module_names]
-
-        # Check that files are there
+        # Verify saved files
         assert (temp_dir_path / "model.pth").exists()
-        for module_name in module_names:
-            module_dir_path = temp_dir_path / module_name
-            module_dir_files = set(str(elt.name) for elt in module_dir_path.glob("**/*"))
-            for file_name in ["client.zip", "server.zip"]:
-                assert file_name in module_dir_files
+        module_names_list = module_names if isinstance(module_names, list) else [module_names]
+        for module_name in module_names_list:
+            module_dir = temp_dir_path / module_name
+            files = {file.name for file in module_dir.glob("**/*")}
+            assert "client.zip" in files and "server.zip" in files
 
 
 # Dependency 'huggingface-hub' raises a 'FutureWarning' from version 0.23.0 when calling the
 # 'from_pretrained' method
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @pytest.mark.parametrize(
-    "list_or_str_private_modules_names, expected_accuracy, has_pbs, has_pbs_reshape",
+    "list_or_str_private_modules_names, expected_accuracy, has_pbs",
     [
-        ("transformer.h.0.mlp", 0.95, True, False),
-        (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True, False),
-        ("transformer.h.0.mlp.c_fc", 1.0, False, True),
+        ("transformer.h.0.mlp", 0.95, True),
+        (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True),
+        ("transformer.h.0.mlp.c_fc", 1.0, False),
     ],
 )
 @pytest.mark.parametrize("transformers_installed", [True, False])
@@ -205,7 +158,6 @@ def test_gpt2_hybrid_mlp(
     list_or_str_private_modules_names,
     expected_accuracy,
     has_pbs,
-    has_pbs_reshape,
    transformers_installed,
     glwe_backend_installed,
     monkeypatch,
@@ -227,10 +179,9 @@ def test_gpt2_hybrid_mlp(
         list_or_str_private_modules_names,
         expected_accuracy,
        has_pbs,
-        has_pbs_reshape,
-        monkeypatch,
         transformers_installed,
         glwe_backend_installed,
+        monkeypatch,
     )
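The test rework leans on one pytest mechanism worth spelling out: `monkeypatch.setitem(sys.modules, name, None)` makes any subsequent import of `name` raise `ImportError`, which is exactly what the call-time `has_glwe_backend()` check observes, so the `importlib.reload` bookkeeping could be dropped. A self-contained illustration of the mechanism under assumed names; `"json"` stands in for `concrete_ml_extensions` so the example runs anywhere:

```python
# Hedged, standalone sketch of the sys.modules trick used in the tests above.
import sys

import pytest


def has_backend(name: str) -> bool:
    """Return True if `name` is importable, mirroring has_glwe_backend()."""
    try:
        __import__(name)
        return True
    except ImportError:
        return False


def test_backend_detected_as_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    assert has_backend("json")
    # A None entry in sys.modules makes the import machinery raise ImportError
    # for that name, simulating an uninstalled package for this test only.
    monkeypatch.setitem(sys.modules, "json", None)
    assert not has_backend("json")
```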