Skip to content

Commit

Permalink
Adding ability to override number of radial bases
Browse files Browse the repository at this point in the history
  • Loading branch information
rosecers committed Nov 8, 2023
1 parent 374606c commit c18395f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
9 changes: 8 additions & 1 deletion anisoap/representations/ellipsoidal_density_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,12 @@ class EllipsoidalDensityProjection:
Key under which rotations are stored in ase frames arrays
rotation_type : string
Type of rotation object being passed. Currently implemented
are 'quaternion' and 'matrix'\
are 'quaternion' and 'matrix'
num_radial : None, int, list of int
Number of radial bases to use. Can either correspond to number of
bases per spherical harmonic or a value to use with every harmonic.
If `None`, then for every `l`, `(max_angular - l) // 2 + 1` will
be used.
Attributes
----------
Expand All @@ -393,6 +398,7 @@ def __init__(
radial_gaussian_width=None,
rotation_key="quaternion",
rotation_type="quaternion",
num_radial = None,
):
# Store the input variables
self.max_angular = max_angular
Expand Down Expand Up @@ -424,6 +430,7 @@ def __init__(
radial_hypers["radial_basis"] = radial_basis_name.lower() # lower case
radial_hypers["radial_gaussian_width"] = radial_gaussian_width
radial_hypers["max_angular"] = max_angular
radial_hypers["num_radial"] = num_radial
self.radial_basis = RadialBasis(**radial_hypers)

self.num_ns = self.radial_basis.get_num_radial_functions()
Expand Down
21 changes: 18 additions & 3 deletions anisoap/representations/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class RadialBasis:
"""

def __init__(self, radial_basis, max_angular, **hypers):
def __init__(self, radial_basis, max_angular, num_radial=None, **hypers):
# Store all inputs into internal variables
self.radial_basis = radial_basis
self.max_angular = max_angular
Expand All @@ -110,8 +110,23 @@ def __init__(self, radial_basis, max_angular, **hypers):
# functions, nmax, for each angular frequency l.
self.num_radial_functions = []
for l in range(max_angular + 1):
num_n = (max_angular - l) // 2 + 1
self.num_radial_functions.append(num_n)
if num_radial is None:
num_n = (max_angular - l) // 2 + 1
self.num_radial_functions.append(num_n)
elif isinstance(num_radial, list):
if len(num_radial) < l:
raise ValueError(
"If you specify a list of number of radial components, this list must be of length {}. Received {}.".format(
max_angular + 1, len(num_radial)
)
)
if not isinstance(num_radial[l], int):
raise ValueError("`num_radial` must be None, int, or list of int")
self.num_radial_functions.append(num_radial[l])
elif isinstance(num_radial, int):
self.num_radial_functions.append(num_radial)
else:
raise ValueError("`num_radial` must be None, int, or list of int")

# As part of the initialization, compute the orthonormalization matrix for GTOs
# If we are using the monomial basis, set self.overlap_matrix equal to None
Expand Down

0 comments on commit c18395f

Please sign in to comment.