Skip to content

Commit

Permalink
chore: expose p-error to cifar-16b benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
fd0r committed Feb 1, 2024
1 parent df81aca commit 060d851
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/cifar_benchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ on:
default: "3"
type: string
required: true
p_error:
description: P-error to use
default: "0.01"
type: string
required: true

# FIXME: Add recurrent launching
# https://github.com/zama-ai/concrete-ml-internal/issues/1851
Expand Down Expand Up @@ -128,7 +133,7 @@ jobs:
if: github.event.inputs.benchmark == 'cifar-10-16b'
run: |
source .venv/bin/activate
NUM_SAMPLES=${{ github.event.inputs.num_samples }} python3 ./use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py
NUM_SAMPLES=${{ github.event.inputs.num_samples }} P_ERROR=${{ github.event.inputs.p_error }} python3 ./use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py
python3 ./benchmarks/convert_cifar.py --model-name "16-bits-trained-v0"
- name: Archive raw predictions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3953
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_SAMPLES = int(os.environ.get("NUM_SAMPLES", 1))
P_ERROR = float(os.environ.get("P_ERROR", 0.01))


def measure_execution_time(func):
Expand Down Expand Up @@ -83,7 +84,7 @@ def wrapper(*args, **kwargs):
print("Compiling the model.")
quantized_numpy_module, compilation_execution_time = measure_execution_time(
compile_brevitas_qat_model
)(torch_model, x, configuration=configuration, rounding_threshold_bits=6, p_error=0.01)
)(torch_model, x, configuration=configuration, rounding_threshold_bits=6, p_error=P_ERROR)
assert isinstance(quantized_numpy_module, QuantizedModule)

print(f"Compilation time took {compilation_execution_time} seconds")
Expand Down

0 comments on commit 060d851

Please sign in to comment.