Move checkConcretization for reshapes (#2363)
The included test is the small one provided by @jjsjann123 in #2359, and
it's actually tougher than the original repro. It necessitates either
removing the check that concretized squeezed extents are constant 1, or
concretizing extents that are resized to broadcast as constant 1 so that
we can evaluate `max(0, min(i0, 1))` as `oneVal()` without calling
`simplifyExpr`. I went with removing the check, which means in this
example we have a broadcast dimension with a dynamic extent like `max(0,
min(i0, 1))`. Since we're concretizing to Broadcast, we know that
dimension's extent is not zero; if it were, we'd concretize to `Iteration`
and `SqueezeOp::checkConcretization` would fail the IterType check.
Still, I don't love that the expression cannot be simplified, so it
appears in the kernel (as `i9` and `i10`):
```c++
__global__ void nvfuser_pointwise_f0_c1_r0_g1(Tensor<float, 3, 3> T0, Tensor<float, 2, 2> T4) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128LL * ((nvfuser_index_t)blockIdx.x));
  Tensor<float, 3, 3> s1;
  s1.data = T0.data;
  s1.logical_size = T0.logical_size;
  s1.alloc_stride = T0.alloc_stride;
  Array<nvfuser_index_t, 3, 1> a2;
  a2 = s1.logical_size;
  nvfuser_index_t i3;
  i3 = a2[2LL];
  nvfuser_index_t i4;
  i4 = max(0LL, (min(4LL, i3)));
  nvfuser_index_t i5;
  i5 = min(i3, 4LL);
  nvfuser_index_t i6;
  i6 = max(0LL, i5);
  Array<nvfuser_index_t, 3, 1> a7;
  a7 = s1.logical_size;
  nvfuser_index_t i8;
  i8 = a7[0LL];
  nvfuser_index_t i9;
  i9 = min(i8, 1LL);
  nvfuser_index_t i10;
  i10 = max(0LL, i9);
  Array<nvfuser_index_t, 3, 1> a11;
  a11 = s1.logical_size;
  nvfuser_index_t i12;
  i12 = a11[1LL];
  nvfuser_index_t i13;
  i13 = (max(0LL, (min(2LL, i12)))) * i4;
  nvfuser_index_t i14;
  i14 = i0 % i13;
  nvfuser_index_t i15;
  i15 = min(i12, 2LL);
  nvfuser_index_t i16;
  i16 = max(0LL, i15);
  if ((i0 < i13)) {
    float T1[1LL];
    T1[0LL] = 0LL;
    T1[0LL]
       = T0[((((i3 * i12) * (i0 / i13)) + (i3 * (i14 / i4))) + (i14 % i4))];
    float T5[1LL];
    T5[0LL]
       = T1[0LL];
    T4[i0]
       = T5[0LL];
  }
}
```
If you look closely, though, `i10` is unused, so it will be DCEd anyway.
Still, it might be nice to concretize broadcast extents to 1 just to
clean up these expressions when they appear downstream. I tried that
hastily but ran into some issues, so I'll leave it for another PR.
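
For reference, the simplification argument above can be checked in standalone form. This is a minimal sketch, not nvFuser code: `broadcastExtent` is a hypothetical stand-in for the resized extent expression, showing that `max(0, min(i0, 1))` evaluates to 1 for every input extent `i0 >= 1`, which is the only case reachable once the dimension has concretized to Broadcast:
```c++
#include <algorithm>
#include <cassert>
#include <cstdint>

// Stand-in for the unsimplified extent expression max(0, min(i0, 1))
// that slice-then-broadcast concretization produces.
int64_t broadcastExtent(int64_t i0) {
  return std::max<int64_t>(0, std::min<int64_t>(i0, 1));
}

int main() {
  // If i0 were 0, the domain would concretize to Iteration rather than
  // Broadcast, so only i0 >= 1 is reachable for a Broadcast extent.
  for (int64_t i0 = 1; i0 <= 100; ++i0) {
    assert(broadcastExtent(i0) == 1);
  }
  return 0;
}
```
This is exactly why skipping `simplifyExpr` is safe: the expression is not syntactically the constant 1, but it cannot evaluate to anything else.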

Fixes #2359
jacobhinkle authored Jun 8, 2024
1 parent f5c8c9c commit b7e3694
Showing 3 changed files with 41 additions and 8 deletions.
10 changes: 7 additions & 3 deletions csrc/dynamic_transform.cpp
```diff
@@ -739,9 +739,10 @@ void DynamicTransformConcretizer::concretizeReshape() {
 
   auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis);
 
-  // We do the replacement directly here, but we must still check that the
-  // replacement is valid
-  checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv);
+  // NOTE: The replacement might not yet actually be valid. For example, if
+  // inp_tv contains Symbolic domains that need to be squeezed, this check
+  // would fail at this point. So we skip checkConcretizedUses here and
+  // perform it later in mutate(TensorView*).
 
   // Extent expressions often change when concretizing a reshape. Here we
   // replace these in all downstream expressions so that the Fusion looks just
@@ -1031,6 +1032,9 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
   // TensorDomain and then TensorView
   mutate(tv->domain());
   OptOutMutator::mutate(tv);
+  // Check concretization is valid after we've done the replacement. See note
+  // about squeeze inside concretizeReshape above.
+  checkConcretizedUses(tv, tv);
 }
 
 // Almost an exact copy of OptOutMutator::mutate(TensorDomain*), but
```
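
The reordering this hunk performs can be summarized as: validation that would fail while domains are still Symbolic is deferred until after mutation has applied the replacement. A minimal sketch with hypothetical types (these are not nvFuser's actual `TensorView` or `checkConcretizedUses`), just to illustrate the two-phase pattern:
```c++
#include <cassert>
#include <vector>

struct Domain {
  bool symbolic = false; // unresolved until mutation runs
};

struct TensorView {
  std::vector<Domain> domains;
};

// Would spuriously fail if called while any domain is still symbolic.
bool checkConcretizedUses(const TensorView& tv) {
  for (const auto& d : tv.domains) {
    if (d.symbolic) {
      return false;
    }
  }
  return true;
}

void mutate(TensorView& tv) {
  for (auto& d : tv.domains) {
    d.symbolic = false; // concretize the placeholder
  }
  // Check only after the replacement has been applied.
  assert(checkConcretizedUses(tv));
}

int main() {
  TensorView tv{{Domain{true}, Domain{false}}};
  // Checking at replacement time, as the old code did, is premature:
  assert(!checkConcretizedUses(tv));
  mutate(tv); // the deferred check now passes
  return 0;
}
```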
8 changes: 3 additions & 5 deletions csrc/ir/nodes.cpp
```diff
@@ -1409,11 +1409,9 @@ void SqueezeOp::checkConcretization(Val* old_val, Val* new_val) const {
         new_id->toString(),
         " must concretize to IterType::Broadcast but found ",
         new_id->toString());
-    NVF_CHECK(
-        !new_id->hasExpandedExtent(), "Can not squeeze expanded dimension(s).");
-    NVF_CHECK(
-        new_id->extent()->isOneInt(),
-        "Can not squeeze dimension(s) with size != 1.");
+    // NOTE: we do not check the extent here. Even if the extent is not a const
+    // scalar we know that it would simplify to 1 for these inputs, since this
+    // IterDomain is concretized to Broadcast.
   }
 }
```
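
A hedged sketch of the relaxed check, using stand-in types rather than nvFuser's real `IterDomain`: validation keys on the concretized IterType alone, since a Broadcast extent is guaranteed to evaluate to 1 even when it is a non-constant expression like `max(0, min(i0, 1))`:
```c++
#include <algorithm>
#include <cassert>
#include <functional>

enum class IterType { Iteration, Broadcast };

struct IterDomain {
  IterType iter_type;
  std::function<long()> extent; // possibly a deferred, unsimplified expression
};

void checkSqueezeConcretization(const IterDomain& new_id) {
  // Old behavior also required the extent to be the literal constant 1,
  // which rejects unsimplified-but-valid extents. New behavior: the
  // IterType alone is sufficient.
  assert(new_id.iter_type == IterType::Broadcast);
}

int main() {
  long i0 = 7; // runtime extent of the sliced input dimension
  IterDomain bcast{
      IterType::Broadcast, [i0]() { return std::max(0L, std::min(i0, 1L)); }};
  checkSqueezeConcretization(bcast); // passes without simplifying the extent
  assert(bcast.extent() == 1); // and the extent does evaluate to 1
  return 0;
}
```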
31 changes: 31 additions & 0 deletions tests/python/test_python_frontend.py
```diff
@@ -4055,6 +4055,37 @@ def fusion_func(fd: FusionDefinition) -> None:
             "FusionDefinition's execute() did not run correctly with profile enabled!"
         )
 
+    # Small repro from https://github.com/NVIDIA/Fuser/issues/2359
+    def test_reshape_squeeze_concretization(self):
+        inputs = [
+            torch.randn((100,), dtype=torch.float32, device="cuda:0").as_strided(
+                (2, 5, 10), (50, 10, 1)
+            ),
+        ]
+
+        def fusion_func(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, -1, -1],
+                contiguity=[True, True, True],
+                dtype=DataType.Float,
+                is_cpu=False,
+                stride_order=[2, 1, 0],
+            )
+            T1 = fd.ops.slice(
+                T0, start_indices=[0, 0, 0], end_indices=[1, 2, 4], strides=[1, 1, 1]
+            )
+            S2 = fd.define_scalar(1, dtype=DataType.Int)
+            S3 = fd.define_scalar(8, dtype=DataType.Int)
+            V4 = fd.define_vector([S2, S3], dtype=DataType.Int)
+            V5 = fd.define_vector([S3], dtype=DataType.Int)
+            T6 = fd.ops.reshape(T1, new_shape=V4)
+            T7 = fd.ops.reshape(T6, new_shape=V5)
+            # this works fine
+            # T7 = fd.ops.reshape(T1, new_shape=V5)
+            fd.add_output(T7)
+
+        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
+
 
 if __name__ == "__main__":
     run_tests()
```
