Skip to content

Commit

Permalink
added tests to ensure constant n functionality is correct
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-lin1027 committed Nov 13, 2023
1 parent b3008db commit 2ab9464
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 8 deletions.
6 changes: 3 additions & 3 deletions anisoap/representations/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def __init__(
num_n = (max_angular - l) // 2 + 1
self.num_radial_functions.append(num_n)
elif isinstance(max_radial, list):
if len(max_radial) < l:
raise ValueError(
if len(max_radial) <= l:
raise ValueError(
"If you specify a list of number of radial components, this list must be of length {}. Received {}.".format(
max_angular + 1, len(max_radial)
)
Expand Down Expand Up @@ -235,7 +235,7 @@ def orthonormalize_basis(self, features: TensorMap):
for n in n_arr:
if n < 1:
n = 1
sigma_arr.append(self.hypers["cutoff_radius"] * np.sqrt(n) / nmax)
sigma_arr.append(self.cutoff_radius * np.sqrt(n) / nmax)

sigma_arr = np.array(sigma_arr)
prefactor_arr = gto_prefactor(l_2n_arr, sigma_arr)
Expand Down
6 changes: 1 addition & 5 deletions tests/test_ellipsoidal_density_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,7 @@ class TestBadInputs:
ValueError,
"radial_gaussian_width is set as an integer, which could cause overflow errors. Pass in float.",
],
[
{**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": 3},
ValueError,
"Only one of max_radial or radial_gaussian_width can be independently specified",
],

]

@pytest.mark.parametrize("hypers,error_type,expected_message", test_hypers)
Expand Down
51 changes: 51 additions & 0 deletions tests/test_radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,57 @@ def test_radial_functions_n7(self):
for l, num in enumerate(num_ns):
assert num == num_ns_exact[l]

def test_radial_functions_n8(self):
basis_gto = RadialBasis(
radial_basis="monomial", max_angular=6, max_radial=[1, 2, 3, 4, 5, 6, 7], cutoff_radius=5
)
num_ns = basis_gto.get_num_radial_functions()

# We specify max_radial so it's decoupled from max_angular.
num_ns_exact = [1, 2, 3, 4, 5, 6, 7]
assert len(num_ns) == len(num_ns_exact)
for l, num in enumerate(num_ns):
assert num == num_ns_exact[l]

class TestBadInputs:
"""
Class for testing if radial_basis fails with bad inputs
"""
DEFAULT_HYPERS = {
"max_angular": 10,
"radial_basis": "gto",
"radial_gaussian_width": 5.0,
"cutoff_radius": 1.0,
}
test_hypers = [
# [
# {**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": 3},
# ValueError,
# "Only one of max_radial or radial_gaussian_width can be independently specified",
# ],
[
{**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": [1, 2, 3]}, # default max_angular = 10
ValueError,
"If you specify a list of number of radial components, this list must be of length 11. Received 3."
],
[
{**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": "nonsense"},
ValueError,
"`max_radial` must be None, int, or list of int"
],
[
{**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": [1, "nonsense", 2]},
ValueError,
"`max_radial` must be None, int, or list of int"
],

]

@pytest.mark.parametrize("hypers,error_type,expected_message", test_hypers)
def test_hypers(self, hypers, error_type, expected_message):
with pytest.raises(error_type) as cm:
RadialBasis(**hypers)
assert cm.message == expected_message

class TestGaussianParameters:
"""
Expand Down

0 comments on commit 2ab9464

Please sign in to comment.