Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move checkConcretization for reshapes (#2363)
The included test is the small one provided by @jjsjann123 in #2359 and it's actually tougher than the original repro. It necessitates either removing the check that concretized squeezed extents are constant 1 or to concretize Resized to broadcast extents as constant 1 so that we can evaluate `max(0, min(i0, 1))` as `oneVal()` without calling `simplifyExpr`. I went with removing the check, which means in this example we have broadcast dimension with a dynamic shape like `max(0, min(i0, 1))`. Since we're concretizing to Broadcast, we know that dimension is not zero; if it were then we'd concretize to `Iteration` and `SqueezeOp::checkConcretization` would fail the IterType check. Still, I don't love that the expression cannot be simplified so it appears in the kernel (`i9` and `i10`): ```c++ __global__ void nvfuser_pointwise_f0_c1_r0_g1(Tensor<float, 3, 3> T0, Tensor<float, 2, 2> T4) { nvfuser_index_t i0; i0 = ((nvfuser_index_t)threadIdx.x) + (128LL * ((nvfuser_index_t)blockIdx.x)); Tensor<float, 3, 3> s1; s1.data = T0.data; s1.logical_size = T0.logical_size; s1.alloc_stride = T0.alloc_stride; Array<nvfuser_index_t, 3, 1> a2; a2 = s1.logical_size; nvfuser_index_t i3; i3 = a2[2LL]; nvfuser_index_t i4; i4 = max(0LL, (min(4LL, i3))); nvfuser_index_t i5; i5 = min(i3, 4LL); nvfuser_index_t i6; i6 = max(0LL, i5); Array<nvfuser_index_t, 3, 1> a7; a7 = s1.logical_size; nvfuser_index_t i8; i8 = a7[0LL]; nvfuser_index_t i9; i9 = min(i8, 1LL); nvfuser_index_t i10; i10 = max(0LL, i9); Array<nvfuser_index_t, 3, 1> a11; a11 = s1.logical_size; nvfuser_index_t i12; i12 = a11[1LL]; nvfuser_index_t i13; i13 = (max(0LL, (min(2LL, i12)))) * i4; nvfuser_index_t i14; i14 = i0 % i13; nvfuser_index_t i15; i15 = min(i12, 2LL); nvfuser_index_t i16; i16 = max(0LL, i15); if ((i0 < i13)) { float T1[1LL]; T1[0LL] = 0LL; T1[0LL] = T0[((((i3 * i12) * (i0 / i13)) + (i3 * (i14 / i4))) + (i14 % i4))]; float T5[1LL]; T5[0LL] = T1[0LL]; T4[i0] = T5[0LL]; } } ``` If you look closely though, `i10` is not used so it will be DCEd anyway. Still, it might be nice to concretize broadcast extents to 1 just to clean up these expressions if they appear downstream. I tried that hastily but ran into some issues so I'll leave it for another PR. Fixes #2359
- Loading branch information