From 04c414a57dafabbbb2b37fcc634d2b86d96423bf Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 23 Jan 2024 22:30:31 -0500 Subject: [PATCH] add universal Python inference interface DeepPot (#3164) Need discussion for other classes. --------- Signed-off-by: Jinzhe Zeng --- deepmd/infer/deep_pot.py | 3 +- deepmd_utils/infer/__init__.py | 6 ++ deepmd_utils/infer/backend.py | 33 +++++++++ deepmd_utils/infer/deep_pot.py | 126 +++++++++++++++++++++++++++++++++ source/tests/test_uni_infer.py | 27 +++++++ 5 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 deepmd_utils/infer/__init__.py create mode 100644 deepmd_utils/infer/backend.py create mode 100644 deepmd_utils/infer/deep_pot.py create mode 100644 source/tests/test_uni_infer.py diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 81cfdde7a8..45db3fcb0c 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -26,6 +26,7 @@ from deepmd.utils.sess import ( run_sess, ) +from deepmd_utils.infer.deep_pot import DeepPot as DeepPotBase if TYPE_CHECKING: from pathlib import ( @@ -35,7 +36,7 @@ log = logging.getLogger(__name__) -class DeepPot(DeepEval): +class DeepPot(DeepEval, DeepPotBase): """Constructor. Parameters diff --git a/deepmd_utils/infer/__init__.py b/deepmd_utils/infer/__init__.py new file mode 100644 index 0000000000..644f5e1f43 --- /dev/null +++ b/deepmd_utils/infer/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .deep_pot import ( + DeepPot, +) + +__all__ = ["DeepPot"] diff --git a/deepmd_utils/infer/backend.py b/deepmd_utils/infer/backend.py new file mode 100644 index 0000000000..809e19466b --- /dev/null +++ b/deepmd_utils/infer/backend.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from enum import ( + Enum, +) + + +class DPBackend(Enum): + """DeePMD-kit backend.""" + + TensorFlow = 1 + PyTorch = 2 + Paddle = 3 + Unknown = 4 + + +def detect_backend(filename: str) -> DPBackend: + """Detect the backend of the given model file. + + Parameters + ---------- + filename : str + The model file name + """ + if filename.endswith(".pb"): + return DPBackend.TensorFlow + elif filename.endswith(".pth") or filename.endswith(".pt"): + return DPBackend.PyTorch + elif filename.endswith(".pdmodel"): + return DPBackend.Paddle + return DPBackend.Unknown + + +__all__ = ["DPBackend", "detect_backend"] diff --git a/deepmd_utils/infer/deep_pot.py b/deepmd_utils/infer/deep_pot.py new file mode 100644 index 0000000000..dec0a7c47c --- /dev/null +++ b/deepmd_utils/infer/deep_pot.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + List, + Optional, + Tuple, + Union, +) + +import numpy as np + +from deepmd_utils.utils.batch_size import ( + AutoBatchSize, +) + +from .backend import ( + DPBackend, + detect_backend, +) + + +class DeepPot(ABC): + """Potential energy model. + + Parameters + ---------- + model_file : Path + The name of the frozen model file. + auto_batch_size : bool or int or AutoBatchSize, default: True + If True, automatic batch size will be used. If int, it will be used + as the initial batch size. + neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional + The ASE neighbor list class to produce the neighbor list. If None, the + neighbor list will be built natively in the model. + """ + + @abstractmethod + def __init__( + self, + model_file, + *args, + auto_batch_size: Union[bool, int, AutoBatchSize] = True, + neighbor_list=None, + **kwargs, + ) -> None: + pass + + def __new__(cls, model_file: str, *args, **kwargs): + if cls is DeepPot: + backend = detect_backend(model_file) + if backend == DPBackend.TensorFlow: + from deepmd.infer.deep_pot import DeepPot as DeepPotTF + + return super().__new__(DeepPotTF) + elif backend == DPBackend.PyTorch: + from deepmd_pt.infer.deep_eval import DeepPot as DeepPotPT + + return super().__new__(DeepPotPT) + else: + raise NotImplementedError("Unsupported backend: " + str(backend)) + return super().__new__(cls) + + @abstractmethod + def eval( + self, + coords: np.ndarray, + cells: np.ndarray, + atom_types: List[int], + atomic: bool = False, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + efield: Optional[np.ndarray] = None, + mixed_type: bool = False, + ) -> Tuple[np.ndarray, ...]: + """Evaluate energy, force, and virial. If atomic is True, + also return atomic energy and atomic virial. + + Parameters + ---------- + coords : np.ndarray + The coordinates of the atoms, in shape (nframes, natoms, 3). + cells : np.ndarray + The cell vectors of the system, in shape (nframes, 9). If the system + is not periodic, set it to None. + atom_types : List[int] + The types of the atoms. If mixed_type is False, the shape is (natoms,); + otherwise, the shape is (nframes, natoms). + atomic : bool, optional + Whether to return atomic energy and atomic virial, by default False. + fparam : np.ndarray, optional + The frame parameters, by default None. + aparam : np.ndarray, optional + The atomic parameters, by default None. + efield : np.ndarray, optional + The electric field, by default None. + mixed_type : bool, optional + Whether the system contains mixed atom types, by default False. + + Returns + ------- + energy + The energy of the system, in shape (nframes,). + force + The force of the system, in shape (nframes, natoms, 3). + virial + The virial of the system, in shape (nframes, 9). + atomic_energy + The atomic energy of the system, in shape (nframes, natoms). Only returned + when atomic is True. + atomic_virial + The atomic virial of the system, in shape (nframes, natoms, 9). Only returned + when atomic is True. + """ + # This method has been used by: + # documentation python.md + # dp model_devi: +fparam, +aparam, +mixed_type + # dp test: +atomic, +fparam, +aparam, +efield, +mixed_type + # finetune: +mixed_type + # dpdata + # ase + + +__all__ = ["DeepPot"] diff --git a/source/tests/test_uni_infer.py b/source/tests/test_uni_infer.py new file mode 100644 index 0000000000..6b70d17f7e --- /dev/null +++ b/source/tests/test_uni_infer.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for the universal Python inference interface.""" + +import os +import unittest + +from common import ( + tests_path, +) + +from deepmd.infer.deep_pot import DeepPot as DeepPotTF +from deepmd.utils.convert import ( + convert_pbtxt_to_pb, +) +from deepmd_utils.infer.deep_pot import DeepPot as DeepPot + + +class TestUniversalInfer(unittest.TestCase): + @classmethod + def setUpClass(cls): + convert_pbtxt_to_pb( + str(tests_path / os.path.join("infer", "deeppot-r.pbtxt")), "deeppot.pb" + ) + + def test_deep_pot(self): + dp = DeepPot("deeppot.pb") + self.assertIsInstance(dp, DeepPotTF)