Skip to content

Commit

Permalink
feat: update UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 4, 2024
1 parent f36988d commit 9c9cbbe
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ def compute_or_load_stat(
self.dp_model.compute_or_load_stat(sampled_func, stat_file_path)
self.zbl_model.compute_output_stats(sampled_func, stat_file_path)

def change_energy_bias(self) -> None:
    """Adjust the model's energy bias (e.g. after fine-tuning on new data).

    NOTE(review): placeholder — intentionally a no-op so callers (such as the
    training loop, which invokes this on DPZBLModel) do not fail. Implement
    before relying on bias shifting for the linear/ZBL atomic model.
    """
    # TODO: implement energy-bias adjustment for the combined model
    pass


def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ def compute_output_stats(
torch.tensor(bias_atom_e, device=env.DEVICE).view([ntypes, 1])
)

def change_energy_bias(self) -> None:
    """Adjust the model's energy bias (e.g. after fine-tuning on new data).

    NOTE(review): placeholder — intentionally a no-op so callers do not fail.
    Implement before relying on bias shifting for the pair-table atomic model.
    """
    # TODO: implement energy-bias adjustment for the pair-table model
    pass

def forward_atomic(
self,
extended_coord: torch.Tensor,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
EnergyStdLoss,
TensorLoss,
)
from deepmd.pt.model.model import (
DPZBLModel
)
from deepmd.pt.model.model import (
get_model,
get_zbl_model,
Expand Down Expand Up @@ -516,6 +519,9 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
ntest=ntest,
bias_shift=model_params.get("bias_shift", "delta"),
)
elif isinstance(self.model, DPZBLModel):
# need to be updated
self.model.change_energy_bias()
if init_frz_model is not None:
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
self.model.load_state_dict(frz_model.state_dict())
Expand Down

0 comments on commit 9c9cbbe

Please sign in to comment.