Merge pull request #569 from CompRhys/cleanup-zbl
Clean up the unused PolynomialCutoff instance in ZBLBasis and remove the r_max argument.
ilyes319 authored Dec 19, 2024
2 parents 9436e4a + fc6fa95 commit 8c5a80b
Showing 3 changed files with 128 additions and 38 deletions.
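In effect, callers now construct the ZBL pair repulsion without a cutoff radius; the per-pair cutoff is derived from covalent radii instead. A minimal before/after sketch of the API change (the variable name and the value 5.0 are illustrative, not from the diff):

```python
from mace.modules.radial import ZBLBasis

# Before this commit: r_max had to be supplied alongside p.
# zbl = ZBLBasis(r_max=5.0, p=6)

# After this commit: r_max is dropped; passing it still works but only
# triggers a deprecation warning, since the cutoff now comes from the
# covalent radii of each atom pair.
zbl = ZBLBasis(p=6)
```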
2 changes: 1 addition & 1 deletion mace/modules/models.py
@@ -96,7 +96,7 @@ def __init__(
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
if pair_repulsion:
- self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff)
+ self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
self.pair_repulsion = True

sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
81 changes: 44 additions & 37 deletions mace/modules/radial.py
@@ -4,6 +4,8 @@
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

+ import logging

import ase
import numpy as np
import torch
@@ -110,67 +112,70 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]

@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
"""
Equation (8)
"""Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
Equation (8) -- TODO: from where?
"""

p: torch.Tensor
r_max: torch.Tensor

def __init__(self, r_max: float, p=6):
super().__init__()
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
- # yapf: disable
+ return self.calculate_envelope(x, self.r_max, self.p.to(torch.int))
+
+ @staticmethod
+ def calculate_envelope(
+ x: torch.Tensor, r_max: torch.Tensor, p: int
+ ) -> torch.Tensor:
+ r_over_r_max = x / r_max
envelope = (
- 1.0
- - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p)
- + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1)
- - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2)
+ 1.0
+ - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
+ - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
)
- # yapf: enable
-
- # noinspection PyUnresolvedReferences
- return envelope * (x < self.r_max)
+ return envelope * (x < r_max)

def __repr__(self):
return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"


@compile_mode("script")
class ZBLBasis(torch.nn.Module):
"""
Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
"""Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
with a polynomial cutoff envelope.
"""

p: torch.Tensor
- r_max: torch.Tensor

- def __init__(self, r_max: float, p=6, trainable=False):
+ def __init__(self, p=6, trainable=False, **kwargs):
super().__init__()
- self.register_buffer(
- "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
- )
+ if "r_max" in kwargs:
+ logging.warning(
+ "r_max is deprecated. r_max is determined from the covalent radii."
+ )

+ # Pre-calculate the p coefficients for the ZBL potential
self.register_buffer(
"c",
torch.tensor(
[0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype()
),
)
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
- self.cutoff = PolynomialCutoff(r_max, p)
if trainable:
self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
self.a_prefactor = torch.nn.Parameter(
@@ -208,12 +213,7 @@ def forward(
)
v_edges = (14.3996 * Z_u * Z_v) / x * phi
r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
- envelope = (
- 1.0
- - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p)
- + self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1)
- - (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2)
- ) * (x < r_max)
+ envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
v_edges = 0.5 * v_edges * envelope
V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
return V_ZBL.squeeze(-1)
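Putting the visible pieces of forward together: each edge contributes a screened, smoothly truncated Coulomb repulsion, and the cutoff is now the sum of the two covalent radii rather than a global r_max. Schematically (phi is the ZBL universal screening function whose prefactors are the c buffer registered above; its screening length a_uv, set by a_prefactor and a_exp, is computed in code not shown in this hunk; 14.3996 eV·Å is e²/(4πε₀)):

```latex
% Per-edge ZBL repulsion as assembled in forward:
V_{uv} = \frac{1}{2}\,\frac{14.3996\, Z_u Z_v}{r_{uv}}\;
         \phi\!\left(\frac{r_{uv}}{a_{uv}}\right)\,
         f_p\!\left(\frac{r_{uv}}{r_{\mathrm{cov}}(Z_u)+r_{\mathrm{cov}}(Z_v)}\right)
```

The node-wise output V_ZBL is then the scatter-sum of V_uv over the receiving nodes.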
@@ -224,8 +224,8 @@ def __repr__(self):

@compile_mode("script")
class AgnesiTransform(torch.nn.Module):
"""
Agnesi transform see ACEpotentials.jl, JCP 2023, p. 160
"""Agnesi transform - see section on Radial transformations in
ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783).
"""

def __init__(
@@ -265,21 +265,27 @@ def forward(
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
- r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
+ r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
+ r_over_r_0 = x / r_0
return (
- 1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p)))
- ) ** (-1)
+ 1
+ + (
+ self.a
+ * torch.pow(r_over_r_0, self.q)
+ / (1 + torch.pow(r_over_r_0, self.q - self.p))
+ )
+ ).reciprocal_()

def __repr__(self):
return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})"
return (
f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
)
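Written out, the value returned by the refactored forward is unchanged from the old expression; both compute, with r_0 half the sum of the pair's covalent radii and a, q, p the registered constants:

```latex
t(r) = \left[\, 1 + a\,\frac{(r/r_0)^{q}}{1 + (r/r_0)^{\,q-p}} \,\right]^{-1},
\qquad r_0 = \tfrac{1}{2}\bigl(r_{\mathrm{cov}}(Z_u) + r_{\mathrm{cov}}(Z_v)\bigr)
```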


@simplify_if_compile
@compile_mode("script")
class SoftTransform(torch.nn.Module):
"""
Soft Transform
"""
"""Soft Transform."""

def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False):
super().__init__()
@@ -312,9 +318,10 @@ def forward(
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4
+ r_over_r_0 = x / r_0
y = (
x
- + (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b))
+ + (1 / 2) * torch.tanh(-r_over_r_0 - self.a * torch.pow(r_over_r_0, self.b))
+ 1 / 2
)
return y
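For reference, the transform computed above is, with r_0 a quarter of the sum of the pair's covalent radii:

```latex
y(r) = r + \tfrac{1}{2}\tanh\!\left( -\frac{r}{r_0} - a\left(\frac{r}{r_0}\right)^{b} \right) + \tfrac{1}{2},
\qquad r_0 = \tfrac{1}{4}\bigl(r_{\mathrm{cov}}(Z_u) + r_{\mathrm{cov}}(Z_v)\bigr)
```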
83 changes: 83 additions & 0 deletions tests/modules/test_radial.py
@@ -0,0 +1,83 @@
import pytest
import torch
from mace.modules.radial import ZBLBasis, AgnesiTransform

@pytest.fixture
def zbl_basis():
return ZBLBasis(p=6, trainable=False)

def test_zbl_basis_initialization(zbl_basis):
assert zbl_basis.p == torch.tensor(6.0)
assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817]))

assert zbl_basis.a_exp == torch.tensor(0.300)
assert zbl_basis.a_prefactor == torch.tensor(0.4543)
assert not zbl_basis.a_exp.requires_grad
assert not zbl_basis.a_prefactor.requires_grad

def test_trainable_zbl_basis_initialization(zbl_basis):
zbl_basis = ZBLBasis(p=6, trainable=True)
assert zbl_basis.p == torch.tensor(6.0)
assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817]))

assert zbl_basis.a_exp == torch.tensor(0.300)
assert zbl_basis.a_prefactor == torch.tensor(0.4543)
assert zbl_basis.a_exp.requires_grad
assert zbl_basis.a_prefactor.requires_grad

def test_forward(zbl_basis):
x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges]
node_attrs = torch.tensor([[1, 0], [0, 1]]) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers
edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges]
atomic_numbers = torch.tensor([1, 6]) # [n_nodes]
output = zbl_basis(x, node_attrs, edge_index, atomic_numbers)

assert output.shape == torch.Size([node_attrs.shape[0]])
assert torch.is_tensor(output)
assert torch.allclose(
output,
torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()),
rtol=1e-2
)

@pytest.fixture
def agnesi():
return AgnesiTransform(trainable=False)

def test_agnesi_transform_initialization(agnesi: AgnesiTransform):
assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4)
assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4)
assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4)
assert not agnesi.a.requires_grad
assert not agnesi.q.requires_grad
assert not agnesi.p.requires_grad

def test_trainable_agnesi_transform_initialization():
agnesi = AgnesiTransform(trainable=True)

assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4)
assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4)
assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4)
assert agnesi.a.requires_grad
assert agnesi.q.requires_grad
assert agnesi.p.requires_grad

def test_agnesi_transform_forward():
agnesi = AgnesiTransform()
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1)
node_attrs = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.get_default_dtype())
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
atomic_numbers = torch.tensor([1, 6, 8])
output = agnesi(x, node_attrs, edge_index, atomic_numbers)
assert output.shape == x.shape
assert torch.is_tensor(output)
assert torch.allclose(
output,
torch.tensor(
[0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype()
).unsqueeze(-1),
rtol=1e-2
)

if __name__ == "__main__":
pytest.main([__file__])
