forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[inductor]Let output or input_as_strided match exact strides (pytorch…
…#130956) Fixes pytorch#130394 TorchInductor doesn't respect original strides of outputs. It opens up optimization opportunities like changing up memory layout. But for some cases, such as the case in pytorch#130394, we do need the output match the exact stride as required. The correctness is the first priority goal. So, this PR adds a new API `ir.ExternKernel.require_exact_strides(x, exact_strides, allow_padding=False)` to fix the issue. This PR enables non-dense outputs' strides follow the strides required by semantics. The comparison between the original and after this fix for the test is the below. ```python @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 128 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex % 8 x1 = (xindex // 8) - x2 = xindex tmp0 = tl.load(in_ptr0 + (x0 + (16*x1)), xmask) tmp1 = tmp0 + tmp0 - tl.store(out_ptr0 + (x2), tmp1, xmask) + tl.store(out_ptr0 + (x0 + (16*x1)), tmp1, xmask) def call(args): arg0_1, = args args.clear() assert_size_stride(arg0_1, (16, 8), (16, 1)) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) - buf1 = empty_strided_cuda((16, 8), (8, 1), torch.float32) + buf1 = empty_strided_cuda((16, 8), (16, 1), torch.float32) stream0 = get_raw_stream(0) triton_poi_fused_add_copy_0.run(arg0_1, buf1, 128, grid=grid(128), stream=stream0) del arg0_1 return (buf1, ) ``` The buf1 is created with exact stride required by users, and its values are written in same stride with the input. Pull Request resolved: pytorch#130956 Approved by: https://github.com/eellison, https://github.com/blaine-rister
- Loading branch information
1 parent
cdb9df5
commit a63efee
Showing
3 changed files
with
194 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters