[inductor cpp] support vectorization for index_expr that depends on tiling itervar or with indirect indexing (pytorch#114545)

As stated in the title, this PR enables vectorization for the case where the index_expr depends on the vectorized itervar. There are two cases:
1. The vectorized itervar has a constant stride in the index_expr. In this case we vectorize the index_expr with `Vectorized<int32>::arange`.
2. Otherwise, we load the index_expr vector non-contiguously with a per-lane loop (a sketch of both cases follows this list).
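
For illustration, here is a minimal standalone sketch (not inductor-generated code) of how the index vector is built in each case. It is written against the public `at::vec::Vectorized` API; the function names are hypothetical, the lane count comes from `Vectorized<int32_t>::size()`, and the non-contiguous case reuses the example index expression from `test_expr_vec_non_contiguous`.
```c++
#include <ATen/cpu/vec/vec.h>
#include <array>

using VecI32 = at::vec::Vectorized<int32_t>;

// Case 1: the vectorized itervar x1 has constant stride 1 in the index_expr,
// so the lanes are simply x1, x1+1, x1+2, ...
VecI32 index_with_constant_stride(int64_t x1) {
  return VecI32::arange(static_cast<int32_t>(x1), 1);
}

// Case 2: the index_expr is not affine in the vectorized itervar, so each lane
// is evaluated with scalar arithmetic into a buffer and the buffer is then
// loaded as a vector.
VecI32 index_without_constant_stride(int64_t x1, int64_t x2) {
  alignas(64) std::array<int32_t, VecI32::size()> tmpbuf;
  for (int64_t lane = 0; lane < VecI32::size(); ++lane) {
    tmpbuf[lane] = static_cast<int32_t>(31 + 63 * ((x1 + lane) / 32) + x2 / 32);
  }
  return VecI32::loadu(tmpbuf.data());
}
```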

Below is the generated code for the first case, from the test `test_concat_inner_vec`. Here the index_expr is `x1`, which is itself the vectorized itervar, with a constant stride of 1, so we vectorize it with `arange`. We use `all_zero` as a short-cut on the mask to avoid executing nested masked regions that would be invalid when no lane is active (a sketch of this pattern follows the generated code below).
Before:
```c++
            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
                {
                    auto tmp0 = c10::convert<long>(x1);
                    auto tmp1 = static_cast<long>(0);
                    auto tmp2 = tmp0 >= tmp1;
                    auto tmp3 = static_cast<long>(35);
                    auto tmp4 = tmp0 < tmp3;
                    auto tmp5 = [&]
                    {
                        auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
                        return tmp6;
                    }
                    ;
                    auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
                    auto tmp8 = tmp0 >= tmp3;
                    auto tmp9 = static_cast<long>(155);
                    auto tmp10 = tmp0 < tmp9;
                    auto tmp11 = [&]
                    {
                        auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
                        return tmp12;
                    }
                    ;
...
```
After:
```c++
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = c10::convert<int>(x1);
                    auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
                    auto tmp2 = static_cast<int>(0);
                    auto tmp3 = at::vec::Vectorized<int>(tmp2);
                    auto tmp4 = to_float_mask(tmp1 >= tmp3);
                    auto tmp5 = static_cast<int>(35);
                    auto tmp6 = at::vec::Vectorized<int>(tmp5);
                    auto tmp7 = to_float_mask(tmp1 < tmp6);
                    auto tmp8 = [&]
                    {
                        auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
                        return tmp9;
                    }
                    ;
                    auto tmp10 =
                    [&]
                    {
                        if (all_zero(to_float_mask(tmp7)))
                        {
                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                        }
                        else
                        {
                            return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
                        }
                    }
                    ()
                    ;
...
```
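
To make the mask short-cut explicit, here is a small sketch of the pattern above. `all_zero`, `to_float_mask`, and `masked_load` in the generated code are helpers from inductor's `cpp_prefix.h`; the sketch below is not that implementation, it uses hypothetical function names and a portable stand-in for the all-zero check built only on the public ATen vec API.
```c++
#include <ATen/cpu/vec/vec.h>
#include <array>

using VecF = at::vec::Vectorized<float>;

// Portable stand-in for all_zero(): true iff every lane of the float mask is zero.
// (Active lanes carry an all-ones bit pattern, which compares != 0.0f.)
bool mask_all_zero(const VecF& mask) {
  alignas(64) std::array<float, VecF::size()> buf;
  mask.store(buf.data());
  for (float v : buf) {
    if (v != 0.0f) {
      return false;
    }
  }
  return true;
}

// The masked-select pattern emitted above: skip the masked body entirely when no
// lane is active (it may touch out-of-range data), otherwise evaluate it once and
// blend inactive lanes back to 0.
template <typename Body>
VecF masked_or_zero(const VecF& mask, Body&& body) {
  if (mask_all_zero(mask)) {
    return VecF(0.0f);
  }
  return VecF::blendv(VecF(0.0f), body(), mask);
}
```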

Below is the generated code for the second case, from the test `test_expr_vec_non_contiguous`. Here the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))`, which depends on the vectorized itervar `x1` but does not have a constant stride in it, so we load the index_expr vector with a per-lane loop. (In fact, this can be further optimized: the index_expr is invariant across the lanes in the range `[x1, x1+16)`, so it can be regarded as a scalar; a sketch of that idea follows the generated code below, and the optimization itself is left to a follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of the non-contiguous load.
Before:
```c++
            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    {
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
                            auto tmp1 = static_cast<long>(2048);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                return tmp4;
                            }
                            ;
                            auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
                        }
                        out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
                    }
                }
            }
```
After:
```c++
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
                {
                    {
                        #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 =
                            [&]
                            {
                                __at_align__ std::array<int, 16> tmpbuf;
                                #pragma GCC unroll 16
                                for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                {
                                    tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
                                }
                                return at::vec::Vectorized<int>::loadu(tmpbuf.data());
                            }
                            ()
                            ;
                            auto tmp1 = static_cast<int>(2048);
                            auto tmp2 = at::vec::Vectorized<int>(tmp1);
                            auto tmp3 = to_float_mask(tmp0 < tmp2);
                            auto tmp4 = [&]
                            {
                                auto tmp5 =
                                [&]
                                {
                                    __at_align__ std::array<float, 16> tmpbuf;
                                    #pragma GCC unroll 16
                                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                    {
                                        if (vector_lane_mask_check(tmp3, x1_inner))
                                        {
                                            tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                        }
                                    }
                                    return at::vec::Vectorized<float>::loadu(tmpbuf.data());
                                }
                                ()
                                ;
                                return tmp5;
                            }
                            ;
                            auto tmp6 =
                            [&]
                            {
                                if (all_zero(to_float_mask(tmp3)))
                                {
                                    return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                }
                                else
                                {
                                    return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
                                }
                            }
                            ()
                            ;
                            tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
                        }
                        tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
                    }
                }
            }
        }
```
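
On the parenthetical above: the follow-up optimization is not part of this PR; below is a minimal sketch of the idea, assuming `x1` is the tile start (a multiple of 16) so that `c10::div_floor_integer(x1 + i, 32L)` is the same for every lane `i` in `[0, 16)`. The function name is hypothetical.
```c++
#include <ATen/cpu/vec/vec.h>

// Compute the lane-invariant index once as a scalar and broadcast it, instead of
// filling a per-lane buffer as in the generated kernel above.
at::vec::Vectorized<int32_t> invariant_index_as_scalar(int64_t x1, int64_t x2) {
  auto idx = static_cast<int32_t>(31 + 63 * (x1 / 32) + x2 / 32);
  return at::vec::Vectorized<int32_t>(idx);  // broadcast to every lane
}
```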

Pull Request resolved: pytorch#114545
Approved by: https://github.com/lezcano
jgong5 authored and pytorchmergebot committed Dec 26, 2023
1 parent a254fbf commit ffe6f9a
Showing 4 changed files with 367 additions and 110 deletions.

`test/inductor/test_cpu_repro.py` (29 changes: 27 additions & 2 deletions):
```diff
@@ -1830,7 +1830,7 @@ def get_index():
         itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")]
 
         tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
-        # The moset inner loop variable is used in the index_expr
+        # The most inner loop variable is used in the index_expr
         with CppVecKernelChecker(
             args=None, num_threads=1, tiling_factor=tiling_factor
         ) as vec_checker:
@@ -1843,7 +1843,7 @@ def get_index():
             vec_checker.ranges = ranges[:2]
             submodules = {"get_index": get_index}
             InterpreterShim(_graph, submodules).run(V.get_ops_handler())
-            self.assertFalse(vec_checker.simd_vec)
+            self.assertTrue(vec_checker.simd_vec)
 
         # Most inner loop variable irrevalant
         with CppVecKernelChecker(
@@ -2719,6 +2719,31 @@ def forward(self, idx, x):
         self.assertTrue("cvt_lowp_fp_to_fp32" not in code)
         self.assertTrue("cvt_fp32_to_lowp_fp" not in code)
 
+    def test_concat_inner_vec(self):
+        def fn(x, y):
+            return F.relu(torch.cat([x, y], dim=1))
+
+        x = torch.randn(32, 35)
+        y = torch.randn(32, 120)
+        metrics.reset()
+        self.common(fn, (x, y))
+        assert metrics.generated_cpp_vec_kernel_count == 1
+
+    def test_expr_vec_non_contiguous(self):
+        def fn(x):
+            # the pattern from sebotnet33ts_256
+            y = torch.nn.functional.pad(x, (0, 31)).reshape(-1, 33, 63)
+            y = y[:, :32, 31:].reshape(4, 32, 1, 32, 32).expand(-1, -1, 32, -1, -1)
+            y = y.permute(0, 3, 1, 4, 2).clone(memory_format=torch.contiguous_format)
+            y = y.view(4, 1024, 1024)
+            return y.softmax(dim=-1)
+
+        x = torch.randn(128, 2048)
+        metrics.reset()
+        self.common(fn, (x,))
+        # 4 kernels for max, exp, sum and div
+        assert metrics.generated_cpp_vec_kernel_count == 4
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
```

`torch/_inductor/codegen/common.py` (21 changes: 13 additions & 8 deletions):
```diff
@@ -823,7 +823,7 @@ def clone(self):
     def generate(
         self,
         buffer: IndentedBuffer,
-        expr: Union[str, CSEVariable, OpsValue],
+        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
         *,
         bounds: ValueRanges = ValueRanges.unknown(),
         write=True,
@@ -832,15 +832,15 @@ def generate(
         if isinstance(expr, OpsValue):
             expr = expr.value
 
-        assert isinstance(expr, (str, CSEVariable)), type(expr)
+        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
         assert write or assignment
         if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr
-        cache_key = expr
+        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
         var = self.cache.get(cache_key, None)
         if not var:
             var = self.newvar(bounds) if assignment else None
@@ -850,11 +850,17 @@ def generate(
                     V.kernel.current_node.codegen_originating_info(
                         buffer, only_once=True
                     )
-                if assignment:
-                    line = f"{self.prefix}{var} = {expr}{self.suffix}"
+                if isinstance(expr, IndentedBuffer):
+                    if assignment:
+                        buffer.writeline(f"{self.prefix}{var} =")
+                    buffer.splice(expr)
+                    buffer.writeline(self.suffix)
                 else:
-                    line = f"{expr}{self.suffix}"
-                buffer.writeline(line)
+                    if assignment:
+                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
+                    else:
+                        line = f"{expr}{self.suffix}"
+                    buffer.writeline(line)
         else:
             var.bounds = var.bounds.tighten(bounds)
 
@@ -1237,7 +1243,6 @@ class OptimizationContext:
 
     dtype: Optional[torch.dtype] = None
     ops_name: str = ""
-    is_most_inner_loop_irrevelant: bool = False
 
     # Load uint8 value as float32
     is_load_uint8_as_float: bool = False
```
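
The `common.py` change lets `CSE.generate` accept a multi-line expression as an `IndentedBuffer`: it caches on the buffer's text and splices the buffer into the kernel instead of writing a single line. Below is a hypothetical usage sketch; the `kernel` object and the `IndentedBuffer` import path are assumptions, not code from this PR.
```python
# Hypothetical sketch: handing a multi-line C++ lambda to CSE.generate as an IndentedBuffer.
# `kernel` stands for a CppVecKernel-like object with a `loads` buffer and a `cse` instance.
from torch._inductor.utils import IndentedBuffer  # import path is an assumption

expr_buf = IndentedBuffer()
expr_buf.writeline("[&]")
expr_buf.writeline("{")
with expr_buf.indent():
    expr_buf.writeline("return at::vec::Vectorized<int>::loadu(tmpbuf.data());")
expr_buf.writeline("}")
expr_buf.writeline("()")

# With the diff above, generate() caches on expr_buf.getvalue(), writes "auto tmpN ="
# and then splices the buffer, producing the multi-line lambda assignments seen in
# the generated kernels.
csevar = kernel.cse.generate(kernel.loads, expr_buf)
```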