Skip to content

Commit

Permalink
Add DataRequirementItem
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 29, 2024
1 parent 6020a2b commit 5db7883
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 118 deletions.
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
Type,
)

from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
Expand Down Expand Up @@ -93,7 +96,7 @@ def model_output_type(self) -> str:
"""Get the output type for the model."""

@abstractmethod
def data_requirement(self) -> dict:
def data_requirement(self) -> List[DataRequirementItem]:
"""Get the data requirement for the model."""

@abstractmethod
Expand Down
9 changes: 8 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .make_model import (
make_model,
Expand All @@ -14,6 +21,6 @@
# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
def data_requirement(self) -> dict:
def data_requirement(self) -> List[DataRequirementItem]:
"""Get the data requirement for the model."""
raise NotImplementedError
41 changes: 24 additions & 17 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
)

import torch

from deepmd.utils.data import (
DataRequirementItem,
)

from .dp_model import (
DPModel,
)
Expand Down Expand Up @@ -92,21 +97,23 @@ def forward_lower(
return model_predict

@property
def data_requirement(self):
data_requirement = {
"dipole": {
"ndof": 3,
"atomic": False,
"must": False,
"high_prec": False,
"type_sel": self.get_sel_type(),
},
"atomic_dipole": {
"ndof": 3,
"atomic": True,
"must": False,
"high_prec": False,
"type_sel": self.get_sel_type(),
},
}
def data_requirement(self) -> List[DataRequirementItem]:
data_requirement = [
DataRequirementItem(
"dipole",
ndof=3,
atomic=False,
must=False,
high_prec=False,
type_sel=self.get_sel_type(),
),
DataRequirementItem(
"atomic_dipole",
ndof=3,
atomic=True,
must=False,
high_prec=False,
type_sel=self.get_sel_type(),
),
]
return data_requirement
77 changes: 43 additions & 34 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
)

Expand All @@ -12,6 +13,9 @@
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .make_model import (
make_model,
Expand Down Expand Up @@ -99,38 +103,43 @@ def forward_lower(
return model_predict

@property
def data_requirement(self):
data_requirement = {
"energy": {
"ndof": 1,
"atomic": False,
"must": False,
"high_prec": True,
},
"force": {
"ndof": 3,
"atomic": True,
"must": False,
"high_prec": False,
},
"virial": {
"ndof": 9,
"atomic": False,
"must": False,
"high_prec": False,
},
"atom_ener": {
"ndof": 1,
"atomic": True,
"must": False,
"high_prec": False,
},
"atom_pref": {
"ndof": 1,
"atomic": True,
"must": False,
"high_prec": False,
"repeat": 3,
},
}
def data_requirement(self) -> List[DataRequirementItem]:
data_requirement = [
DataRequirementItem(
"energy",
ndof=1,
atomic=False,
must=False,
high_prec=True,
),
DataRequirementItem(
"force",
ndof=3,
atomic=True,
must=False,
high_prec=False,
),
DataRequirementItem(
"virial",
ndof=9,
atomic=False,
must=False,
high_prec=False,
),
DataRequirementItem(
"atom_ener",
ndof=1,
atomic=True,
must=False,
high_prec=False,
),
DataRequirementItem(
"atom_pref",
ndof=1,
atomic=True,
must=False,
high_prec=False,
repeat=3,
),
]
return data_requirement
78 changes: 44 additions & 34 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
)

import torch

from deepmd.utils.data import (
DataRequirementItem,
)

from .dp_model import (
DPModel,
)
Expand Down Expand Up @@ -97,38 +102,43 @@ def forward_lower(
return model_predict

@property
def data_requirement(self):
data_requirement = {
"energy": {
"ndof": 1,
"atomic": False,
"must": False,
"high_prec": True,
},
"force": {
"ndof": 3,
"atomic": True,
"must": False,
"high_prec": False,
},
"virial": {
"ndof": 9,
"atomic": False,
"must": False,
"high_prec": False,
},
"atom_ener": {
"ndof": 1,
"atomic": True,
"must": False,
"high_prec": False,
},
"atom_pref": {
"ndof": 1,
"atomic": True,
"must": False,
"high_prec": False,
"repeat": 3,
},
}
def data_requirement(self) -> List[DataRequirementItem]:
data_requirement = [
DataRequirementItem(
"energy",
ndof=1,
atomic=False,
must=False,
high_prec=True,
),
DataRequirementItem(
"force",
ndof=3,
atomic=True,
must=False,
high_prec=False,
),
DataRequirementItem(
"virial",
ndof=9,
atomic=False,
must=False,
high_prec=False,
),
DataRequirementItem(
"atom_ener",
ndof=1,
atomic=True,
must=False,
high_prec=False,
),
DataRequirementItem(
"atom_pref",
ndof=1,
atomic=True,
must=False,
high_prec=False,
repeat=3,
),
]
return data_requirement
6 changes: 5 additions & 1 deletion deepmd/pt/model/model/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Optional,
)

from deepmd.dpmodel.model.base_model import (
make_base_model,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -85,6 +89,6 @@ def compute_or_load_stat(
"""
raise NotImplementedError

def data_requirement(self) -> dict:
def data_requirement(self) -> List[DataRequirementItem]:
"""Get the data requirement for the model."""
raise NotImplementedError
41 changes: 24 additions & 17 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
)

import torch

from deepmd.utils.data import (
DataRequirementItem,
)

from .dp_model import (
DPModel,
)
Expand Down Expand Up @@ -76,21 +81,23 @@ def forward_lower(
return model_predict

@property
def get_data_requirement(self):
data_requirement = {
"polar": {
"ndof": 9,
"atomic": False,
"must": False,
"high_prec": False,
"type_sel": self.get_sel_type(),
},
"atomic_polar": {
"ndof": 9,
"atomic": True,
"must": False,
"high_prec": False,
"type_sel": self.get_sel_type(),
},
}
def get_data_requirement(self) -> List[DataRequirementItem]:
data_requirement = [
DataRequirementItem(
"polar",
ndof=9,
atomic=False,
must=False,
high_prec=False,
type_sel=self.get_sel_type(),
),
DataRequirementItem(
"atomic_polar",
ndof=9,
atomic=True,
must=False,
high_prec=False,
type_sel=self.get_sel_type(),
),
]
return data_requirement
7 changes: 5 additions & 2 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from deepmd.pt.utils.dataset import (
DeepmdDataSetForLoader,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.data_system import (
prob_sys_size_ext,
process_sys_probs,
Expand Down Expand Up @@ -147,10 +150,10 @@ def __getitem__(self, idx):
batch["sid"] = idx
return batch

def add_data_requirement(self, dict_of_keys):
def add_data_requirement(self, data_requirement: List[DataRequirementItem]):
"""Add data requirement for each system in multiple systems."""
for system in self.systems:
system.add_data_requirement(dict_of_keys)
system.add_data_requirement(data_requirement)


_sentinel = object()
Expand Down
Loading

0 comments on commit 5db7883

Please sign in to comment.