Skip to content

Commit

Permalink
added comments explaining orthonormalizing each subblock per neighbor…
Browse files Browse the repository at this point in the history
… and refactored the code to create a boolean-array mask variable called neighbor_mask
  • Loading branch information
arthur-lin1027 committed Nov 14, 2023
1 parent a5f2f9a commit 0e138c6
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions anisoap/representations/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,34 +202,29 @@ def orthonormalize_basis(self, features: TensorMap):
)
return features
for label, block in features.items():
# Each block's `properties` dimension contains radial channels for each neighbor species
# Hence we have to iterate through each neighbor species and orthonormalize the block in subblocks
# Each subblock is indexed using the neighbor_mask boolean array.
neighbors = np.unique(block.properties["neighbor_species"])
for neighbor in neighbors:
l = label["angular_channel"]
n_arr = block.properties["n"][
block.properties["neighbor_species"] == neighbor
].flatten()
neighbor_mask = block.properties["neighbor_species"] == neighbor
n_arr = block.properties["n"][neighbor_mask].flatten()
l_2n_arr = l + 2 * n_arr
# normalize all the GTOs by the appropriate prefactor first, since the overlap matrix is in terms of
# normalized GTOs
prefactor_arr = gto_prefactor(
l_2n_arr, self.hypers["radial_gaussian_width"]
)
block.values[:, :, block.properties["neighbor_species"] == neighbor] = (
block.values[:, :, block.properties["neighbor_species"] == neighbor]
* prefactor_arr
)
block.values[:, :, neighbor_mask] *= prefactor_arr

gto_overlap_matrix_slice = self.overlap_matrix[l_2n_arr, :][:, l_2n_arr]
orthonormalization_matrix = inverse_matrix_sqrt(
gto_overlap_matrix_slice
)
block.values[
:, :, block.properties["neighbor_species"] == neighbor
] = np.einsum(
block.values[:, :, neighbor_mask] = np.einsum(
"ijk,kl->ijl",
block.values[
:, :, block.properties["neighbor_species"] == neighbor
],
block.values[:, :, neighbor_mask],
orthonormalization_matrix,
)

Expand Down

0 comments on commit 0e138c6

Please sign in to comment.