diff --git a/.github/workflows/build_cc.yml b/.github/workflows/build_cc.yml index f029517d80..991be798aa 100644 --- a/.github/workflows/build_cc.yml +++ b/.github/workflows/build_cc.yml @@ -27,6 +27,10 @@ jobs: cache: 'pip' - uses: lukka/get-cmake@latest - run: python -m pip install tensorflow + - name: Download libtorch + run: | + wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip -O libtorch.zip + unzip libtorch.zip - run: | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb \ && sudo dpkg -i cuda-keyring_1.0-1_all.deb \ @@ -48,13 +52,17 @@ jobs: && sudo apt-get update \ && sudo apt-get install -y rocm-dev hipcub-dev if: matrix.variant == 'rocm' - - run: source/install/build_cc.sh + - run: | + export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch + source/install/build_cc.sh env: DP_VARIANT: ${{ matrix.dp_variant }} DOWNLOAD_TENSORFLOW: "FALSE" CMAKE_GENERATOR: Ninja if: matrix.variant != 'clang' - - run: source/install/build_cc.sh + - run: | + export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch + source/install/build_cc.sh env: DP_VARIANT: cpu DOWNLOAD_TENSORFLOW: "FALSE" diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index 23076e9bf5..fa109cac5e 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -68,6 +68,17 @@ jobs: - uses: docker/setup-qemu-action@v3 name: Setup QEMU if: matrix.platform_id == 'manylinux_aarch64' && matrix.os == 'ubuntu-latest' + # detect version in advance. 
See #3168 + - uses: actions/setup-python@v5 + name: Install Python + with: + python-version: '3.11' + cache: 'pip' + if: matrix.dp_pkg_name == 'deepmd-kit-cu11' + - run: | + python -m pip install setuptools_scm + python -c "from setuptools_scm import get_version;print('SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DEEPMD-KIT-CU11='+get_version())" >> $GITHUB_ENV + if: matrix.dp_pkg_name == 'deepmd-kit-cu11' - name: Build wheels uses: pypa/cibuildwheel@v2.16 env: diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index a9a162432c..c5460109f4 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -37,6 +37,8 @@ jobs: && sudo apt-get update \ && sudo apt-get -y install cuda-cudart-dev-12-2 cuda-nvcc-12-2 python -m pip install tensorflow + wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip -O libtorch.zip + unzip libtorch.zip env: DEBIAN_FRONTEND: noninteractive # Initializes the CodeQL tools for scanning. @@ -46,7 +48,9 @@ jobs: languages: ${{ matrix.language }} queries: security-extended,security-and-quality - name: "Run, Build Application using script" - run: source/install/build_cc.sh + run: | + export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch + source/install/build_cc.sh env: DP_VARIANT: cuda DOWNLOAD_TENSORFLOW: "FALSE" diff --git a/.github/workflows/test_cc.yml b/.github/workflows/test_cc.yml index ef6fade8e5..1ded666070 100644 --- a/.github/workflows/test_cc.yml +++ b/.github/workflows/test_cc.yml @@ -18,7 +18,13 @@ jobs: mpi: mpich - uses: lukka/get-cmake@latest - run: python -m pip install tensorflow - - run: source/install/test_cc_local.sh + - name: Download libtorch + run: | + wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip -O libtorch.zip + unzip libtorch.zip + - run: | + export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch + source/install/test_cc_local.sh env: OMP_NUM_THREADS: 1 TF_INTRA_OP_PARALLELISM_THREADS: 1 diff 
# SPDX-License-Identifier: LGPL-3.0-or-later
from enum import (
    Enum,
)


class DPBackend(Enum):
    """DeePMD-kit backend."""

    TensorFlow = 1
    PyTorch = 2
    Paddle = 3
    Unknown = 4


def detect_backend(filename: str) -> DPBackend:
    """Detect the backend of the given model file.

    The decision is made purely from the file name suffix (case-sensitive);
    the file itself is never opened.

    Parameters
    ----------
    filename : str or path-like
        The model file name. Path objects (e.g. ``pathlib.Path``) are
        accepted and converted to ``str`` before matching.

    Returns
    -------
    DPBackend
        ``TensorFlow`` for ``.pb``, ``PyTorch`` for ``.pth`` / ``.pt``,
        ``Paddle`` for ``.pdmodel``, and ``Unknown`` for anything else.
    """
    # Callers may hand over a Path object, which has no ``endswith``;
    # normalise to a plain string first.
    name = str(filename)
    if name.endswith(".pb"):
        return DPBackend.TensorFlow
    # ``str.endswith`` takes a tuple of suffixes, so one call covers both
    # PyTorch extensions.
    if name.endswith((".pth", ".pt")):
        return DPBackend.PyTorch
    if name.endswith(".pdmodel"):
        return DPBackend.Paddle
    return DPBackend.Unknown


__all__ = ["DPBackend", "detect_backend"]
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
    ABC,
    abstractmethod,
)
from typing import (
    List,
    Optional,
    Tuple,
    Union,
)

import numpy as np

from deepmd_utils.utils.batch_size import (
    AutoBatchSize,
)

from .backend import (
    DPBackend,
    detect_backend,
)


