Skip to content

Commit

Permalink
fix(pt): store min_nbor_dist in the state dict
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 5, 2024
1 parent dabedd2 commit b186980
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
8 changes: 6 additions & 2 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,14 @@ def train(
# save min_nbor_dist
if min_nbor_dist is not None:
if not multi_task:
trainer.model.min_nbor_dist = min_nbor_dist
trainer.model.min_nbor_dist = torch.tensor(
min_nbor_dist, dtype=torch.float64, device=DEVICE
)
else:
for model_item in min_nbor_dist:
trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item]
trainer.model[model_item].min_nbor_dist = torch.tensor(
min_nbor_dist[model_item], dtype=torch.float64, device=DEVICE
)
trainer.run()


Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/model/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, *args, **kwargs):
"""Construct a basic model for different tasks."""
torch.nn.Module.__init__(self)
self.model_def_script = ""
self.min_nbor_dist = None
self.register_buffer("min_nbor_dist", None)

def compute_or_load_stat(
self,
Expand Down Expand Up @@ -50,7 +50,9 @@ def get_model_def_script(self) -> str:
@torch.jit.export
def get_min_nbor_dist(self) -> Optional[float]:
"""Get the minimum distance between two atoms."""
return self.min_nbor_dist
if self.min_nbor_dist is None:
return None
return self.min_nbor_dist.item()

@torch.jit.export
def get_ntypes(self):
Expand Down

0 comments on commit b186980

Please sign in to comment.