Skip to content

Commit

Permalink
use a Python implementation of take_along_axis
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Sep 25, 2024
1 parent 7de9ee3 commit d65206f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 31 deletions.
37 changes: 37 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for the array API."""

import array_api_compat


def support_array_api(version: str) -> callable:
"""Mark a function as supporting the specific version of the array API.
Expand All @@ -27,3 +29,38 @@ def set_version(func: callable) -> callable:
return func

return set_version


# array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816
# but it hasn't been released yet
# below is a pure Python implementation of take_along_axis
# https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595
def xp_swapaxes(a, axis1, axis2):
xp = array_api_compat.array_namespace(a)
axes = list(range(a.ndim))
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
a = xp.permute_dims(a, axes)
return a


def xp_take_along_axis(arr, indices, axis):
xp = array_api_compat.array_namespace(arr)
arr = xp_swapaxes(arr, axis, -1)
indices = xp_swapaxes(indices, axis, -1)

m = arr.shape[-1]
n = indices.shape[-1]

shape = list(arr.shape)
shape.pop(-1)
shape = [*shape, n]

arr = xp.reshape(arr, (-1,))
indices = xp.reshape(indices, (-1, n))

offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
indices = xp.reshape(offset + indices, (-1,))

out = xp.take(arr, indices)
out = xp.reshape(out, shape)
return xp_swapaxes(out, axis, -1)
18 changes: 6 additions & 12 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand Down Expand Up @@ -900,19 +903,10 @@ def call(
nlist_mask = nlist != -1
# nfnl x nnei x 1
sw = xp.where(nlist_mask[:, :, None], sw, xp.full_like(sw, 0.0))
nall = atype_embd_ext.shape[1]
nfidx = xp.reshape(
xp.repeat(xp.arange(nf) * nall, nloc * nnei), (nf * nloc, nnei)
)
nlist_ = nlist + nfidx
nlist_masked = xp.where(nlist_mask, nlist_, nfidx)
# index = xp.tile(xp.reshape(nlist_masked,(nf, -1, 1)), (1, 1, self.tebd_dim))
nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
# nfnl x nnei x tebd_dim
# atype_embd_nlist = xp.take_along_axis(atype_embd_ext, index, axis=1)
index = xp.reshape(nlist_masked, [-1])
atype_embd_nlist = xp.take(
xp.reshape(atype_embd_ext, (nf * nall, self.tebd_dim)), index, axis=0
)
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
Expand Down
14 changes: 3 additions & 11 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from deepmd.dpmodel.array_api import (
support_array_api,
xp_take_along_axis,
)


Expand Down Expand Up @@ -51,17 +52,8 @@ def _make_env_mat(
mask = nlist >= 0
nlist = nlist * xp.astype(mask, nlist.dtype)
# nf x (nloc x nnei) x 3
# index = xp.reshape(nlist, (nf, -1, 1))
# index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3))
# coord_r = xp.take_along_axis(coord, xp.tile(index, (1, 1, 3)), 1)
# note: array api doesn't contain take_along_axis until the next version
# reimplement
nall = coord.shape[1]
index = xp.reshape(nlist, (nf * nloc * nnei,)) + xp.repeat(
(xp.arange(nf) * nall), nloc * nnei
)
coord_ = xp.reshape(coord, (-1, 3))
coord_r = xp.take(coord_, index, axis=0)
index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3))
coord_r = xp_take_along_axis(coord, index, 1)
# nf x nloc x nnei x 3
coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
# nf x nloc x 1 x 3
Expand Down
10 changes: 5 additions & 5 deletions deepmd/dpmodel/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)


class AtomExcludeMask:
"""Computes the type exclusion mask for atoms."""
Expand Down Expand Up @@ -123,11 +127,7 @@ def build_type_exclude_mask(
index = xp.reshape(
xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei)
)
# type_j = xp.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
index = xp.reshape(index, [-1])
index += xp.repeat(xp.arange(nf) * (nall + 1), nloc * nnei)
type_j = xp.take(xp.reshape(ae, [-1]), index, axis=0)
type_j = xp.reshape(type_j, (nf, nloc, nnei))
type_j = xp_take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
type_ij = type_i[:, :, None] + type_j
# nf x (nloc x nnei)
type_ij = xp.reshape(type_ij, (nf, nloc * nnei))
Expand Down
10 changes: 7 additions & 3 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)

from .region import (
normalize_coord,
to_face_distance,
Expand Down Expand Up @@ -165,17 +169,17 @@ def nlist_distinguish_types(
mask = nlist == -1
tnlist_0 = nlist.copy()
tnlist_0[mask] = 0
tnlist = xp.take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze()
tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze()
tnlist = xp.where(mask, -1, tnlist)
snsel = tnlist.shape[2]
for ii, ss in enumerate(sel):
pick_mask = (tnlist == ii).astype(xp.int32)
sorted_indices = xp.argsort(-pick_mask, kind="stable", axis=-1)
pick_mask_sorted = -xp.sort(-pick_mask, axis=-1)
inlist = xp.take_along_axis(nlist, sorted_indices, axis=2)
inlist = xp_take_along_axis(nlist, sorted_indices, axis=2)
inlist = xp.where(~pick_mask_sorted.astype(bool), -1, inlist)
ret_nlist.append(xp.split(inlist, [ss, snsel - ss], axis=-1)[0])
ret = xp.concatenate(ret_nlist, axis=-1)
ret = xp.concat(ret_nlist, axis=-1)
return ret


Expand Down

0 comments on commit d65206f

Please sign in to comment.