Skip to content

Commit

Permalink
fix masked scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 20, 2025
1 parent 21a3413 commit 6ed21f5
Showing 1 changed file with 0 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,6 @@ def _masked_scatter(
assert isinstance(backup, Tensor)
if in_dim not in s.dims:
s = rf.expand_dim(s, in_dim)
if in_dim not in backup.dims:
backup = rf.expand_dim(backup, in_dim)
# Do the reverse of _masked_select above.
# First replace the dims back.
if any(d in reverse_dim_map for d in s.dims):
Expand Down

0 comments on commit 6ed21f5

Please sign in to comment.