Skip to content

Commit

Permalink
Add small repro as python test
Browse files Browse the repository at this point in the history
This test is actually tougher than the big repro on #2359. It
necessitates either removing the check altogether or taking the path I
chose, which is to concretize Resized to broadcast extents as constant 1 so
that we can evaluate max(0, min(i0, 1)) as 1 without calling
simplifyExpr. A less invasive solution would be to remove the extent
check in `SqueezeOp::checkConcretization`.

We could also just remove `Expr::checkConcretization` and
`checkConcretizedUses`. They are currently only used for SqueezeOp and
probably do not add much value anyway.
  • Loading branch information
jacobhinkle committed Jun 7, 2024
1 parent c3bf547 commit 4b10eb1
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4055,6 +4055,37 @@ def fusion_func(fd: FusionDefinition) -> None:
"FusionDefinition's execute() did not run correctly with profile enabled!"
)

# Small repro from https://github.com/NVIDIA/Fuser/issues/2359
def test_reshape_squeeze_concretization(self):
    """Check a slice followed by two chained reshapes runs and is correct.

    Small repro from https://github.com/NVIDIA/Fuser/issues/2359. A
    (2, 5, 10) input is sliced to [0:1, 0:2, 0:4], reshaped to [1, 8] and
    then reshaped again to [8]. Besides checking that the fusion executes,
    the output is compared against the equivalent PyTorch eager result.
    """
    inputs = [
        torch.randn((100,), dtype=torch.float32, device="cuda:0").as_strided(
            (2, 5, 10), (50, 10, 1)
        ),
    ]

    def fusion_func(fd: FusionDefinition) -> None:
        T0 = fd.define_tensor(
            shape=[-1, -1, -1],
            contiguity=[True, True, True],
            dtype=DataType.Float,
            is_cpu=False,
            stride_order=[2, 1, 0],
        )
        T1 = fd.ops.slice(
            T0, start_indices=[0, 0, 0], end_indices=[1, 2, 4], strides=[1, 1, 1]
        )
        S2 = fd.define_scalar(1, dtype=DataType.Int)
        S3 = fd.define_scalar(8, dtype=DataType.Int)
        V4 = fd.define_vector([S2, S3], dtype=DataType.Int)
        V5 = fd.define_vector([S3], dtype=DataType.Int)
        # The intermediate [1, 8] reshape is what makes this a repro:
        # reshaping T1 straight to V5 (fd.ops.reshape(T1, new_shape=V5))
        # works fine.
        T6 = fd.ops.reshape(T1, new_shape=V4)
        T7 = fd.ops.reshape(T6, new_shape=V5)
        fd.add_output(T7)

    nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

    # Eager reference: the same slice flattened to 8 elements. The fusion
    # only moves data, so the comparison is exact.
    # NOTE(review): assumes exec_nvfuser returns the list of fusion outputs
    # as its first element, so nvf_out[0] is T7 - confirm against the helper.
    expected = inputs[0][:1, :2, :4].reshape(8)
    torch.testing.assert_close(nvf_out[0], expected, rtol=0, atol=0)


# Script entry point: invoke the test runner when this file is executed
# directly rather than imported.
if __name__ == "__main__":
    run_tests()

0 comments on commit 4b10eb1

Please sign in to comment.