
feat(jax/array-api): dpa1 #4160

Merged: 17 commits, Oct 9, 2024
241 changes: 175 additions & 66 deletions deepmd/dpmodel/descriptor/dpa1.py

Large diffs are not rendered by default.
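
The pattern applied throughout this PR is to resolve the array namespace from the input arrays instead of hard-coding NumPy, so the same dpmodel code runs on NumPy or JAX arrays. A minimal sketch of that dispatch (the double() helper is illustrative only, not part of the PR):

import array_api_compat
import numpy as np

def double(x):
    # pick the namespace (NumPy, JAX, ...) that matches the input array
    xp = array_api_compat.array_namespace(x)
    return xp.multiply(x, 2)

assert np.all(double(np.arange(3)) == np.array([0, 2, 4]))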

37 changes: 24 additions & 13 deletions deepmd/dpmodel/utils/env_mat.py
@@ -44,33 +44,43 @@ def _make_env_mat(
     protection: float = 0.0,
 ):
     """Make smooth environment matrix."""
+    xp = array_api_compat.array_namespace(nlist)
     nf, nloc, nnei = nlist.shape
     # nf x nall x 3
-    coord = coord.reshape(nf, -1, 3)
+    coord = xp.reshape(coord, (nf, -1, 3))
     mask = nlist >= 0
-    nlist = nlist * mask
+    nlist = nlist * xp.astype(mask, nlist.dtype)
     # nf x (nloc x nnei) x 3
-    index = np.tile(nlist.reshape(nf, -1, 1), (1, 1, 3))
-    coord_r = np.take_along_axis(coord, index, 1)
+    # index = xp.reshape(nlist, (nf, -1, 1))
+    # index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3))
+    # coord_r = xp.take_along_axis(coord, xp.tile(index, (1, 1, 3)), 1)
+    # note: array api doesn't contain take_along_axis until the next version
+    # reimplement
+    nall = coord.shape[1]
+    index = xp.reshape(nlist, (nf * nloc * nnei,)) + xp.repeat(
+        (xp.arange(nf) * nall), nloc * nnei
+    )
+    coord_ = xp.reshape(coord, (-1, 3))
+    coord_r = xp.take(coord_, index, axis=0)
     # nf x nloc x nnei x 3
-    coord_r = coord_r.reshape(nf, nloc, nnei, 3)
+    coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
     # nf x nloc x 1 x 3
-    coord_l = coord[:, :nloc].reshape(nf, -1, 1, 3)
+    coord_l = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3))
     # nf x nloc x nnei x 3
     diff = coord_r - coord_l
     # nf x nloc x nnei
-    length = np.linalg.norm(diff, axis=-1, keepdims=True)
+    length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
     # for index 0 nloc atom
-    length = length + ~np.expand_dims(mask, -1)
+    length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
     t0 = 1 / (length + protection)
     t1 = diff / (length + protection) ** 2
     weight = compute_smooth_weight(length, ruct_smth, rcut)
-    weight = weight * np.expand_dims(mask, -1)
+    weight = weight * xp.astype(xp.expand_dims(mask, axis=-1), weight.dtype)
     if radial_only:
         env_mat = t0 * weight
     else:
-        env_mat = np.concatenate([t0, t1], axis=-1) * weight
-    return env_mat, diff * np.expand_dims(mask, -1), weight
+        env_mat = xp.concat([t0, t1], axis=-1) * weight
+    return env_mat, diff * xp.astype(xp.expand_dims(mask, axis=-1), diff.dtype), weight
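
The heart of this hunk is the flat-take workaround: the array API standard has no take_along_axis yet, so the per-frame neighbor gather becomes a single xp.take on flattened coordinates, with each frame's indices shifted by frame * nall. A minimal self-check of the equivalence, using NumPy as a stand-in backend and made-up sizes:

import numpy as np

nf, nall, nloc, nnei = 2, 6, 3, 4  # hypothetical sizes for the check
rng = np.random.default_rng(0)
coord = rng.standard_normal((nf, nall, 3))
nlist = rng.integers(0, nall, size=(nf, nloc, nnei))

# reference gather with take_along_axis
ref = np.take_along_axis(coord, np.tile(nlist.reshape(nf, -1, 1), (1, 1, 3)), 1)

# portable gather: flatten the frames, offset each frame's indices by frame * nall
index = nlist.reshape(nf * nloc * nnei) + np.repeat(np.arange(nf) * nall, nloc * nnei)
out = np.take(coord.reshape(-1, 3), index, axis=0).reshape(nf, nloc * nnei, 3)

assert np.allclose(ref, out)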


class EnvMat(NativeOP):
@@ -122,13 +132,14 @@ def call(
         switch
             The value of switch function. shape: nf x nloc x nnei
         """
+        xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
         em, diff, sw = self._call(nlist, coord_ext, radial_only)
         nf, nloc, nnei = nlist.shape
         atype = atype_ext[:, :nloc]
         if davg is not None:
-            em -= davg[atype]
+            em -= xp.reshape(xp.take(davg, xp.reshape(atype, (-1,)), axis=0), em.shape)
         if dstd is not None:
-            em /= dstd[atype]
+            em /= xp.reshape(xp.take(dstd, xp.reshape(atype, (-1,)), axis=0), em.shape)
         return em, diff, sw

     def _call(self, nlist, coord_ext, radial_only):
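
davg[atype] relies on integer (fancy) indexing, which the array API standard does not guarantee, so the statistics lookup becomes a flat xp.take plus a reshape. A short NumPy sketch of the equivalence, with illustrative sizes:

import numpy as np

ntypes, nf, nloc, nd = 3, 2, 4, 8  # hypothetical sizes
davg = np.random.default_rng(1).standard_normal((ntypes, nd))
atype = np.random.default_rng(2).integers(0, ntypes, size=(nf, nloc))

ref = davg[atype]  # fancy indexing: nf x nloc x nd
out = np.take(davg, atype.reshape(-1), axis=0).reshape(nf, nloc, nd)
assert np.allclose(ref, out)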
33 changes: 22 additions & 11 deletions deepmd/dpmodel/utils/exclude_mask.py
@@ -4,6 +4,7 @@
     Tuple,
 )

+import array_api_compat
 import numpy as np


@@ -49,8 +50,9 @@ def build_type_exclude_mask(
             otherwise being 1.

         """
+        xp = array_api_compat.array_namespace(atype)
         nf, natom = atype.shape
-        return self.type_mask[atype].reshape(nf, natom)
+        return xp.reshape(self.type_mask[atype], (nf, natom))


class PairExcludeMask:
@@ -68,7 +70,7 @@ def __init__(
                 self.exclude_types.add((tt[0], tt[1]))
                 self.exclude_types.add((tt[1], tt[0]))
         # ntypes + 1 for nlist masks
-        self.type_mask = np.array(
+        type_mask = np.array(
             [
                 [
                     1 if (tt_i, tt_j) not in self.exclude_types else 0
@@ -79,7 +81,7 @@
             dtype=np.int32,
         )
         # (ntypes+1 x ntypes+1)
-        self.type_mask = self.type_mask.reshape([-1])
+        self.type_mask = type_mask.reshape([-1])

def get_exclude_types(self):
return self.exclude_types
@@ -106,23 +108,32 @@ def build_type_exclude_mask(
             otherwise being 1.

         """
+        xp = array_api_compat.array_namespace(nlist, atype_ext)
         if len(self.exclude_types) == 0:
             # safely return 1 if nothing is excluded.
-            return np.ones_like(nlist, dtype=np.int32)
+            return xp.ones_like(nlist, dtype=xp.int32)
         nf, nloc, nnei = nlist.shape
         nall = atype_ext.shape[1]
         # add virtual atom of type ntypes. nf x nall+1
-        ae = np.concatenate(
-            [atype_ext, self.ntypes * np.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
+        ae = xp.concat(
+            [atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
         )
-        type_i = atype_ext[:, :nloc].reshape(nf, nloc) * (self.ntypes + 1)
+        type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1)
         # nf x nloc x nnei
-        index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei)
-        type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
+        index = xp.reshape(
+            xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei)
+        )
+        # type_j = xp.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
+        index = xp.reshape(index, [-1])
+        index += xp.repeat(xp.arange(nf) * (nall + 1), nloc * nnei)
+        type_j = xp.take(xp.reshape(ae, [-1]), index, axis=0)
+        type_j = xp.reshape(type_j, (nf, nloc, nnei))
         type_ij = type_i[:, :, None] + type_j
         # nf x (nloc x nnei)
-        type_ij = type_ij.reshape(nf, nloc * nnei)
-        mask = self.type_mask[type_ij].reshape(nf, nloc, nnei)
+        type_ij = xp.reshape(type_ij, (nf, nloc * nnei))
+        mask = xp.reshape(
+            xp.take(self.type_mask, xp.reshape(type_ij, (-1,))), (nf, nloc, nnei)
+        )
         return mask

def __contains__(self, item):
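
The pair mask uses the flat-gather trick twice: neighbor types come from the padded type array ae (flattened, with per-frame offsets of nall + 1, where the extra slot holds the virtual type for -1 neighbors), and each pair (i, j) is then looked up at flat index i * (ntypes + 1) + j in the flattened type_mask. A toy NumPy illustration of the flat pair lookup:

import numpy as np

ntypes = 2
type_mask = np.ones((ntypes + 1, ntypes + 1), dtype=np.int32)
type_mask[0, 1] = type_mask[1, 0] = 0  # exclude the (0, 1) pair
flat_mask = type_mask.reshape(-1)

for ti in range(ntypes + 1):
    for tj in range(ntypes + 1):
        # flat index ti * (ntypes + 1) + tj addresses entry (ti, tj)
        assert flat_mask[ti * (ntypes + 1) + tj] == type_mask[ti, tj]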
38 changes: 23 additions & 15 deletions deepmd/dpmodel/utils/network.py
@@ -148,15 +148,18 @@ def deserialize(cls, data: dict) -> "NativeLayer":
             num_out,
             **data,
         )
-        obj.w, obj.b, obj.idt = (
+        w, b, idt = (
             variables["w"],
             variables.get("b", None),
             variables.get("idt", None),
         )
-        if obj.b is not None:
-            obj.b = obj.b.ravel()
-        if obj.idt is not None:
-            obj.idt = obj.idt.ravel()
+        if b is not None:
+            b = b.ravel()
+        if idt is not None:
+            idt = idt.ravel()
+        obj.w = w
+        obj.b = b
+        obj.idt = idt
         obj.check_shape_consistency()
         return obj

@@ -177,8 +180,11 @@ def check_type_consistency(self):

         def check_var(var):
             if var is not None:
+                # the array API standard doesn't provide an API to get the dtype name;
+                # this is rather hacky
+                dtype_name = str(var.dtype).split(".")[-1]
                 # assertion "float64" == "double" would fail
-                assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision]
+                assert PRECISION_DICT[dtype_name] is PRECISION_DICT[precision]

         check_var(self.w)
         check_var(self.b)
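
The dtype-name hack relies on the string form of the dtype: the array API standard defines no .name attribute, but str(var.dtype) ends in the name both for plain NumPy ("float64") and for namespaced backends (e.g. "torch.float64"), so the component after the last dot is a usable key into PRECISION_DICT. A quick sketch:

import numpy as np

var = np.zeros(3, dtype=np.float64)
dtype_name = str(var.dtype).split(".")[-1]  # "float64": no dot, so unchanged
assert dtype_name == "float64"
# a namespaced dtype string such as "torch.float64" would also yield "float64"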
@@ -251,7 +257,7 @@ def call(self, x: np.ndarray) -> np.ndarray:
         if self.resnet and self.w.shape[1] == self.w.shape[0]:
             y += x
         elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
-            y += xp.concatenate([x, x], axis=-1)
+            y += xp.concat([x, x], axis=-1)
         return y
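
concat is the array API standard's spelling of NumPy's concatenate, and array_api_compat exposes it uniformly across backends. A one-line check through the compat NumPy namespace:

import array_api_compat.numpy as xp
import numpy as np

x = np.ones((2, 3))
assert xp.concat([x, x], axis=-1).shape == (2, 6)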


@@ -362,10 +368,11 @@ def __init__(
             precision=precision,
             seed=seed,
         )
-        self.w = self.w.squeeze(0)  # keep the weight shape to be [num_in]
+        xp = array_api_compat.array_namespace(self.w, self.b)
+        self.w = xp.squeeze(self.w, 0)  # keep the weight shape to be [num_in]
         if self.uni_init:
-            self.w = np.ones_like(self.w)
-            self.b = np.zeros_like(self.b)
+            self.w = xp.ones_like(self.w)
+            self.b = xp.zeros_like(self.b)
         # only to keep consistent with other backends
         self.trainable = trainable
@@ -378,8 +385,8 @@ def serialize(self) -> dict:
             The serialized layer.
         """
         data = {
-            "w": self.w,
-            "b": self.b,
+            "w": np.array(self.w),
+            "b": np.array(self.b),
         }
         return {
             "@class": "LayerNorm",
@@ -473,11 +480,12 @@ def call(self, x: np.ndarray) -> np.ndarray:

     @staticmethod
     def layer_norm_numpy(x, shape, weight=None, bias=None, eps=1e-5):
+        xp = array_api_compat.array_namespace(x)
         # mean and variance
-        mean = np.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True)
-        var = np.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True)
+        mean = xp.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True)
+        var = xp.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True)
         # normalize
-        x_normalized = (x - mean) / np.sqrt(var + eps)
+        x_normalized = (x - mean) / xp.sqrt(var + eps)
         # shift and scale
         if weight is not None and bias is not None:
             x_normalized = x_normalized * weight + bias
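
layer_norm_numpy now resolves its namespace from the input, so the same code path normalizes NumPy or JAX arrays. A NumPy sanity check with toy shapes that the output has zero mean and (up to eps) unit variance over the trailing axes:

import numpy as np

x = np.random.default_rng(3).standard_normal((4, 6))
shape, eps = (6,), 1e-5
axes = tuple(range(-len(shape), 0))
mean = np.mean(x, axis=axes, keepdims=True)
var = np.var(x, axis=axes, keepdims=True)
y = (x - mean) / np.sqrt(var + eps)
assert np.allclose(np.mean(y, axis=-1), 0.0, atol=1e-7)
assert np.allclose(np.var(y, axis=-1), var[..., 0] / (var[..., 0] + eps))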