diff --git a/tests/data/torch/mnist_2b_s1_1.zip b/tests/data/torch/mnist_2b_s1_1.zip
deleted file mode 100644
index 9487827ad..000000000
--- a/tests/data/torch/mnist_2b_s1_1.zip
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:35ef93103b0b0d55632d322a4eae2c0c63b93507f59c6afd326cf84487fba5a5
-size 67730
diff --git a/tests/data/torch/mnist_test_batch.zip b/tests/data/torch/mnist_test_batch.zip
deleted file mode 100644
index d2afb1d52..000000000
--- a/tests/data/torch/mnist_test_batch.zip
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b185001c5bd1dd31e475bd3ea0446231c0fc549b21f12b2ea51f07cf1b9adca7
-size 90728
diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py
index 4abc0f5c9..e8cf8572e 100644
--- a/tests/torch/test_compile_torch.py
+++ b/tests/torch/test_compile_torch.py
@@ -1,16 +1,13 @@
 """Tests for the torch to numpy module."""
 
 # pylint: disable=too-many-lines
-import io
 import tempfile
-import zipfile
 from functools import partial
 from inspect import signature
 from pathlib import Path
 
 import numpy
 import onnx
-import onnxruntime as ort
 import pytest
 import torch
 import torch.quantization
@@ -808,112 +805,6 @@ def test_compile_where_net(default_configuration, check_is_good_execution_for_cm
     numpy.testing.assert_allclose(torch_output, quantized_output, rtol=1e-2, atol=1e-2)
 
 
-@pytest.mark.parametrize("verbose", [True, False], ids=["with_verbose", "without_verbose"])
-# pylint: disable-next=too-many-locals
-def test_pretrained_mnist_qat(
-    default_configuration,
-    check_accuracy,
-    verbose,
-    check_graph_output_has_no_tlu,
-    check_is_good_execution_for_cml_vs_circuit,
-    is_weekly_option,
-):
-    """Load a QAT MNIST model and confirm we get the same results in FHE simulation as with ONNX."""
-    if not is_weekly_option:
-        pytest.skip("Tests too long")
-
-    onnx_file_path = "tests/data/torch/mnist_2b_s1_1.zip"
-    mnist_test_path = "tests/data/torch/mnist_test_batch.zip"
-
-    # Load ONNX model from zip file
-    with zipfile.ZipFile(onnx_file_path, "r") as archive_model:
-        onnx_model_serialized = io.BytesIO(archive_model.read("mnist_2b_s1_1.onnx")).read()
-        onnx_model = onnx.load_model_from_string(onnx_model_serialized)
-
-    onnx.checker.check_model(onnx_model)
-
-    # Load test data and ground truth from zip file
-    with zipfile.ZipFile(mnist_test_path, "r") as archive_data:
-        mnist_data = numpy.load(
-            io.BytesIO(archive_data.read("mnist_test_batch.npy")), allow_pickle=True
-        ).item()
-
-    # Get the test data
-    inputset = mnist_data["test_data"]
-
-    # Run through ONNX runtime and collect results
-    ort_session = ort.InferenceSession(onnx_model_serialized)
-
-    onnx_results = numpy.zeros((inputset.shape[0],), dtype=numpy.int64)
-    for i, x_test in enumerate(inputset):
-        onnx_outputs = ort_session.run(
-            None,
-            {onnx_model.graph.input[0].name: x_test.reshape(1, -1)},
-        )
-        onnx_results[i] = numpy.argmax(onnx_outputs[0])
-
-    # Compile to Concrete ML in FHE simulation mode, with a high bit-width
-    n_bits = {
-        "model_inputs": 16,
-        "op_weights": 2,
-        "op_inputs": 2,
-        "model_outputs": 16,
-    }
-
-    quantized_numpy_module = compile_onnx_model(
-        onnx_model,
-        inputset,
-        configuration=default_configuration,
-        n_bits=n_bits,
-        verbose=verbose,
-    )
-
-    quantized_numpy_module.check_model_is_compiled()
-
-    check_is_good_execution_for_cml_vs_circuit(inputset, quantized_numpy_module, simulate=True)
-
-    # Collect FHE simulation results
-    results = []
-    for i in range(inputset.shape[0]):
-
-        # Extract example i for each tensor in the tuple input-set
-        # while keeping the dimension of the original tensors.
-        # e.g., if input-set is a tuple of two (100, 10) tensors
-        # then q_x becomes a tuple of two tensors of shape (1, 10).
-        x = tuple(input[[i]] for input in to_tuple(inputset))
-        result = numpy.argmax(quantized_numpy_module.forward(*x, fhe="simulate"))
-        results.append(result)
-
-    # Compare ONNX runtime vs FHE simulation mode
-    check_accuracy(onnx_results, results, threshold=0.999)
-
-    # Make sure absolute accuracy is good, this model should have at least 90% accuracy
-    check_accuracy(mnist_data["gt"], results, threshold=0.9)
-
-    # Compile to Concrete ML using the FHE simulation mode and compatible bit-width
-    n_bits = {
-        "model_inputs": 7,
-        "op_weights": 2,
-        "op_inputs": 2,
-        "model_outputs": 7,
-    }
-
-    quantized_numpy_module = compile_onnx_model(
-        onnx_model,
-        inputset,
-        import_qat=True,
-        configuration=default_configuration,
-        n_bits=n_bits,
-        verbose=verbose,
-    )
-
-    # As this is a custom QAT network, the input goes through multiple univariate
-    # ops that form a quantizer. Thus it has input TLUs. But it should not have output TLUs
-    check_graph_output_has_no_tlu(quantized_numpy_module.fhe_circuit.graph)
-
-    assert quantized_numpy_module.fhe_circuit.graph.maximum_integer_bit_width() <= 8
-
-
 def test_qat_import_bits_check(default_configuration):
     """Test that compile_brevitas_qat_model does not need an n_bits config."""