feat(jax): support neural networks #4156

Merged: 3 commits, merged on Sep 23, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test_cuda.yml

@@ -51,7 +51,7 @@ jobs:
       - run: |
           export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
           export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
-          source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch] mpi4py
+          source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py
         env:
           DP_VARIANT: cuda
           DP_ENABLE_NATIVE_OPTIMIZATION: 1
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml

@@ -28,7 +28,7 @@ jobs:
           source/install/uv_with_retry.sh pip install --system mpich
           source/install/uv_with_retry.sh pip install --system "torch==2.3.0+cpu.cxx11.abi" -i https://download.pytorch.org/whl/
           export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
-          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test] horovod[tensorflow-cpu] mpi4py
+          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py
         env:
           # Please note that uv has some issues with finding
           # existing TensorFlow package. Currently, it uses
110 changes: 110 additions & 0 deletions deepmd/backend/jax.py

@@ -0,0 +1,110 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from importlib.util import (
    find_spec,
)
from typing import (
    TYPE_CHECKING,
    Callable,
    ClassVar,
    List,
    Type,
)

from deepmd.backend.backend import (
    Backend,
)

if TYPE_CHECKING:
    from argparse import (
        Namespace,
    )

    from deepmd.infer.deep_eval import (
        DeepEvalBackend,
    )
    from deepmd.utils.neighbor_stat import (
        NeighborStat,
    )


@Backend.register("jax")
class JAXBackend(Backend):
    """JAX backend."""

    name = "JAX"
    """The formal name of the backend."""
    features: ClassVar[Backend.Feature] = (
        Backend.Feature(0)
        # Backend.Feature.ENTRY_POINT
        # | Backend.Feature.DEEP_EVAL
        # | Backend.Feature.NEIGHBOR_STAT
        # | Backend.Feature.IO
    )
    """The features of the backend."""
    suffixes: ClassVar[List[str]] = []
    """The suffixes of the backend."""

    def is_available(self) -> bool:
        """Check if the backend is available.

        Returns
        -------
        bool
            Whether the backend is available.
        """
        return find_spec("jax") is not None

    @property
    def entry_point_hook(self) -> Callable[["Namespace"], None]:
        """The entry point hook of the backend.

        Returns
        -------
        Callable[[Namespace], None]
            The entry point hook of the backend.
        """
        raise NotImplementedError

    @property
    def deep_eval(self) -> Type["DeepEvalBackend"]:
        """The Deep Eval backend of the backend.

        Returns
        -------
        type[DeepEvalBackend]
            The Deep Eval backend of the backend.
        """
        raise NotImplementedError

    @property
    def neighbor_stat(self) -> Type["NeighborStat"]:
        """The neighbor statistics of the backend.

        Returns
        -------
        type[NeighborStat]
            The neighbor statistics of the backend.
        """
        raise NotImplementedError

    @property
    def serialize_hook(self) -> Callable[[str], dict]:
        """The serialize hook to convert the model file to a dictionary.

        Returns
        -------
        Callable[[str], dict]
            The serialize hook of the backend.
        """
        raise NotImplementedError

    @property
    def deserialize_hook(self) -> Callable[[str, dict], None]:
        """The deserialize hook to convert the dictionary to a model file.

        Returns
        -------
        Callable[[str, dict], None]
            The deserialize hook of the backend.
        """
        raise NotImplementedError
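
For reviewers, a minimal usage sketch of the new class, using only what this diff introduces (the @Backend.register("jax") hook, the deepmd.backend.jax import path, and is_available()); everything else is deliberately left unimplemented for now:

from deepmd.backend.jax import JAXBackend

backend = JAXBackend()
# True when the `jax` package can be imported in the current environment.
print(backend.is_available())
# The feature set is still empty (Backend.Feature(0)); deep_eval,
# neighbor_stat, and the serialization hooks raise NotImplementedError.
print(backend.features)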
22 changes: 22 additions & 0 deletions deepmd/dpmodel/common.py

@@ -3,6 +3,10 @@
     ABC,
     abstractmethod,
 )
+from typing import (
+    Any,
+    Optional,
+)
 
 import ml_dtypes
 import numpy as np
@@ -59,6 +63,24 @@ def __call__(self, *args, **kwargs):
         return self.call(*args, **kwargs)
 
 
+def to_numpy_array(x: Any) -> Optional[np.ndarray]:
+    """Convert an array to a NumPy array.
+
+    Parameters
+    ----------
+    x : Any
+        The array to be converted.
+
+    Returns
+    -------
+    Optional[np.ndarray]
+        The NumPy array.
+    """
+    if x is None:
+        return None
+    return np.asarray(x)
+
+
 __all__ = [
     "GLOBAL_NP_FLOAT_PRECISION",
     "GLOBAL_ENER_FLOAT_PRECISION",
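
A short sketch of why serialization routes through this helper: optional parameters may be None, and arrays may live on a non-NumPy backend. np.asarray accepts anything implementing __array__ (JAX arrays do), so both cases collapse into one call:

import numpy as np

from deepmd.dpmodel.common import to_numpy_array

assert to_numpy_array(None) is None  # unset parameters (e.g. idt) pass through
assert isinstance(to_numpy_array([1.0, 2.0]), np.ndarray)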
50 changes: 38 additions & 12 deletions deepmd/dpmodel/utils/network.py

@@ -15,13 +15,20 @@
     Union,
 )
 
+import array_api_compat
 import numpy as np
 
 from deepmd.dpmodel import (
     DEFAULT_PRECISION,
     PRECISION_DICT,
     NativeOP,
 )
+from deepmd.dpmodel.array_api import (
+    support_array_api,
+)
+from deepmd.dpmodel.common import (
+    to_numpy_array,
+)
 from deepmd.dpmodel.utils.seed import (
     child_seed,
 )
@@ -105,9 +112,9 @@ def serialize(self) -> dict:
             The serialized layer.
         """
         data = {
-            "w": self.w,
-            "b": self.b,
-            "idt": self.idt,
+            "w": to_numpy_array(self.w),
+            "b": to_numpy_array(self.b),
+            "idt": to_numpy_array(self.idt),
         }
         return {
             "@class": "Layer",
@@ -215,6 +222,7 @@ def dim_in(self) -> int:
     def dim_out(self) -> int:
         return self.w.shape[1]
 
+    @support_array_api(version="2022.12")
     def call(self, x: np.ndarray) -> np.ndarray:
         """Forward pass.
 
@@ -230,59 +238,77 @@ def call(self, x: np.ndarray) -> np.ndarray:
         """
         if self.w is None or self.activation_function is None:
             raise ValueError("w, b, and activation_function must be set")
+        xp = array_api_compat.array_namespace(x)
         fn = get_activation_fn(self.activation_function)
         y = (
-            np.matmul(x, self.w) + self.b
+            xp.matmul(x, self.w) + self.b
             if self.b is not None
-            else np.matmul(x, self.w)
+            else xp.matmul(x, self.w)
         )
         y = fn(y)
         if self.idt is not None:
             y *= self.idt
         if self.resnet and self.w.shape[1] == self.w.shape[0]:
             y += x
         elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
-            y += np.concatenate([x, x], axis=-1)
+            y += xp.concatenate([x, x], axis=-1)
         return y
 
 
+@support_array_api(version="2022.12")
 def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
     activation_function = activation_function.lower()
     if activation_function == "tanh":
-        return np.tanh
+
+        def fn(x):
+            xp = array_api_compat.array_namespace(x)
+            return xp.tanh(x)
+
+        return fn
     elif activation_function == "relu":
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # https://stackoverflow.com/a/47936476/9567349
-            return x * (x > 0)
+            return x * xp.astype(x > 0, x.dtype)
 
         return fn
     elif activation_function in ("gelu", "gelu_tf"):
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
-            return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
+            return (
+                0.5
+                * x
+                * (1 + xp.tanh(xp.sqrt(xp.asarray(2 / xp.pi)) * (x + 0.044715 * x**3)))
+            )
 
         return fn
     elif activation_function == "relu6":
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
-            return np.minimum(np.maximum(x, 0), 6)
+            return xp.where(
+                x < 0, xp.full_like(x, 0), xp.where(x > 6, xp.full_like(x, 6), x)
+            )
 
         return fn
     elif activation_function == "softplus":
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
-            return np.log(1 + np.exp(x))
+            return xp.log(1 + xp.exp(x))
 
         return fn
     elif activation_function == "sigmoid":
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
-            return 1 / (1 + np.exp(-x))
+            return 1 / (1 + xp.exp(-x))
 
         return fn
     elif activation_function.lower() in ("none", "linear"):
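
The pattern repeated in each activation above is what makes these layers backend-agnostic: array_api_compat.array_namespace(x) resolves a namespace (NumPy, jax.numpy, ...) from the input's type, so the same closure runs on any array-API-compatible array. A minimal sketch of the dispatch, assuming array_api_compat and JAX are installed:

import array_api_compat
import jax.numpy as jnp
import numpy as np


def tanh_any(x):
    # Pick the array-API namespace that matches the input array.
    xp = array_api_compat.array_namespace(x)
    return xp.tanh(x)


print(type(tanh_any(np.ones(3))))   # a NumPy ndarray
print(type(tanh_any(jnp.ones(3))))  # a JAX array

This also explains the less obvious rewrites: the array-API standard leaves promotion between boolean and floating arrays unspecified, hence the explicit xp.astype(x > 0, x.dtype) in relu, and 2 / xp.pi is wrapped in xp.asarray, presumably because the standard specifies xp.sqrt for arrays rather than Python scalars.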
14 changes: 10 additions & 4 deletions deepmd/dpmodel/utils/type_embed.py

@@ -5,8 +5,12 @@
     Union,
 )
 
+import array_api_compat
 import numpy as np
 
+from deepmd.dpmodel.array_api import (
+    support_array_api,
+)
 from deepmd.dpmodel.common import (
     PRECISION_DICT,
     NativeOP,
@@ -92,16 +96,18 @@ def __init__(
             bias=self.use_tebd_bias,
         )
 
+    @support_array_api(version="2022.12")
     def call(self) -> np.ndarray:
         """Compute the type embedding network."""
+        sample_array = self.embedding_net[0]["w"]
+        xp = array_api_compat.array_namespace(sample_array)
         if not self.use_econf_tebd:
-            embed = self.embedding_net(
-                np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
-            )
+            embed = self.embedding_net(xp.eye(self.ntypes, dtype=sample_array.dtype))
         else:
             embed = self.embedding_net(self.econf_tebd)
         if self.padding:
-            embed = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
+            embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype)
+            embed = xp.concatenate([embed, embed_pad], axis=0)
         return embed
 
     @classmethod
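
The padding rewrite is behavior-preserving: pad is not part of the 2022.12 array-API standard, so the single zero row that np.pad used to append is now built explicitly. A quick NumPy-only check of the equivalence:

import numpy as np

embed = np.arange(6.0).reshape(3, 2)
# Old form: pad one zero row onto the end of axis 0.
old = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
# New, array-API-portable form: concatenate an explicit zero row.
pad_row = np.zeros((1, embed.shape[-1]), dtype=embed.dtype)
new = np.concatenate([embed, pad_row], axis=0)
assert (old == new).all()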
2 changes: 2 additions & 0 deletions deepmd/jax/__init__.py

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""JAX backend."""
37 changes: 37 additions & 0 deletions deepmd/jax/common.py

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Union,
    overload,
)

import numpy as np

from deepmd.jax.env import (
    jnp,
)


@overload
def to_jax_array(array: np.ndarray) -> jnp.ndarray: ...


@overload
def to_jax_array(array: None) -> None: ...


def to_jax_array(array: Union[np.ndarray]) -> Union[jnp.ndarray]:
    """Convert a numpy array to a JAX array.

    Parameters
    ----------
    array : np.ndarray
        The numpy array to convert.

    Returns
    -------
    jnp.ndarray
        The JAX array.
    """
    if array is None:
        return None
    return jnp.array(array)
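
A round-trip sketch combining this helper with to_numpy_array from deepmd/dpmodel/common.py (assumes JAX is installed):

import numpy as np

from deepmd.dpmodel.common import to_numpy_array
from deepmd.jax.common import to_jax_array

w = np.random.default_rng(0).normal(size=(4, 4))
w_jax = to_jax_array(w)         # host NumPy array -> JAX array
w_back = to_numpy_array(w_jax)  # JAX array -> host NumPy array
assert np.allclose(w, w_back)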
14 changes: 14 additions & 0 deletions deepmd/jax/env.py

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

__all__ = [
"jax",
"jnp",
]
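
Two environment choices here are worth flagging for reviewers: jax_enable_x64 switches JAX from its default 32-bit arrays to 64-bit, matching the double-precision NumPy reference implementation, and XLA_PYTHON_CLIENT_PREALLOCATE=false stops JAX from preallocating most of the GPU memory at import time. A quick dtype check:

from deepmd.jax.env import jnp

# With jax_enable_x64 active (set in deepmd/jax/env.py), Python floats
# promote to float64 rather than JAX's default float32.
print(jnp.asarray(1.0).dtype)  # float64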
1 change: 1 addition & 0 deletions deepmd/jax/utils/__init__.py

@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later