Skip to content

Commit

Permalink
Merge branch 'devel' into chore/stat-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml authored Apr 1, 2024
2 parents e4545cf + 2e6ab1b commit 1a143a3
Show file tree
Hide file tree
Showing 100 changed files with 894 additions and 73 deletions.
76 changes: 76 additions & 0 deletions backend/find_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os
import site
from functools import (
lru_cache,
)
from importlib.machinery import (
FileFinder,
)
from importlib.util import (
find_spec,
)
from pathlib import (
Path,
)
from sysconfig import (
get_path,
)
from typing import (
Optional,
)


@lru_cache
def find_pytorch() -> Optional[str]:
"""Find PyTorch library.
Tries to find PyTorch in the order of:
1. Environment variable `PYTORCH_ROOT` if set
2. The current Python environment.
3. user site packages directory if enabled
4. system site packages directory (purelib)
Considering the default PyTorch package still uses old CXX11 ABI, we
cannot install it automatically.
Returns
-------
str, optional
PyTorch library path if found.
"""
if os.environ.get("DP_ENABLE_PYTORCH", "0") == "0":
return None
pt_spec = None

if (pt_spec is None or not pt_spec) and os.environ.get("PYTORCH_ROOT") is not None:
site_packages = Path(os.environ.get("PYTORCH_ROOT")).parent.absolute()
pt_spec = FileFinder(str(site_packages)).find_spec("torch")

# get pytorch spec
# note: isolated build will not work for backend
if pt_spec is None or not pt_spec:
pt_spec = find_spec("torch")

if not pt_spec and site.ENABLE_USER_SITE:
# first search TF from user site-packages before global site-packages
site_packages = site.getusersitepackages()
if site_packages:
pt_spec = FileFinder(site_packages).find_spec("torch")

if not pt_spec:
# purelib gets site-packages path
site_packages = get_path("purelib")
if site_packages:
pt_spec = FileFinder(site_packages).find_spec("torch")

# get install dir from spec
try:
pt_install_dir = pt_spec.submodule_search_locations[0] # type: ignore
# AttributeError if ft_spec is None
# TypeError if submodule_search_locations are None
# IndexError if submodule_search_locations is an empty list
except (AttributeError, TypeError, IndexError):
pt_install_dir = None
return pt_install_dir
16 changes: 16 additions & 0 deletions backend/read_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
Version,
)

from .find_pytorch import (
find_pytorch,
)
from .find_tensorflow import (
find_tensorflow,
get_tf_version,
Expand Down Expand Up @@ -99,6 +102,19 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str]:
cmake_args.append("-DENABLE_TENSORFLOW=OFF")
tf_version = None

if os.environ.get("DP_ENABLE_PYTORCH", "0") == "1":
pt_install_dir = find_pytorch()
if pt_install_dir is None:
raise RuntimeError("Cannot find installed PyTorch.")
cmake_args.extend(
[
"-DENABLE_PYTORCH=ON",
f"-DCMAKE_PREFIX_PATH={pt_install_dir}",
]
)
else:
cmake_args.append("-DENABLE_PYTORCH=OFF")

