Skip to content

Commit

Permalink
use angle_only_cos
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 29, 2024
1 parent fc461d7 commit da822e3
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 5 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def __init__(
g1_out_mlp: bool = True,
scale_dist: bool = True,
multiscale_mode: str = "None",
angle_only_cos: bool = False,
ln_eps: Optional[float] = 1e-5,
) -> None:
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -341,6 +342,7 @@ def __init__(
self.g1_out_mlp = g1_out_mlp
self.scale_dist = scale_dist
self.multiscale_mode = multiscale_mode
self.angle_only_cos = angle_only_cos
# to keep consistent with default value in this backends
if ln_eps is None:
ln_eps = 1e-5
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def init_subclass_params(sub_data, sub_class):
g1_out_mlp=self.repformer_args.g1_out_mlp,
scale_dist=self.repformer_args.scale_dist,
multiscale_mode=self.repformer_args.multiscale_mode,
angle_only_cos=self.repformer_args.angle_only_cos,
seed=child_seed(seed, 1),
)
self.rcsl_list = [
Expand Down
11 changes: 10 additions & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
g1_out_mlp: bool = True,
scale_dist: bool = True,
multiscale_mode: str = "None",
angle_only_cos: bool = False,
) -> None:
r"""
The repformer descriptor block.
Expand Down Expand Up @@ -267,16 +268,24 @@ def __init__(
self.scale_dist = scale_dist
self.multiscale_mode = multiscale_mode
self.prec = PRECISION_DICT[precision]
self.angle_only_cos = angle_only_cos
if num_a % 2 != 1:
raise ValueError(f"{num_a=} must be an odd integer")
circular_harmonics_order = (num_a - 1) // 2
self.fourier_expansion = Fourier(
order=circular_harmonics_order,
learnable=True,
angle_only_cos=angle_only_cos,
precision=precision,
)
self.angle_embedding_in_features = self.num_a
if self.angle_only_cos:
self.angle_embedding_in_features = 1 + circular_harmonics_order
self.angle_embedding = torch.nn.Linear(
in_features=self.num_a, out_features=self.a_dim, bias=False, dtype=self.prec
in_features=self.angle_embedding_in_features,
out_features=self.a_dim,
bias=False,
dtype=self.prec,
)
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
Expand Down
16 changes: 12 additions & 4 deletions deepmd/pt/model/network/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ class Fourier(nn.Module):
"""Fourier Expansion for angle features."""

def __init__(
self, *, order: int = 5, learnable: bool = False, precision="float64"
self,
*,
order: int = 5,
learnable: bool = False,
angle_only_cos: bool = False,
precision="float64",
) -> None:
"""Initialize the Fourier expansion.
Expand All @@ -36,6 +41,7 @@ def __init__(
self.order = order
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.angle_only_cos = angle_only_cos
# Initialize frequencies at canonical
if learnable:
self.frequencies = torch.nn.Parameter(
Expand All @@ -50,13 +56,15 @@ def __init__(

def forward(self, x: Tensor) -> Tensor:
"""Apply Fourier expansion to a feature Tensor."""
result = x.new_zeros(x.shape[0], 1 + 2 * self.order)
out_size = 1 + 2 * self.order if not self.angle_only_cos else 1 + self.order
result = x.new_zeros(x.shape[0], out_size)
result[:, 0] = 1 / torch.sqrt(
torch.tensor([2], device=result.device, dtype=result.dtype)
)
tmp = torch.outer(x, self.frequencies)
result[:, 1 : self.order + 1] = torch.sin(tmp)
result[:, self.order + 1 :] = torch.cos(tmp)
result[:, 1 : self.order + 1] = torch.cos(tmp)
if not self.angle_only_cos:
result[:, self.order + 1 :] = torch.sin(tmp)
return result / (torch.pi**0.5)


Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,12 @@ def dpa2_repformer_args():
optional=True,
default=True,
),
Argument(
"angle_only_cos",
bool,
optional=True,
default=False,
),
Argument(
"multiscale_mode",
str,
Expand Down

0 comments on commit da822e3

Please sign in to comment.