[inductor] Let output or input_as_strided match exact strides (pytorch#130956)

Fixes pytorch#130394

TorchInductor doesn't respect the original strides of outputs, which opens up optimization opportunities such as changing the memory layout. But in some cases, such as the one in pytorch#130394, the output must match the exact strides that were requested, and correctness takes priority. This PR adds a new API, `ir.ExternKernel.require_exact_strides(x, exact_strides, allow_padding=False)`, to fix the issue, so that non-dense outputs follow the strides required by their semantics.
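For intuition, here is a small illustration (not part of the PR) of why a stride *order* alone is not enough: a dense buffer and the non-dense view used in the new test share the same stride order but not the same strides.

```python
import torch

# A non-dense view: logical size (16, 8), but each row starts 16 elements
# after the previous one in storage, i.e. strides (16, 1).
full = torch.randn(16, 16)
view = torch.as_strided(full, (16, 8), (16, 1))

# A freshly allocated dense tensor of the same size has strides (8, 1).
dense = torch.empty(16, 8)


def stride_order(t):
    # dims sorted from fastest- to slowest-varying
    return sorted(range(t.dim()), key=t.stride)


# Both tensors have the same stride order (dim 1 is fastest-varying) ...
assert stride_order(view) == stride_order(dense)

# ... but only the view carries the exact strides (16, 1) the user asked for,
# which is what require_exact_strides is meant to preserve.
assert view.stride() == (16, 1) and dense.stride() == (8, 1)
```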

Below is a comparison of the code generated for the test before and after this fix.

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 8
    x1 = (xindex // 8)
-   x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (16*x1)), xmask)
    tmp1 = tmp0 + tmp0
-   tl.store(out_ptr0 + (x2), tmp1, xmask)
+   tl.store(out_ptr0 + (x0 + (16*x1)), tmp1, xmask)

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (16, 8), (16, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
-       buf1 = empty_strided_cuda((16, 8), (8, 1), torch.float32)
+       buf1 = empty_strided_cuda((16, 8), (16, 1), torch.float32)
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_copy_0.run(arg0_1, buf1, 128, grid=grid(128), stream=stream0)
        del arg0_1
    return (buf1, )
```

`buf1` is now created with the exact strides required by the user, and its values are written with the same strides as the input.
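The behavior can be checked end to end with a small script along the lines of the test added in this PR (a sketch run on the default device; the actual test uses the suite's `self.common` helper and device):

```python
import torch

def fn(x):
    result = x + x
    # The user explicitly asks for an output with x's (non-dense) strides.
    out = torch.empty_strided(x.size(), x.stride(), device=x.device)
    out[:] = result
    return out

full = torch.randn(16, 16)
view = torch.as_strided(full, (16, 8), full.stride())  # size (16, 8), strides (16, 1)

eager_out = fn(view)
compiled_out = torch.compile(fn)(view)

# With this change, the compiled output keeps the requested (16, 1) strides.
assert eager_out.stride() == compiled_out.stride() == (16, 1)
```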

Pull Request resolved: pytorch#130956
Approved by: https://github.com/eellison, https://github.com/blaine-rister
FindHao authored and pytorchmergebot committed Aug 24, 2024
1 parent cdb9df5 commit a63efee
Showing 3 changed files with 194 additions and 51 deletions.
18 changes: 18 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -7257,6 +7257,24 @@ def fn_channels_last(x):
[torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)],
)

def test_exact_stride(self):
    full = torch.randn((16, 16), device=self.device)
    view = torch.as_strided(full, (16, 8), full.stride())

    def fn(x):
        result = x + x
        result_strided = torch.empty_strided(
            x.size(), x.stride(), device=self.device
        )
        result_strided[:] = result
        return result_strided

    self.common(fn, [view])
    reference_out = fn(view)
    compiled_fn = torch.compile(fn)
    actual_out = compiled_fn(view)
    self.assertEqual(reference_out.stride(), actual_out.stride())

def test_like_channels_last(self):
def foo():
randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32)
63 changes: 44 additions & 19 deletions torch/_inductor/graph.py
@@ -186,7 +186,9 @@ def getattr_recursive(
return attr_itr


def mark_nodes_dislike_padding(g: Graph) -> None:
def mark_nodes_dislike_padding(
g: Graph, user_visible_outputs: Optional[Dict[str, None]]
) -> None:
"""
Nodes like convolution/convolution_backward want their inputs to be dense.
If we pad their inputs, we end up with extra calls to copy kernels! On the other hand, padding usually helps reduction.
@@ -233,7 +235,9 @@ def _get_overload_packet(
op = _get_overload_packet(cur)
if not op:
continue
if op in ops_dislike_padding:
if op in ops_dislike_padding or (
user_visible_outputs and cur.name in user_visible_outputs
):
cur.meta["dislike_padding"] = True

if cur.meta.get("dislike_padding", False):
@@ -415,11 +419,11 @@ def __init__(
self.nodes_prefer_channels_last = (
self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
)
mark_nodes_dislike_padding(gm.graph)
self._warned_fallback = {"aten.convolution_backward"}
self.user_visible_outputs = (
user_visible_outputs if user_visible_outputs is not None else {}
)
mark_nodes_dislike_padding(gm.graph, user_visible_outputs)
self.cache_key: str = "" # This is the cache key for the compiled artifact
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
self.cache_linemap: List[
@@ -1352,27 +1356,48 @@ def debug(msg: str) -> None:
n.meta["val"], torch.Tensor
):
strides = n.meta["val"].stride()
dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0
# requiring a stride order for a non-dense output wouldn't
# recreate the same strides, and would fail with view, defer for now.
if not unbacked_symbols_in_strides and dense and len(strides):
stride_order = ir.get_stride_order(strides)
if (
len(result.get_size()) == 4
and n in self.nodes_prefer_channels_last
and n.name not in self.user_visible_outputs
and not is_input_for_as_strided
):
stride_order = ir.NHWC_STRIDE_ORDER

if len(strides):
allow_padding = (
n.name not in self.user_visible_outputs
and not is_input_for_as_strided
)
result = ir.ExternKernel.require_stride_order(
result, stride_order, allow_padding=allow_padding
dense = torch._prims_common.is_non_overlapping_and_dense(
n.meta["val"]
)
unbacked_symbols_in_strides = (
len(free_unbacked_symbols(strides)) > 0
)
if (
not unbacked_symbols_in_strides
and dense
and len(result.get_size()) == 4
and n in self.nodes_prefer_channels_last
and n.name not in self.user_visible_outputs
and not is_input_for_as_strided
):
strides = ir.FlexibleLayout.stride_ordered_for_memory_format(
result.get_size(), torch.channels_last
)
if not unbacked_symbols_in_strides and len(strides):
# To avoid converting possible view ops to a copy kernel, we keep the previous
# require_stride_order behavior for views. But ultimately it's better to require
# the right strides at the tensor definition.
if n.meta["val"]._is_view() or isinstance(
result.data, ir.BaseView
):
result = ir.ExternKernel.require_stride_order(
result,
ir.get_stride_order(strides),
allow_padding=allow_padding,
)
else:
strides = [
s.node.expr if isinstance(s, torch.SymInt) else s
for s in strides
]
result = ir.ExternKernel.require_exact_strides(
result, strides, allow_padding=allow_padding
)

# Realize if (1) any user need inputs realized, or (2) there is
# already too many reads and rematerializing can be bad.
164 changes: 132 additions & 32 deletions torch/_inductor/ir.py
@@ -745,6 +745,22 @@ def welford_combine_fn(
raise NotImplementedError(f"unknown reduction_type={reduction_type}")


def significant_strides_equal(
    strides1: Sequence[_IntLike], strides2: Sequence[_IntLike], size: Sequence[_IntLike]
) -> bool:
    """
    Returns true if the strides are equal, ignoring dimensions of size 1.
    """
    non_1_indices = [
        i
        for i, dim in enumerate(size)
        if V.graph.sizevars.size_hint(dim, fallback=2) != 1
    ]
    strides1 = [V.graph.sizevars.size_hint(strides1[i]) for i in non_1_indices]
    strides2 = [V.graph.sizevars.size_hint(strides2[i]) for i in non_1_indices]
    return strides1 == strides2


@dataclasses.dataclass
class Reduction(Loops):
reduction_ranges: List[Expr]
@@ -2085,6 +2101,7 @@ def as_storage_and_layout(
want_contiguous: bool = False,
stride_order: Optional[Sequence[Union[int, Integer]]] = None,
allow_padding: bool = False,
exact_strides: Optional[Sequence[Union[int, Integer]]] = None,
) -> Tuple[StorageBox, Layout]:
"""
Try to simplify x into a StorageBox and a Layout.
@@ -2099,6 +2116,7 @@
want_contiguous=want_contiguous,
stride_order=stride_order,
allow_padding=allow_padding,
exact_strides=exact_strides,
)
if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
if freeze:
@@ -2109,6 +2127,10 @@
x.data.freeze_layout_with_stride_order(
stride_order, allow_padding=allow_padding
)
elif exact_strides is not None:
x.data.freeze_layout_with_exact_strides(
exact_strides, allow_padding=allow_padding
)
else:
x.data.decide_layout()
return x, x.data.layout
@@ -3202,6 +3224,19 @@ def as_stride_order(self, order, allow_padding=False):
self.offset,
)

def as_exact_strides(self, exact_strides, allow_padding=False):
    new_stride = exact_strides
    if self.should_pad_strides() and allow_padding:
        new_stride = self._pad_strides(new_stride, self.size, self.dtype)

    return FixedLayout(
        self.device,
        self.dtype,
        self.size,
        new_stride,
        self.offset,
    )

def as_fill_order(self, order):
new_stride = self.fill_ordered(self.size, order)
if self.should_pad_strides():
@@ -3428,6 +3463,12 @@ def freeze_layout_with_same_order(self, stride):
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_same_order(stride)

def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False):
    assert isinstance(self.layout, FlexibleLayout)
    self.layout = self.layout.as_exact_strides(
        exact_strides, allow_padding=allow_padding
    )

def is_zero_elements(self):
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type]

@@ -4680,53 +4721,93 @@ def require_stride1(cls, x):
return cls.copy_input(x)

@classmethod
def require_stride_order(cls, x, order, allow_padding=False):
def require_strides(
cls,
x,
order: Optional[Sequence[int]] = None,
exact_strides: Optional[Sequence[_IntLike]] = None,
allow_padding=False,
):
assert order is not None or exact_strides is not None
if x.get_numel() == 0: # Layout doesn't matter
return x

# require x to have the layout as strided_ordered as order
# require x to have the layout
if is_storage_and_layout(x):
while isinstance(x.get_layout(), NonOwningLayout):
x = x.get_layout().view
if isinstance(x.get_layout(), FlexibleLayout):
# If the FlexibleLayout already has the size and stride in the required order,
# freeze it to a FixedLayout by using its current size and stride.
# The behavior of using its current size and stride or the given order can be different
# if the size and stride have ambiguity, for example for a 4D input where the iC = 1:
# size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last),
# the current size and stride already satisfies this order.
# However by freezing it to the required order, the layout will be changed to:
# size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary.

# fix flexiblelayout to be FixedLayout with stride_order
as_storage_and_layout(
x,
freeze=True,
want_contiguous=False,
stride_order=get_stride_order(
V.graph.sizevars.size_hints(x.get_layout().stride)
if order:
# If the FlexibleLayout already has the size and stride in the required order,
# freeze it to a FixedLayout by using its current size and stride.
# The behavior of using its current size and stride or the given order can be different
# if the size and stride have ambiguity, for example for a 4D input where the iC = 1:
# size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last),
# the current size and stride already satisfies this order.
# However by freezing it to the required order, the layout will be changed to:
# size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary.

# fix flexiblelayout to be FixedLayout with stride_order
as_storage_and_layout(
x,
freeze=True,
want_contiguous=False,
stride_order=get_stride_order(
V.graph.sizevars.size_hints(x.get_layout().stride)
)
if is_stride_order_storage_and_layout(x, order)
else order,
allow_padding=allow_padding,
)
return x
else:
# If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides.
as_storage_and_layout(
x,
freeze=True,
want_contiguous=False,
stride_order=None,
allow_padding=allow_padding,
exact_strides=exact_strides,
)
return x
elif isinstance(x.get_layout(), FixedLayout) and (
(order and x.get_layout().is_stride_ordered(order))
or (
exact_strides
and significant_strides_equal(
exact_strides, x.get_layout().stride, x.get_size()
)
if is_stride_order_storage_and_layout(x, order)
else order,
allow_padding=allow_padding,
)
return x
elif isinstance(
x.get_layout(), FixedLayout
) and x.get_layout().is_stride_ordered(order):
):
return x
elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE):
if isinstance(x.get_layout().real_layout(), FlexibleLayout):
raise AssertionError(
"the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout"
)
elif isinstance(
x.get_layout().real_layout(), FixedLayout
) and x.get_layout().real_layout().is_stride_ordered(order):
elif isinstance(x.get_layout().real_layout(), FixedLayout) and (
(order and x.get_layout().real_layout().is_stride_ordered(order))
or (
exact_strides
and significant_strides_equal(
exact_strides,
x.get_layout().real_layout().stride,
x.get_size(),
)
)
):
return x

# TODO - Storage to InputBuffer
if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
if isinstance(x, InputBuffer) and (
(order and x.get_layout().is_stride_ordered(order))
or (
exact_strides
and significant_strides_equal(
exact_strides, x.get_layout().stride, x.get_size()
)
)
):
return x
if (
isinstance(x, TensorBox)
@@ -4737,7 +4818,14 @@ def require_stride_order(cls, x, order, allow_padding=False):
):
try:
x.data = cls.convert_to_reinterpret_view(x.data)
return cls.require_stride_order(x, order, allow_padding=allow_padding)
if order:
return cls.require_stride_order(
x, order, allow_padding=allow_padding
)
elif exact_strides:
return cls.require_exact_strides(
x, exact_strides, allow_padding=allow_padding
)
except NotImplementedError:
pass
# Although this is a clone, inductor is good about fusing clones into previous
@@ -4749,10 +4837,22 @@ def require_stride_order(cls, x, order, allow_padding=False):
want_contiguous=False,
stride_order=order,
allow_padding=allow_padding,
exact_strides=exact_strides,
)
assert is_stride_order_storage_and_layout(x, order)
if order:
assert is_stride_order_storage_and_layout(x, order)
return x

@classmethod
def require_exact_strides(cls, x, exact_strides, allow_padding=False):
    return cls.require_strides(
        x, exact_strides=exact_strides, allow_padding=allow_padding
    )

@classmethod
def require_stride_order(cls, x, order, allow_padding=False):
    return cls.require_strides(x, order=order, allow_padding=allow_padding)

@classmethod
def require_channels_last(cls, x):
return cls.require_stride_order(x, NHWC_STRIDE_ORDER)
