Skip to content

Commit

Permalink
fix se_a_ebd_v2 when nloc != nall (#3037)
Browse files Browse the repository at this point in the history
See also #2390 and #2505...

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Dec 7, 2023
1 parent 3c54949 commit fe488a4
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,16 +782,16 @@ def _pass_filter(
type_i = -1
if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor:
inputs_i = descrpt2r4(inputs_i, natoms)
self.atype_nloc = tf.reshape(
tf.slice(atype, [0, 0], [-1, natoms[0]]), [-1]
) # when nloc != nall, pass nloc to mask
if len(self.exclude_types):
atype_nloc = tf.reshape(
tf.slice(atype, [0, 0], [-1, natoms[0]]), [-1]
) # when nloc != nall, pass nloc to mask
mask = self.build_type_exclude_mask(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
atype_nloc,
self.atype_nloc,
tf.shape(inputs_i)[0],
)
inputs_i *= mask
Expand Down Expand Up @@ -956,7 +956,7 @@ def _filter_lower(
extra_embedding_index = self.nei_type_vec
else:
padding_ntypes = type_embedding.shape[0]
atype_expand = tf.reshape(self.atype, [-1, 1])
atype_expand = tf.reshape(self.atype_nloc, [-1, 1])
idx_i = tf.tile(atype_expand * padding_ntypes, [1, self.nnei])
idx_j = tf.reshape(self.nei_type_vec, [-1, self.nnei])
idx = idx_i + idx_j
Expand Down Expand Up @@ -1002,7 +1002,7 @@ def _filter_lower(
[-1, two_side_type_embedding.shape[-1]],
)

atype_expand = tf.reshape(self.atype, [-1, 1])
atype_expand = tf.reshape(self.atype_nloc, [-1, 1])
idx_i = tf.tile(atype_expand * padding_ntypes, [1, self.nnei])
idx_j = tf.reshape(self.nei_type_vec, [-1, self.nnei])
idx = idx_i + idx_j
Expand Down

0 comments on commit fe488a4

Please sign in to comment.