INTERNAL ASSERT FAILED: Allocations must be based on constant integers for local memory #2792

Closed
IvanYashchuk opened this issue Aug 15, 2024 · 3 comments
@IvanYashchuk (Collaborator)

I get the following error when running one of the Thunder benchmarks (pytest thunder/benchmarks/targets.py -k "test_litgpt_qkv_split_rope[Llama-3-8B-backward-bs1-thunder+nvfuser+torch.compile]"):

RuntimeError: false INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/executor.cpp":437, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Allocations must be based on constant integers for local memory. However, found: T10_l[ iblockIdx.x166{( ceilDiv(( ceilDiv(( ceilDiv(( 2 * ( 8 * ( 8192 * 128 ) ) ), 8) ), blockDim.x) ), 1) )}, ithreadIdx.x165{blockDim.x}, iUS167{1}, iV163{8}, iS170{( ceilDiv(( ceilDiv(( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 8) ), 2) ), 1) )}, iUS171{1}, iUR169{2} ] ca_pos( 3 ), T10_l[ iblockIdx.x166{( ceilDiv(( ceilDiv(( ceilDiv(( 2 * ( 8 * ( 8192 * 128 ) ) ), 8) ), blockDim.x) ), 1) )}, ithreadIdx.x165{blockDim.x}, iUS167{1}, iV163{8}, iS170{( ceilDiv(( ceilDiv(( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 8) ), 2) ), 1) )}, iUS171{1}, iUR169{2} ] ca_pos( 3 ),  have dynamic allocations but are placed in local memory.
Exception raised from compileFusion at /opt/pytorch/nvfuser/csrc/executor.cpp:437 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f0494fc3d2f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x7f04953250d3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::ScheduleHeuristic, long, long, long, long) + 0x1f77 (0x7f049533eb87 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
#  0: NVIDIA H100 80GB HBM3
# torch version: 2.5.0a0+gitefc6e84
# cuda version: 12.6
# nvfuser version: 0.2.10+git0d4813d
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1, -1, -1], contiguity=[True, True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[4, 3, 2, 1, 0])
    S2 = fd.define_scalar(2, dtype=DataType.Int)
    S3 = fd.define_scalar(8, dtype=DataType.Int)
    S4 = fd.define_scalar(4, dtype=DataType.Int)
    S5 = fd.define_scalar(8192, dtype=DataType.Int)
    S6 = fd.define_scalar(128, dtype=DataType.Int)
    V7 = fd.define_vector([S2, S3, S4, S5, S6], dtype=DataType.Int)
    T8 = fd.ops.reshape(T0, new_shape=V7)
    T9 = fd.ops.cast(T8, dtype=DataType.Float)
    T10 = fd.ops.sum(T9, dims=[2], keepdim=False, dtype=DataType.Null)
    T11 = fd.ops.cast(T10, dtype=DataType.BFloat16)
    T12 = fd.ops.cast(T1, dtype=DataType.Float)
    T13 = fd.ops.sum(T12, dims=[2], keepdim=False, dtype=DataType.Null)
    T14 = fd.ops.cast(T13, dtype=DataType.BFloat16)
    fd.add_output(T11)
    fd.add_output(T14)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((67108864,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 32, 8192, 128), (33554432, 1048576, 128, 1)),
    torch.randn((67108864,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 8, 4, 8192, 128), (33554432, 4194304, 1048576, 128, 1)),
]
fd.execute(inputs)
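
For reference, here is a minimal sketch (my own illustration, not part of the original report) of the eager PyTorch computation the fusion performs: T0 is reshaped from [2, 32, 8192, 128] to [2, 8, 4, 8192, 128], and both tensors are reduced over dimension 2 in fp32 before being cast back to bfloat16.

import torch

# Inputs matching the repro above (both tensors are contiguous).
t0 = torch.randn(2, 32, 8192, 128, dtype=torch.bfloat16, device="cuda")
t1 = torch.randn(2, 8, 4, 8192, 128, dtype=torch.bfloat16, device="cuda")

# T8/T9/T10/T11: split dim 1 (32 -> 8 * 4), sum over the new dim 2 in fp32, cast back.
out0 = t0.reshape(2, 8, 4, 8192, 128).float().sum(dim=2).to(torch.bfloat16)
# T12/T13/T14: sum T1 over dim 2 in fp32, cast back.
out1 = t1.float().sum(dim=2).to(torch.bfloat16)

assert out0.shape == (2, 8, 8192, 128)
assert out1.shape == (2, 8, 8192, 128)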
@naoyam (Collaborator) commented Aug 15, 2024

Another instance of #2702?

CC: @jacobhinkle

@jacobhinkle (Collaborator)

> Another instance of #2702?

Yes, I think this is a dupe of #2702. I verified that #2714 fixes it, as long as I address the known issue with PolymorphicValue::isSame for the StructHandle special case: cf. #2714 (comment).

IvanYashchuk added a commit to Lightning-AI/lightning-thunder that referenced this issue Aug 16, 2024
…st of executors (#974)

We should use the Inductor-based concatenation executor (through torch.compile) by default, since it gives us a performance improvement for the sections of the network that include the torch.cat operation (mainly RoPE). Previously we didn't enable it by default because of memory leaks, which have since been fixed (see Lightning-AI/lit-thunder-LEGACY#2194 if you have access).

In addition, this PR avoids hitting a problem with nvFuser (NVIDIA/Fuser#2792). The backward of RoPE was generating an nvFuser fusion in between TorchCompile regions because the torchcompile_cat executor wasn't marked as able to execute torch.sum. I also added a test that verifies that only one fusion region is created.
In the (NVIDIA-internal) Mixology dashboard, the Llama-3-8B, Mistral-7B-v0.1, and stablecode-completion-alpha-3b models do not work with the Thunder + Inductor concatenation executor; this problem should be fixed by this PR.
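
For context, opting into the Inductor-based concatenation executor in Thunder looks roughly like the sketch below. This is my own illustration, not part of the PR; the import path and executor name (thunder.executors.torch_compile.torch_compile_cat_ex) are assumptions and may differ between Thunder versions.

import torch
import thunder
# Assumption: the Inductor-based cat executor is exposed under this name;
# check your Thunder version for the exact import path.
from thunder.executors.torch_compile import torch_compile_cat_ex

def rope_like(q, sin, cos):
    # Toy stand-in for a RoPE-style computation that ends in torch.cat.
    q1, q2 = q.chunk(2, dim=-1)
    rotated = torch.cat((-q2, q1), dim=-1)
    return q * cos + rotated * sin

# Listing the cat executor explicitly lets Inductor claim the cat-heavy
# regions (e.g. the RoPE backward) instead of nvFuser.
jitted = thunder.jit(rope_like, executors=(torch_compile_cat_ex,))

q = torch.randn(1, 8, 8192, 128, device="cuda", requires_grad=True)
sin = torch.randn(8192, 128, device="cuda")
cos = torch.randn(8192, 128, device="cuda")
out = jitted(q, sin, cos)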
@jacobhinkle (Collaborator)

Fixed by #2714

jacobhinkle self-assigned this Aug 28, 2024