diff --git a/anisoap/representations/ellipsoidal_density_projection.py b/anisoap/representations/ellipsoidal_density_projection.py index 6cd19ea..f155274 100644 --- a/anisoap/representations/ellipsoidal_density_projection.py +++ b/anisoap/representations/ellipsoidal_density_projection.py @@ -375,7 +375,12 @@ class EllipsoidalDensityProjection: Key under which rotations are stored in ase frames arrays rotation_type : string Type of rotation object being passed. Currently implemented - are 'quaternion' and 'matrix'\ + are 'quaternion' and 'matrix' + num_radial : None, int, list of int + Number of radial bases to use. Can either correspond to number of + bases per spherical harmonic or a value to use with every harmonic. + If `None`, then for every `l`, `(max_angular - l) // 2 + 1` will + be used. Attributes ---------- @@ -393,6 +398,7 @@ def __init__( radial_gaussian_width=None, rotation_key="quaternion", rotation_type="quaternion", + num_radial = None, ): # Store the input variables self.max_angular = max_angular @@ -424,6 +430,7 @@ def __init__( radial_hypers["radial_basis"] = radial_basis_name.lower() # lower case radial_hypers["radial_gaussian_width"] = radial_gaussian_width radial_hypers["max_angular"] = max_angular + radial_hypers["num_radial"] = num_radial self.radial_basis = RadialBasis(**radial_hypers) self.num_ns = self.radial_basis.get_num_radial_functions() diff --git a/anisoap/representations/radial_basis.py b/anisoap/representations/radial_basis.py index 9c1e208..1682f4a 100644 --- a/anisoap/representations/radial_basis.py +++ b/anisoap/representations/radial_basis.py @@ -98,7 +98,7 @@ class RadialBasis: """ - def __init__(self, radial_basis, max_angular, **hypers): + def __init__(self, radial_basis, max_angular, num_radial=None, **hypers): # Store all inputs into internal variables self.radial_basis = radial_basis self.max_angular = max_angular @@ -110,8 +110,23 @@ def __init__(self, radial_basis, max_angular, **hypers): # functions, nmax, for each angular frequency l. self.num_radial_functions = [] for l in range(max_angular + 1): - num_n = (max_angular - l) // 2 + 1 - self.num_radial_functions.append(num_n) + if num_radial is None: + num_n = (max_angular - l) // 2 + 1 + self.num_radial_functions.append(num_n) + elif isinstance(num_radial, list): + if len(num_radial) < l: + raise ValueError( + "If you specify a list of number of radial components, this list must be of length {}. Received {}.".format( + max_angular + 1, len(num_radial) + ) + ) + if not isinstance(num_radial[l], int): + raise ValueError("`num_radial` must be None, int, or list of int") + self.num_radial_functions.append(num_radial[l]) + elif isinstance(num_radial, int): + self.num_radial_functions.append(num_radial) + else: + raise ValueError("`num_radial` must be None, int, or list of int") # As part of the initialization, compute the orthonormalization matrix for GTOs # If we are using the monomial basis, set self.overlap_matrix equal to None