Skip to content

Commit

Permalink
Fix uts
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 5, 2024
1 parent 62c30cb commit 7804da8
Show file tree
Hide file tree
Showing 15 changed files with 33 additions and 38 deletions.
1 change: 1 addition & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def get_spin_model(data: dict) -> SpinModel:
if "env_protection" not in data["descriptor"]:
data["descriptor"]["env_protection"] = 1e-6
if data["descriptor"]["type"] in ["se_e2_a"]:
# only expand sel for se_e2_a
data["descriptor"]["sel"] += data["descriptor"]["sel"]
backbone_model = get_standard_model(data)
return SpinModel(backbone_model=backbone_model, spin=spin)
Expand Down
2 changes: 0 additions & 2 deletions deepmd/dpmodel/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
class SpinModel:
"""A spin model wrapper, with spin input preprocess and output split."""

__USE_SPIN_INPUT__: bool = True

def __init__(
self,
backbone_model,
Expand Down
12 changes: 6 additions & 6 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,12 @@ def test_ener(
else: # pt support for spin
force_r = force
test_force_r = test_data["force"][:numb_test]
force_m = force_m.reshape(-1, 3)[mask_mag.reshape(-1)].reshape(nframes, -1)
test_force_m = (
test_data["force_mag"][:numb_test]
.reshape(-1, 3)[mask_mag.reshape(-1)]
.reshape(nframes, -1)
)
# The shape of force_m and test_force_m are [-1, 3],
# which is designed for mixed_type cases
force_m = force_m.reshape(-1, 3)[mask_mag.reshape(-1)]
test_force_m = test_data["force_mag"][:numb_test].reshape(-1, 3)[
mask_mag.reshape(-1)
]

diff_e = energy - test_data["energy"][:numb_test].reshape([-1, 1])
mae_e = mae(diff_e)
Expand Down
5 changes: 3 additions & 2 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def __init__(
neighbor_list=neighbor_list,
**kwargs,
)
if getattr(self.deep_eval, "has_spin_pt", False) and hasattr(
if getattr(self.deep_eval, "_has_spin", False) and hasattr(
self, "output_def_mag"
):
self.deep_eval.output_def = self.output_def_mag
Expand Down Expand Up @@ -529,7 +529,8 @@ def has_efield(self) -> bool:
@property
def has_spin(self) -> bool:
"""Check if the model has spin."""
return getattr(self.deep_eval, "has_spin_pt", False)
# use _has_spin to differentiate from has_spin form the old tf implementation
return getattr(self.deep_eval, "_has_spin", False)

def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""
Expand Down
2 changes: 1 addition & 1 deletion deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def eval(
force,
virial,
)
if getattr(self.deep_eval, "has_spin_pt", False):
if getattr(self.deep_eval, "_has_spin", False):
force_mag = results["energy_derv_r_mag"].reshape(nframes, natoms, 3)
mask_mag = results["mask_mag"].reshape(nframes, natoms, 1)
result = (*list(result), force_mag, mask_mag)
Expand Down
15 changes: 9 additions & 6 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def __init__(
self.auto_batch_size = auto_batch_size
else:
raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize")
self.has_spin_pt = getattr(self.dp.model["Default"], "has_spin", False)
if callable(self.has_spin_pt):
self.has_spin_pt = self.has_spin_pt()
self._has_spin = getattr(self.dp.model["Default"], "has_spin", False)
if callable(self._has_spin):
self._has_spin = self._has_spin()

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
Expand Down Expand Up @@ -243,7 +243,7 @@ def eval(
coords, atom_types, len(atom_types.shape) > 1
)
request_defs = self._get_request_defs(atomic)
if "spin" not in kwargs:
if "spin" not in kwargs or kwargs["spin"] is None:
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, fparam, aparam, request_defs
)
Expand Down Expand Up @@ -570,6 +570,9 @@ def eval_model(
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
has_spin = getattr(model, "has_spin", False)
if callable(has_spin):
has_spin = has_spin()
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
box_input = None
if cells is None:
Expand All @@ -596,7 +599,7 @@ def eval_model(
"box": batch_box,
"do_atomic_virial": atomic,
}
if getattr(model, "__USE_SPIN_INPUT__", False):
if has_spin:
input_dict["spin"] = batch_spin
batch_output = model(**input_dict)
if isinstance(batch_output, tuple):
Expand Down Expand Up @@ -723,7 +726,7 @@ def eval_model(
"force": force_out,
"virial": virial_out,
}
if getattr(model, "__USE_SPIN_INPUT__", False):
if has_spin:
results_dict["force_mag"] = force_mag_out
if atomic:
results_dict["atom_energy"] = atomic_energy_out
Expand Down
14 changes: 2 additions & 12 deletions deepmd/pt/loss/ener_spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def __init__(
self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference
self.has_fr = (start_pref_fr != 0.0 and limit_pref_fr != 0.0) or inference
self.has_fm = (start_pref_fm != 0.0 and limit_pref_fm != 0.0) or inference

# TODO need support for virial, atomic energy and atomic pref
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
# TODO need support for atomic energy and atomic pref
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference

Expand Down Expand Up @@ -172,17 +173,6 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum()
loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)

if self.has_v and "virial" in model_pred and "virial" in label:
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = l2_virial_loss.detach()
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = rmse_v.detach()
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = mae_v.detach()
if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return loss, more_loss
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_stats(self) -> Dict[str, StatItem]:

def get_emask(self, nlist: torch.Tensor, atype: torch.Tensor) -> torch.Tensor:
"""
Compute the pair-wise type mask for given nlist and atype,
Compute the pair-wise type mask for given nlist and atype, for data stat
with shape same as nlist.
1 for include and 0 for exclude.
"""
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_spin_model(model_params):
):
model_params["descriptor"]["env_protection"] = 1e-6
if model_params["descriptor"]["type"] in ["se_e2_a"]:
# only expand sel for se_e2_a
model_params["descriptor"]["sel"] += model_params["descriptor"]["sel"]
backbone_model = get_standard_model(model_params)
return SpinEnergyModel(backbone_model=backbone_model, spin=spin)
Expand Down
2 changes: 0 additions & 2 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
class SpinModel(torch.nn.Module):
"""A spin model wrapper, with spin input preprocess and output split."""

__USE_SPIN_INPUT__: bool = True

def __init__(
self,
backbone_model,
Expand Down
1 change: 0 additions & 1 deletion examples/spin/se_e2_a/input_torch.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
{
"model": {
"type": "spin",
"type_map": [
"Ni",
"O"
Expand Down
1 change: 1 addition & 0 deletions source/tests/common/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
p_examples / "zinc_protein" / "zinc_se_a_mask.json",
p_examples / "dos" / "train" / "input.json",
p_examples / "spin" / "se_e2_a" / "input_tf.json",
p_examples / "spin" / "se_e2_a" / "input_torch.json",
p_examples / "dprc" / "normal" / "input.json",
p_examples / "dprc" / "pairwise" / "input.json",
p_examples / "dprc" / "generalized_force" / "input.json",
Expand Down
3 changes: 2 additions & 1 deletion source/tests/pt/model/test_deeppot.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def test_dp_test(self):
).reshape(1, -1, 3)
atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)

e, f, v, ae, av = dp.eval(coord, cell, atype, atomic=True)
ret = dp.eval(coord, cell, atype, atomic=True)
e, f, v, ae, av = ret[0], ret[1], ret[2], ret[3], ret[4]
self.assertEqual(e.shape, (1, 1))
self.assertEqual(f.shape, (1, 5, 3))
self.assertEqual(v.shape, (1, 9))
Expand Down
6 changes: 4 additions & 2 deletions source/tests/pt/test_init_frz_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ def test_dp_test(self):
).reshape(1, -1, 3)
atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)

e1, f1, v1, ae1, av1 = dp1.eval(coord, cell, atype, atomic=True)
e2, f2, v2, ae2, av2 = dp2.eval(coord, cell, atype, atomic=True)
ret1 = dp1.eval(coord, cell, atype, atomic=True)
e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4]
ret2 = dp2.eval(coord, cell, atype, atomic=True)
e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4]
np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10)
Expand Down
4 changes: 2 additions & 2 deletions source/tests/tf/test_deeppot_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def test_convert_012(self):
convert_pbtxt_to_pb(str(infer_path / "sea_012.pbtxt"), old_model)
run_dp(f"dp convert-from 0.12 -i {old_model} -o {new_model}")
dp = DeepPot(new_model)
_, _, _, _, _ = dp.eval(self.coords, self.box, self.atype, atomic=True)
ret = dp.eval(self.coords, self.box, self.atype, atomic=True)
os.remove(old_model)
os.remove(new_model)

Expand All @@ -814,7 +814,7 @@ def test_convert(self):
convert_pbtxt_to_pb(str(infer_path / "sea_012.pbtxt"), old_model)
run_dp(f"dp convert-from -i {old_model} -o {new_model}")
dp = DeepPot(new_model)
_, _, _, _, _ = dp.eval(self.coords, self.box, self.atype, atomic=True)
ret = dp.eval(self.coords, self.box, self.atype, atomic=True)
os.remove(old_model)
os.remove(new_model)

Expand Down

0 comments on commit 7804da8

Please sign in to comment.