Skip to content

Commit

Permalink
use tf.case
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 11, 2024
1 parent 0dcd3b1 commit 76bf825
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions deepmd/jax/jax2tf/format_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,13 @@ def n_nsel_greater_than_nsel():
def n_nsel_equal_nsel():
return nlist

ret = tf.cond(
n_nsel < nsel,
n_nsel_less_than_nsel,
lambda: tf.cond(
n_nsel > nsel,
n_nsel_greater_than_nsel,
n_nsel_equal_nsel,
),
ret = tf.case(
{
tf.less(n_nsel, nsel): n_nsel_less_than_nsel,
tf.greater(n_nsel, nsel): n_nsel_greater_than_nsel,
},
default=n_nsel_equal_nsel,
exclusive=True,
)
# do a reshape any way; this will tell the xla the shape without any dynamic shape
ret = tnp.reshape(ret, [n_nf, n_nloc, nsel])
Expand Down

0 comments on commit 76bf825

Please sign in to comment.