From 2ab9464b6d1653cad1a446b77336eac926d5d4a0 Mon Sep 17 00:00:00 2001 From: Arthur Lin Date: Mon, 13 Nov 2023 17:35:29 -0600 Subject: [PATCH] added tests to ensure constant n functionality is correct --- anisoap/representations/radial_basis.py | 6 +-- tests/test_ellipsoidal_density_projection.py | 6 +-- tests/test_radial_basis.py | 51 ++++++++++++++++++++ 3 files changed, 55 insertions(+), 8 deletions(-) diff --git a/anisoap/representations/radial_basis.py b/anisoap/representations/radial_basis.py index 4eb1468..c7e96c7 100644 --- a/anisoap/representations/radial_basis.py +++ b/anisoap/representations/radial_basis.py @@ -117,8 +117,8 @@ def __init__( num_n = (max_angular - l) // 2 + 1 self.num_radial_functions.append(num_n) elif isinstance(max_radial, list): - if len(max_radial) < l: - raise ValueError( + if len(max_radial) <= l: + raise ValueError( "If you specify a list of number of radial components, this list must be of length {}. Received {}.".format( max_angular + 1, len(max_radial) ) @@ -235,7 +235,7 @@ def orthonormalize_basis(self, features: TensorMap): for n in n_arr: if n < 1: n = 1 - sigma_arr.append(self.hypers["cutoff_radius"] * np.sqrt(n) / nmax) + sigma_arr.append(self.cutoff_radius * np.sqrt(n) / nmax) sigma_arr = np.array(sigma_arr) prefactor_arr = gto_prefactor(l_2n_arr, sigma_arr) diff --git a/tests/test_ellipsoidal_density_projection.py b/tests/test_ellipsoidal_density_projection.py index c3c2f92..73766d1 100644 --- a/tests/test_ellipsoidal_density_projection.py +++ b/tests/test_ellipsoidal_density_projection.py @@ -134,11 +134,7 @@ class TestBadInputs: ValueError, "radial_gaussian_width is set as an integer, which could cause overflow errors. Pass in float.", ], - [ - {**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": 3}, - ValueError, - "Only one of max_radial or radial_gaussian_width can be independently specified", - ], + ] @pytest.mark.parametrize("hypers,error_type,expected_message", test_hypers) diff --git a/tests/test_radial_basis.py b/tests/test_radial_basis.py index 9a07bd9..1fe0205 100644 --- a/tests/test_radial_basis.py +++ b/tests/test_radial_basis.py @@ -48,6 +48,57 @@ def test_radial_functions_n7(self): for l, num in enumerate(num_ns): assert num == num_ns_exact[l] + def test_radial_functions_n8(self): + basis_gto = RadialBasis( + radial_basis="monomial", max_angular=6, max_radial=[1, 2, 3, 4, 5, 6, 7], cutoff_radius=5 + ) + num_ns = basis_gto.get_num_radial_functions() + + # We specify max_radial so it's decoupled from max_angular. + num_ns_exact = [1, 2, 3, 4, 5, 6, 7] + assert len(num_ns) == len(num_ns_exact) + for l, num in enumerate(num_ns): + assert num == num_ns_exact[l] + +class TestBadInputs: + """ + Class for testing if radial_basis fails with bad inputs + """ + DEFAULT_HYPERS = { + "max_angular": 10, + "radial_basis": "gto", + "radial_gaussian_width": 5.0, + "cutoff_radius": 1.0, + } + test_hypers = [ + # [ + # {**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": 3}, + # ValueError, + # "Only one of max_radial or radial_gaussian_width can be independently specified", + # ], + [ + {**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": [1, 2, 3]}, # default max_angular = 10 + ValueError, + "If you specify a list of number of radial components, this list must be of length 11. Received 3." + ], + [ + {**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": "nonsense"}, + ValueError, + "`max_radial` must be None, int, or list of int" + ], + [ + {**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": [1, "nonsense", 2]}, + ValueError, + "`max_radial` must be None, int, or list of int" + ], + + ] + + @pytest.mark.parametrize("hypers,error_type,expected_message", test_hypers) + def test_hypers(self, hypers, error_type, expected_message): + with pytest.raises(error_type) as cm: + RadialBasis(**hypers) + assert cm.message == expected_message class TestGaussianParameters: """