Skip to content

Commit

Permalink
change the test name for json run_train mh
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Oct 29, 2024
1 parent ae84fa2 commit c1bb3b2
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions tests/test_run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,13 +601,13 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs):
assert np.allclose(Es, ref_Es, atol=1e-1)


def test_run_train_foundation_multihead(tmp_path, fitting_configs):
def test_run_train_foundation_multihead_json(tmp_path, fitting_configs):
fitting_configs_dft = []
fitting_configs_mp2 = []
for i, c in enumerate(fitting_configs):

if i in (0, 1):
continue # skip isolated atoms, as energies specified by json files below
continue # skip isolated atoms, as energies specified by json files below
elif i % 2 == 0:
c.info["head"] = "DFT"
fitting_configs_dft.append(c)
Expand All @@ -625,8 +625,14 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs):
json.dump(E0s, f)

heads = {
"DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", "E0s": f"{str(tmp_path)}/fit_multihead_dft.json"},
"MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json"},
"DFT": {
"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz",
"E0s": f"{str(tmp_path)}/fit_multihead_dft.json",
},
"MP2": {
"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz",
"E0s": f"{str(tmp_path)}/fit_multihead_mp2.json",
},
}
yaml_str = "heads:\n"
for key, value in heads.items():
Expand Down

0 comments on commit c1bb3b2

Please sign in to comment.