class DeepPot(ABC):
    """Potential energy model.

    Instantiating ``DeepPot`` directly dispatches to a backend-specific
    implementation chosen from the suffix of ``model_file`` (see ``__new__``).

    Parameters
    ----------
    model_file : Path
        The name of the frozen model file.
    auto_batch_size : bool or int or AutoBatchSize, default: True
        If True, automatic batch size will be used. If int, it will be used
        as the initial batch size.
    neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional
        The ASE neighbor list class to produce the neighbor list. If None, the
        neighbor list will be built natively in the model.
    """

    @abstractmethod
    def __init__(
        self,
        model_file,
        *args,
        auto_batch_size: Union[bool, int, AutoBatchSize] = True,
        neighbor_list=None,
        **kwargs,
    ) -> None:
        pass

    def __new__(cls, model_file: str, *args, **kwargs):
        """Allocate an instance of the implementation matching ``model_file``.

        When called on ``DeepPot`` itself, the backend is detected from the
        file name and an instance of the corresponding backend class is
        allocated. The backend modules are imported inside the branches, so
        only the selected one is imported. When called on a subclass, a plain
        instance of that subclass is allocated.

        Parameters
        ----------
        model_file : str or path-like
            The model file name; only its suffix is inspected here.
        *args, **kwargs
            Accepted for signature compatibility with ``__init__``; they are
            not used during allocation.

        Raises
        ------
        NotImplementedError
            If the detected backend is neither TensorFlow nor PyTorch.
        """
        if cls is DeepPot:
            # The class docstring allows a Path for model_file, while the
            # suffix detection works on strings; convert before detecting.
            backend = detect_backend(str(model_file))
            if backend == DPBackend.TensorFlow:
                from deepmd.infer.deep_pot import DeepPot as DeepPotTF

                return super().__new__(DeepPotTF)
            elif backend == DPBackend.PyTorch:
                from deepmd_pt.infer.deep_eval import DeepPot as DeepPotPT

                return super().__new__(DeepPotPT)
            else:
                raise NotImplementedError("Unsupported backend: " + str(backend))
        return super().__new__(cls)

    @abstractmethod
    def eval(
        self,
        coords: np.ndarray,
        cells: np.ndarray,
        atom_types: List[int],
        atomic: bool = False,
        fparam: Optional[np.ndarray] = None,
        aparam: Optional[np.ndarray] = None,
        efield: Optional[np.ndarray] = None,
        mixed_type: bool = False,
    ) -> Tuple[np.ndarray, ...]:
        """Evaluate energy, force, and virial. If atomic is True,
        also return atomic energy and atomic virial.

        Parameters
        ----------
        coords : np.ndarray
            The coordinates of the atoms, in shape (nframes, natoms, 3).
        cells : np.ndarray
            The cell vectors of the system, in shape (nframes, 9). If the system
            is not periodic, set it to None.
        atom_types : List[int]
            The types of the atoms. If mixed_type is False, the shape is (natoms,);
            otherwise, the shape is (nframes, natoms).
        atomic : bool, optional
            Whether to return atomic energy and atomic virial, by default False.
        fparam : np.ndarray, optional
            The frame parameters, by default None.
        aparam : np.ndarray, optional
            The atomic parameters, by default None.
        efield : np.ndarray, optional
            The electric field, by default None.
        mixed_type : bool, optional
            Whether the system contains mixed atom types, by default False.

        Returns
        -------
        energy
            The energy of the system, in shape (nframes,).
        force
            The force of the system, in shape (nframes, natoms, 3).
        virial
            The virial of the system, in shape (nframes, 9).
        atomic_energy
            The atomic energy of the system, in shape (nframes, natoms). Only returned
            when atomic is True.
        atomic_virial
            The atomic virial of the system, in shape (nframes, natoms, 9). Only returned
            when atomic is True.
        """
        # This method has been used by:
        # documentation python.md
        # dp model_devi: +fparam, +aparam, +mixed_type
        # dp test: +atomic, +fparam, +aparam, +efield, +mixed_type
        # finetune: +mixed_type
        # dpdata
        # ase


