Skip to content

Commit

Permalink
array api
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 1, 2024
1 parent 1cac90b commit 84c1900
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 73 deletions.
20 changes: 14 additions & 6 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from deepmd.dpmodel import (
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
Expand Down Expand Up @@ -794,7 +797,7 @@ def call(
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nframes, -1).shape[1] // 3
nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3
# nlists
nlist_dict = build_multiple_neighbor_list(
coord_ext,
Expand All @@ -803,7 +806,10 @@ def call(
self.nsel_list,
)
# repinit
g1_ext = self.type_embedding.call()[atype_ext]
g1_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
(nframes, nall, self.tebd_dim),
)
g1_inp = g1_ext[:, :nloc, :]
g1, _, _, _, _ = self.repinit(
nlist_dict[
Expand All @@ -828,16 +834,18 @@ def call(
g1_ext,
mapping,
)
g1 = xp.concatenate([g1, g1_three_body], axis=-1)
g1 = xp.concat([g1, g1_three_body], axis=-1)
# linear to change shape
g1 = self.g1_shape_tranform(g1)
if self.add_tebd_to_repinit_out:
assert self.tebd_transform is not None
g1 = g1 + self.tebd_transform(g1_inp)
# mapping g1
assert mapping is not None
mapping_ext = xp.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1]))
g1_ext = xp.take_along_axis(g1, mapping_ext, axis=1)
mapping_ext = xp.tile(
xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1])
)
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
Expand All @@ -851,7 +859,7 @@ def call(
mapping,
)
if self.concat_output_tebd:
g1 = xp.concatenate([g1, g1_inp], axis=-1)
g1 = xp.concat([g1, g1_inp], axis=-1)
return g1, rot_mat, g2, h2, sw

def serialize(self) -> dict:
Expand Down
Loading

0 comments on commit 84c1900

Please sign in to comment.