Skip to content

Commit

Permalink
chore: pass rounding method enum (not its name) to the compiler, add a rounding-mode test, and use circuit simulation in the benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed Mar 28, 2024
1 parent b6145e5 commit f213738
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/concrete/ml/torch/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _compile_torch_or_onnx_model(

# mypy
assert not isinstance(n_bits_rounding, str) and n_bits_rounding is not None
rounding_threshold_bits = {"n_bits": n_bits_rounding, "method": method.name}
rounding_threshold_bits = {"n_bits": n_bits_rounding, "method": method}

inputset_as_numpy_tuple = tuple(
convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
Expand Down
32 changes: 32 additions & 0 deletions tests/torch/test_compile_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,3 +1420,35 @@ def test_compile_torch_model_rounding_threshold_bits_errors(
rounding_threshold_bits=rounding_threshold_bits,
configuration=default_configuration,
)


@pytest.mark.parametrize(
    "rounding_method, expected_reinterpret",
    [
        ("APPROXIMATE", True),
        ("EXACT", False),
    ],
)
def test_rounding_mode(rounding_method, expected_reinterpret, default_configuration):
    """Test that the underlying FHE circuit uses the right rounding method."""
    # Small fully-connected model and a random input set, just enough to compile
    fc_model = FCSmall(input_output=5, activation_function=nn.ReLU)
    inputset = torch.randn(10, 5)

    compiled_module = compile_torch_model(
        torch_model=fc_model,
        torch_inputset=inputset,
        rounding_threshold_bits={"method": rounding_method, "n_bits": 4},
        configuration=default_configuration,
    )

    # Inspect the generated MLIR text for the rounding lowering actually used
    mlir = compiled_module.fhe_circuit.mlir
    if not expected_reinterpret:
        # EXACT rounding must not be lowered to a precision reinterpretation
        assert (
            "FHE.reinterpret_precision" not in mlir
        ), "Unexpected 'FHE.reinterpret_precision' found."
    else:
        # APPROXIMATE rounding is lowered to reinterpret_precision, replacing FHE.round
        assert (
            "FHE.reinterpret_precision" in mlir and "FHE.round" not in mlir
        ), "Expected 'FHE.reinterpret_precision' found but 'FHE.round' should not be present."
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,9 @@ def wrapper(*args, **kwargs):
print(f"Quantization of a single input (image) took {quantization_execution_time} seconds")
print(f"Size of CLEAR input is {q_x_numpy.nbytes} bytes\n")

# Use new VL with .simulate() once CP's multi-parameter/precision bug is fixed
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3856
p_error = quantized_numpy_module.fhe_circuit.p_error
expected_quantized_prediction, clear_inference_time = measure_execution_time(
partial(quantized_numpy_module.fhe_circuit.graph, p_error=p_error)
partial(quantized_numpy_module.fhe_circuit.simulate, p_error=p_error)
)(q_x_numpy)

# Encrypt the input
Expand Down

0 comments on commit f213738

Please sign in to comment.