Skip to content

Commit

Permalink
find element in list rather than tuple, which easily makes bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 30, 2024
1 parent 49a10a9 commit 6efab8c
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 20 deletions.
20 changes: 10 additions & 10 deletions deepmd/model_format/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,29 +200,29 @@ def output_def(self):
)

def __setitem__(self, key, value):
if key in ("bias_atom_e"):
if key in ["bias_atom_e"]:
self.bias_atom_e = value
elif key in ("fparam_avg"):
elif key in ["fparam_avg"]:
self.fparam_avg = value
elif key in ("fparam_inv_std"):
elif key in ["fparam_inv_std"]:
self.fparam_inv_std = value
elif key in ("aparam_avg"):
elif key in ["aparam_avg"]:
self.aparam_avg = value
elif key in ("aparam_inv_std"):
elif key in ["aparam_inv_std"]:
self.aparam_inv_std = value
else:
raise KeyError(key)

Check warning on line 214 in deepmd/model_format/fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/fitting.py#L214

Added line #L214 was not covered by tests

def __getitem__(self, key):
if key in ("bias_atom_e"):
if key in ["bias_atom_e"]:
return self.bias_atom_e
elif key in ("fparam_avg"):
elif key in ["fparam_avg"]:
return self.fparam_avg
elif key in ("fparam_inv_std"):
elif key in ["fparam_inv_std"]:
return self.fparam_inv_std
elif key in ("aparam_avg"):
elif key in ["aparam_avg"]:
return self.aparam_avg
elif key in ("aparam_inv_std"):
elif key in ["aparam_inv_std"]:
return self.aparam_inv_std
else:
raise KeyError(key)

Check warning on line 228 in deepmd/model_format/fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/fitting.py#L228

Added line #L228 was not covered by tests
Expand Down
20 changes: 10 additions & 10 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,30 +166,30 @@ def output_def(self) -> FittingOutputDef:
)

def __setitem__(self, key, value):
if key in ("bias_atom_e"):
if key in ["bias_atom_e"]:
# correct bias_atom_e shape. user may provide stupid shape
self.bias_atom_e = value
elif key in ("fparam_avg"):
elif key in ["fparam_avg"]:
self.fparam_avg = value
elif key in ("fparam_inv_std"):
elif key in ["fparam_inv_std"]:
self.fparam_inv_std = value
elif key in ("aparam_avg"):
elif key in ["aparam_avg"]:
self.aparam_avg = value
elif key in ("aparam_inv_std"):
elif key in ["aparam_inv_std"]:
self.aparam_inv_std = value
else:
raise KeyError(key)

Check warning on line 181 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L181

Added line #L181 was not covered by tests

def __getitem__(self, key):
if key in ("bias_atom_e"):
if key in ["bias_atom_e"]:
return self.bias_atom_e
elif key in ("fparam_avg"):
elif key in ["fparam_avg"]:
return self.fparam_avg
elif key in ("fparam_inv_std"):
elif key in ["fparam_inv_std"]:
return self.fparam_inv_std
elif key in ("aparam_avg"):
elif key in ["aparam_avg"]:
return self.aparam_avg
elif key in ("aparam_inv_std"):
elif key in ["aparam_inv_std"]:
return self.aparam_inv_std
else:
raise KeyError(key)

Check warning on line 195 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L195

Added line #L195 was not covered by tests
Expand Down
19 changes: 19 additions & 0 deletions source/tests/common/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,22 @@ def test_self_exception(
with self.assertRaises(ValueError) as context:
ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap)
self.assertIn("input aparam", context.exception)

def test_get_set(self):
ifn0 = InvarFitting(
"energy",
self.nt,
3,
1,
)
rng = np.random.default_rng()
foo = rng.normal([3, 4])
for ii in [
"bias_atom_e",
"fparam_avg",
"fparam_inv_std",
"aparam_avg",
"aparam_inv_std",
]:
ifn0[ii] = foo
np.testing.assert_allclose(foo, ifn0[ii])
19 changes: 19 additions & 0 deletions source/tests/pt/test_ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,22 @@ def test_jit(
use_tebd=(not distinguish_types),
).to(env.DEVICE)
torch.jit.script(ft0)

def test_get_set(self):
ifn0 = InvarFitting(
"energy",
self.nt,
3,
1,
)
rng = np.random.default_rng()
foo = rng.normal([3, 4])
for ii in [
"bias_atom_e",
"fparam_avg",
"fparam_inv_std",
"aparam_avg",
"aparam_inv_std",
]:
ifn0[ii] = torch.tensor(foo, dtype=dtype, device=env.DEVICE)
np.testing.assert_allclose(foo, ifn0[ii].detach().cpu().numpy())

0 comments on commit 6efab8c

Please sign in to comment.