cmake_args = [
"-DBUILD_PY_IF:BOOL=TRUE",
*cmake_args,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class DescrptSeA(NativeOP, BaseDescriptor):
The cut-off radius :math:`r_c`
rcut_smth
From where the environment matrix should be smoothed :math:`r_s`
sel : list[str]
sel : list[int]
sel[i] specifies the maxmum number of type i atoms in the cut-off radius
neuron : list[int]
Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(
self,
rcut: float,
rcut_smth: float,
sel: List[str],
sel: List[int],
neuron: List[int] = [24, 48, 96],
axis_neuron: int = 8,
resnet_dt: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class DescrptSeR(NativeOP, BaseDescriptor):
The cut-off radius :math:`r_c`
rcut_smth
From where the environment matrix should be smoothed :math:`r_s`
sel : list[str]
sel : list[int]
sel[i] specifies the maxmum number of type i atoms in the cut-off radius
neuron : list[int]
Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(
self,
rcut: float,
rcut_smth: float,
sel: List[str],
sel: List[int],
neuron: List[int] = [24, 48, 96],
resnet_dt: bool = False,
trainable: bool = True,
Expand Down
124 changes: 124 additions & 0 deletions deepmd/dpmodel/utils/type_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Optional,
)

import numpy as np

from deepmd.dpmodel.common import (
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.utils.network import (
EmbeddingNet,
)
from deepmd.utils.version import (
check_version_compatibility,
)


class TypeEmbedNet(NativeOP):
r"""Type embedding network.
Parameters
----------
ntypes : int
Number of atom types
neuron : list[int]
Number of neurons in each hidden layers of the embedding net
resnet_dt
Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b)
activation_function
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are |PRECISION|
trainable
If the weights of embedding net are trainable.
seed
Random seed for initializing the network parameters.
padding
Concat the zero padding to the output, as the default embedding of empty type.
"""

def __init__(
self,
*,
ntypes: int,
neuron: List[int],
resnet_dt: bool = False,
activation_function: str = "tanh",
precision: str = "default",
trainable: bool = True,
seed: Optional[int] = None,
padding: bool = False,
) -> None:
self.ntypes = ntypes
self.neuron = neuron
self.seed = seed
self.resnet_dt = resnet_dt
self.precision = precision
self.activation_function = str(activation_function)
self.trainable = trainable
self.padding = padding
self.embedding_net = EmbeddingNet(
ntypes,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)

def call(self) -> np.ndarray:
"""Compute the type embedding network."""
embed = self.embedding_net(
np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
)
if self.padding:
embed = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
return embed

@classmethod
def deserialize(cls, data: dict):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data_cls = data.pop("@class")
assert data_cls == "TypeEmbedNet", f"Invalid class {data_cls}"

embedding_net = EmbeddingNet.deserialize(data.pop("embedding"))
type_embedding_net = cls(**data)
type_embedding_net.embedding_net = embedding_net
return type_embedding_net

def serialize(self) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
return {
"@class": "TypeEmbedNet",
"@version": 1,
"ntypes": self.ntypes,
"neuron": self.neuron,
"resnet_dt": self.resnet_dt,
"precision": self.precision,
"activation_function": self.activation_function,
"trainable": self.trainable,
"padding": self.padding,
"embedding": self.embedding_net.serialize(),
}
9 changes: 9 additions & 0 deletions deepmd/pt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

# import customized OPs globally
from deepmd.pt.cxx_op import (
ENABLE_CUSTOMIZED_OP,
)

__all__ = [
"ENABLE_CUSTOMIZED_OP",
]
43 changes: 43 additions & 0 deletions deepmd/pt/cxx_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import platform

import torch

from deepmd.env import (
SHARED_LIB_DIR,
)


def load_library(module_name: str) -> bool:
"""Load OP library.
Parameters
----------
module_name : str
Name of the module
Returns
-------
bool
Whether the library is loaded successfully
"""
if platform.system() == "Windows":
ext = ".dll"
prefix = ""
else:
ext = ".so"
prefix = "lib"

module_file = (SHARED_LIB_DIR / (prefix + module_name)).with_suffix(ext).resolve()

if module_file.is_file():
torch.ops.load_library(module_file)
return True
return False


ENABLE_CUSTOMIZED_OP = load_library("deepmd_op_pt")

__all__ = [
"ENABLE_CUSTOMIZED_OP",
]
4 changes: 4 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from deepmd.main import (
parse_args,
)
from deepmd.pt.cxx_op import (
ENABLE_CUSTOMIZED_OP,
)
from deepmd.pt.infer import (
inference,
)
Expand Down Expand Up @@ -224,6 +227,7 @@ def get_backend_info(self) -> dict:
return {
"Backend": "PyTorch",
"PT ver": f"v{torch.__version__}-g{torch.version.git_version[:11]}",
"Enable custom OP": ENABLE_CUSTOMIZED_OP,
}


Expand Down
Loading

0 comments on commit 1a143a3

Please sign in to comment.