feat(jax/array-api): dpa1
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed Sep 24, 2024
1 parent 0b72dae commit cad9034
Showing 25 changed files with 900 additions and 170 deletions.
241 changes: 175 additions & 66 deletions deepmd/dpmodel/descriptor/dpa1.py

Large diffs are not rendered by default.

37 changes: 24 additions & 13 deletions deepmd/dpmodel/utils/env_mat.py
@@ -44,33 +44,43 @@ def _make_env_mat(
     protection: float = 0.0,
 ):
     """Make smooth environment matrix."""
+    xp = array_api_compat.array_namespace(nlist)
     nf, nloc, nnei = nlist.shape
     # nf x nall x 3
-    coord = coord.reshape(nf, -1, 3)
+    coord = xp.reshape(coord, (nf, -1, 3))
     mask = nlist >= 0
-    nlist = nlist * mask
+    nlist = nlist * xp.astype(mask, nlist.dtype)
     # nf x (nloc x nnei) x 3
-    index = np.tile(nlist.reshape(nf, -1, 1), (1, 1, 3))
-    coord_r = np.take_along_axis(coord, index, 1)
+    # index = xp.reshape(nlist, (nf, -1, 1))
+    # index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3))
+    # coord_r = xp.take_along_axis(coord, xp.tile(index, (1, 1, 3)), 1)
+    # note: array api doesn't contain take_along_axis until the next version
+    # reimplement
+    nall = coord.shape[1]
+    index = xp.reshape(nlist, (nf * nloc * nnei,)) + xp.repeat(
+        (xp.arange(nf) * nall), nloc * nnei
+    )
+    coord_ = xp.reshape(coord, (-1, 3))
+    coord_r = xp.take(coord_, index, axis=0)
     # nf x nloc x nnei x 3
-    coord_r = coord_r.reshape(nf, nloc, nnei, 3)
+    coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
     # nf x nloc x 1 x 3
-    coord_l = coord[:, :nloc].reshape(nf, -1, 1, 3)
+    coord_l = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3))
     # nf x nloc x nnei x 3
     diff = coord_r - coord_l
     # nf x nloc x nnei
-    length = np.linalg.norm(diff, axis=-1, keepdims=True)
+    length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
     # for index 0 nloc atom
-    length = length + ~np.expand_dims(mask, -1)
+    length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
     t0 = 1 / (length + protection)
     t1 = diff / (length + protection) ** 2
     weight = compute_smooth_weight(length, ruct_smth, rcut)
-    weight = weight * np.expand_dims(mask, -1)
+    weight = weight * xp.astype(xp.expand_dims(mask, axis=-1), weight.dtype)
     if radial_only:
         env_mat = t0 * weight
     else:
-        env_mat = np.concatenate([t0, t1], axis=-1) * weight
-    return env_mat, diff * np.expand_dims(mask, -1), weight
+        env_mat = xp.concat([t0, t1], axis=-1) * weight
+    return env_mat, diff * xp.astype(xp.expand_dims(mask, axis=-1), diff.dtype), weight


 class EnvMat(NativeOP):
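The reimplementation in the hunk above works around the missing `take_along_axis` by offsetting each frame's neighbor indices into a flattened coordinate array. A self-contained sketch of the trick, with illustrative shapes and random data (names follow the diff; nothing here is part of the commit):

```python
import array_api_compat
import numpy as np

nf, nloc, nnei, nall = 2, 4, 3, 5
rng = np.random.default_rng(0)
coord = rng.normal(size=(nf, nall, 3))                # nf x nall x 3
nlist = rng.integers(0, nall, size=(nf, nloc, nnei))  # neighbor indices

xp = array_api_compat.array_namespace(nlist)
# Shift frame f's neighbor indices by f * nall so they index a flat (nf*nall, 3) array.
index = xp.reshape(nlist, (nf * nloc * nnei,)) + xp.repeat(
    xp.arange(nf) * nall, nloc * nnei
)
coord_r = xp.take(xp.reshape(coord, (-1, 3)), index, axis=0)
coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))

# The same gather expressed with NumPy's take_along_axis, as in the removed code.
ref = np.take_along_axis(coord, np.tile(nlist.reshape(nf, -1, 1), (1, 1, 3)), 1)
assert np.allclose(coord_r, ref.reshape(nf, nloc, nnei, 3))
```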
@@ -122,13 +132,14 @@ def call(
         switch
             The value of switch function. shape: nf x nloc x nnei
         """
+        xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
         em, diff, sw = self._call(nlist, coord_ext, radial_only)
         nf, nloc, nnei = nlist.shape
         atype = atype_ext[:, :nloc]
         if davg is not None:
-            em -= davg[atype]
+            em -= xp.reshape(xp.take(davg, xp.reshape(atype, (-1,)), axis=0), em.shape)
         if dstd is not None:
-            em /= dstd[atype]
+            em /= xp.reshape(xp.take(dstd, xp.reshape(atype, (-1,)), axis=0), em.shape)
         return em, diff, sw

     def _call(self, nlist, coord_ext, radial_only):
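For context on `compute_smooth_weight` in the first hunk of this file: it supplies the switching factor that takes the environment matrix smoothly to zero between `ruct_smth` and `rcut`. A minimal sketch assuming DeePMD-kit's usual fifth-order switching polynomial; the authoritative definition lives elsewhere in `deepmd/dpmodel/utils/env_mat.py`:

```python
import numpy as np

def smooth_weight(r, rmin, rmax):
    # 1 below rmin, 0 beyond rmax, C^2-smooth polynomial in between
    # (assumed form: uu^3 * (-6 uu^2 + 15 uu - 10) + 1).
    uu = np.clip((r - rmin) / (rmax - rmin), 0.0, 1.0)
    return uu**3 * (-6 * uu**2 + 15 * uu - 10) + 1

assert smooth_weight(np.array(0.5), 1.0, 2.0) == 1.0  # inside rmin: full weight
assert smooth_weight(np.array(2.5), 1.0, 2.0) == 0.0  # beyond rmax: zero weight
```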
33 changes: 22 additions & 11 deletions deepmd/dpmodel/utils/exclude_mask.py
@@ -4,6 +4,7 @@
     Tuple,
 )

+import array_api_compat
 import numpy as np


@@ -49,8 +50,9 @@ def build_type_exclude_mask(
             otherwise being 1.
         """
+        xp = array_api_compat.array_namespace(atype)
         nf, natom = atype.shape
-        return self.type_mask[atype].reshape(nf, natom)
+        return xp.reshape(self.type_mask[atype], (nf, natom))


 class PairExcludeMask:
@@ -68,7 +70,7 @@ def __init__(
             self.exclude_types.add((tt[0], tt[1]))
             self.exclude_types.add((tt[1], tt[0]))
         # ntypes + 1 for nlist masks
-        self.type_mask = np.array(
+        type_mask = np.array(
             [
                 [
                     1 if (tt_i, tt_j) not in self.exclude_types else 0
@@ -79,7 +81,7 @@
             dtype=np.int32,
         )
         # (ntypes+1 x ntypes+1)
-        self.type_mask = self.type_mask.reshape([-1])
+        self.type_mask = type_mask.reshape([-1])

     def get_exclude_types(self):
         return self.exclude_types
@@ -106,23 +108,32 @@ def build_type_exclude_mask(
             otherwise being 1.
         """
+        xp = array_api_compat.array_namespace(nlist, atype_ext)
         if len(self.exclude_types) == 0:
             # safely return 1 if nothing is excluded.
-            return np.ones_like(nlist, dtype=np.int32)
+            return xp.ones_like(nlist, dtype=xp.int32)
         nf, nloc, nnei = nlist.shape
         nall = atype_ext.shape[1]
         # add virtual atom of type ntypes. nf x nall+1
-        ae = np.concatenate(
-            [atype_ext, self.ntypes * np.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
+        ae = xp.concat(
+            [atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
         )
-        type_i = atype_ext[:, :nloc].reshape(nf, nloc) * (self.ntypes + 1)
+        type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1)
         # nf x nloc x nnei
-        index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei)
-        type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
+        index = xp.reshape(
+            xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei)
+        )
+        # type_j = xp.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
+        index = xp.reshape(index, [-1])
+        index += xp.repeat(xp.arange(nf) * (nall + 1), nloc * nnei)
+        type_j = xp.take(xp.reshape(ae, [-1]), index, axis=0)
+        type_j = xp.reshape(type_j, (nf, nloc, nnei))
         type_ij = type_i[:, :, None] + type_j
         # nf x (nloc x nnei)
-        type_ij = type_ij.reshape(nf, nloc * nnei)
-        mask = self.type_mask[type_ij].reshape(nf, nloc, nnei)
+        type_ij = xp.reshape(type_ij, (nf, nloc * nnei))
+        mask = xp.reshape(
+            xp.take(self.type_mask, xp.reshape(type_ij, (-1,))), (nf, nloc, nnei)
+        )
         return mask

     def __contains__(self, item):
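The lookup above encodes a type pair (t_i, t_j) as a single index t_i * (ntypes + 1) + t_j into the flattened mask table; the virtual type `ntypes` absorbs padded (-1) neighbors. A minimal sketch of that encoding with illustrative excluded pairs:

```python
import numpy as np

ntypes = 2
exclude = {(0, 1), (1, 0)}  # illustrative symmetric exclusion
# (ntypes+1) x (ntypes+1) table, flattened; row/column ntypes is the virtual type.
type_mask = np.array(
    [
        [0 if (ti, tj) in exclude else 1 for tj in range(ntypes + 1)]
        for ti in range(ntypes + 1)
    ],
    dtype=np.int32,
).reshape(-1)

assert type_mask[0 * (ntypes + 1) + 1] == 0       # excluded pair (0, 1)
assert type_mask[0 * (ntypes + 1) + 0] == 1       # allowed pair (0, 0)
assert type_mask[0 * (ntypes + 1) + ntypes] == 1  # virtual neighbor never excluded
```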
38 changes: 23 additions & 15 deletions deepmd/dpmodel/utils/network.py
@@ -148,15 +148,18 @@ def deserialize(cls, data: dict) -> "NativeLayer":
             num_out,
             **data,
         )
-        obj.w, obj.b, obj.idt = (
+        w, b, idt = (
             variables["w"],
             variables.get("b", None),
             variables.get("idt", None),
         )
-        if obj.b is not None:
-            obj.b = obj.b.ravel()
-        if obj.idt is not None:
-            obj.idt = obj.idt.ravel()
+        if b is not None:
+            b = b.ravel()
+        if idt is not None:
+            idt = idt.ravel()
+        obj.w = w
+        obj.b = b
+        obj.idt = idt
         obj.check_shape_consistency()
         return obj

@@ -177,8 +180,11 @@ def check_type_consistency(self):

         def check_var(var):
             if var is not None:
+                # array api standard doesn't provide a API to get the dtype name
+                # this is really hacked
+                dtype_name = str(var.dtype).split(".")[-1]
                 # assertion "float64" == "double" would fail
-                assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision]
+                assert PRECISION_DICT[dtype_name] is PRECISION_DICT[precision]

         check_var(self.w)
         check_var(self.b)
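A standalone illustration of the dtype-name workaround above. The Array API standard exposes no portable attribute for a dtype's name, so the string form is parsed; the torch-style string below is a stand-in so the example stays NumPy-only:

```python
import numpy as np

def dtype_name(dtype) -> str:
    # str(np.dtype(np.float64)) -> "float64"; str(torch.float64) -> "torch.float64".
    # Taking the last dot-separated token normalizes both spellings.
    return str(dtype).split(".")[-1]

assert dtype_name(np.dtype(np.float64)) == "float64"
assert dtype_name("torch.float64") == "float64"  # stand-in for a torch dtype
```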
@@ -251,7 +257,7 @@ def call(self, x: np.ndarray) -> np.ndarray:
         if self.resnet and self.w.shape[1] == self.w.shape[0]:
             y += x
         elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
-            y += xp.concatenate([x, x], axis=-1)
+            y += xp.concat([x, x], axis=-1)
         return y
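`concat` is the Array API spelling of NumPy's `concatenate`, and `array_api_compat` provides it for NumPy inputs. A quick illustrative check:

```python
import array_api_compat
import numpy as np

x = np.ones((2, 2))
xp = array_api_compat.array_namespace(x)
y = xp.concat([x, x], axis=-1)  # maps to np.concatenate for NumPy arrays
assert y.shape == (2, 4)
```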


@@ -362,10 +368,11 @@ def __init__(
             precision=precision,
             seed=seed,
         )
-        self.w = self.w.squeeze(0)  # keep the weight shape to be [num_in]
+        xp = array_api_compat.array_namespace(self.w, self.b)
+        self.w = xp.squeeze(self.w, 0)  # keep the weight shape to be [num_in]
         if self.uni_init:
-            self.w = np.ones_like(self.w)
-            self.b = np.zeros_like(self.b)
+            self.w = xp.ones_like(self.w)
+            self.b = xp.zeros_like(self.b)
         # only to keep consistent with other backends
         self.trainable = trainable

@@ -378,8 +385,8 @@ def serialize(self) -> dict:
             The serialized layer.
         """
         data = {
-            "w": self.w,
-            "b": self.b,
+            "w": np.array(self.w),
+            "b": np.array(self.b),
         }
         return {
             "@class": "LayerNorm",
@@ -473,11 +480,12 @@ def call(self, x: np.ndarray) -> np.ndarray:

     @staticmethod
     def layer_norm_numpy(x, shape, weight=None, bias=None, eps=1e-5):
+        xp = array_api_compat.array_namespace(x)
         # mean and variance
-        mean = np.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True)
-        var = np.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True)
+        mean = xp.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True)
+        var = xp.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True)
         # normalize
-        x_normalized = (x - mean) / np.sqrt(var + eps)
+        x_normalized = (x - mean) / xp.sqrt(var + eps)
         # shift and scale
         if weight is not None and bias is not None:
             x_normalized = x_normalized * weight + bias
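The layer-norm math above normalizes over the trailing `shape` axes, y = (x - mean) / sqrt(var + eps), then optionally scales and shifts. A quick NumPy check with illustrative data:

```python
import numpy as np

x = np.random.default_rng(0).normal(size=(2, 4))
eps = 1e-5
mean = x.mean(axis=-1, keepdims=True)  # statistics over the normalized axes
var = x.var(axis=-1, keepdims=True)
y = (x - mean) / np.sqrt(var + eps)
assert np.allclose(y.mean(axis=-1), 0.0, atol=1e-7)
assert np.allclose(y.std(axis=-1), 1.0, atol=1e-2)  # ~1 up to eps
```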
