From 38128666627c6d1dc6f56e55d2f9c7313a7902bf Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:42:09 +0800 Subject: [PATCH] Fix single-task training&data stat --- deepmd/pt/model/descriptor/dpa2.py | 2 +- deepmd/pt/model/model/__init__.py | 9 ++++----- deepmd/pt/model/model/model.py | 4 ++-- deepmd/utils/path.py | 1 + examples/water/dpa2/input_torch.json | 8 ++------ examples/water/se_atten/input_torch.json | 2 ++ examples/water/se_e2_a/input_torch.json | 1 + 7 files changed, 13 insertions(+), 14 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index e693116cf4..b1df56a004 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -304,7 +304,7 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) } for item in merged ] - descrpt.compute_input_stats(merged_tmp) + descrpt.compute_input_stats(merged_tmp, path) def serialize(self) -> dict: """Serialize the obj to dict.""" diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 0dc9ae20af..b823a051f5 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -20,7 +20,7 @@ BaseDescriptor, ) from deepmd.pt.model.task import ( - Fitting, + BaseFitting, ) from .dp_model import ( @@ -61,7 +61,7 @@ def get_zbl_model(model_params): fitting_net["out_dim"] = descriptor.get_dim_emb() if "ener" in fitting_net["type"]: fitting_net["return_energy"] = True - fitting = Fitting(**fitting_net) + fitting = BaseFitting(**fitting_net) dp_model = DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"]) # pairtab filepath = model_params["use_srtab"] @@ -97,9 +97,8 @@ def get_model(model_params): fitting_net["out_dim"] = descriptor.get_dim_emb() if "ener" in fitting_net["type"]: fitting_net["return_energy"] = True - fitting = Fitting(**fitting_net) - - model = EnergyModel(descriptor, fitting, type_map=model_params["type_map"]) + fitting = BaseFitting(**fitting_net) + model = DPModel(descriptor, fitting, type_map=model_params["type_map"]) model.model_def_script = json.dumps(model_params) return model diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index 0f5e27aea9..e32d2f307d 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -59,9 +59,9 @@ # in DPAtomicModel (and other classes), but this requires the developer aware # of it when developing it... class BaseModel(make_base_model()): - def __init__(self): + def __init__(self, *args, **kwargs): """Construct a basic model for different tasks.""" - super().__init__() + super().__init__(*args, **kwargs) def compute_or_load_stat( self, diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index c9a7cd8554..79361b6c23 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -355,6 +355,7 @@ def save_numpy(self, arr: np.ndarray) -> None: if self._name in self._keys: del self.root[self._name] self.root.create_dataset(self._name, data=arr) + self.root.flush() def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. diff --git a/examples/water/dpa2/input_torch.json b/examples/water/dpa2/input_torch.json index 9d783b35d5..108e75df62 100644 --- a/examples/water/dpa2/input_torch.json +++ b/examples/water/dpa2/input_torch.json @@ -1,18 +1,13 @@ { "_comment": "that's all", "model": { - "type_embedding": { - "neuron": [ - 8 - ], - "tebd_input_mode": "concat" - }, "type_map": [ "O", "H" ], "descriptor": { "type": "dpa2", + "tebd_dim": 8, "repinit_rcut": 9.0, "repinit_rcut_smth": 8.0, "repinit_nsel": 120, @@ -74,6 +69,7 @@ "_comment": " that's all" }, "training": { + "stat_file": "./dpa2", "training_data": { "systems": [ "../data/data_0", diff --git a/examples/water/se_atten/input_torch.json b/examples/water/se_atten/input_torch.json index 7da3d64164..bc948cc2a0 100644 --- a/examples/water/se_atten/input_torch.json +++ b/examples/water/se_atten/input_torch.json @@ -15,6 +15,7 @@ 50, 100 ], + "tebd_dim": 8, "axis_neuron": 16, "attn": 128, "attn_layer": 2, @@ -59,6 +60,7 @@ "_comment": " that's all" }, "training": { + "stat_file": "./dpa1", "training_data": { "systems": [ "../data/data_0", diff --git a/examples/water/se_e2_a/input_torch.json b/examples/water/se_e2_a/input_torch.json index 053a721a44..c686b49d45 100644 --- a/examples/water/se_e2_a/input_torch.json +++ b/examples/water/se_e2_a/input_torch.json @@ -51,6 +51,7 @@ "_comment": " that's all" }, "training": { + "stat_file": "./se_e2_a", "training_data": { "systems": [ "../data/data_0",