diff --git a/.github/workflows/cifar_benchmark.yaml b/.github/workflows/cifar_benchmark.yaml index 45d68c688..5888e4daf 100644 --- a/.github/workflows/cifar_benchmark.yaml +++ b/.github/workflows/cifar_benchmark.yaml @@ -26,6 +26,11 @@ on: default: "3" type: string required: true + p_error: + description: P-error to use + default: "0.01" + type: string + required: true # FIXME: Add recurrent launching # https://github.com/zama-ai/concrete-ml-internal/issues/1851 @@ -128,7 +133,7 @@ jobs: if: github.event.inputs.benchmark == 'cifar-10-16b' run: | source .venv/bin/activate - NUM_SAMPLES=${{ github.event.inputs.num_samples }} python3 ./use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py + NUM_SAMPLES=${{ github.event.inputs.num_samples }} P_ERROR=${{ github.event.inputs.p_error }} python3 ./use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py python3 ./benchmarks/convert_cifar.py --model-name "16-bits-trained-v0" - name: Archive raw predictions diff --git a/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py b/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py index fa58da69b..c552ce2de 100644 --- a/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py +++ b/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py @@ -20,6 +20,7 @@ # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3953 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" NUM_SAMPLES = int(os.environ.get("NUM_SAMPLES", 1)) +P_ERROR = float(os.environ.get("P_ERROR", 0.01)) def measure_execution_time(func): @@ -83,7 +84,7 @@ def wrapper(*args, **kwargs): print("Compiling the model.") quantized_numpy_module, compilation_execution_time = measure_execution_time( compile_brevitas_qat_model -)(torch_model, x, configuration=configuration, rounding_threshold_bits=6, p_error=0.01) +)(torch_model, x, configuration=configuration, rounding_threshold_bits=6, p_error=P_ERROR) assert isinstance(quantized_numpy_module, QuantizedModule) print(f"Compilation time took {compilation_execution_time} seconds")