Skip to content

Commit

Permalink
Merge branch 'devel' into spin_rf
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored Feb 19, 2024
2 parents c19f829 + ab35468 commit cda2e60
Show file tree
Hide file tree
Showing 98 changed files with 2,620 additions and 688 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_cc.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
on:
push:
branches-ignore:
- "gh-readonly-queue/*"
- "gh-readonly-queue/**"
pull_request:
merge_group:
concurrency:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Build and upload to PyPI
on:
push:
branches-ignore:
- "gh-readonly-queue/*"
- "gh-readonly-queue/**"
pull_request:
merge_group:

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/codeql.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: "CodeQL"
on:
push:
branches-ignore:
- "gh-readonly-queue/*"
- "gh-readonly-queue/**"
pull_request:
schedule:
- cron: '45 2 * * 2'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/package_c.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Build C library
on:
push:
branches-ignore:
- "gh-readonly-queue/*"
- "gh-readonly-queue/**"
pull_request:
merge_group:
concurrency:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_cc.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
on:
push:
branches-ignore:
- "gh-readonly-queue/*"
- "gh-readonly-queue/**"
pull_request:
merge_group:
concurrency:
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ on:
types:
- "labeled"
# to let the PR pass the test
- "created"
- "opened"
- "reopened"
- "synchronize"
merge_group:
concurrency:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
on:
push:
branches-ignore:
- "gh-readonly-queue/*"
- "gh-readonly-queue/**"
pull_request:
merge_group:
concurrency:
Expand Down
29 changes: 29 additions & 0 deletions deepmd/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Backends.
Avoid directly importing third-party libraries in this module for performance.
"""
# copy from dpdata
from importlib import (
import_module,
metadata,
)
from pathlib import (
Path,
)

PACKAGE_BASE = "deepmd.backend"
NOT_LOADABLE = ("__init__.py",)

for module_file in Path(__file__).parent.glob("*.py"):
if module_file.name not in NOT_LOADABLE:
module_name = f".{module_file.stem}"
import_module(module_name, PACKAGE_BASE)

# https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html
try:
eps = metadata.entry_points(group="deepmd.backend")
except TypeError:
eps = metadata.entry_points().get("deepmd.backend", [])
for ep in eps:
plugin = ep.load()
201 changes: 201 additions & 0 deletions deepmd/backend/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
abstractmethod,
)
from enum import (
Flag,
auto,
)
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
Dict,
List,
Type,
)

from deepmd.utils.plugin import (
Plugin,
PluginVariant,
)

if TYPE_CHECKING:
from argparse import (
Namespace,
)

from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)


class Backend(PluginVariant):
r"""General backend class.
Examples
--------
>>> @Backend.register("tf")
>>> @Backend.register("tensorflow")
>>> class TensorFlowBackend(Backend):
... pass
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a backend plugin.
Parameters
----------
key : str
the key of a backend
Returns
-------
Callable[[object], object]
the decorator to register backend
"""
return Backend.__plugins.register(key.lower())

@staticmethod
def get_backend(key: str) -> Type["Backend"]:
"""Get the backend by key.
Parameters
----------
key : str
the key of a backend
Returns
-------
Backend
the backend
"""
try:
backend = Backend.__plugins.get_plugin(key.lower())
except KeyError:
raise KeyError(f"Backend {key} is not registered.")
assert isinstance(backend, type)
return backend

@staticmethod
def get_backends() -> Dict[str, Type["Backend"]]:
"""Get all the registered backend names.
Returns
-------
list
all the registered backends
"""
return Backend.__plugins.plugins

@staticmethod
def get_backends_by_feature(
feature: "Backend.Feature",
) -> Dict[str, Type["Backend"]]:
"""Get all the registered backend names with a specific feature.
Parameters
----------
feature : Backend.Feature
the feature flag
Returns
-------
list
all the registered backends with the feature
"""
return {
key: backend
for key, backend in Backend.__plugins.plugins.items()
if backend.features & feature
}

@staticmethod
def detect_backend_by_model(filename: str) -> Type["Backend"]:
"""Detect the backend of the given model file.
Parameters
----------
filename : str
The model file name
"""
filename = str(filename).lower()
for backend in Backend.get_backends().values():
for suffix in backend.suffixes:
if filename.endswith(suffix):
return backend
raise ValueError(f"Cannot detect the backend of the model file {filename}.")

class Feature(Flag):
"""Feature flag to indicate whether the backend supports certain features."""

ENTRY_POINT = auto()
"""Support entry point hook."""
DEEP_EVAL = auto()
"""Support Deep Eval backend."""
NEIGHBOR_STAT = auto()
"""Support neighbor statistics."""

name: ClassVar[str] = "Unknown"
"""The formal name of the backend.
To be consistent, this name should be also registered in the plugin system."""

features: ClassVar[Feature] = Feature(0)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = []
"""The supported suffixes of the saved model.
The first element is considered as the default suffix."""

@abstractmethod
def is_available(self) -> bool:
"""Check if the backend is available.
Returns
-------
bool
Whether the backend is available.
"""

@property
@abstractmethod
def entry_point_hook(self) -> Callable[["Namespace"], None]:
"""The entry point hook of the backend.
Returns
-------
Callable[[Namespace], None]
The entry point hook of the backend.
"""
pass

@property
@abstractmethod
def deep_eval(self) -> Type["DeepEvalBackend"]:
"""The Deep Eval backend of the backend.
Returns
-------
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
pass

@property
@abstractmethod
def neighbor_stat(self) -> Type["NeighborStat"]:
"""The neighbor statistics of the backend.
Returns
-------
type[NeighborStat]
The neighbor statistics of the backend.
"""
pass
86 changes: 86 additions & 0 deletions deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
List,
Type,
)

from deepmd.backend.backend import (
Backend,
)

if TYPE_CHECKING:
from argparse import (
Namespace,
)

from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)


@Backend.register("dp")
@Backend.register("dpmodel")
@Backend.register("np")
@Backend.register("numpy")
class DPModelBackend(Backend):
"""DPModel backend that uses NumPy as the reference implementation."""

name = "DPModel"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = Backend.Feature.NEIGHBOR_STAT
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".dp"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
"""Check if the backend is available.
Returns
-------
bool
Whether the backend is available.
"""
return True

@property
def entry_point_hook(self) -> Callable[["Namespace"], None]:
"""The entry point hook of the backend.
Returns
-------
Callable[[Namespace], None]
The entry point hook of the backend.
"""
raise NotImplementedError(f"Unsupported backend: {self.name}")

@property
def deep_eval(self) -> Type["DeepEvalBackend"]:
"""The Deep Eval backend of the backend.
Returns
-------
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError(f"Unsupported backend: {self.name}")

@property
def neighbor_stat(self) -> Type["NeighborStat"]:
"""The neighbor statistics of the backend.
Returns
-------
type[NeighborStat]
The neighbor statistics of the backend.
"""
from deepmd.dpmodel.utils.neighbor_stat import (
NeighborStat,
)

return NeighborStat
Loading

0 comments on commit cda2e60

Please sign in to comment.