From f2137389504eec80f0220223896fc26cc6e56522 Mon Sep 17 00:00:00 2001 From: jfrery Date: Thu, 28 Mar 2024 12:02:14 +0100 Subject: [PATCH] chore: fix new feature --- src/concrete/ml/torch/compile.py | 2 +- tests/torch/test_compile_torch.py | 32 +++++++++++++++++++ .../evaluate_one_example_fhe.py | 4 +-- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/concrete/ml/torch/compile.py b/src/concrete/ml/torch/compile.py index be85233714..d04b2017f3 100644 --- a/src/concrete/ml/torch/compile.py +++ b/src/concrete/ml/torch/compile.py @@ -212,7 +212,7 @@ def _compile_torch_or_onnx_model( # mypy assert not isinstance(n_bits_rounding, str) and n_bits_rounding is not None - rounding_threshold_bits = {"n_bits": n_bits_rounding, "method": method.name} + rounding_threshold_bits = {"n_bits": n_bits_rounding, "method": method} inputset_as_numpy_tuple = tuple( convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset) diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py index ea0564c75d..8fd7e686e4 100644 --- a/tests/torch/test_compile_torch.py +++ b/tests/torch/test_compile_torch.py @@ -1420,3 +1420,35 @@ def test_compile_torch_model_rounding_threshold_bits_errors( rounding_threshold_bits=rounding_threshold_bits, configuration=default_configuration, ) + + +@pytest.mark.parametrize( + "rounding_method, expected_reinterpret", + [ + ("APPROXIMATE", True), + ("EXACT", False), + ], +) +def test_rounding_mode(rounding_method, expected_reinterpret, default_configuration): + """Test that the underlying FHE circuit uses the right rounding method.""" + model = FCSmall(input_output=5, activation_function=nn.ReLU) + torch_inputset = torch.randn(10, 5) + configuration = default_configuration + + compiled_module = compile_torch_model( + torch_model=model, + torch_inputset=torch_inputset, + rounding_threshold_bits={"method": rounding_method, "n_bits": 4}, + configuration=configuration, + ) + + # Convert compiled module to string to search for patterns + mlir = compiled_module.fhe_circuit.mlir + if expected_reinterpret: + assert ( + "FHE.reinterpret_precision" in mlir and "FHE.round" not in mlir + ), "Expected 'FHE.reinterpret_precision' found but 'FHE.round' should not be present." + else: + assert ( + "FHE.reinterpret_precision" not in mlir + ), "Unexpected 'FHE.reinterpret_precision' found." diff --git a/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py b/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py index d3c54671bb..fff70a89ee 100644 --- a/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py +++ b/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py @@ -138,11 +138,9 @@ def wrapper(*args, **kwargs): print(f"Quantization of a single input (image) took {quantization_execution_time} seconds") print(f"Size of CLEAR input is {q_x_numpy.nbytes} bytes\n") - # Use new VL with .simulate() once CP's multi-parameter/precision bug is fixed - # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3856 p_error = quantized_numpy_module.fhe_circuit.p_error expected_quantized_prediction, clear_inference_time = measure_execution_time( - partial(quantized_numpy_module.fhe_circuit.graph, p_error=p_error) + partial(quantized_numpy_module.fhe_circuit.simulate, p_error=p_error) )(q_x_numpy) # Encrypt the input