Skip to content

Commit

Permalink
fix: hybrid test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Oct 17, 2024
1 parent ba292fb commit 46e2767
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/torch/test_hybrid_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,9 @@ def prepare_data(x, y, test_size=0.1, random_state=42):
y_glwe = hybrid_local(x1_test, fhe="execute").numpy()

y1_test = y1_test.numpy()
acc_fp32 = numpy.sum(numpy.argmax(y_torch, axis=1) == y1_test)
acc_qm = numpy.sum(numpy.argmax(y_qm, axis=1) == y1_test)
acc_glwe = numpy.sum(numpy.argmax(y_glwe, axis=1) == y1_test)
n_correct_fp32 = numpy.sum(numpy.argmax(y_torch, axis=1) == y1_test)
n_correct_qm = numpy.sum(numpy.argmax(y_qm, axis=1) == y1_test)
n_correct_glwe = numpy.sum(numpy.argmax(y_glwe, axis=1) == y1_test)

# These two should be exactly the same
assert numpy.all(numpy.allclose(y_torch, y_hybrid_torch, rtol=1, atol=0.001))
Expand All @@ -349,9 +349,9 @@ def prepare_data(x, y, test_size=0.1, random_state=42):
assert numpy.all(numpy.allclose(y_qm, y_glwe, rtol=1, atol=threshold_fhe))
assert numpy.all(numpy.allclose(y_torch, y_glwe, rtol=1, atol=threshold_fhe))

acc_threshold_fhe = 0.01
n_correct_delta_threshold_fhe = 1
# Check accuracy between fp32 and glwe
assert numpy.abs(acc_fp32 - acc_glwe) < acc_threshold_fhe
assert numpy.abs(n_correct_fp32 - n_correct_glwe) <= n_correct_delta_threshold_fhe

# Check accuracy between quantized and glwe
assert numpy.abs(acc_qm - acc_glwe) < acc_threshold_fhe
assert numpy.abs(n_correct_qm - n_correct_glwe) <= n_correct_delta_threshold_fhe

0 comments on commit 46e2767

Please sign in to comment.