Commit
Use prim functions instead of vanilla functions, supporting double-backward

HydrogenSulfate committed Sep 10, 2024
1 parent 55f71f6 commit ebdbed2
Showing 14 changed files with 75 additions and 45 deletions.
2 changes: 1 addition & 1 deletion deepmd/pd/entrypoints/main.py
@@ -359,7 +359,7 @@ def freeze(FLAGS):
input_spec=[
InputSpec([None, 192, 3], dtype="float64", name="coord"),
InputSpec([None, 192], dtype="int64", name="atype"),
InputSpec([None, 192, 3], dtype="float64", name="box"),
InputSpec([None, 3, 3], dtype="float64", name="box"),
],
)
extra_files = {}
19 changes: 14 additions & 5 deletions deepmd/pd/loss/ener.py
@@ -281,12 +281,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
drdq_reshape = drdq.reshape(
[-1, natoms * 3, self.numb_generalized_coord]
)
-gen_force_label = paddle.einsum(
-"bij,bi->bj", drdq_reshape, force_label_reshape_nframes
-)
-gen_force = paddle.einsum(
-"bij,bi->bj", drdq_reshape, force_reshape_nframes

+# gen_force_label = paddle.einsum(
+# "bij,bi->bj", drdq_reshape, force_label_reshape_nframes
+# )
+gen_force_label = (
+drdq_reshape * force_label_reshape_nframes.unsqueeze(-1)
+).sum([-2])

+# gen_force = paddle.einsum(
+# "bij,bi->bj", drdq_reshape, force_reshape_nframes
+# )
+gen_force = (drdq_reshape * force_reshape_nframes.unsqueeze(-1)).sum(
+[-2]
)

diff_gen_force = gen_force_label - gen_force
l2_gen_force_loss = paddle.square(diff_gen_force).mean()
if not self.inference:
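The einsum-to-primitive rewrite above is a pure refactor: contracting "bij,bi->bj" is the same as broadcasting the vector over the last axis and summing over the middle one. A minimal self-check of the identity (tensor sizes here are made up for illustration):

    import paddle

    nframes, ncoord, ngen = 2, 12, 4  # hypothetical sizes
    drdq_reshape = paddle.randn([nframes, ncoord, ngen])
    force = paddle.randn([nframes, ncoord])

    # broadcast multiply, then reduce over the contracted axis (-2);
    # this multiply+sum form is built from primitive ops, which is what
    # the commit relies on for double-backward support
    prim = (drdq_reshape * force.unsqueeze(-1)).sum([-2])
    ref = paddle.einsum("bij,bi->bj", drdq_reshape, force)
    assert paddle.allclose(prim, ref).item()
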
2 changes: 1 addition & 1 deletion deepmd/pd/model/atomic_model/base_atomic_model.py
@@ -252,7 +252,7 @@ def forward_common_atomic(
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
-atom_mask = ext_atom_mask[:, :nloc].to(paddle.int32)
+atom_mask = ext_atom_mask[:, :nloc].astype(paddle.int32)
if self.atom_excl is not None:
atom_mask *= self.atom_excl(atype)

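This file, and several below, replace Tensor.to(...) with Tensor.astype(...). Both cast dtypes; astype maps onto Paddle's single cast primitive, while .to is a newer convenience wrapper, presumably outside what the prim pass decomposes. A small sketch of the equivalence:

    import paddle

    ext_atom_mask = paddle.ones([1, 4], dtype="bool")

    a = ext_atom_mask.astype(paddle.int32)  # primitive cast op
    b = ext_atom_mask.to(paddle.int32)      # wrapper, same values
    assert a.dtype == b.dtype and bool((a == b).all())
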
2 changes: 1 addition & 1 deletion deepmd/pd/model/descriptor/env_mat.py
@@ -42,7 +42,7 @@ def _make_env_mat(
if radial_only:
env_mat = t0 * weight
else:
-env_mat = paddle.concat([t0, t1], axis=-1) * weight
+env_mat = paddle.concat([t0.astype(t1.dtype), t1], axis=-1) * weight
return env_mat, diff * mask.unsqueeze(-1).astype(diff.dtype), weight


8 changes: 4 additions & 4 deletions deepmd/pd/model/descriptor/se_a.py
@@ -652,7 +652,7 @@ def forward(
else:
assert self.filter_layers is not None
dmatrix = dmatrix.reshape([-1, self.nnei, 4])
-dmatrix = dmatrix.to(dtype=self.prec)
+dmatrix = dmatrix.astype(self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = paddle.zeros(
@@ -672,7 +672,7 @@ def forward(
# ti: center atom type, ii: neighbor type...
ii = embedding_idx // self.ntypes
ti = embedding_idx % self.ntypes
-ti_mask = atype.flatten().equal(ti)
+ti_mask = atype.flatten() == ti
# nfnl x nt
if ti_mask is not None:
mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]]
@@ -704,8 +704,8 @@ def forward(
result = result.reshape([nf, nloc, self.filter_neuron[-1] * self.axis_neuron])
rot_mat = rot_mat.reshape([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
-result.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
-rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
+result.astype(env.GLOBAL_PD_FLOAT_PRECISION),
+rot_mat.astype(env.GLOBAL_PD_FLOAT_PRECISION),
None,
None,
sw,
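Alongside the dtype casts, Tensor.equal(ti) becomes the == operator. Both produce the same boolean mask; the operator form routes through the elementwise equal primitive, which presumably composes better with the prim machinery. A tiny illustration:

    import paddle

    atype = paddle.to_tensor([[0, 1], [1, 0]])
    ti = 1
    # the comparison yields a bool tensor usable for mask indexing
    ti_mask = atype.flatten() == ti
    print(ti_mask.tolist())  # [False, True, True, False]
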
21 changes: 14 additions & 7 deletions deepmd/pd/model/descriptor/se_atten.py
@@ -515,7 +515,10 @@ def forward(
atype_tebd=atype_tebd_nnei,
nlist_tebd=atype_tebd_nlist,
) # shape is [nframes*nall, self.neei, out_size]
-input_r = paddle.nn.functional.normalize(
+# input_r = paddle.nn.functional.normalize(
+# dmatrix.reshape([-1, self.nnei, 4])[:, :, 1:4], axis=-1
+# )
+input_r = aux.normalize(
dmatrix.reshape([-1, self.nnei, 4])[:, :, 1:4], axis=-1
)
gg = self.dpa1_attention(
@@ -566,9 +569,10 @@ def forward(
else:
raise NotImplementedError

-input_r = paddle.nn.functional.normalize(
-rr.reshape([-1, self.nnei, 4])[:, :, 1:4], axis=-1
-)
+# input_r = paddle.nn.functional.normalize(
+# rr.reshape([-1, self.nnei, 4])[:, :, 1:4], axis=-1
+# )
+input_r = aux.normalize(rr.reshape([-1, self.nnei, 4])[:, :, 1:4], axis=-1)
gg = self.dpa1_attention(
gg, nlist_mask, input_r=input_r, sw=sw
) # shape is [nframes*nloc, self.neei, out_size]
@@ -946,9 +950,12 @@ def forward(
)

if self.normalize:
-q = paddle_func.normalize(q, axis=-1)
-k = paddle_func.normalize(k, axis=-1)
-v = paddle_func.normalize(v, axis=-1)
+# q = paddle_func.normalize(q, axis=-1)
+# k = paddle_func.normalize(k, axis=-1)
+# v = paddle_func.normalize(v, axis=-1)
+q = aux.normalize(q, axis=-1)
+k = aux.normalize(k, axis=-1)
+v = aux.normalize(v, axis=-1)

q = q * self.scaling
# (nf x nloc) x num_heads x head_dim x nnei
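Every paddle.nn.functional.normalize call is redirected to aux.normalize. The diff does not show aux.normalize itself, but a plausible decomposition of L2 normalization into square/sum/sqrt/divide primitives, all of which support double-backward, would look like this (epsilon handling assumed to match F.normalize):

    import paddle

    def normalize(x, axis=-1, epsilon=1e-12):
        # x / max(||x||_2, eps), built from elementary ops only
        norm = paddle.sqrt(paddle.sum(paddle.square(x), axis=axis, keepdim=True))
        return x / paddle.clip(norm, min=epsilon)

    x = paddle.randn([2, 5, 3])
    ref = paddle.nn.functional.normalize(x, axis=-1)
    assert paddle.allclose(normalize(x), ref).item()
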
12 changes: 10 additions & 2 deletions deepmd/pd/model/descriptor/se_t_tebd.py
@@ -814,7 +814,11 @@ def forward(
# nfnl x nt_j x 3
rr_j = rr[:, :, 1:]
# nfnl x nt_i x nt_j
env_ij = paddle.einsum("ijm,ikm->ijk", rr_i, rr_j)
# env_ij = paddle.einsum("ijm,ikm->ijk", rr_i, rr_j)
env_ij = (
# ij1m x i1km -> ijkm -> ijk
rr_i.unsqueeze(2) * rr_j.unsqueeze(1)
).sum(-1)
# nfnl x nt_i x nt_j x 1
ss = env_ij.unsqueeze(-1)

@@ -850,7 +854,11 @@ def forward(
raise NotImplementedError

# nfnl x ng
-res_ij = paddle.einsum("ijk,ijkm->im", env_ij, gg)
+# res_ij = paddle.einsum("ijk,ijkm->im", env_ij, gg)
+res_ij = (
+# ijk1 x ijkm -> ijkm -> im
+env_ij.unsqueeze(-1) * gg
+).sum([1, 2])
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
# nf x nl x ng
result = res_ij.reshape([nframes, nloc, self.filter_neuron[-1]])
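Both contractions in this file follow the same unsqueeze-multiply-sum recipe as in ener.py. A quick check of the two identities, with made-up sizes:

    import paddle

    nfnl, nt, m, ng = 4, 3, 5, 2  # hypothetical sizes
    rr_i = paddle.randn([nfnl, nt, m])
    rr_j = paddle.randn([nfnl, nt, m])
    gg = paddle.randn([nfnl, nt, nt, ng])

    # "ijm,ikm->ijk": broadcast over j/k, contract the last axis
    env_ij = (rr_i.unsqueeze(2) * rr_j.unsqueeze(1)).sum(-1)
    assert paddle.allclose(env_ij, paddle.einsum("ijm,ikm->ijk", rr_i, rr_j)).item()

    # "ijk,ijkm->im": broadcast over m, contract the j and k axes
    res_ij = (env_ij.unsqueeze(-1) * gg).sum([1, 2])
    assert paddle.allclose(res_ij, paddle.einsum("ijk,ijkm->im", env_ij, gg)).item()
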
4 changes: 2 additions & 2 deletions deepmd/pd/model/model/transform_output.py
@@ -180,7 +180,7 @@ def fit_output_to_model_output(
atom_axis = -(len(shap) + 1)
if vdef.reducible:
kk_redu = get_reduce_name(kk)
-model_ret[kk_redu] = paddle.sum(vv.to(redu_prec), axis=atom_axis)
+model_ret[kk_redu] = paddle.sum(vv.astype(redu_prec), axis=atom_axis)
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
dr, dc = take_deriv(
@@ -197,7 +197,7 @@
assert dc is not None
model_ret[kk_derv_c] = dc
model_ret[kk_derv_c + "_redu"] = paddle.sum(
-model_ret[kk_derv_c].to(redu_prec), axis=1
+model_ret[kk_derv_c].astype(redu_prec), axis=1
)
return model_ret

4 changes: 2 additions & 2 deletions deepmd/pd/model/network/mlp.py
@@ -222,7 +222,7 @@ def forward(
The output.
"""
ori_prec = xx.dtype
-xx = xx.to(self.prec)
+xx = xx.astype(self.prec)
yy = (
paddle.matmul(xx, self.matrix.astype(self.prec)) + self.bias
if self.bias is not None
@@ -237,7 +237,7 @@ def forward(
yy += paddle.concat([xx, xx], axis=-1)
else:
yy = yy
-yy = yy.to(ori_prec)
+yy = yy.astype(ori_prec)
return yy

def serialize(self) -> dict:
9 changes: 6 additions & 3 deletions deepmd/pd/model/network/network.py
@@ -1220,9 +1220,12 @@ def forward(
k = k.reshape([-1, self.nnei, self.hidden_dim])
v = v.reshape([-1, self.nnei, self.hidden_dim])
if self.normalize:
-q = F.normalize(q, axis=-1)
-k = F.normalize(k, axis=-1)
-v = F.normalize(v, axis=-1)
+# q = F.normalize(q, axis=-1)
+# k = F.normalize(k, axis=-1)
+# v = F.normalize(v, axis=-1)
+q = aux.normalize(q, axis=-1)
+k = aux.normalize(k, axis=-1)
+v = aux.normalize(v, axis=-1)
q = q * self.scaling
k = k.transpose([0, 2, 1])
# [nframes * nloc, nnei, nnei]
3 changes: 2 additions & 1 deletion deepmd/pd/model/task/fitting.py
@@ -517,6 +517,7 @@ def _forward_common(
else:
for type_i, ll in enumerate(self.filter_layers.networks):
mask = (atype == type_i).unsqueeze(-1)
+mask.stop_gradient = True
mask = paddle.tile(mask, (1, 1, net_dim_out))
atom_property = ll(xx)
if xx_zeros is not None:
@@ -536,4 +537,4 @@ def _forward_common(
mask = self.emask(atype)
# nf x nloc x nod
outs = outs * mask[:, :, None].astype(outs.dtype)
-return {self.var_name: outs.to(env.GLOBAL_PD_FLOAT_PRECISION)}
+return {self.var_name: outs.astype(env.GLOBAL_PD_FLOAT_PRECISION)}
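The added mask.stop_gradient = True explicitly detaches the boolean selector from the autodiff graph, presumably so that the double-backward pass never tries to trace through it. A minimal sketch of the pattern:

    import paddle

    atype = paddle.to_tensor([[0, 1, 1]])
    # the mask is a non-differentiable selector, so detach it explicitly
    mask = (atype == 0).unsqueeze(-1)
    mask.stop_gradient = True
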
8 changes: 4 additions & 4 deletions deepmd/pd/train/training.py
@@ -57,6 +57,7 @@
DEVICE,
JIT,
LOCAL_RANK,
+NUM_WORKERS,
SAMPLER_RECORD,
enable_prim,
)
@@ -179,13 +180,12 @@ def get_dataloader_and_buffer(_data, _params):
_dataloader = DataLoader(
_data,
batch_sampler=paddle.io.BatchSampler(
-sampler=_sampler, drop_last=False
+sampler=_sampler,
+drop_last=False,
),
-# batch_size=None,
-num_workers=0
+num_workers=NUM_WORKERS
if dist.is_available()
else 0, # setting to 0 diverges the behavior of its iterator; should be >=1
-# drop_last=False,
collate_fn=lambda batch: batch[0], # prevent extra conversion
# pin_memory=True,
)
24 changes: 13 additions & 11 deletions deepmd/pd/utils/neighbor_stat.py
@@ -90,10 +90,11 @@ def forward(
# remove the diagonal elements
mask = paddle.eye(nloc, nall).to(dtype=paddle.bool, device=diff.place)
# diff[:, mask] = float("inf")
-diff.masked_fill_(
-paddle.broadcast_to(mask.unsqueeze([0, -1]), diff.shape),
-paddle.to_tensor(float("inf")),
-)
+# diff.masked_fill_(
+# paddle.broadcast_to(mask.unsqueeze([0, -1]), diff.shape),
+# paddle.to_tensor(float("inf")),
+# )
+diff[paddle.broadcast_to(mask.unsqueeze([0, -1]), diff.shape)] = float("inf")
rr2 = paddle.sum(paddle.square(diff), axis=-1)
min_rr2 = paddle.min(rr2, axis=-1)
# count the number of neighbors
@@ -102,12 +103,12 @@
nnei = paddle.zeros((nframes, nloc, self.ntypes), dtype=paddle.int64)
for ii in range(self.ntypes):
nnei[:, :, ii] = paddle.sum(
mask & extend_atype.equal(ii)[:, None, :], axis=-1
mask & ((extend_atype == ii)[:, None, :]), axis=-1
)
else:
mask = rr2 < self.rcut**2
# virtual types (<0) are not counted
nnei = paddle.sum(mask & extend_atype.ge(0)[:, None, :], axis=-1).reshape(
nnei = paddle.sum(mask & ((extend_atype > 0)[:, None, :]), axis=-1).reshape(
[nframes, nloc, 1]
)
max_nnei = paddle.max(nnei, axis=1)
@@ -184,11 +185,12 @@ def _execute(
cell
The cell.
"""
-minrr2, max_nnei = self.op(
-paddle.to_tensor(coord, place=DEVICE),
-paddle.to_tensor(atype, place=DEVICE),
-paddle.to_tensor(cell, place=DEVICE) if cell is not None else None,
-)
+with paddle.no_grad():
+minrr2, max_nnei = self.op(
+paddle.to_tensor(coord, place=DEVICE),
+paddle.to_tensor(atype, place=DEVICE),
+paddle.to_tensor(cell, place=DEVICE) if cell is not None else None,
+)
minrr2 = minrr2.numpy()
max_nnei = max_nnei.numpy()
return minrr2, max_nnei
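Two notes on this file. First, the in-place masked_fill_ becomes a plain boolean-mask assignment, which lowers to indexing primitives; the two forms fill the same elements:

    import paddle

    diff = paddle.randn([2, 3, 4, 3])
    mask = paddle.eye(3, 4).astype(paddle.bool)
    full = paddle.broadcast_to(mask.unsqueeze([0, -1]), diff.shape)

    a = diff.clone()
    a[full] = float("inf")  # boolean-mask assignment (new form)
    b = diff.clone()
    b.masked_fill_(full, paddle.to_tensor(float("inf")))  # previous form
    assert bool((a == b).all())

Second, extend_atype > 0 is not the same predicate as the old extend_atype.ge(0): it also drops real atoms of type 0, which the adjacent comment ("virtual types (<0) are not counted") does not suggest is intended, so that hunk may deserve a second look.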
2 changes: 1 addition & 1 deletion deepmd/pt/entrypoints/main.py
@@ -555,7 +555,7 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
else:
FLAGS = args

-set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None)
+set_log_handles(FLAGS.log_level, Path(FLAGS.log_path), mpi_log=None)
log.debug("Log handles were successfully set")
log.info("DeePMD version: %s", __version__)

Expand Down
