Internal assert with Thunder CUDA Falcon 40b like #3505
This is the new error:
Repro:

```python
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id38(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 32, 2048], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 2048, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[1, 2048, 512], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[1, 2048, 512], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.ops.permute(T0, dims=[0, 2, 1])
    T5 = fd.ops.cat([T4, T4], dim=-1)
    T6 = fd.ops.cos(T5)
    T7 = fd.ops.sin(T5)
    S8 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T9 = fd.ops.mul(T6, S8)
    S10 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T11 = fd.ops.mul(T7, S10)
    T12 = fd.ops.cast(T9, dtype=DataType.BFloat16)
    T13 = fd.ops.cast(T11, dtype=DataType.BFloat16)
    T19 = fd.ops.reshape(T1, new_shape=[1, 2048, 32, 64])
    T20 = fd.ops.permute(T19, dims=[0, 2, 1, 3])
    T26 = fd.ops.reshape(T2, new_shape=[1, 2048, 8, 64])
    T27 = fd.ops.permute(T26, dims=[0, 2, 1, 3])
    T33 = fd.ops.reshape(T3, new_shape=[1, 2048, 8, 64])
    T34 = fd.ops.permute(T33, dims=[0, 2, 1, 3])
    T40 = fd.ops.broadcast_in_dim(T12, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T46 = fd.ops.broadcast_in_dim(T13, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T52 = fd.ops.broadcast_in_dim(T40, shape=[1, 32, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T53 = fd.ops.cast(T20, dtype=DataType.Float)
    T54 = fd.ops.cast(T52, dtype=DataType.Float)
    T55 = fd.ops.mul(T53, T54)
    T71 = fd.ops.slice(T20, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 2048, 32], strides=[1, 1, 1, 1])
    T87 = fd.ops.slice(T20, start_indices=[0, 0, 0, 32], end_indices=[1, 32, 2048, 64], strides=[1, 1, 1, 1])
    T88 = fd.ops.cast(T87, dtype=DataType.Float)
    T89 = fd.ops.neg(T88)
    T90 = fd.ops.cast(T89, dtype=DataType.BFloat16)
    T91 = fd.ops.cat([T90, T71], dim=-1)
    T97 = fd.ops.broadcast_in_dim(T46, shape=[1, 32, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T98 = fd.ops.cast(T91, dtype=DataType.Float)
    T99 = fd.ops.cast(T97, dtype=DataType.Float)
    T100 = fd.ops.mul(T98, T99)
    T101 = fd.ops.add(T55, T100)
    T102 = fd.ops.cast(T101, dtype=DataType.BFloat16)
    T108 = fd.ops.broadcast_in_dim(T40, shape=[1, 8, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T109 = fd.ops.cast(T27, dtype=DataType.Float)
    T110 = fd.ops.cast(T108, dtype=DataType.Float)
    T111 = fd.ops.mul(T109, T110)
    T127 = fd.ops.slice(T27, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 2048, 32], strides=[1, 1, 1, 1])
    T143 = fd.ops.slice(T27, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 2048, 64], strides=[1, 1, 1, 1])
    T144 = fd.ops.cast(T143, dtype=DataType.Float)
    T145 = fd.ops.neg(T144)
    T146 = fd.ops.cast(T145, dtype=DataType.BFloat16)
    T147 = fd.ops.cat([T146, T127], dim=-1)
    T153 = fd.ops.broadcast_in_dim(T46, shape=[1, 8, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    T154 = fd.ops.cast(T147, dtype=DataType.Float)
    T155 = fd.ops.cast(T153, dtype=DataType.Float)
    T156 = fd.ops.mul(T154, T155)
    T157 = fd.ops.add(T111, T156)
    T158 = fd.ops.cast(T157, dtype=DataType.BFloat16)
    T165 = fd.ops.broadcast_in_dim(T158, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4])
    T172 = fd.ops.broadcast_in_dim(T165, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T178 = fd.ops.reshape(T172, new_shape=[1, 32, 2048, 64])
    T185 = fd.ops.broadcast_in_dim(T34, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4])
    T192 = fd.ops.broadcast_in_dim(T185, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T198 = fd.ops.reshape(T192, new_shape=[1, 32, 2048, 64])
    fd.add_output(T34)
    fd.add_output(T54)
    fd.add_output(T99)
    fd.add_output(T102)
    fd.add_output(T110)
    fd.add_output(T155)
    fd.add_output(T158)
    fd.add_output(T178)
    fd.add_output(T198)

with FusionDefinition() as fd:
    nvfuser_fusion_id38(fd)

inputs = [
    torch.testing.make_tensor((1, 32, 2048), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 512), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 2048, 512), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
```
I think we have multiple issues in the original script. I opened one issue in #3512 to track the smaller repro.
A slightly smaller repro (needs to run on branch #3513):
Looking at @naoyam's indexing WAR PR #3454, it sounds like we don't yet support broadcasting after the resizing, which is happening in the pattern:
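For illustration, this is the pattern as I read it from the repro above: a cat (a resize op) whose output is then broadcast with broadcast_in_dim. This is a reduced, hypothetical sketch, not claimed to reproduce the assert on its own:

```python
import torch
from nvfuser import FusionDefinition, DataType

def resize_then_broadcast(fd: FusionDefinition) -> None:
    # [1, 2048, 32] input, mirroring T4 = permute(T0) in the full repro
    T0 = fd.define_tensor(shape=[1, 2048, 32], contiguity=[None, True, True],
                          dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.ops.cat([T0, T0], dim=-1)  # resize: result is [1, 2048, 64]
    # broadcast after the resize, as with T40/T46/T52 in the full repro
    T2 = fd.ops.broadcast_in_dim(T1, shape=[1, 1, 2048, 64], broadcast_dims=[0, 2, 3])
    T3 = fd.ops.broadcast_in_dim(T2, shape=[1, 32, 2048, 64], broadcast_dims=[0, 1, 2, 3])
    fd.add_output(T3)

with FusionDefinition() as fd:
    resize_then_broadcast(fd)

inputs = [torch.testing.make_tensor((1, 2048, 32), dtype=torch.float32, device='cuda:0')]
fd.execute(inputs)
```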
Fixes this error of #3505:

```
Error from segmentation group 9: INTERNAL ASSERT FAILED at "/Fuser/csrc/id_model/indexing_traversal.cpp":102, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Indexing path for resize not found: iblockIdx.y376{( ceilDiv(1280, blockDim.x) )}
```

The error happens when trying to use the indexing WAR for resize that was recently added (#3454). The WAR itself is limited; in particular, it does not work with promoted loop IDs. That limitation should be fine for the RoPE scheduling I've been working on, but it is a real issue in general.

This PR avoids the issue by limiting the use of the WAR. Currently, the WAR is used whenever there is at least a single resize expr in a single math expr. That is overly pessimistic, since the indexing issue only happens when there are multiple resize exprs that result in a cycle in the exact graph. For example, if there is only one resize, there can be no cycle, so the indexing WAR is not necessary. This PR attempts to limit the use of the WAR by doing a slightly deeper analysis.

The added check should entirely disable the WAR for the current default scheduling, where resize is only allowed with fusion inputs, which means there can be no multiple dependent resize exprs in a single fusion. The limitation of the WAR remains, but it does not matter for RoPE, and with this PR it should also not matter for general cases.
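For reference, a plain-Python sketch of the gist of the added check, with hypothetical names only (the real check lives in nvFuser's C++ indexing code, and as the later comments note, per-expr counting like this turns out to be insufficient for exprs like cat):

```python
# Hypothetical sketch, not an actual nvFuser API: resize_exprs_between(inp, expr)
# stands in for an exact-graph query returning the resize exprs on the path
# from a producer `inp` to the math expr `expr`.
def indexing_war_needed(expr_inputs, expr, resize_exprs_between) -> bool:
    # Old behavior: any resize feeding the expr triggered the WAR.
    # Sketched new behavior: a cycle in the exact graph needs at least two
    # resize exprs, so a single resize never requires the WAR.
    resizes = [r for inp in expr_inputs for r in resize_exprs_between(inp, expr)]
    return len(resizes) >= 2
```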
I believe this error is now fixed (#3515). I haven't tested the fix with the exact same repro, though. Let me know if it still hits the error.
The smaller repro posted above works after your fix, but it still fails with the original repro script. If we expand the smaller repro slightly to include the
Ah, that's of course because a cat has multiple inputs and their resize ops can be connected. So, what I mentioned here is wrong. I also encountered a different issue with the fix. I'm going to revert the change and try a simpler fix.
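For illustration, a reduced, hypothetical sketch of just that sub-pattern, mirroring the rotary halves in the full repro: both inputs of the cat come from slices (resize ops) of the same tensor. Not claimed to trigger the assert on its own:

```python
import torch
from nvfuser import FusionDefinition, DataType

def cat_of_two_slices(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[1, 32, 2048, 64], contiguity=[None, True, True, True],
                          dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    # Two slices (resize ops) of the same producer...
    T1 = fd.ops.slice(T0, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 2048, 32], strides=[1, 1, 1, 1])
    T2 = fd.ops.slice(T0, start_indices=[0, 0, 0, 32], end_indices=[1, 32, 2048, 64], strides=[1, 1, 1, 1])
    T3 = fd.ops.cast(fd.ops.neg(fd.ops.cast(T2, dtype=DataType.Float)), dtype=DataType.BFloat16)
    # ...feed the same cat, so one expr has multiple resized inputs.
    T4 = fd.ops.cat([T3, T1], dim=-1)
    fd.add_output(T4)

with FusionDefinition() as fd:
    cat_of_two_slices(fd)

inputs = [torch.testing.make_tensor((1, 32, 2048, 64), dtype=torch.bfloat16, device='cuda:0')]
fd.execute(inputs)
```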
@jjsjann123 I hit the other error with the repro. How do I reproduce the indexing error?
@jjsjann123 Could you test this fix with your repro? #3530 |
The other error in the repro is caused by the pointwise scheduler picking the wrong reference. I'm patching it here in #3513. But don't worry about it; let me try your fix myself.
This is a second attempt to fix #3505. The first attempt is #3515. As mentioned [here](#3505 (comment)), the first fix isn't sufficient when an expr has multiple resized inputs, like concat. The actual condition we need to check is between each producer and consumer pair, not between producers, so this second attempt just changes how we check the condition.
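A very rough plain-Python sketch of the structural difference between the two attempts, with hypothetical names (the actual check is in nvFuser's C++ code; the `condition_*` callables stand in for the real exact-graph condition):

```python
# Hypothetical sketch of where the condition is evaluated in each attempt.
def war_check_first_attempt(producers, condition_between_producers):
    # First attempt: evaluate the condition between the producers themselves,
    # which misses cases like cat where the producers' resizes connect.
    return condition_between_producers(producers)

def war_check_second_attempt(consumer, producers, condition_between):
    # Second attempt: evaluate the condition for each producer-consumer pair.
    return any(condition_between(p, consumer) for p in producers)
```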
nvfuser-0.2.23+git8546b62
Repro:

```
pytest-3 thunder/tests/test_jit_general.py::test_litgpt_variants[cuda-falcon-40b-like]
```