Skip to content

Commit

Permalink
fix: fix DeepGlobalPolar and DeepWFC initlization
Browse files Browse the repository at this point in the history
Fix #3561. Fix #3562.

Not sure if some one uses them, but it's good to keep compatibility.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed May 28, 2024
1 parent 0bcb84f commit 14d9364
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 19 deletions.
3 changes: 3 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class DeepEvalBackend(ABC):
"dos_redu": "dos",
"mask_mag": "mask_mag",
"mask": "mask",
# old models in v1
"global_polar": "global_polar",
"wfc": "wfc",
}

@abstractmethod
Expand Down
27 changes: 26 additions & 1 deletion deepmd/infer/deep_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@

import numpy as np

from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
)
from deepmd.infer.deep_tensor import (
DeepTensor,
OldDeepTensor,
)


Expand Down Expand Up @@ -36,7 +42,7 @@ def output_tensor_name(self) -> str:
return "polar"


class DeepGlobalPolar(DeepTensor):
class DeepGlobalPolar(OldDeepTensor):
@property
def output_tensor_name(self) -> str:
return "global_polar"
Expand Down Expand Up @@ -95,3 +101,22 @@ def eval(
mixed_type=mixed_type,
**kwargs,
)

@property
def output_def(self) -> ModelOutputDef:
"""Get the output definition of this model."""
# no atomic or differentiable output is defined
return ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
self.output_tensor_name,
shape=[-1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
atomic=False,
),
]
)
)
21 changes: 21 additions & 0 deletions deepmd/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,24 @@ def output_def(self) -> ModelOutputDef:
]
)
)


class OldDeepTensor(DeepTensor):
"""Old tensor models from v1, which has no gradient output."""

# See https://github.com/deepmodeling/deepmd-kit/blob/1d1b251a2c5f05d1401aa89be792f9ed18b8f096/source/train/Model.py#L264
def eval_full(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: dict,
) -> Tuple[np.ndarray, ...]:
"""Unsupported method."""
raise RuntimeError(

Check warning on line 255 in deepmd/infer/deep_tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_tensor.py#L255

Added line #L255 was not covered by tests
"This model does not support eval_full method. Use eval instead."
)
28 changes: 26 additions & 2 deletions deepmd/infer/deep_wfc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
)
from deepmd.infer.deep_tensor import (
DeepTensor,
OldDeepTensor,
)


class DeepWFC(DeepTensor):
class DeepWFC(OldDeepTensor):
"""Deep WFC model.
Parameters
Expand All @@ -26,3 +31,22 @@ class DeepWFC(DeepTensor):
@property
def output_tensor_name(self) -> str:
return "wfc"

@property
def output_def(self) -> ModelOutputDef:
"""Get the output definition of this model."""
# no reduciable or differentiable output is defined
return ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
self.output_tensor_name,
shape=[-1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
atomic=True,
),
]
)
)
41 changes: 25 additions & 16 deletions source/tests/tf/test_get_potential.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Test if `DeepPotential` facto function returns the right type of potential."""

import tempfile
import unittest

from deepmd.infer.deep_polar import (
DeepGlobalPolar,
)
from deepmd.infer.deep_wfc import (
DeepWFC,
)
from deepmd.tf.infer import (
DeepDipole,
DeepPolar,
Expand Down Expand Up @@ -35,16 +42,19 @@ def setUp(self):
str(self.work_dir / "deeppolar.pbtxt"), str(self.work_dir / "deep_polar.pb")
)

# TODO add model files for globalpolar and WFC
# convert_pbtxt_to_pb(
# str(self.work_dir / "deepglobalpolar.pbtxt"),
# str(self.work_dir / "deep_globalpolar.pb")
# )
with open(self.work_dir / "deeppolar.pbtxt") as f:
deeppolar_pbtxt = f.read()

# convert_pbtxt_to_pb(
# str(self.work_dir / "deepwfc.pbtxt"),
# str(self.work_dir / "deep_wfc.pb")
# )
# not an actual globalpolar and wfc model, but still good enough for testing factory
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(deeppolar_pbtxt.replace("polar", "global_polar"))
f.flush()
convert_pbtxt_to_pb(f.name, str(self.work_dir / "deep_globalpolar.pb"))

with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(deeppolar_pbtxt.replace("polar", "wfc"))
f.flush()
convert_pbtxt_to_pb(f.name, str(self.work_dir / "deep_wfc.pb"))

def tearDown(self):
for f in self.work_dir.glob("*.pb"):
Expand All @@ -62,11 +72,10 @@ def test_factory(self):
dp = DeepPotential(self.work_dir / "deep_pot.pb")
self.assertIsInstance(dp, DeepPot, msg.format(DeepPot, type(dp)))

# TODO add model files for globalpolar and WFC
# dp = DeepPotential(self.work_dir / "deep_globalpolar.pb")
# self.assertIsInstance(
# dp, DeepGlobalPolar, msg.format(DeepGlobalPolar, type(dp))
# )
dp = DeepPotential(self.work_dir / "deep_globalpolar.pb")
self.assertIsInstance(
dp, DeepGlobalPolar, msg.format(DeepGlobalPolar, type(dp))
)

# dp = DeepPotential(self.work_dir / "deep_wfc.pb")
# self.assertIsInstance(dp, DeepWFC, msg.format(DeepWFC, type(dp)))
dp = DeepPotential(self.work_dir / "deep_wfc.pb")
self.assertIsInstance(dp, DeepWFC, msg.format(DeepWFC, type(dp)))

0 comments on commit 14d9364

Please sign in to comment.