From 046c412b675ea5d172e48ad8f3545bb902459d81 Mon Sep 17 00:00:00 2001
From: jfrery
Date: Fri, 22 Nov 2024 11:48:21 +0100
Subject: [PATCH] chore: fix coverage + fix wrong Conv1D transformer import

---
 src/concrete/ml/torch/hybrid_model.py |  20 +--
 tests/torch/test_hybrid_converter.py  | 204 +++++++++++---------------
 2 files changed, 97 insertions(+), 127 deletions(-)

diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py
index e8bbf6c34..87de43bae 100644
--- a/src/concrete/ml/torch/hybrid_model.py
+++ b/src/concrete/ml/torch/hybrid_model.py
@@ -69,7 +69,7 @@ def convert_conv1d_to_linear(layer_or_module):
         or the Conv1D layer converted to a Linear layer.
     """
     try:
-        from transformers import Conv1D  # pylint: disable=import-outside-toplevel
+        from transformers.modeling_utils import Conv1D  # pylint: disable=import-outside-toplevel
     except ImportError:  # pragma: no cover
         return layer_or_module
 
@@ -412,13 +412,14 @@ def _replace_modules(self):
             if is_pure_linear_layer:
                 module = self.private_modules[module_name]
                 # Use weight shape instead of in/out_features
-                if hasattr(module, "weight"):
-                    input_dim = module.weight.shape[
-                        1
-                    ]  # Input dimension is second dimension for Linear layers
-                    output_dim = module.weight.shape[0]  # Output dimension is first dimension
-                else:
-                    input_dim = output_dim = 0
+                input_dim, output_dim = (
+                    (
+                        module.weight.shape[1],
+                        module.weight.shape[0],
+                    )
+                    if hasattr(module, "weight")
+                    else (0, 0)
+                )
 
                 is_pure_linear_layer = (
                     is_pure_linear_layer and input_dim >= 512 and output_dim >= 512
@@ -474,8 +475,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
 
         if fhe_mode in (HybridFHEMode.EXECUTE, HybridFHEMode.REMOTE, HybridFHEMode.DISABLE):
             # Initialize executor only if not already done
-            if self.executor is None:
-                self.executor = GLWELinearLayerExecutor()
+            self.executor = self.executor or GLWELinearLayerExecutor()
 
             # Generate keys only if needed and not already done
             if fhe_mode != HybridFHEMode.DISABLE and self.executor.private_key is None:
diff --git a/tests/torch/test_hybrid_converter.py b/tests/torch/test_hybrid_converter.py
index 67af03037..909c0a392 100644
--- a/tests/torch/test_hybrid_converter.py
+++ b/tests/torch/test_hybrid_converter.py
@@ -37,166 +37,138 @@ def test_tuple_serialization(tup):
     assert tup == underscore_str_to_tuple(tuple_to_underscore_str(tup))
 
 
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+def setup_test_environment():
+    """Save the original state of critical modules for restoration."""
+    original_modules = {}
+    for module_name in [
+        "transformers",
+        "concrete_ml_extensions",
+        "concrete.ml.quantization.linear_op_glwe_backend",
+        "concrete.ml.torch.hybrid_model",
+    ]:
+        original_modules[module_name] = sys.modules.get(module_name)
+    return original_modules
+
+
+# pylint: disable=too-many-arguments, too-many-locals, too-many-statements, too-many-branches
 def run_hybrid_llm_test(
     model: torch.nn.Module,
     inputs: torch.Tensor,
-    module_names: Union[str, List],
-    expected_accuracy,
+    module_names: Union[str, List[str]],
+    expected_accuracy: float,
     has_pbs: bool,
-    has_pbs_reshape: bool,
-    monkeypatch,
-    transformers_installed,
-    glwe_backend_installed,
+    transformers_installed: bool,
+    glwe_backend_installed: bool,
+    monkeypatch: pytest.MonkeyPatch,
 ):
     """Run the test for any model with its private module names."""
-    # Multi-parameter strategy is used in order to speed-up the FHE executions
+    # Configure the model
     configuration = Configuration(
         single_precision=False,
         compress_input_ciphertexts=True,
    )
-    logits_simulate = None
-
-    with monkeypatch.context() as m:
-        if not transformers_installed:
-            m.setitem(sys.modules, "transformers", None)
-        if has_pbs_reshape:
-            has_pbs = True
-
-        # Patching for GLWE backend
-        if not glwe_backend_installed:
-            m.setitem(sys.modules, "concrete_ml_extensions", None)
-
-        # Reload the affected modules to ensure the changes take effect
-        importlib.reload(concrete.ml.quantization.linear_op_glwe_backend)
-        importlib.reload(concrete.ml.torch.hybrid_model)
-
-        hybrid_model = HybridFHEModel(model, module_names)
-        is_compiled = False
-        try:
-            hybrid_model.compile_model(
-                inputs,
-                p_error=10e-40,  # compare precisely simulate and disable
-                n_bits=9,
-                rounding_threshold_bits=8,
-                configuration=configuration,
-            )
-            is_compiled = True
-        except RuntimeError as error:
-            # When reshaping adds PBSs we sometimes encounter NoParametersFound
-            # when compiling. In this case we skip the rest since we can't simulate
-            # without compilation.
-            # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4183
-            assert "NoParametersFound" in error.args[0]
-            pytest.skip(error.args[0])
-
-        # Check we can run the simulate locally
-        if has_pbs or not glwe_backend_installed:
-            logits_simulate = hybrid_model(inputs, fhe="simulate").logits
+    # Mock sys.modules to simulate missing modules
+    if not transformers_installed:
+        monkeypatch.setitem(sys.modules, "transformers", None)
+    if not glwe_backend_installed:
+        monkeypatch.setitem(sys.modules, "concrete_ml_extensions", None)
+
+    # Reload affected modules after mocking
+    importlib.reload(concrete.ml.quantization.linear_op_glwe_backend)
+    importlib.reload(concrete.ml.torch.hybrid_model)
+
+    # Initialize and compile the hybrid model
+    hybrid_model = HybridFHEModel(model, module_names)
+
+    try:
+        hybrid_model.compile_model(
+            inputs,
+            p_error=10e-40,
+            n_bits=9,
+            rounding_threshold_bits=8,
+            configuration=configuration,
+        )
+    except RuntimeError as error:
+        # Skip test if NoParametersFound error occurs
+        if "NoParametersFound" in str(error):
+            pytest.skip(str(error))
         else:
-            with pytest.raises(AssertionError, match=".*fhe=simulate is not supported.*"):
-                hybrid_model(inputs, fhe="simulate")
+            raise
 
-    if has_pbs:
-        # Check for non-zero programmable bootstrapping
-        for module in hybrid_model.private_q_modules.values():
-            assert module.fhe_circuit.statistics["programmable_bootstrap_count"] > 0, (
-                "Programmable bootstrap count should be greater than 0, "
-                f"but found {module.fhe_circuit.statistics['programmable_bootstrap_count']}"
-            )
+    # Run the model in different modes
+    logits_simulate = None
+    if has_pbs or not glwe_backend_installed:
+        logits_simulate = hybrid_model(inputs, fhe="simulate").logits
     else:
-        # Check for zero programmable bootstrapping
-        for module in hybrid_model.private_q_modules.values():
-            # The RemoteModule does not have a circuit if it was optimized
-            # (in the case of pure linear remote modules)
-            assert (
-                not module.fhe_circuit
-                or module.fhe_circuit.statistics["programmable_bootstrap_count"] == 0
-            ), (
-                "Programmable bootstrap count should be 0, "
-                f"but found {module.fhe_circuit.statistics['programmable_bootstrap_count']}"
-            )
+        with pytest.raises(AssertionError, match=".*fhe=simulate is not supported.*"):
+            hybrid_model(inputs, fhe="simulate")
 
     logits_disable = hybrid_model(inputs, fhe="disable").logits
     logits_original = hybrid_model(inputs, fhe="torch").logits
 
-    # Compare the topk accuracy of the FHE simulate circuit vs. the original.
-    k = 5
-
-    # Check that the topk next tokens are similar for the different FHE modes
-    # and the original model.
+    # Check programmable bootstrap counts if not glwe backend
+    if not glwe_backend_installed:
+        for module in hybrid_model.private_q_modules.values():
+            pbs_count = module.fhe_circuit.statistics.get("programmable_bootstrap_count", 0)
+            if has_pbs:
+                assert pbs_count > 0, "Expected programmable bootstrap count > 0"
+            else:
+                assert pbs_count == 0, "Expected programmable bootstrap count == 0"
 
-    # Get the topk indices for logits_disable and logits_simulate
+    # Compare top-k accuracy
+    k = 5
     topk_disable = logits_disable.topk(k, dim=-1).indices
     topk_original = logits_original.topk(k, dim=-1).indices
-
-    # Compute accuracy of disable and simulate by checking
-    # how many labels correspond with the topk_original
     accuracy_disable = (topk_disable == topk_original).float().mean().item()
 
-    # Ensure logits_disable and logits_original return the same output for the logits
-    # Assert that both accuracy values are above the expected threshold
     assert (
         accuracy_disable >= expected_accuracy
-    ), f"Disable accuracy {accuracy_disable:.4f} is below the expected {expected_accuracy:.4f}"
+    ), f"Disable accuracy {accuracy_disable:.4f} is below expected {expected_accuracy:.4f}"
 
     if logits_simulate is not None:
-        assert torch.allclose(logits_disable, logits_simulate, atol=1e-7), "Outputs do not match!"
+        assert torch.allclose(logits_disable, logits_simulate, atol=1e-7)
        topk_simulate = logits_simulate.topk(k, dim=-1).indices
         accuracy_simulate = (topk_simulate == topk_original).float().mean().item()
-        assert accuracy_simulate >= expected_accuracy, (
-            f"Simulate accuracy {accuracy_simulate:.4f} is below "
-            f"the expected {expected_accuracy:.4f}"
-        )
+        assert (
+            accuracy_simulate >= expected_accuracy
+        ), f"Simulate accuracy {accuracy_simulate:.4f} is below expected {expected_accuracy:.4f}"
 
+    # Test model saving and deployment
     with tempfile.TemporaryDirectory() as temp_dir:
         temp_dir_path = Path(temp_dir)
-        # Get the temp directory path
 
         if not has_pbs and glwe_backend_installed:
-
-            if is_compiled:
-                # Deployment of GLWE backend hybrid models is not yet supported
-                with pytest.raises(
-                    NotImplementedError, match="GLWE backend deployment is not yet supported"
-                ):
-                    hybrid_model.save_and_clear_private_info(temp_dir_path)
-            else:
-                # Check that we get an error when trying to save a non-compiled model
-                with pytest.raises(
-                    AttributeError,
-                    match="The quantized module is not compiled. Please run compile*",
-                ):
-                    hybrid_model.save_and_clear_private_info(temp_dir_path)
+            with pytest.raises(
+                NotImplementedError, match="GLWE backend deployment is not yet supported"
+            ):
+                hybrid_model.save_and_clear_private_info(temp_dir_path)
         else:
-            hybrid_model.save_and_clear_private_info(temp_dir_path)
+            # If transformers is not installed, skip the saving test
+            if not transformers_installed:
+                pytest.skip("Skipping save test as transformers module is not available")
 
+            hybrid_model.save_and_clear_private_info(temp_dir_path)
             hybrid_model.set_fhe_mode("remote")
 
-            # At this point, the hybrid model does not have
-            # the parameters necessaryto run the module_names
-            module_names = module_names if isinstance(module_names, list) else [module_names]
-
-            # Check that files are there
+            # Verify saved files
             assert (temp_dir_path / "model.pth").exists()
-            for module_name in module_names:
-                module_dir_path = temp_dir_path / module_name
-                module_dir_files = set(str(elt.name) for elt in module_dir_path.glob("**/*"))
-                for file_name in ["client.zip", "server.zip"]:
-                    assert file_name in module_dir_files
+            module_names_list = module_names if isinstance(module_names, list) else [module_names]
+            for module_name in module_names_list:
+                module_dir = temp_dir_path / module_name
+                files = {file.name for file in module_dir.glob("**/*")}
+                assert "client.zip" in files and "server.zip" in files
 
 
 # Dependency 'huggingface-hub' raises a 'FutureWarning' from version 0.23.0 when calling the
 # 'from_pretrained' method
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @pytest.mark.parametrize(
-    "list_or_str_private_modules_names, expected_accuracy, has_pbs, has_pbs_reshape",
+    "list_or_str_private_modules_names, expected_accuracy, has_pbs",
     [
-        ("transformer.h.0.mlp", 0.95, True, False),
-        (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True, False),
-        ("transformer.h.0.mlp.c_fc", 1.0, False, True),
+        ("transformer.h.0.mlp", 0.95, True),
+        (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True),
+        ("transformer.h.0.mlp.c_fc", 1.0, False),
     ],
 )
 @pytest.mark.parametrize("transformers_installed", [True, False])
@@ -205,7 +177,6 @@ def test_gpt2_hybrid_mlp(
     list_or_str_private_modules_names,
     expected_accuracy,
     has_pbs,
-    has_pbs_reshape,
     transformers_installed,
     glwe_backend_installed,
     monkeypatch,
@@ -227,10 +198,9 @@ def test_gpt2_hybrid_mlp(
         list_or_str_private_modules_names,
         expected_accuracy,
         has_pbs,
-        has_pbs_reshape,
-        monkeypatch,
         transformers_installed,
         glwe_backend_installed,
+        monkeypatch,
     )
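
Note on the Conv1D import fix above: the Conv1D layer used by GPT-2-style models in HuggingFace transformers stores its weight with shape (in_features, out_features), i.e. transposed with respect to torch.nn.Linear. The sketch below is not the convert_conv1d_to_linear implementation from this repository; it is a minimal illustration, under that weight-layout assumption, of what converting such a layer into an equivalent Linear layer involves (the helper name conv1d_to_linear_sketch is hypothetical).

    import torch
    from torch import nn

    def conv1d_to_linear_sketch(conv1d_layer) -> nn.Linear:
        # Assumption: conv1d_layer.weight has shape (in_features, out_features),
        # as in transformers' GPT-2 Conv1D, so it is transposed for nn.Linear.
        in_features, out_features = conv1d_layer.weight.shape
        linear = nn.Linear(in_features, out_features, bias=conv1d_layer.bias is not None)
        with torch.no_grad():
            linear.weight.copy_(conv1d_layer.weight.t())
            if conv1d_layer.bias is not None:
                linear.bias.copy_(conv1d_layer.bias)
        return linear

Under that assumption, conv1d_to_linear_sketch(block.mlp.c_fc)(x) should match block.mlp.c_fc(x) up to floating-point error, which is the equivalence the Conv1D-to-Linear conversion in hybrid_model.py relies on.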