__all__ = ["DeepPot"]
@@ def setup(app): .. |PRECISION| replace:: {list_to_doc(PRECISION_DICT.keys())} """ +myst_substitutions = { + "tensorflow_icon": """![TensorFlow](/_static/tensorflow.svg){class=platform-icon}""", + "pytorch_icon": """![PyTorch](/_static/pytorch.svg){class=platform-icon}""", +} + # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for @@ -298,6 +306,8 @@ def setup(app): myst_enable_extensions = [ "dollarmath", "colon_fence", + "substitution", + "attrs_inline", ] myst_fence_as_directive = ("math",) # fix emoji issue in pdf diff --git a/doc/freeze/compress.md b/doc/freeze/compress.md index 7394f77143..b6c8966c60 100644 --- a/doc/freeze/compress.md +++ b/doc/freeze/compress.md @@ -1,4 +1,8 @@ -# Compress a model +# Compress a model {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: ## Theory diff --git a/doc/model/dplr.md b/doc/model/dplr.md index feea84e562..317630ebe5 100644 --- a/doc/model/dplr.md +++ b/doc/model/dplr.md @@ -1,4 +1,8 @@ -# Deep potential long-range (DPLR) +# Deep potential long-range (DPLR) {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: Notice: **The interfaces of DPLR are not stable and subject to change** diff --git a/doc/model/dprc.md b/doc/model/dprc.md index c7547a769f..48e18e8d89 100644 --- a/doc/model/dprc.md +++ b/doc/model/dprc.md @@ -1,4 +1,8 @@ -# Deep Potential - Range Correction (DPRc) +# Deep Potential - Range Correction (DPRc) {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: Deep Potential - Range Correction (DPRc) is designed to combine with QM/MM method, and corrects energies from a low-level QM/MM method to a high-level QM/MM method: diff --git a/doc/model/linear.md b/doc/model/linear.md index b5e7c5c76a..3891559d90 100644 --- a/doc/model/linear.md +++ b/doc/model/linear.md @@ -1,4 +1,8 @@ 
-## Linear model +## Linear model {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: One can linearly combine existing models with arbitrary coefficients: diff --git a/doc/model/pairtab.md b/doc/model/pairtab.md index 115345796a..719bb95004 100644 --- a/doc/model/pairtab.md +++ b/doc/model/pairtab.md @@ -1,4 +1,8 @@ -# Interpolation or combination with a pairwise potential +# Interpolation or combination with a pairwise potential {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: ## Theory In applications like the radiation damage simulation, the interatomic distance may become too close, so that the DFT calculations fail. diff --git a/doc/model/train-energy-spin.md b/doc/model/train-energy-spin.md index d155ec977d..e0b3968c09 100644 --- a/doc/model/train-energy-spin.md +++ b/doc/model/train-energy-spin.md @@ -1,4 +1,8 @@ -# Fit spin energy +# Fit spin energy {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: In this section, we will take `$deepmd_source_dir/examples/NiO/se_e2_a/input.json` as an example of the input file. diff --git a/doc/model/train-energy.md b/doc/model/train-energy.md index 90e027d7a0..74a933c79c 100644 --- a/doc/model/train-energy.md +++ b/doc/model/train-energy.md @@ -1,4 +1,8 @@ -# Fit energy +# Fit energy {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.json` as an example of the input file. 
diff --git a/doc/model/train-fitting-dos.md b/doc/model/train-fitting-dos.md index bbe5b50690..b74ab3acf7 100644 --- a/doc/model/train-fitting-dos.md +++ b/doc/model/train-fitting-dos.md @@ -1,4 +1,8 @@ -# Fit electronic density of states (DOS) +# Fit electronic density of states (DOS) {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: Here we present an API to DeepDOS model, which can be used to fit electronic density of states (DOS) (which is a vector). diff --git a/doc/model/train-fitting-tensor.md b/doc/model/train-fitting-tensor.md index 90370adfcf..3272418a7c 100644 --- a/doc/model/train-fitting-tensor.md +++ b/doc/model/train-fitting-tensor.md @@ -1,4 +1,8 @@ -# Fit `tensor` like `Dipole` and `Polarizability` +# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: Unlike `energy`, which is a scalar, one may want to fit some high dimensional physical quantity, like `dipole` (vector) and `polarizability` (matrix, shortened as `polar`). Deep Potential has provided different APIs to do this. In this example, we will show you how to train a model to fit a water system. A complete training input script of the examples can be found in diff --git a/doc/model/train-hybrid.md b/doc/model/train-hybrid.md index 58b66f25e0..1db3f49a1f 100644 --- a/doc/model/train-hybrid.md +++ b/doc/model/train-hybrid.md @@ -1,4 +1,8 @@ -# Descriptor `"hybrid"` +# Descriptor `"hybrid"` {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: This descriptor hybridizes multiple descriptors to form a new descriptor. For example, we have a list of descriptors denoted by $\mathcal D_1$, $\mathcal D_2$, ..., $\mathcal D_N$, the hybrid descriptor is the concatenation of the list, i.e. $\mathcal D = (\mathcal D_1, \mathcal D_2, \cdots, \mathcal D_N)$. 
diff --git a/doc/model/train-se-a-mask.md b/doc/model/train-se-a-mask.md index 17c211ec73..6d0e2e0320 100644 --- a/doc/model/train-se-a-mask.md +++ b/doc/model/train-se-a-mask.md @@ -1,4 +1,8 @@ -# Descriptor `"se_a_mask"` +# Descriptor `"se_a_mask"` {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: Descriptor `se_a_mask` is a concise implementation of the descriptor `se_e2_a`, diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 7480ddbc12..b4e346327d 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -1,4 +1,8 @@ -# Descriptor `"se_atten"` +# Descriptor `"se_atten"` {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: ## DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation diff --git a/doc/model/train-se-e2-a-tebd.md b/doc/model/train-se-e2-a-tebd.md index cb6ce6674f..7797a8f3c0 100644 --- a/doc/model/train-se-e2-a-tebd.md +++ b/doc/model/train-se-e2-a-tebd.md @@ -1,4 +1,8 @@ -# Type embedding approach +# Type embedding approach {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: We generate a specific type embedding vector for each atom type so that we can share one descriptor embedding net and one fitting net in total, which greatly reduces training complexity. diff --git a/doc/model/train-se-e2-a.md b/doc/model/train-se-e2-a.md index 537253a6d9..d40bb513ea 100644 --- a/doc/model/train-se-e2-a.md +++ b/doc/model/train-se-e2-a.md @@ -1,4 +1,8 @@ -# Descriptor `"se_e2_a"` +# Descriptor `"se_e2_a"` {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: The notation of `se_e2_a` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from all information (both angular and radial) of atomic configurations. The `e2` stands for the embedding with two-atom information. 
This descriptor was described in detail in [the DeepPot-SE paper](https://arxiv.org/abs/1805.09003). diff --git a/doc/model/train-se-e2-r.md b/doc/model/train-se-e2-r.md index f2f990b16a..c2c5fcfcd9 100644 --- a/doc/model/train-se-e2-r.md +++ b/doc/model/train-se-e2-r.md @@ -1,4 +1,8 @@ -# Descriptor `"se_e2_r"` +# Descriptor `"se_e2_r"` {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: The notation of `se_e2_r` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from the radial information of atomic configurations. The `e2` stands for the embedding with two-atom information. diff --git a/doc/model/train-se-e3.md b/doc/model/train-se-e3.md index 5b0710a389..4eb35357a0 100644 --- a/doc/model/train-se-e3.md +++ b/doc/model/train-se-e3.md @@ -1,4 +1,8 @@ -# Descriptor `"se_e3"` +# Descriptor `"se_e3"` {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: The notation of `se_e3` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from all information (both angular and radial) of atomic configurations. The embedding takes bond angles between a central atom and its two neighboring atoms as input (denoted by `e3`). diff --git a/doc/nvnmd/nvnmd.md b/doc/nvnmd/nvnmd.md index c11fee0bc9..7c00baad27 100644 --- a/doc/nvnmd/nvnmd.md +++ b/doc/nvnmd/nvnmd.md @@ -1,4 +1,8 @@ -# Introduction +# Introduction {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: NVNMD stands for non-von Neumann molecular dynamics. 
diff --git a/doc/train/finetuning.md b/doc/train/finetuning.md index ebc7cda2c9..bbab74f41e 100644 --- a/doc/train/finetuning.md +++ b/doc/train/finetuning.md @@ -1,4 +1,8 @@ -# Finetune the pretrained model +# Finetune the pretrained model {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: Pretraining-and-finetuning is a widely used approach in other fields such as Computer Vision (CV) or Natural Language Processing (NLP) to vastly reduce the training cost, while it's not trivial in potential models. diff --git a/doc/train/gpu-limitations.md b/doc/train/gpu-limitations.md index 5df76d28c9..dee606c2a3 100644 --- a/doc/train/gpu-limitations.md +++ b/doc/train/gpu-limitations.md @@ -1,4 +1,5 @@ -# Known limitations of using GPUs +# Known limitations of using GPUs {{ tensorflow_icon }} + If you use DeePMD-kit in a GPU environment, the acceptable value range of some variables is additionally restricted compared to the CPU environment due to the software's GPU implementations: 1. The number of atom types of a given system must be less than 128. 2. The maximum distance between an atom and its neighbors must be less than 128. It can be controlled by setting the rcut value of training parameters. 
diff --git a/doc/train/multi-task-training.md b/doc/train/multi-task-training.md index c647e6905e..76f404ab88 100644 --- a/doc/train/multi-task-training.md +++ b/doc/train/multi-task-training.md @@ -1,4 +1,8 @@ -# Multi-task training +# Multi-task training {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: ## Theory diff --git a/doc/train/parallel-training.md b/doc/train/parallel-training.md index 98d12f2b9b..4c707e5607 100644 --- a/doc/train/parallel-training.md +++ b/doc/train/parallel-training.md @@ -1,4 +1,8 @@ -# Parallel training +# Parallel training {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: Currently, parallel training is enabled in a synchronized way with help of [Horovod](https://github.com/horovod/horovod). Depending on the number of training processes (according to MPI context) and the number of GPU cards available, DeePMD-kit will decide whether to launch the training in parallel (distributed) mode or in serial mode. Therefore, no additional options are specified in your JSON/YAML input file. diff --git a/doc/train/tensorboard.md b/doc/train/tensorboard.md index 4846005216..1d6c5f0d68 100644 --- a/doc/train/tensorboard.md +++ b/doc/train/tensorboard.md @@ -1,4 +1,8 @@ -# TensorBoard Usage +# TensorBoard Usage {{ tensorflow_icon }} + +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: TensorBoard provides the visualization and tooling needed for machine learning experimentation. 
Full instructions for TensorBoard can be found diff --git a/pyproject.toml b/pyproject.toml index e91fd320f3..550fbc4b54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,7 @@ environment-pass = [ "DP_VARIANT", "CUDA_VERSION", "DP_PKG_NAME", + "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DEEPMD-KIT-CU11", ] environment = { PIP_PREFER_BINARY="1", DP_LAMMPS_VERSION="stable_2Aug2023_update2", DP_ENABLE_IPI="1", MPI_HOME="/usr/lib64/mpich", PATH="/usr/lib64/mpich/bin:$PATH" } before-all = [ diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index c1c9b8e7fe..c273bc9263 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.16) project(DeePMD) +option(ENABLE_TENSORFLOW "Enable TensorFlow interface" OFF) +option(ENABLE_PYTORCH "Enable PyTorch interface" OFF) option(BUILD_TESTING "Build test and enable converage" OFF) set(DEEPMD_C_ROOT "" @@ -131,6 +133,7 @@ if(INSTALL_TENSORFLOW) set(USE_TF_PYTHON_LIBS TRUE) endif(INSTALL_TENSORFLOW) if(USE_TF_PYTHON_LIBS) + set(ENABLE_TENSORFLOW TRUE) if(NOT "$ENV{CIBUILDWHEEL}" STREQUAL "1") find_package( Python @@ -141,11 +144,31 @@ if(USE_TF_PYTHON_LIBS) set(PYTHON_INCLUDE_DIRS ${PYTHON_INCLUDE_DIR}) endif() endif(USE_TF_PYTHON_LIBS) +if(TENSORFLOW_ROOT) + set(ENABLE_TENSORFLOW TRUE) +endif() # find tensorflow, I need tf abi info -if(NOT DEEPMD_C_ROOT) +if(ENABLE_TENSORFLOW AND NOT DEEPMD_C_ROOT) find_package(tensorflow REQUIRED) endif() +if(ENABLE_PYTORCH AND NOT DEEPMD_C_ROOT) + find_package(Torch REQUIRED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") +endif() +# log enabled backends +if(NOT DEEPMD_C_ROOT) + message(STATUS "Enabled backends:") + if(ENABLE_TENSORFLOW) + message(STATUS "- TensorFlow") + endif() + if(ENABLE_PYTORCH) + message(STATUS "- PyTorch") + endif() + if(NOT ENABLE_TENSORFLOW AND NOT ENABLE_PYTORCH) + message(FATAL_ERROR "No backend is enabled.") + endif() +endif() # find threads find_package(Threads) @@ -233,10 +256,13 @@ 
if(DEEPMD_C_ROOT) # use variable for TF path to set deepmd_c path set(TensorFlow_LIBRARY_PATH "${DEEPMD_C_ROOT}/lib") set(TENSORFLOW_INCLUDE_DIRS "${DEEPMD_C_ROOT}/include") + set(TORCH_LIBRARIES "${DEEPMD_C_ROOT}/lib/libtorch.so") endif() if(NOT DEEPMD_C_ROOT) - add_subdirectory(op/) + if(ENABLE_TENSORFLOW) + add_subdirectory(op/) + endif() add_subdirectory(lib/) endif() if(BUILD_PY_IF) diff --git a/source/api_cc/CMakeLists.txt b/source/api_cc/CMakeLists.txt index 2f296e3dfd..cd42594f1e 100644 --- a/source/api_cc/CMakeLists.txt +++ b/source/api_cc/CMakeLists.txt @@ -11,8 +11,16 @@ add_library(${libname} SHARED ${LIB_SRC}) # link: libdeepmd libdeepmd_op libtensorflow_cc libtensorflow_framework target_link_libraries(${libname} PUBLIC ${LIB_DEEPMD}) -target_link_libraries(${libname} PRIVATE TensorFlow::tensorflow_cc - TensorFlow::tensorflow_framework) +if(ENABLE_TENSORFLOW) + target_link_libraries(${libname} PRIVATE TensorFlow::tensorflow_cc + TensorFlow::tensorflow_framework) + target_compile_definitions(${libname} PRIVATE BUILD_TENSORFLOW) +endif() +if(ENABLE_PYTORCH) + target_link_libraries(${libname} PRIVATE "${TORCH_LIBRARIES}") + target_compile_definitions(${libname} PRIVATE BUILD_PYTORCH) +endif() + target_include_directories( ${libname} PUBLIC $ @@ -55,3 +63,14 @@ ${CMAKE_INSTALL_PREFIX}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}${libname}${LOW_PREC_V add_subdirectory(tests) endif() endif(BUILD_PY_IF) + +if(BUILD_TESTING) + # A compilation test to make sure api_cc can compile without any backend + add_library(deepmd_cc_test_no_backend SHARED ${LIB_SRC}) + target_link_libraries(deepmd_cc_test_no_backend PUBLIC ${LIB_DEEPMD}) + target_include_directories( + deepmd_cc_test_no_backend + PUBLIC $ + $ + $) +endif() diff --git a/source/api_cc/include/DataModifierTF.h b/source/api_cc/include/DataModifierTF.h index 2ca3729525..c0021c6947 100644 --- a/source/api_cc/include/DataModifierTF.h +++ b/source/api_cc/include/DataModifierTF.h @@ -3,6 +3,7 @@ #include "DataModifier.h" 
#include "common.h" +#include "commonTF.h" namespace deepmd { /** diff --git a/source/api_cc/include/DeepPotTF.h b/source/api_cc/include/DeepPotTF.h index 0580c61da5..699b0ff7fe 100644 --- a/source/api_cc/include/DeepPotTF.h +++ b/source/api_cc/include/DeepPotTF.h @@ -3,6 +3,7 @@ #include "DeepPot.h" #include "common.h" +#include "commonTF.h" #include "neighbor_list.h" namespace deepmd { diff --git a/source/api_cc/include/DeepTensorTF.h b/source/api_cc/include/DeepTensorTF.h index 3c724dce88..3ca316a29f 100644 --- a/source/api_cc/include/DeepTensorTF.h +++ b/source/api_cc/include/DeepTensorTF.h @@ -3,6 +3,7 @@ #include "DeepTensor.h" #include "common.h" +#include "commonTF.h" #include "neighbor_list.h" namespace deepmd { diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 7982c4f89d..0392747979 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -10,12 +10,6 @@ #include "neighbor_list.h" #include "version.h" -#ifdef TF_PRIVATE -#include "tf_private.h" -#else -#include "tf_public.h" -#endif - namespace deepmd { typedef double ENERGYTYPE; @@ -175,143 +169,8 @@ struct tf_exception : public deepmd::deepmd_exception { : deepmd::deepmd_exception(std::string("TensorFlow Error: ") + msg){}; }; -/** - * @brief Check TensorFlow status. Exit if not OK. - * @param[in] status TensorFlow status. - **/ -void check_status(const tensorflow::Status& status); - std::string name_prefix(const std::string& name_scope); -/** - * @brief Get the value of a tensor. - * @param[in] session TensorFlow session. - * @param[in] name The name of the tensor. - * @param[in] scope The scope of the tensor. - * @return The value of the tensor. - **/ -template -VT session_get_scalar(tensorflow::Session* session, - const std::string name, - const std::string scope = ""); - -/** - * @brief Get the vector of a tensor. - * @param[out] o_vec The output vector. - * @param[in] session TensorFlow session. - * @param[in] name The name of the tensor. 
- * @param[in] scope The scope of the tensor. - **/ -template -void session_get_vector(std::vector& o_vec, - tensorflow::Session* session, - const std::string name_, - const std::string scope = ""); - -/** - * @brief Get the type of a tensor. - * @param[in] session TensorFlow session. - * @param[in] name The name of the tensor. - * @param[in] scope The scope of the tensor. - * @return The type of the tensor as int. - **/ -int session_get_dtype(tensorflow::Session* session, - const std::string name, - const std::string scope = ""); - -/** - * @brief Get input tensors. - * @param[out] input_tensors Input tensors. - * @param[in] dcoord_ Coordinates of atoms. - * @param[in] ntypes Number of atom types. - * @param[in] datype_ Atom types. - * @param[in] dbox Box matrix. - * @param[in] cell_size Cell size. - * @param[in] fparam_ Frame parameters. - * @param[in] aparam_ Atom parameters. - * @param[in] atommap Atom map. - * @param[in] scope The scope of the tensors. - * @param[in] aparam_nall Whether the atomic dimesion of atomic parameters is - * nall. - */ -template -int session_input_tensors( - std::vector>& input_tensors, - const std::vector& dcoord_, - const int& ntypes, - const std::vector& datype_, - const std::vector& dbox, - const double& cell_size, - const std::vector& fparam_, - const std::vector& aparam_, - const deepmd::AtomMap& atommap, - const std::string scope = "", - const bool aparam_nall = false); - -/** - * @brief Get input tensors. - * @param[out] input_tensors Input tensors. - * @param[in] dcoord_ Coordinates of atoms. - * @param[in] ntypes Number of atom types. - * @param[in] datype_ Atom types. - * @param[in] dlist Neighbor list. - * @param[in] fparam_ Frame parameters. - * @param[in] aparam_ Atom parameters. - * @param[in] atommap Atom map. - * @param[in] nghost Number of ghost atoms. - * @param[in] ago Update the internal neighbour list if ago is 0. - * @param[in] scope The scope of the tensors. 
- * @param[in] aparam_nall Whether the atomic dimesion of atomic parameters is - * nall. - */ -template -int session_input_tensors( - std::vector>& input_tensors, - const std::vector& dcoord_, - const int& ntypes, - const std::vector& datype_, - const std::vector& dbox, - InputNlist& dlist, - const std::vector& fparam_, - const std::vector& aparam_, - const deepmd::AtomMap& atommap, - const int nghost, - const int ago, - const std::string scope = "", - const bool aparam_nall = false); - -/** - * @brief Get input tensors for mixed type. - * @param[out] input_tensors Input tensors. - * @param[in] nframes Number of frames. - * @param[in] dcoord_ Coordinates of atoms. - * @param[in] ntypes Number of atom types. - * @param[in] datype_ Atom types. - * @param[in] dlist Neighbor list. - * @param[in] fparam_ Frame parameters. - * @param[in] aparam_ Atom parameters. - * @param[in] atommap Atom map. - * @param[in] nghost Number of ghost atoms. - * @param[in] ago Update the internal neighbour list if ago is 0. - * @param[in] scope The scope of the tensors. - * @param[in] aparam_nall Whether the atomic dimesion of atomic parameters is - * nall. - */ -template -int session_input_tensors_mixed_type( - std::vector>& input_tensors, - const int& nframes, - const std::vector& dcoord_, - const int& ntypes, - const std::vector& datype_, - const std::vector& dbox, - const double& cell_size, - const std::vector& fparam_, - const std::vector& aparam_, - const deepmd::AtomMap& atommap, - const std::string scope = "", - const bool aparam_nall = false); - /** * @brief Read model file to a string. * @param[in] model Path to the model. 
diff --git a/source/api_cc/include/commonTF.h b/source/api_cc/include/commonTF.h new file mode 100644 index 0000000000..0c14597e30 --- /dev/null +++ b/source/api_cc/include/commonTF.h @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include +#include + +#ifdef TF_PRIVATE +#include "tf_private.h" +#else +#include "tf_public.h" +#endif + +namespace deepmd { +/** + * @brief Check TensorFlow status. Exit if not OK. + * @param[in] status TensorFlow status. + **/ +void check_status(const tensorflow::Status& status); + +/** + * @brief Get the value of a tensor. + * @param[in] session TensorFlow session. + * @param[in] name The name of the tensor. + * @param[in] scope The scope of the tensor. + * @return The value of the tensor. + **/ +template +VT session_get_scalar(tensorflow::Session* session, + const std::string name, + const std::string scope = ""); + +/** + * @brief Get the vector of a tensor. + * @param[out] o_vec The output vector. + * @param[in] session TensorFlow session. + * @param[in] name The name of the tensor. + * @param[in] scope The scope of the tensor. + **/ +template +void session_get_vector(std::vector& o_vec, + tensorflow::Session* session, + const std::string name_, + const std::string scope = ""); + +/** + * @brief Get the type of a tensor. + * @param[in] session TensorFlow session. + * @param[in] name The name of the tensor. + * @param[in] scope The scope of the tensor. + * @return The type of the tensor as int. + **/ +int session_get_dtype(tensorflow::Session* session, + const std::string name, + const std::string scope = ""); + +/** + * @brief Get input tensors. + * @param[out] input_tensors Input tensors. + * @param[in] dcoord_ Coordinates of atoms. + * @param[in] ntypes Number of atom types. + * @param[in] datype_ Atom types. + * @param[in] dbox Box matrix. + * @param[in] cell_size Cell size. + * @param[in] fparam_ Frame parameters. + * @param[in] aparam_ Atom parameters. + * @param[in] atommap Atom map. 
+ * @param[in] scope The scope of the tensors. + * @param[in] aparam_nall Whether the atomic dimension of atomic parameters is + * nall. + */ +template +int session_input_tensors( + std::vector>& input_tensors, + const std::vector& dcoord_, + const int& ntypes, + const std::vector& datype_, + const std::vector& dbox, + const double& cell_size, + const std::vector& fparam_, + const std::vector& aparam_, + const deepmd::AtomMap& atommap, + const std::string scope = "", + const bool aparam_nall = false); + +/** + * @brief Get input tensors. + * @param[out] input_tensors Input tensors. + * @param[in] dcoord_ Coordinates of atoms. + * @param[in] ntypes Number of atom types. + * @param[in] datype_ Atom types. + * @param[in] dlist Neighbor list. + * @param[in] fparam_ Frame parameters. + * @param[in] aparam_ Atom parameters. + * @param[in] atommap Atom map. + * @param[in] nghost Number of ghost atoms. + * @param[in] ago Update the internal neighbour list if ago is 0. + * @param[in] scope The scope of the tensors. + * @param[in] aparam_nall Whether the atomic dimension of atomic parameters is + * nall. + */ +template +int session_input_tensors( + std::vector>& input_tensors, + const std::vector& dcoord_, + const int& ntypes, + const std::vector& datype_, + const std::vector& dbox, + InputNlist& dlist, + const std::vector& fparam_, + const std::vector& aparam_, + const deepmd::AtomMap& atommap, + const int nghost, + const int ago, + const std::string scope = "", + const bool aparam_nall = false); + +/** + * @brief Get input tensors for mixed type. + * @param[out] input_tensors Input tensors. + * @param[in] nframes Number of frames. + * @param[in] dcoord_ Coordinates of atoms. + * @param[in] ntypes Number of atom types. + * @param[in] datype_ Atom types. + * @param[in] dlist Neighbor list. + * @param[in] fparam_ Frame parameters. + * @param[in] aparam_ Atom parameters. + * @param[in] atommap Atom map. + * @param[in] nghost Number of ghost atoms.
+ * @param[in] ago Update the internal neighbour list if ago is 0. + * @param[in] scope The scope of the tensors. + * @param[in] aparam_nall Whether the atomic dimension of atomic parameters is + * nall. + */ +template +int session_input_tensors_mixed_type( + std::vector>& input_tensors, + const int& nframes, + const std::vector& dcoord_, + const int& ntypes, + const std::vector& datype_, + const std::vector& dbox, + const double& cell_size, + const std::vector& fparam_, + const std::vector& aparam_, + const deepmd::AtomMap& atommap, + const std::string scope = "", + const bool aparam_nall = false); + +} // namespace deepmd diff --git a/source/api_cc/include/version.h.in b/source/api_cc/include/version.h.in index c6bf6cf491..26b0c1be48 100644 --- a/source/api_cc/include/version.h.in +++ b/source/api_cc/include/version.h.in @@ -9,4 +9,5 @@ const std::string global_git_date="@GIT_DATE@"; const std::string global_git_branch="@GIT_BRANCH@"; const std::string global_tf_include_dir="@TensorFlow_INCLUDE_DIRS@"; const std::string global_tf_lib="@TensorFlow_LIBRARY@"; +const std::string global_pt_lib="@TORCH_LIBRARIES@"; const std::string global_model_version="@MODEL_VERSION@"; diff --git a/source/api_cc/src/DataModifier.cc b/source/api_cc/src/DataModifier.cc index 954c969c13..38d1fc879a 100644 --- a/source/api_cc/src/DataModifier.cc +++ b/source/api_cc/src/DataModifier.cc @@ -1,7 +1,9 @@ // SPDX-License-Identifier: LGPL-3.0-or-later #include "DataModifier.h" +#ifdef BUILD_TENSORFLOW #include "DataModifierTF.h" +#endif #include "common.h" using namespace deepmd; @@ -29,9 +31,12 @@ void DipoleChargeModifier::init(const std::string& model, // TODO: To implement detect_backend DPBackend backend = deepmd::DPBackend::TensorFlow; if (deepmd::DPBackend::TensorFlow == backend) { - // TODO: throw errors if TF backend is not built, without mentioning TF +#ifdef BUILD_TENSORFLOW dcm = std::make_shared(model, gpu_rank, name_scope_); +#else + throw deepmd::deepmd_exception("TensorFlow
backend is not built"); +#endif } else if (deepmd::DPBackend::PyTorch == backend) { throw deepmd::deepmd_exception("PyTorch backend is not supported yet"); } else if (deepmd::DPBackend::Paddle == backend) { diff --git a/source/api_cc/src/DataModifierTF.cc b/source/api_cc/src/DataModifierTF.cc index 219139cf89..324cb14098 100644 --- a/source/api_cc/src/DataModifierTF.cc +++ b/source/api_cc/src/DataModifierTF.cc @@ -1,4 +1,5 @@ // SPDX-License-Identifier: LGPL-3.0-or-later +#ifdef BUILD_TENSORFLOW #include "DataModifierTF.h" #include "common.h" @@ -361,3 +362,4 @@ void DipoleChargeModifierTF::computew( compute(dfcorr_, dvcorr_, dcoord_, datype_, dbox, pairs, delef_, nghost, lmp_list); } +#endif // BUILD_TENSORFLOW diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index 083e9b091f..c598549844 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -7,7 +7,9 @@ #include #include "AtomMap.h" +#ifdef BUILD_TENSORFLOW #include "DeepPotTF.h" +#endif #include "device.h" using namespace deepmd; @@ -35,8 +37,11 @@ void DeepPot::init(const std::string& model, // TODO: To implement detect_backend DPBackend backend = deepmd::DPBackend::TensorFlow; if (deepmd::DPBackend::TensorFlow == backend) { - // TODO: throw errors if TF backend is not built, without mentioning TF +#ifdef BUILD_TENSORFLOW dp = std::make_shared(model, gpu_rank, file_content); +#else + throw deepmd::deepmd_exception("TensorFlow backend is not built"); +#endif } else if (deepmd::DPBackend::PyTorch == backend) { throw deepmd::deepmd_exception("PyTorch backend is not supported yet"); } else if (deepmd::DPBackend::Paddle == backend) { diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc new file mode 100644 index 0000000000..c94fb4247b --- /dev/null +++ b/source/api_cc/src/DeepPotPT.cc @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#ifdef BUILD_PYTORCH +#include + +void test_function_please_remove_after_torch_is_actually_used() 
{ + torch::Tensor tensor = torch::rand({2, 3}); +} +#endif diff --git a/source/api_cc/src/DeepPotTF.cc b/source/api_cc/src/DeepPotTF.cc index ef348fe14c..7bf2bebce4 100644 --- a/source/api_cc/src/DeepPotTF.cc +++ b/source/api_cc/src/DeepPotTF.cc @@ -1,4 +1,5 @@ // SPDX-License-Identifier: LGPL-3.0-or-later +#ifdef BUILD_TENSORFLOW #include "DeepPotTF.h" #include @@ -1051,3 +1052,4 @@ void DeepPotTF::computew_mixed_type(std::vector& ener, compute_mixed_type(ener, force, virial, atom_energy, atom_virial, nframes, coord, atype, box, fparam, aparam); } +#endif diff --git a/source/api_cc/src/DeepTensor.cc b/source/api_cc/src/DeepTensor.cc index 2c88ab2f4b..a0596e046f 100644 --- a/source/api_cc/src/DeepTensor.cc +++ b/source/api_cc/src/DeepTensor.cc @@ -3,7 +3,9 @@ #include +#ifdef BUILD_TENSORFLOW #include "DeepTensorTF.h" +#endif #include "common.h" using namespace deepmd; @@ -31,8 +33,11 @@ void DeepTensor::init(const std::string &model, // TODO: To implement detect_backend DPBackend backend = deepmd::DPBackend::TensorFlow; if (deepmd::DPBackend::TensorFlow == backend) { - // TODO: throw errors if TF backend is not built, without mentioning TF +#ifdef BUILD_TENSORFLOW dt = std::make_shared(model, gpu_rank, name_scope_); +#else + throw deepmd::deepmd_exception("TensorFlow backend is not built."); +#endif } else if (deepmd::DPBackend::PyTorch == backend) { throw deepmd::deepmd_exception("PyTorch backend is not supported yet"); } else if (deepmd::DPBackend::Paddle == backend) { diff --git a/source/api_cc/src/DeepTensorTF.cc b/source/api_cc/src/DeepTensorTF.cc index 436e389ad2..34a47bc6f3 100644 --- a/source/api_cc/src/DeepTensorTF.cc +++ b/source/api_cc/src/DeepTensorTF.cc @@ -1,4 +1,5 @@ // SPDX-License-Identifier: LGPL-3.0-or-later +#ifdef BUILD_TENSORFLOW #include "DeepTensorTF.h" using namespace deepmd; @@ -844,3 +845,4 @@ void DeepTensorTF::computew(std::vector &global_tensor, atom_virial.clear(); } } +#endif diff --git a/source/api_cc/src/common.cc 
b/source/api_cc/src/common.cc index 2f75aaa291..a552f646f1 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -3,6 +3,8 @@ #include +#include + #include "AtomMap.h" #include "device.h" #if defined(_WIN32) @@ -20,10 +22,13 @@ // not windows #include #endif +#ifdef BUILD_TENSORFLOW +#include "commonTF.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/text_format.h" using namespace tensorflow; +#endif static std::vector split(const std::string& input_, const std::string& delimiter) { @@ -300,12 +305,14 @@ void deepmd::NeighborListData::make_inlist(InputNlist& inlist) { inlist.firstneigh = &firstneigh[0]; } +#ifdef BUILD_TENSORFLOW void deepmd::check_status(const tensorflow::Status& status) { if (!status.ok()) { std::cout << status.ToString() << std::endl; throw deepmd::tf_exception(status.ToString()); } } +#endif void throw_env_not_set_warning(std::string env_name) { std::cerr << "DeePMD-kit WARNING: Environmental variable " << env_name @@ -345,6 +352,7 @@ void deepmd::get_env_nthreads(int& num_intra_nthreads, } void deepmd::load_op_library() { +#ifdef BUILD_TENSORFLOW tensorflow::Env* env = tensorflow::Env::Default(); #if defined(_WIN32) std::string dso_path = "deepmd_op.dll"; @@ -358,6 +366,7 @@ void deepmd::load_op_library() { dso_path + " is not found! 
You can add the library directory to LD_LIBRARY_PATH"); } +#endif } std::string deepmd::name_prefix(const std::string& scope) { @@ -368,6 +377,7 @@ std::string deepmd::name_prefix(const std::string& scope) { return prefix; } +#ifdef BUILD_TENSORFLOW template int deepmd::session_input_tensors( std::vector>& input_tensors, @@ -850,6 +860,7 @@ int deepmd::session_get_dtype(tensorflow::Session* session, // cast enum to int return (int)output_rc.dtype(); } +#endif template void deepmd::select_map(std::vector& out, @@ -940,6 +951,7 @@ void deepmd::select_map_inv(typename std::vector::iterator out, } } +#ifdef BUILD_TENSORFLOW template int deepmd::session_get_scalar(Session*, const std::string, const std::string); @@ -989,6 +1001,7 @@ template void deepmd::session_get_vector(std::vector&, Session*, const std::string, const std::string); +#endif template void deepmd::select_map(std::vector& out, const std::vector& in, @@ -1018,6 +1031,7 @@ template void deepmd::select_map_inv( const std::vector& idx_map, const int& stride); +#ifdef BUILD_TENSORFLOW template double deepmd::session_get_scalar(Session*, const std::string, const std::string); @@ -1026,6 +1040,7 @@ template void deepmd::session_get_vector(std::vector&, Session*, const std::string, const std::string); +#endif template void deepmd::select_map(std::vector& out, const std::vector& in, @@ -1055,6 +1070,7 @@ template void deepmd::select_map_inv( const std::vector& idx_map, const int& stride); +#ifdef BUILD_TENSORFLOW template deepmd::STRINGTYPE deepmd::session_get_scalar( Session*, const std::string, const std::string); @@ -1093,13 +1109,19 @@ template void deepmd::select_map_inv( const typename std::vector::const_iterator in, const std::vector& idx_map, const int& stride); +#endif void deepmd::read_file_to_string(std::string model, std::string& file_content) { +#ifdef BUILD_TENSORFLOW deepmd::check_status(tensorflow::ReadFileToString(tensorflow::Env::Default(), model, &file_content)); +#else + throw 
deepmd::deepmd_exception("TODO: read_file_to_string only support TF"); +#endif } void deepmd::convert_pbtxt_to_pb(std::string fn_pb_txt, std::string fn_pb) { +#ifdef BUILD_TENSORFLOW int fd = open(fn_pb_txt.c_str(), O_RDONLY); tensorflow::protobuf::io::ZeroCopyInputStream* input = new tensorflow::protobuf::io::FileInputStream(fd); @@ -1109,8 +1131,13 @@ void deepmd::convert_pbtxt_to_pb(std::string fn_pb_txt, std::string fn_pb) { std::fstream output(fn_pb, std::ios::out | std::ios::trunc | std::ios::binary); graph_def.SerializeToOstream(&output); +#else + throw deepmd::deepmd_exception( + "convert_pbtxt_to_pb: TensorFlow backend is not enabled."); +#endif } +#ifdef BUILD_TENSORFLOW template int deepmd::session_input_tensors( std::vector>& input_tensors, const std::vector& dcoord_, @@ -1272,6 +1299,7 @@ template int deepmd::session_input_tensors_mixed_type( const deepmd::AtomMap& atommap, const std::string scope, const bool aparam_nall); +#endif void deepmd::print_summary(const std::string& pre) { int num_intra_nthreads, num_inter_nthreads; @@ -1292,8 +1320,13 @@ void deepmd::print_summary(const std::string& pre) { std::cout << pre << "build variant: cpu" << "\n"; #endif +#ifdef BUILD_TENSORFLOW std::cout << pre << "build with tf inc: " + global_tf_include_dir << "\n"; std::cout << pre << "build with tf lib: " + global_tf_lib << "\n"; +#endif +#ifdef BUILD_PYTORCH + std::cout << pre << "build with pt lib: " + global_pt_lib << "\n"; +#endif std::cout << pre << "set tf intra_op_parallelism_threads: " << num_intra_nthreads << "\n"; diff --git a/source/install/build_cc.sh b/source/install/build_cc.sh index fef9e82ebc..83a586049d 100755 --- a/source/install/build_cc.sh +++ b/source/install/build_cc.sh @@ -20,7 +20,13 @@ NPROC=$(nproc --all) BUILD_TMP_DIR=${SCRIPT_PATH}/../build mkdir -p ${BUILD_TMP_DIR} cd ${BUILD_TMP_DIR} -cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DUSE_TF_PYTHON_LIBS=TRUE ${CUDA_ARGS} -DLAMMPS_VERSION=stable_2Aug2023_update2 .. 
+cmake -D ENABLE_TENSORFLOW=ON \ + -D ENABLE_PYTORCH=ON \ + -D CMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ + -D USE_TF_PYTHON_LIBS=TRUE \ + ${CUDA_ARGS} \ + -D LAMMPS_VERSION=stable_2Aug2023_update2 \ + .. cmake --build . -j${NPROC} cmake --install . diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index 5ee43d2af3..7037a00a6c 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -577,6 +577,15 @@ class ProdEnvMatAOp : public OpKernel { mesh_tensor.flat().data(), mesh_tensor_size, nloc, nei_mode, rcut_r, max_cpy_trial, max_nnei_trial); + // max_nbor_size may be changed after _prepare_coord_nlist_gpu + // So we need to update the uint64_temp tensor if necessary + if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) { + TensorShape uint64_shape; + uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2); + OP_REQUIRES_OK(context, context->allocate_temp( + DT_UINT64, uint64_shape, &uint64_temp)); + array_longlong = uint64_temp.flat().data(); + } // launch the gpu(nv) compute function deepmd::prod_env_mat_a_gpu(em, em_deriv, rij, nlist, coord, type, gpu_inlist, array_int, array_longlong, @@ -875,6 +884,16 @@ class ProdEnvMatROp : public OpKernel { mesh_tensor.flat().data(), mesh_tensor_size, nloc, nei_mode, rcut, max_cpy_trial, max_nnei_trial); + // max_nbor_size may be changed after _prepare_coord_nlist_gpu + // So we need to update the uint64_temp tensor if necessary + if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) { + TensorShape uint64_shape; + uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2); + OP_REQUIRES_OK(context, context->allocate_temp( + DT_UINT64, uint64_shape, &uint64_temp)); + array_longlong = uint64_temp.flat().data(); + } + // launch the gpu(nv) compute function deepmd::prod_env_mat_r_gpu(em, em_deriv, rij, nlist, coord, type, gpu_inlist, array_int, array_longlong, @@ -1221,6 +1240,16 @@ class ProdEnvMatAMixOp : public OpKernel { 
mesh_tensor.flat().data(), mesh_tensor_size, nloc, nei_mode, rcut_r, max_cpy_trial, max_nnei_trial); + // max_nbor_size may be changed after _prepare_coord_nlist_gpu + // So we need to update the uint64_temp tensor if necessary + if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) { + TensorShape uint64_shape; + uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2); + OP_REQUIRES_OK(context, context->allocate_temp( + DT_UINT64, uint64_shape, &uint64_temp)); + array_longlong = uint64_temp.flat().data(); + } + // launch the gpu(nv) compute function deepmd::prod_env_mat_a_gpu(em, em_deriv, rij, nlist, coord, type, gpu_inlist, array_int, array_longlong, diff --git a/source/tests/test_uni_infer.py b/source/tests/test_uni_infer.py new file mode 100644 index 0000000000..6b70d17f7e --- /dev/null +++ b/source/tests/test_uni_infer.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for the universal Python inference interface.""" + +import os +import unittest + +from common import ( + tests_path, +) + +from deepmd.infer.deep_pot import DeepPot as DeepPotTF +from deepmd.utils.convert import ( + convert_pbtxt_to_pb, +) +from deepmd_utils.infer.deep_pot import DeepPot as DeepPot + + +class TestUniversalInfer(unittest.TestCase): + @classmethod + def setUpClass(cls): + convert_pbtxt_to_pb( + str(tests_path / os.path.join("infer", "deeppot-r.pbtxt")), "deeppot.pb" + ) + + def test_deep_pot(self): + dp = DeepPot("deeppot.pb") + self.assertIsInstance(dp, DeepPotTF)