diff --git a/tests/torch/test_hybrid_converter.py b/tests/torch/test_hybrid_converter.py index 397cf5b83..f07454652 100644 --- a/tests/torch/test_hybrid_converter.py +++ b/tests/torch/test_hybrid_converter.py @@ -329,9 +329,9 @@ def prepare_data(x, y, test_size=0.1, random_state=42): y_glwe = hybrid_local(x1_test, fhe="execute").numpy() y1_test = y1_test.numpy() - acc_fp32 = numpy.sum(numpy.argmax(y_torch, axis=1) == y1_test) - acc_qm = numpy.sum(numpy.argmax(y_qm, axis=1) == y1_test) - acc_glwe = numpy.sum(numpy.argmax(y_glwe, axis=1) == y1_test) + n_correct_fp32 = numpy.sum(numpy.argmax(y_torch, axis=1) == y1_test) + n_correct_qm = numpy.sum(numpy.argmax(y_qm, axis=1) == y1_test) + n_correct_glwe = numpy.sum(numpy.argmax(y_glwe, axis=1) == y1_test) # These two should be exactly the same assert numpy.all(numpy.allclose(y_torch, y_hybrid_torch, rtol=1, atol=0.001)) @@ -349,9 +349,9 @@ def prepare_data(x, y, test_size=0.1, random_state=42): assert numpy.all(numpy.allclose(y_qm, y_glwe, rtol=1, atol=threshold_fhe)) assert numpy.all(numpy.allclose(y_torch, y_glwe, rtol=1, atol=threshold_fhe)) - acc_threshold_fhe = 0.01 + n_correct_delta_threshold_fhe = 1 # Check accuracy between fp32 and glwe - assert numpy.abs(acc_fp32 - acc_glwe) < acc_threshold_fhe + assert numpy.abs(n_correct_fp32 - n_correct_glwe) <= n_correct_delta_threshold_fhe # Check accuracy between quantized and glwe - assert numpy.abs(acc_qm - acc_glwe) < acc_threshold_fhe + assert numpy.abs(n_correct_qm - n_correct_glwe) <= n_correct_delta_threshold_fhe