Skip to content

Commit

Permalink
update inference code(WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Sep 14, 2024
1 parent 0896c90 commit 2e79d68
Show file tree
Hide file tree
Showing 14 changed files with 1,730 additions and 185 deletions.
5 changes: 4 additions & 1 deletion deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,13 @@ def freeze(FLAGS):
paddle.jit.save(
model,
path=FLAGS.output,
skip_prune_program=True,
# extra_files,
)
pir_flag = os.getenv("FLAGS_enable_pir_api", "false")
suffix = "json" if pir_flag.lower() in ["true", "1"] else "pdmodel"
log.info(
f"Paddle inference model has been exported to: {FLAGS.output}.pdmodel(.pdiparams)"
f"Paddle inference model has been exported to: {FLAGS.output}.{suffix}(.pdiparams)"
)


Expand Down
3 changes: 2 additions & 1 deletion deepmd/pd/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,8 @@ def _format_nlist(
axis=-1,
)

if n_nnei > nnei or extra_nlist_sort:
# if n_nnei > nnei or extra_nlist_sort:
if False:
n_nf, n_nloc, n_nnei = nlist.shape
m_real_nei = nlist >= 0
nlist = paddle.where(m_real_nei, nlist, 0)
Expand Down
45 changes: 29 additions & 16 deletions deepmd/pd/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Union,
)

import numpy as np
import paddle

from deepmd.pd.utils import (
Expand Down Expand Up @@ -101,10 +100,11 @@ def build_neighbor_list(
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if np.prod(coord.shape) > 0:
# if coord.numel().item() > 0:
if True > 0:
xmax = paddle.max(coord) + 2.0 * rcut
else:
xmax = paddle.zeros([1], dtype=coord.dtype).to(device=coord.place) + 2.0 * rcut
xmax = paddle.zeros([], dtype=coord.dtype).to(device=coord.place) + 2.0 * rcut
# nf x nall
is_vir = atype < 0
coord1 = paddle.where(
Expand All @@ -118,7 +118,8 @@ def build_neighbor_list(
diff = coord1.reshape([batch_size, -1, 3]).unsqueeze(1) - coord0.reshape(
[batch_size, -1, 3]
).unsqueeze(2)
assert list(diff.shape) == [batch_size, nloc, nall, 3]
if paddle.in_dynamic_mode():
assert list(diff.shape) == [batch_size, nloc, nall, 3]
# nloc x nall
# rr = paddle.linalg.norm(diff, axis=-1)
rr = aux.norm(diff, axis=-1)
Expand Down Expand Up @@ -147,7 +148,8 @@ def _trim_mask_distinguish_nlist(
nsel = sum(sel)
# nloc x nsel
batch_size, nloc, nnei = rr.shape
assert batch_size == is_vir_cntl.shape[0]
if paddle.in_dynamic_mode():
assert batch_size == is_vir_cntl.shape[0]
if nsel <= nnei:
rr = rr[:, :, :nsel]
nlist = nlist[:, :, :nsel]
Expand All @@ -171,7 +173,8 @@ def _trim_mask_distinguish_nlist(
],
axis=-1,
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
if paddle.in_dynamic_mode():
assert list(nlist.shape) == [batch_size, nloc, nsel]
nlist = paddle.where(
paddle.logical_or((rr > rcut), is_vir_cntl[:, :nloc, None]), -1, nlist
)
Expand Down Expand Up @@ -264,7 +267,8 @@ def build_directional_neighbor_list(
sel = [sel]
# nloc x nall x 3
diff = coord_neig[:, None, :, :] - coord_cntl[:, :, None, :]
assert list(diff.shape) == [batch_size, nloc_cntl, nall_neig, 3]
if paddle.in_dynamic_mode():
assert list(diff.shape) == [batch_size, nloc_cntl, nall_neig, 3]
# nloc x nall
# rr = paddle.linalg.norm(diff, axis=-1)
rr = aux.norm(diff, axis=-1)
Expand Down Expand Up @@ -372,7 +376,8 @@ def build_multiple_neighbor_list(
value being the corresponding nlist.
"""
assert len(rcuts) == len(nsels)
if paddle.in_dynamic_mode():
assert len(rcuts) == len(nsels)
if len(rcuts) == 0:
return {}
nb, nloc, nsel = nlist.shape
Expand Down Expand Up @@ -473,17 +478,25 @@ def extend_coord_with_ghosts(
# 3
nbuff = paddle.amax(nbuff, axis=0) # faster than paddle.max
nbuff_cpu = nbuff.cpu()
xi = paddle.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1).to(
dtype=env.GLOBAL_PD_FLOAT_PRECISION, device="cpu"
xi = (
paddle.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1)
.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
.cpu()
) # pylint: disable=no-explicit-dtype
yi = paddle.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1).to(
dtype=env.GLOBAL_PD_FLOAT_PRECISION, device="cpu"
yi = (
paddle.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1)
.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
.cpu()
) # pylint: disable=no-explicit-dtype
zi = paddle.arange(-nbuff_cpu[2], nbuff_cpu[2] + 1, 1).to(
dtype=env.GLOBAL_PD_FLOAT_PRECISION, device="cpu"
zi = (
paddle.arange(-nbuff_cpu[2], nbuff_cpu[2] + 1, 1)
.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
.cpu()
) # pylint: disable=no-explicit-dtype
eye_3 = paddle.eye(3, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to(
dtype=env.GLOBAL_PD_FLOAT_PRECISION, device="cpu"
eye_3 = (
paddle.eye(3, dtype=env.GLOBAL_PD_FLOAT_PRECISION)
.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
.cpu()
)
xyz = xi.reshape([-1, 1, 1, 1]) * eye_3[0]
xyz = xyz + yi.reshape([1, -1, 1, 1]) * eye_3[1]
Expand Down
8 changes: 4 additions & 4 deletions examples/water/lmp/in.lammps
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mass 1 16
mass 2 2

# See https://deepmd.rtfd.io/lammps/ for usage
pair_style deepmd frozen_model.pb
pair_style deepmd /workspace/hesensen/deepmd_backend/deepmd_paddle_new/examples/water/se_e2_a/torch_infer.pth
# If atom names (O H in this example) are not set in the pair_coeff command, the type_map defined by the training parameter will be used by default.
pair_coeff * * O H

Expand All @@ -21,7 +21,7 @@ velocity all create 330.0 23456789
fix 1 all nvt temp 330.0 330.0 0.5
timestep 0.0005
thermo_style custom step pe ke etotal temp press vol
thermo 100
dump 1 all custom 100 water.dump id type x y z
thermo 1
dump 1 all custom 1 water.dump id type x y z

run 1000
run 10
Loading

0 comments on commit 2e79d68

Please sign in to comment.