diff --git a/mtenn/tests/test_combination.py b/mtenn/tests/test_combination.py index 4798588..13355c3 100644 --- a/mtenn/tests/test_combination.py +++ b/mtenn/tests/test_combination.py @@ -11,9 +11,7 @@ @pytest.fixture() def models_and_inputs(): model_test = SchNet( - PygSchNet( - hidden_channels=16, num_filters=16, num_interactions=2, num_gaussians=2 - ) + PygSchNet(hidden_channels=2, num_filters=2, num_interactions=2, num_gaussians=2) ) model_ref = deepcopy(model_test) model_ref = SchNet.get_model(model_ref, strategy="complex") @@ -56,7 +54,7 @@ def test_mean_combination(models_and_inputs): ref_param_dict = dict(model_ref.named_parameters()) assert all( [ - np.allclose(p.grad, ref_param_dict[n].grad, atol=1e-7) + np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7) for n, p in model_test.named_parameters() ] ) @@ -88,7 +86,7 @@ def test_max_combination(models_and_inputs): ref_param_dict = dict(model_ref.named_parameters()) assert all( [ - np.allclose(p.grad, ref_param_dict[n].grad, atol=1e-7) + np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7) for n, p in model_test.named_parameters() ] ) @@ -118,7 +116,7 @@ def test_boltzmann_combination(models_and_inputs): ref_param_dict = dict(model_ref.named_parameters()) assert all( [ - np.allclose(p.grad, ref_param_dict[n].grad, atol=1e-7) + np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7) for n, p in model_test.named_parameters() ] )