Commit
[inductor cpp] support vectorization for index_expr that depends on tiling itervar or with indirect indexing (pytorch#114545)

As the title says, this PR enables vectorization for the situation where the index_expr depends on the vectorized itervar. There are two cases:
1. The vectorized itervar has a constant stride in the index_expr. In this case, we vectorize the index_expr with `Vectorized<int32_t>::arange`.
2. Otherwise, we load the index_expr vector non-contiguously with a loop.

Below is the generated code for the first case, from the test `test_concat_inner_vec`. Here the index_expr is `x1`, which is the vectorized itervar itself and thus has a constant stride of 1, so we vectorize it with `arange`. We use `all_zero` to implement a short-cut on masks, avoiding unnecessary execution of nested masked regions that would be invalid.

Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
    {
        auto tmp0 = c10::convert<long>(x1);
        auto tmp1 = static_cast<long>(0);
        auto tmp2 = tmp0 >= tmp1;
        auto tmp3 = static_cast<long>(35);
        auto tmp4 = tmp0 < tmp3;
        auto tmp5 = [&]
        {
            auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
            return tmp6;
        }
        ;
        auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
        auto tmp8 = tmp0 >= tmp3;
        auto tmp9 = static_cast<long>(155);
        auto tmp10 = tmp0 < tmp9;
        auto tmp11 = [&]
        {
            auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
            return tmp12;
        }
        ;
        ...
```

After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
    {
        auto tmp0 = c10::convert<int>(x1);
        auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
        auto tmp2 = static_cast<int>(0);
        auto tmp3 = at::vec::Vectorized<int>(tmp2);
        auto tmp4 = to_float_mask(tmp1 >= tmp3);
        auto tmp5 = static_cast<int>(35);
        auto tmp6 = at::vec::Vectorized<int>(tmp5);
        auto tmp7 = to_float_mask(tmp1 < tmp6);
        auto tmp8 = [&]
        {
            auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
            return tmp9;
        }
        ;
        auto tmp10 = [&]
        {
            if (all_zero(to_float_mask(tmp7)))
            {
                return at::vec::Vectorized<float>(static_cast<float>(0.0));
            }
            else
            {
                return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
            }
        }
        ()
        ;
        ...
```

Below is the generated code for the second case, from the test `test_expr_vec_non_contiguous`. Here the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))`, which depends on the vectorized itervar `x1` and does not have a constant stride, so we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant across the data points in the range [x1, x1+16), so it can be regarded as a scalar; this will be handled in a follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of the non-contiguous load.
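For readers unfamiliar with the mask helpers referenced above (`all_zero`, `vector_lane_mask_check`, `masked_load`), here is a minimal scalar sketch of their assumed semantics. It is only an illustration on a plain-array "mask"; the real helpers operate on `at::vec::Vectorized` types and live in the inductor C++ prefix header, so the signatures below are stand-ins, not the actual API:

```c++
// Scalar stand-in for the mask helpers used by the generated kernels in this PR.
// Assumed semantics only; not the actual at::vec::Vectorized-based implementations.
#include <array>
#include <cstdio>

constexpr int kLanes = 16;
using FloatMask = std::array<float, kLanes>;  // a lane is "on" iff it is nonzero

// True iff no lane is set: lets callers skip an entire masked region.
bool all_zero(const FloatMask& mask) {
    for (float m : mask) {
        if (m != 0.0f) return false;
    }
    return true;
}

// Per-lane test used when gathering non-contiguous elements under a mask.
bool vector_lane_mask_check(const FloatMask& mask, int lane) {
    return mask[lane] != 0.0f;
}

// Load only the lanes whose mask is set; masked-off lanes get `fill`.
std::array<float, kLanes> masked_load(const float* ptr, const FloatMask& mask,
                                      float fill = 0.0f) {
    std::array<float, kLanes> out{};
    for (int i = 0; i < kLanes; ++i) {
        out[i] = vector_lane_mask_check(mask, i) ? ptr[i] : fill;
    }
    return out;
}

int main() {
    float data[kLanes];
    for (int i = 0; i < kLanes; ++i) {
        data[i] = static_cast<float>(i);
    }

    // Mask on the first 10 lanes, e.g. lanes that fall inside a tensor boundary.
    FloatMask mask{};
    for (int i = 0; i < 10; ++i) {
        mask[i] = 1.0f;
    }

    if (all_zero(mask)) {
        std::puts("mask is all zero: the masked region can be skipped");
    } else {
        auto v = masked_load(data, mask);
        std::printf("lane 5 = %g, lane 12 = %g\n", v[5], v[12]);  // prints 5 and 0
    }
    return 0;
}
```

The generated code follows the same pattern in vector form: `all_zero` short-circuits the `blendv` of a masked region, and `vector_lane_mask_check` guards each element of the per-lane gather loop.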
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
    {
        {
            float tmp_acc0 = -std::numeric_limits<float>::infinity();
            for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
            {
                auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
                auto tmp1 = static_cast<long>(2048);
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = [&]
                {
                    auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                    return tmp4;
                }
                ;
                auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
            }
            out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
        }
    }
}
```

After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
    {
        {
            #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
            float tmp_acc0 = -std::numeric_limits<float>::infinity();
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
            for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
            {
                auto tmp0 = [&]
                {
                    __at_align__ std::array<int, 16> tmpbuf;
                    #pragma GCC unroll 16
                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                    {
                        tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
                    }
                    return at::vec::Vectorized<int>::loadu(tmpbuf.data());
                }
                ()
                ;
                auto tmp1 = static_cast<int>(2048);
                auto tmp2 = at::vec::Vectorized<int>(tmp1);
                auto tmp3 = to_float_mask(tmp0 < tmp2);
                auto tmp4 = [&]
                {
                    auto tmp5 = [&]
                    {
                        __at_align__ std::array<float, 16> tmpbuf;
                        #pragma GCC unroll 16
                        for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                        {
                            if (vector_lane_mask_check(tmp3, x1_inner))
                            {
                                tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                            }
                        }
                        return at::vec::Vectorized<float>::loadu(tmpbuf.data());
                    }
                    ()
                    ;
                    return tmp5;
                }
                ;
                auto tmp6 = [&]
                {
                    if (all_zero(to_float_mask(tmp3)))
                    {
                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                    }
                    else
                    {
                        return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
                    }
                }
                ()
                ;
                tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
            }
            tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
        }
    }
}
```

Pull Request resolved: pytorch#114545
Approved by: https://github.com/lezcano