Skip to content

Commit

Permalink
Pulled latest main changes and merged it.
Browse files Browse the repository at this point in the history
  • Loading branch information
YCC-ProjBackups committed Nov 29, 2023
2 parents 3f53a9f + c70eea0 commit 01f57d5
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- name: Lint with black
run: |
black --check .
black --diff .
- name: Check imports
run: |
isort anisoap/*/*py -m 3 --tc --fgw --up -e -l 88 --check
Expand Down
15 changes: 13 additions & 2 deletions anisoap/representations/ellipsoidal_density_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,12 @@ class EllipsoidalDensityProjection:
Key under which rotations are stored in ase frames arrays
rotation_type : string
Type of rotation object being passed. Currently implemented
are 'quaternion' and 'matrix'\
are 'quaternion' and 'matrix'
max_radial : None, int, list of int
Number of radial bases to use. Can either correspond to number of
bases per spherical harmonic or a value to use with every harmonic.
If `None`, then for every `l`, `(max_angular - l) // 2 + 1` will
be used.
Attributes
----------
Expand All @@ -408,6 +413,7 @@ def __init__(
compute_gradients=False,
subtract_center_contribution=False,
radial_gaussian_width=None,
max_radial=None,
rotation_key="quaternion",
rotation_type="quaternion",
):
Expand Down Expand Up @@ -436,11 +442,16 @@ def __init__(
raise ValueError(
"radial_gaussian_width is set as an integer, which could cause overflow errors. Pass in float."
)

radial_hypers = {
"radial_gaussian_width": radial_gaussian_width,
}
self.radial_basis = RadialBasis(
radial_basis_name.lower(), max_angular, **radial_hypers
radial_basis_name.lower(),
max_angular,
cutoff_radius,
max_radial,
**radial_hypers,
)

self.num_ns = self.radial_basis.get_num_radial_functions()
Expand Down
32 changes: 26 additions & 6 deletions anisoap/representations/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,39 @@ class RadialBasis:
"""

def __init__(self, radial_basis, max_angular, **hypers):
def __init__(
self, radial_basis, max_angular, cutoff_radius, max_radial=None, **hypers
):
# Store all inputs into internal variables
self.radial_basis = radial_basis
self.max_angular = max_angular
self.cutoff_radius = cutoff_radius
self.hypers = hypers
if self.radial_basis not in ["monomial", "gto"]:
raise ValueError(f"{self.radial_basis} is not an implemented basis.")

# As part of the initialization, compute the number of radial basis
# functions, nmax, for each angular frequency l.
# functions, num_n, for each angular frequency l.
# If nmax is given, num_n = nmax + 1 (n ranges from 0 to nmax)
self.num_radial_functions = []
for l in range(max_angular + 1):
num_n = (max_angular - l) // 2 + 1
self.num_radial_functions.append(num_n)
if max_radial is None:
num_n = (max_angular - l) // 2 + 1
self.num_radial_functions.append(num_n)
elif isinstance(max_radial, list):
if len(max_radial) <= l:
raise ValueError(
"If you specify a list of number of radial components, this list must be of length {}. Received {}.".format(
max_angular + 1, len(max_radial)
)
)
if not isinstance(max_radial[l], int):
raise ValueError("`max_radial` must be None, int, or list of int")
self.num_radial_functions.append(max_radial[l] + 1)
elif isinstance(max_radial, int):
self.num_radial_functions.append(max_radial + 1)
else:
raise ValueError("`max_radial` must be None, int, or list of int")

# As part of the initialization, compute the orthonormalization matrix for GTOs
# If we are using the monomial basis, set self.overlap_matrix equal to None
Expand Down Expand Up @@ -167,8 +186,9 @@ def calc_gto_overlap_matrix(self):
Returns:
S: 2D array. The overlap matrix
"""
# Consequence of the floor divide used to compute self.num_radial_functions
max_deg = self.max_angular + 1
max_deg = np.max(
np.arange(self.max_angular + 1) + 2 * np.array(self.num_radial_functions)
)
n_grid = np.arange(max_deg)
sigma = self.hypers["radial_gaussian_width"]
sigma_grid = np.ones(max_deg) * sigma
Expand Down
96 changes: 89 additions & 7 deletions tests/test_radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ class TestNumberOfRadialFunctions:

def test_notimplemented_basis(self):
with pytest.raises(ValueError):
basis = RadialBasis(radial_basis="nonsense", max_angular=5)
basis = RadialBasis(radial_basis="nonsense", max_angular=5, cutoff_radius=5)

def test_radial_functions_n5(self):
basis_gto = RadialBasis(radial_basis="monomial", max_angular=5)
basis_gto = RadialBasis(radial_basis="monomial", max_angular=5, cutoff_radius=5)
num_ns = basis_gto.get_num_radial_functions()

# Compare against exact results
Expand All @@ -27,7 +27,7 @@ def test_radial_functions_n5(self):
assert num == num_ns_exact[l]

def test_radial_functions_n6(self):
basis_gto = RadialBasis(radial_basis="monomial", max_angular=6)
basis_gto = RadialBasis(radial_basis="monomial", max_angular=6, cutoff_radius=5)
num_ns = basis_gto.get_num_radial_functions()

# Compare against exact results
Expand All @@ -36,6 +36,82 @@ def test_radial_functions_n6(self):
for l, num in enumerate(num_ns):
assert num == num_ns_exact[l]

def test_radial_functions_n7(self):
basis_gto = RadialBasis(
radial_basis="monomial", max_angular=6, max_radial=5, cutoff_radius=5
)
num_ns = basis_gto.get_num_radial_functions()

# We specify max_radial so it's decoupled from max_angular.
max_ns_exact = [5, 5, 5, 5, 5, 5, 5]
assert len(num_ns) == len(max_ns_exact)
for l, num in enumerate(num_ns):
assert num == max_ns_exact[l] + 1

def test_radial_functions_n8(self):
basis_gto = RadialBasis(
radial_basis="monomial",
max_angular=6,
max_radial=[1, 2, 3, 4, 5, 6, 7],
cutoff_radius=5,
)
num_ns = basis_gto.get_num_radial_functions()

# We specify max_radial so it's decoupled from max_angular.
max_ns_exact = [1, 2, 3, 4, 5, 6, 7]
assert len(num_ns) == len(max_ns_exact)
for l, num in enumerate(num_ns):
assert num == max_ns_exact[l] + 1


class TestBadInputs:
"""
Class for testing if radial_basis fails with bad inputs
"""

DEFAULT_HYPERS = {
"max_angular": 10,
"radial_basis": "gto",
"radial_gaussian_width": 5.0,
"cutoff_radius": 1.0,
}
test_hypers = [
# [
# {**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": 3},
# ValueError,
# "Only one of max_radial or radial_gaussian_width can be independently specified",
# ],
[
{
**DEFAULT_HYPERS,
"radial_gaussian_width": 5.0,
"max_radial": [1, 2, 3],
}, # default max_angular = 10
ValueError,
"If you specify a list of number of radial components, this list must be of length 11. Received 3.",
],
[
{**DEFAULT_HYPERS, "radial_gaussian_width": 5.0, "max_radial": "nonsense"},
ValueError,
"`max_radial` must be None, int, or list of int",
],
[
{
**DEFAULT_HYPERS,
"radial_gaussian_width": 5.0,
"max_radial": [1, "nonsense", 2],
},
ValueError,
"`max_radial` must be None, int, or list of int",
],
]

@pytest.mark.parametrize("hypers,error_type,expected_message", test_hypers)
def test_hypers(self, hypers, error_type, expected_message):
with pytest.raises(error_type) as cm:
RadialBasis(**hypers)
assert cm.message == expected_message


class TestGaussianParameters:
"""
Expand Down Expand Up @@ -65,9 +141,12 @@ class TestGaussianParameters:
@pytest.mark.parametrize("rotation_matrix", rotation_matrices)
def test_limit_large_sigma(self, sigma, r_ij, lengths, rotation_matrix):
# Initialize the classes
basis_mon = RadialBasis(radial_basis="monomial", max_angular=2)
basis_mon = RadialBasis(radial_basis="monomial", max_angular=2, cutoff_radius=5)
basis_gto = RadialBasis(
radial_basis="gto", radial_gaussian_width=sigma, max_angular=2
radial_basis="gto",
radial_gaussian_width=sigma,
max_angular=2,
cutoff_radius=5,
)

# Get the center and precision matrix
Expand All @@ -93,7 +172,10 @@ def test_limit_large_sigma(self, sigma, r_ij, lengths, rotation_matrix):
def test_limit_small_sigma(self, sigma, r_ij, lengths, rotation_matrix):
# Initialize the class
basis_gto = RadialBasis(
radial_basis="gto", radial_gaussian_width=sigma, max_angular=2
radial_basis="gto",
radial_gaussian_width=sigma,
max_angular=2,
cutoff_radius=5,
)

# Get the center and precision matrix
Expand Down Expand Up @@ -132,7 +214,7 @@ class TestGTOUtils:
def test_nogto_warning(self):
with pytest.warns(UserWarning):
lmax = 5
non_gto_basis = RadialBasis("monomial", lmax)
non_gto_basis = RadialBasis("monomial", lmax, cutoff_radius=5)
# As a proxy for a tensor map, pass in a numpy array for features
features = np.random.random((5, 5))
non_normalized_features = non_gto_basis.orthonormalize_basis(features)
Expand Down

0 comments on commit 01f57d5

Please sign in to comment.