Internal assert with Thunder CUDA Falcon 40b like #3505

Closed · t-vi opened this issue Dec 2, 2024 · 11 comments · Fixed by #3530

t-vi commented Dec 2, 2024

nvfuser-0.2.23+git8546b62

# CUDA devices:
#  0: NVIDIA GeForce RTX 3090
#  1: NVIDIA GeForce RTX 3090
# torch version: 2.6.0a0+git1ef1b3b
# cuda version: 12.1
# nvfuser version: 0.2.23+git8546b62
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1840(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[128, 4], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[128, 4], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[5, 5, 288], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[5, 5, 1024], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T13 = fd.ops.slice(T0, start_indices=[0, 0], end_indices=[5, 4], strides=[1, 1], manual_normalization=0)
    T23 = fd.ops.slice(T1, start_indices=[0, 0], end_indices=[5, 4], strides=[1, 1], manual_normalization=0)
    T30 = fd.ops.reshape(T2, new_shape=[5, 5, 4, 18, 4])
    T31 = fd.ops.permute(T30, dims=[0, 2, 3, 1, 4])
    T50 = fd.ops.slice(T31, start_indices=[0, 0, 0, 0, 0], end_indices=[5, 4, 16, 5, 4], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    T69 = fd.ops.slice(T31, start_indices=[0, 0, 16, 0, 0], end_indices=[5, 4, 17, 5, 4], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    T88 = fd.ops.slice(T31, start_indices=[0, 0, 17, 0, 0], end_indices=[5, 4, 18, 5, 4], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    T95 = fd.ops.broadcast_in_dim(T69, shape=[5, 4, 16, 5, 4], broadcast_dims=[0, 1, 2, 3, 4])
    T102 = fd.ops.broadcast_in_dim(T88, shape=[5, 4, 16, 5, 4], broadcast_dims=[0, 1, 2, 3, 4])
    T108 = fd.ops.reshape(T50, new_shape=[5, 64, 5, 4])
    T114 = fd.ops.reshape(T95, new_shape=[5, 64, 5, 4])
    T120 = fd.ops.reshape(T102, new_shape=[5, 64, 5, 4])
    T136 = fd.ops.slice(T108, start_indices=[0, 0, 0, 0], end_indices=[5, 64, 5, 2], strides=[1, 1, 1, 1], manual_normalization=0)
    T152 = fd.ops.slice(T108, start_indices=[0, 0, 0, 2], end_indices=[5, 64, 5, 4], strides=[1, 1, 1, 1], manual_normalization=0)
    T153 = fd.ops.neg(T152)
    T154 = fd.ops.cat([T153, T136], dim=-1, manual_padding=0)
    T160 = fd.ops.broadcast_in_dim(T13, shape=[5, 64, 5, 4], broadcast_dims=[2, 3])
    T161 = fd.ops.mul(T108, T160)
    T167 = fd.ops.broadcast_in_dim(T23, shape=[5, 64, 5, 4], broadcast_dims=[2, 3])
    T168 = fd.ops.mul(T154, T167)
    T169 = fd.ops.add(T161, T168)
    T185 = fd.ops.slice(T114, start_indices=[0, 0, 0, 0], end_indices=[5, 64, 5, 2], strides=[1, 1, 1, 1], manual_normalization=0)
    T201 = fd.ops.slice(T114, start_indices=[0, 0, 0, 2], end_indices=[5, 64, 5, 4], strides=[1, 1, 1, 1], manual_normalization=0)
    T202 = fd.ops.neg(T201)
    T203 = fd.ops.cat([T202, T185], dim=-1, manual_padding=0)
    T204 = fd.ops.mul(T114, T160)
    T205 = fd.ops.mul(T203, T167)
    T206 = fd.ops.add(T204, T205)
    T222 = fd.ops.slice(T108, start_indices=[0, 0, 0, 0], end_indices=[5, 64, 5, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0)
    T239 = fd.ops.slice(T114, start_indices=[0, 0, 0, 0], end_indices=[5, 64, 5, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    T240 = fd.ops.cat([T206, T239], dim=-1, manual_padding=0)
    S241 = fd.define_scalar(0.707107, dtype=DataType.Double)
    T242 = fd.ops.mul(T223, S241)
    T243 = fd.ops.permute(T240, dims=[0, 1, 3, 2])
    S244 = fd.define_scalar(0.707107, dtype=DataType.Double)
    T245 = fd.ops.mul(T243, S244)
    S246 = fd.define_scalar(1.41421, dtype=DataType.Double)
    S247 = fd.ops.reciprocal(S246)
    T248 = fd.ops.mul(T3, S247)
    T249 = fd.ops.erf(T248)
    S250 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T251 = fd.ops.mul(S250, T249)
    S252 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T253 = fd.ops.add(S252, T251)
    T254 = fd.ops.mul(T3, T253)
    fd.add_output(T120)
    fd.add_output(T160)
    fd.add_output(T167)
    fd.add_output(T242)
    fd.add_output(T245)
    fd.add_output(T254)

with FusionDefinition() as fd:
    nvfuser_fusion_id1840(fd)

inputs = [
    torch.testing.make_tensor((128, 4), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((128, 4), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((5, 5, 288), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((5, 5, 1024), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)
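
As an aside, the tail of the fusion (S246 through T254) is the exact erf-based GELU. A hedged PyTorch equivalent for reference (the function name is illustrative, not part of the original report):

```
import math
import torch

def gelu_erf(x: torch.Tensor) -> torch.Tensor:
    # Mirrors T248..T254: x * (0.5 + 0.5 * erf(x / sqrt(2))).
    return x * (0.5 + 0.5 * torch.erf(x / math.sqrt(2.0)))
```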

Repro: pytest-3 thunder/tests/test_jit_general.py::test_litgpt_variants[cuda-falcon-40b-like]

FAILED thunder/tests/test_jit_general.py::test_litgpt_variants[cuda-falcon-40b-like] - RuntimeError:  INTERNAL ASSERT FAILED at "/Fuser/csrc/runtime/fusion_kernel_runtime.cpp":407, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.

Error from segmentation group 9:  INTERNAL ASSERT FAILED at "/Fuser/csrc/id_model/indexing_traversal.cpp":102, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Indexing path for resize not found: iblockIdx.y376{( ceilDiv(1280, blockDim.x) )}
Exception raised from getExprsBetweenForResize at /Fuser/csrc/id_model/indexing_traversal.cpp:102 (most recent call first):
frame #2: <unknown function> + 0x5e3798 (0x7ff62f937798 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x5e4492 (0x7ff62f938492 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x5d50d5 (0x7ff62f9290d5 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x5d8c78 (0x7ff62f92cc78 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x5d985a (0x7ff62f92d85a in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x612c96 (0x7ff62f966c96 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x44cbfe (0x7ff62f7a0bfe in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x4546f7 (0x7ff62f7a86f7 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x456c8f (0x7ff62f7aac8f in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x44da4f (0x7ff62f7a1a4f in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x44da4f (0x7ff62f7a1a4f in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x456c8f (0x7ff62f7aac8f in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: <unknown function> + 0x44da4f (0x7ff62f7a1a4f in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: <unknown function> + 0x44da4f (0x7ff62f7a1a4f in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #16: <unknown function> + 0x44c56b (0x7ff62f7a056b in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #17: <unknown function> + 0x416dbe (0x7ff62f76adbe in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #18: nvfuser::GpuLower::run() + 0x239 (0x7ff62f7650c9 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #19: nvfuser::KernelExecutor::compile(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType) + 0x85a (0x7ff62fb5158a in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #20: <unknown function> + 0x806b40 (0x7ff62fb5ab40 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #21: <unknown function> + 0x837a35 (0x7ff62fb8ba35 in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #22: <unknown function> + 0x8386ac (0x7ff62fb8c6ac in /usr/local/lib/python3.10/dist-packages/nvfuser-0.2.23+git8546b62-py3.10-linux-x86_64.egg/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #23: c10::ThreadPool::main_loop(unsigned long) + 0x2bd (0x7ff74b8b35bd in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #24: <unknown function> + 0xdc253 (0x7ff767eb0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #25: <unknown function> + 0x94ac3 (0x7ff76ea45ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #26: clone + 0x44 (0x7ff76ead6a04 in /usr/lib/x86_64-linux-gnu/libc.so.6)


kevinstephano commented Dec 2, 2024

It seems the cat and slice APIs used in the script above are old.

This is the new error:

Error from segmentation group 9:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/transform_iter.cpp":580, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Error during best effort replay, a transformation was called that conflicts with an root-to-logical call.

Repro:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id38(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 32, 2048], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 2048, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[1, 2048, 512], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[1, 2048, 512], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.ops.permute(T0, dims=[0, 2, 1])
    T5 = fd.ops.cat([T4, T4], dim=-1)
    T6 = fd.ops.cos(T5)
    T7 = fd.ops.sin(T5)
    S8 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T9 = fd.ops.mul(T6, S8)
    S10 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T11 = fd.ops.mul(T7, S10)
    T12 = fd.ops.cast(T9, dtype=DataType.BFloat16)
    T13 = fd.ops.cast(T11, dtype=DataType.BFloat16)
    T19 = fd.ops.reshape(T1, new_shape=[1, 2048, 32, 64])
    T20 = fd.ops.permute(T19, dims=[0, 2, 1, 3])
    T26 = fd.ops.reshape(T2, new_shape=[1, 2048, 8, 64])
    T27 = fd.ops.permute(T26, dims=[0, 2, 1, 3])
    T33 = fd.ops.reshape(T3, new_shape=[1, 2048, 8, 64])
    T34 = fd.ops.permute(T33, dims=[0, 2, 1, 3])
    T40 = fd.ops.broadcast_in_dim(T12, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T46 = fd.ops.broadcast_in_dim(T13, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T52 = fd.ops.broadcast_in_dim(T40, shape=[1, 32, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T53 = fd.ops.cast(T20, dtype=DataType.Float)
    T54 = fd.ops.cast(T52, dtype=DataType.Float)
    T55 = fd.ops.mul(T53, T54)
    T71 = fd.ops.slice(T20, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 2048, 32], strides=[1, 1, 1, 1])
    T87 = fd.ops.slice(T20, start_indices=[0, 0, 0, 32], end_indices=[1, 32, 2048, 64], strides=[1, 1, 1, 1])
    T88 = fd.ops.cast(T87, dtype=DataType.Float)
    T89 = fd.ops.neg(T88)
    T90 = fd.ops.cast(T89, dtype=DataType.BFloat16)
    T91 = fd.ops.cat([T90, T71], dim=-1)
    T97 = fd.ops.broadcast_in_dim(T46, shape=[1, 32, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T98 = fd.ops.cast(T91, dtype=DataType.Float)
    T99 = fd.ops.cast(T97, dtype=DataType.Float)
    T100 = fd.ops.mul(T98, T99)
    T101 = fd.ops.add(T55, T100)
    T102 = fd.ops.cast(T101, dtype=DataType.BFloat16)
    T108 = fd.ops.broadcast_in_dim(T40, shape=[1, 8, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T109 = fd.ops.cast(T27, dtype=DataType.Float)
    T110 = fd.ops.cast(T108, dtype=DataType.Float)
    T111 = fd.ops.mul(T109, T110)
    T127 = fd.ops.slice(T27, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 2048, 32], strides=[1, 1, 1, 1])
    T143 = fd.ops.slice(T27, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 2048, 64], strides=[1, 1, 1, 1])
    T144 = fd.ops.cast(T143, dtype=DataType.Float)
    T145 = fd.ops.neg(T144)
    T146 = fd.ops.cast(T145, dtype=DataType.BFloat16)
    T147 = fd.ops.cat([T146, T127], dim=-1)
    T153 = fd.ops.broadcast_in_dim(T46, shape=[1, 8, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T154 = fd.ops.cast(T147, dtype=DataType.Float)
    T155 = fd.ops.cast(T153, dtype=DataType.Float)
    T156 = fd.ops.mul(T154, T155)
    T157 = fd.ops.add(T111, T156)
    T158 = fd.ops.cast(T157, dtype=DataType.BFloat16)
    T165 = fd.ops.broadcast_in_dim(T158, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4])
    T172 = fd.ops.broadcast_in_dim(T165, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T178 = fd.ops.reshape(T172, new_shape=[1, 32, 2048, 64])
    T185 = fd.ops.broadcast_in_dim(T34, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4])
    T192 = fd.ops.broadcast_in_dim(T185, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T198 = fd.ops.reshape(T192, new_shape=[1, 32, 2048, 64])
    fd.add_output(T34)
    fd.add_output(T54)
    fd.add_output(T99)
    fd.add_output(T102)
    fd.add_output(T110)
    fd.add_output(T155)
    fd.add_output(T158)
    fd.add_output(T178)
    fd.add_output(T198)

with FusionDefinition() as fd:
    nvfuser_fusion_id38(fd)

inputs = [
    torch.testing.make_tensor((1, 32, 2048), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 512), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 512), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)

jjsjann123 commented Dec 2, 2024

I think we have multiple issues in the original script. I opened #3512 to track the smaller repro.

jjsjann123 self-assigned this Dec 2, 2024

jjsjann123 commented:

A slightly smaller repro (needs to run on branch #3513):

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id38(fd : FusionDefinition) -> None :
    T5 = fd.define_tensor(shape=[1, 2048, 64], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[1, 2048, 512], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.ops.cos(T5)
    T11 = fd.ops.sin(T5)
    T26 = fd.ops.reshape(T2, new_shape=[1, 2048, 8, 64])
    T27 = fd.ops.permute(T26, dims=[0, 2, 1, 3])
    T40 = fd.ops.broadcast_in_dim(T9, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T46 = fd.ops.broadcast_in_dim(T11, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T108 = fd.ops.broadcast_in_dim(T40, shape=[1, 8, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T153 = fd.ops.broadcast_in_dim(T46, shape=[1, 8, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T111 = fd.ops.mul(T27, T108)
    T127 = fd.ops.slice(T27, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 2048, 32], strides=[1, 1, 1, 1])
    T143 = fd.ops.slice(T27, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 2048, 64], strides=[1, 1, 1, 1])
    T145 = fd.ops.neg(T143)
    T147 = fd.ops.cat([T145, T127], dim=-1)
    T156 = fd.ops.mul(T147, T153)
    T157 = fd.ops.add(T111, T156)
    T165 = fd.ops.broadcast_in_dim(T157, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4])
    T172 = fd.ops.broadcast_in_dim(T165, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T178 = fd.ops.reshape(T172, new_shape=[1, 32, 2048, 64])
    fd.add_output(T108)
    fd.add_output(T153)
    fd.add_output(T157)
    fd.add_output(T178)

with FusionDefinition() as fd:
    nvfuser_fusion_id38(fd)

inputs = [
    torch.testing.make_tensor((1, 2048, 64), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 512), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)

Looking at @naoyam's indexing WAR PR #3454, it sounds like we don't yet support broadcasting after resizing, which happens in this pattern:

    T147 = fd.ops.cat([T145, T127], dim=-1)
    T156 = fd.ops.mul(T147, T153)
    T157 = fd.ops.add(T111, T156)
    T165 = fd.ops.broadcast_in_dim(T157, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4])
    T172 = fd.ops.broadcast_in_dim(T165, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4])
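
For reference, this excerpt is the familiar rotate-half RoPE computation. A hedged PyTorch sketch of what it computes (the function and argument names are illustrative, not part of the repro):

```
import torch

def rope_rotate_half(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half, negate the second half, and swap halves;
    # this mirrors the slice / neg / cat chain (T127, T143, T145, T147).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat([-x2, x1], dim=-1)
    # The broadcast_in_dim + mul + add steps (T111, T156, T157) amount to:
    return x * cos + rotated * sin
```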

naoyam added a commit that referenced this issue Dec 4, 2024
Fixes this error from #3505:

```
Error from segmentation group 9:  INTERNAL ASSERT FAILED at "/Fuser/csrc/id_model/indexing_traversal.cpp":102, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Indexing path for resize not found: iblockIdx.y376{( ceilDiv(1280, blockDim.x) )}
```

The error happens when trying to use the indexing WAR for resize that
was recently added (#3454). The WAR itself is limited; in particular, it
does not work with promoted loop IDs. That limitation is fine for the
RoPE scheduling I've been working on, but it's a real issue in general.

This PR avoids the issue by limiting the use of the WAR. Currently, the
WAR is used whenever there is at least one resize expr in a single math
expr. That is overly pessimistic, since the indexing issue only happens
when there are multiple resize exprs that result in a cycle in the exact
graph. For example, if there's only one resize, there can be no cycle,
so the indexing WAR is not necessary.

This PR limits the use of the WAR by doing a slightly deeper analysis.
The added check should entirely disable the WAR for the current default
scheduling, where resize is only allowed with fusion inputs, meaning
there can be no multiple dependent resize exprs in a single fusion.

The limitation of the WAR remains, but it does not matter for RoPE, and
with this PR it should also not matter for general cases.
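
A rough sketch of that gating, in Python for illustration only (the toy MathExpr type and the counting rule are invented and do not reflect nvFuser's C++ internals; the real check also verifies the resizes actually form a cycle in the exact graph):

```
from dataclasses import dataclass, field

@dataclass
class MathExpr:
    # Toy stand-in for a fusion math expression; resize_deps lists the
    # resize exprs its inputs depend on (hypothetical modeling).
    resize_deps: list = field(default_factory=list)

def resize_war_needed(expr: MathExpr) -> bool:
    # Before this PR: the WAR fired whenever at least one resize fed the
    # expr. After: a single resize cannot close a cycle in the exact
    # graph, so the WAR is only considered with multiple dependent resizes.
    return len(expr.resize_deps) > 1

# A cat of two sliced tensors carries two dependent resizes (WAR kept);
# a lone slice cannot form a cycle, so the WAR is skipped.
assert resize_war_needed(MathExpr(resize_deps=["pad_lhs", "pad_rhs"]))
assert not resize_war_needed(MathExpr(resize_deps=["slice_only"]))
```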

naoyam commented Dec 4, 2024

@jjsjann123 @kevinstephano

I believe this error is now fixed (#3515):

Error from segmentation group 9: INTERNAL ASSERT FAILED at "/Fuser/csrc/id_model/indexing_traversal.cpp":102, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Indexing path for resize not found: iblockIdx.y376{( ceilDiv(1280, blockDim.x) )}

I haven't tested the fix with the exact same repro, though. Let me know if it still hits the error.

jjsjann123 commented:

The smaller repro posted above works after your fix, but it still fails with the original repro script.

If we expand the smaller repro slightly to include T5 = fd.ops.cat([T4, T4], dim=-1) at the beginning, I hit the same error:

import torch
from nvfuser import FusionDefinition, DataType
 
def nvfuser_fusion_id38(fd : FusionDefinition) -> None :
    T4 = fd.define_tensor(shape=[1, 2048, 32], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[1, 2048, 512], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.ops.cat([T4, T4], dim=-1)
    T9 = fd.ops.cos(T5)
    T11 = fd.ops.sin(T5)
    T26 = fd.ops.reshape(T2, new_shape=[1, 2048, 8, 64])
    T27 = fd.ops.permute(T26, dims=[0, 2, 1, 3])
    T40 = fd.ops.broadcast_in_dim(T9, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T46 = fd.ops.broadcast_in_dim(T11, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T108 = fd.ops.broadcast_in_dim(T40, shape=[1, 8, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T153 = fd.ops.broadcast_in_dim(T46, shape=[1, 8, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T111 = fd.ops.mul(T27, T108)
    T127 = fd.ops.slice(T27, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 2048, 32], strides=[1, 1, 1, 1])
    T143 = fd.ops.slice(T27, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 2048, 64], strides=[1, 1, 1, 1])
    T145 = fd.ops.neg(T143)
    T147 = fd.ops.cat([T145, T127], dim=-1)
    T156 = fd.ops.mul(T147, T153)
    T157 = fd.ops.add(T111, T156)
    T165 = fd.ops.broadcast_in_dim(T157, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4])
    T172 = fd.ops.broadcast_in_dim(T165, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T178 = fd.ops.reshape(T172, new_shape=[1, 32, 2048, 64])
    fd.add_output(T108)
    fd.add_output(T153)
    fd.add_output(T157)
    fd.add_output(T178)
 
with FusionDefinition() as fd:
    nvfuser_fusion_id38(fd)

inputs = [
    #torch.testing.make_tensor((1, 2048, 64), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 32), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 512), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)

RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/id_model/indexing_traversal.cpp":156, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Indexing path for resize not found: iblockIdx.x201{32}
Exception raised from getExprsBetweenForResize at /opt/pytorch/nvfuser/csrc/id_model/indexing_traversal.cpp:156 (most recent call first):


naoyam commented Dec 5, 2024

Ah, that's of course because a cat has multiple inputs and their resize ops can be connected. So, what I mentioned here is wrong.
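
A small PyTorch sketch of that point (conceptual only, not nvFuser internals): cat materializes each input into a disjoint slice of the output, i.e. one pad/resize per input, so a single cat carries multiple resize ops that meet in the same output domain.

```
import torch

x1 = torch.randn(2, 3)
x2 = torch.randn(2, 5)
# cat writes each input into a disjoint region of the output, which is
# equivalent to padding (resizing) each input to the output extent.
out = torch.cat([x1, x2], dim=-1)
assert torch.equal(out[:, :3], x1)
assert torch.equal(out[:, 3:], x2)
```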

I also encountered a different issue with the fix. I'm going to revert the change and try a simpler fix.


naoyam commented Dec 5, 2024

@jjsjann123 I hit the other error with the repro. How do I reproduce the indexing error?


naoyam commented Dec 5, 2024

@jjsjann123 Could you test this fix with your repro? #3530

jjsjann123 commented:

The other error in the repro is caused by the pointwise scheduler picking the wrong reference. I'm patching it in #3513.

But don't worry about it, let me try your fix myself.

jjsjann123 commented:

Merging 07797a4 (pw_scheduler_reference_find_patch) and fb0b53a (resize_path_further_fix) did work for the repro Kevin posted above.

cc'ing @naoyam

jjsjann123 commented:

Linking the two PRs that are needed to close this issue: #3513 & #3530

naoyam added a commit that referenced this issue Dec 9, 2024
This is a second attempt to fix #3505. The first attempt was #3515. As
mentioned in a comment on #3505, the first fix isn't sufficient when an
expr has multiple resized inputs, like concat. The actual condition we
need to check is between each producer and consumer pair, not between
producers, so this second attempt just changes how we check the
condition.
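
Illustrative pseudocode of the change, with all names invented (this is not nvFuser's actual API); it only contrasts where the condition is evaluated in the two attempts:

```
from itertools import product

def resizes_conflict(a, b) -> bool:
    # Hypothetical stand-in for the real check: whether the resize exprs
    # between the two could close a cycle in the exact graph.
    return bool(set(a["resizes"]) & set(b["resizes"]))

def war_needed_first_fix(producers) -> bool:
    # #3515 evaluated the condition between pairs of *producers*, which
    # is not sufficient when an expr has multiple resized inputs, like cat.
    return any(resizes_conflict(a, b)
               for a, b in product(producers, producers) if a is not b)

def war_needed_second_fix(producers, consumers) -> bool:
    # #3530 checks each producer-consumer pair instead, covering exprs
    # with multiple resized inputs such as cat.
    return any(resizes_conflict(p, c)
               for p, c in product(producers, consumers))
```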