Skip to content

Commit

Permalink
Update SE2Aggregation class in se2.py to use the last 4 columns of x …
Browse files Browse the repository at this point in the history
…instead of the last 3 columns. Update _SE2Descriptor class in se2.py to set the flow parameter to "target_to_source". and add radial info into env matrix (#135)
  • Loading branch information
floatingCatty authored Apr 20, 2024
1 parent ab3fc92 commit 8349224
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions dptb/nn/embedding/se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def forward(self, x: torch.Tensor, index: torch.LongTensor, **kwargs):
_type_
_description_
"""
direct_vec = x[:, -3:]
x = x[:,:-3].unsqueeze(-1) * direct_vec.unsqueeze(1) # [N_env, D, 3]
direct_vec = x[:, -4:]
x = x[:,:-4].unsqueeze(-1) * direct_vec.unsqueeze(1) # [N_env, D, 3]
return self.reduce(x, index, reduce="mean", dim=0) # [N_atom, D, 3] following the orders of atom index.


Expand All @@ -127,7 +127,7 @@ def __init__(
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"), **kwargs):

super(_SE2Descriptor, self).__init__(aggr=aggr, **kwargs)
super(_SE2Descriptor, self).__init__(aggr=aggr, **kwargs, flow="target_to_source")

if isinstance(device, str):
device = torch.device(device)
Expand Down Expand Up @@ -173,7 +173,7 @@ def forward(self, env_vectors, atom_attr, env_index, edge_index, edge_length):
def message(self, env_vectors, env_attr):
rij = env_vectors.norm(dim=-1, keepdim=True)
snorm = self.smooth(rij, self.rs, self.rc)
env_vectors = snorm * env_vectors / rij
env_vectors = torch.cat([snorm, snorm * env_vectors / rij], dim=-1)
return torch.cat([self.embedding_net(torch.cat([snorm, env_attr], dim=-1)), env_vectors], dim=-1) # [N_env, D_emb + 3]

def update(self, aggr_out):
Expand Down

0 comments on commit 8349224

Please sign in to comment.