-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Problem with large fusions from HuggingFace Llama #3537
Comments
# This might be a smaller repro:
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id6(fd : FusionDefinition) -> None :
    """Record the first repro fusion into ``fd``.

    Declared inputs:
      T0: Float tensor, shape=[1, 6, 32], contiguity=[None, True, True].
      T1: BFloat16 tensor, shape=[1, 6, 2048], contiguity=[None, True, True].

    Registered outputs, in order: T8, T35, T42, T43.

    NOTE(review): the pad/cat/cos/sin sequence looks like a rotary position
    embedding (RoPE) cos/sin computation from the HuggingFace Llama model
    named in the issue title. This is an inference, not confirmed here.
    This is a bug repro, so keep the op and scalar definition order unchanged.
    """
    # NOTE(review): no stride_order is given for T0, but the driver script
    # builds this input with strides (192, 1, 6). Confirm whether that layout
    # mismatch is intended.
    T0 = fd.define_tensor(shape=[1, 6, 32], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(shape=[1, 6, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False)
    S2 = fd.ops.size(T1, dim=0)
    S3 = fd.ops.size(T1, dim=1)
    S4 = fd.define_scalar(32, dtype=DataType.Int)
    S5 = fd.define_scalar(64, dtype=DataType.Int)
    V6 = fd.define_vector([S2, S3, S4, S5], dtype=DataType.Int)
    # Split T1's last dim (2048 = 32 * 64) into [32, 64] ...
    T7 = fd.ops.reshape(T1, new_shape=V6)
    # ... then swap dims 1 and 2: [S2, S3, 32, 64] -> [S2, 32, S3, 64].
    T8 = fd.ops.permute(T7, dims=[0, 2, 1, 3])
    T9 = fd.ops.cast(T8, dtype=DataType.Float)
    # First pad of T0, filled with 0.0. Assumes nvFuser pad widths are
    # (left, right) pairs starting from the last dimension, as in
    # torch.nn.functional.pad (TODO confirm). Under that assumption, the widths
    # [0, 32, 0, 0, 0, 0] pad dim 2 by 32 on the right.
    S10 = fd.define_scalar(0.00000, dtype=DataType.Float)
    S11 = fd.define_scalar(0, dtype=DataType.Int)
    S12 = fd.define_scalar(0, dtype=DataType.Int)
    S13 = fd.define_scalar(0, dtype=DataType.Int)
    S14 = fd.define_scalar(0, dtype=DataType.Int)
    S15 = fd.define_scalar(0, dtype=DataType.Int)
    S16 = fd.define_scalar(32, dtype=DataType.Int)
    V18 = fd.define_vector([S15, S16, S13, S14, S11, S12], dtype=DataType.Int)
    T17 = fd.ops.pad(T0, V18, S10)
    # Second pad of T0, filled with 0.0. The first width is 0 + size(T0, dim=2).
    # Under the same convention, this pads dim 2 on the left by T0's own
    # extent in that dim.
    S19 = fd.define_scalar(0, dtype=DataType.Int)
    S20 = fd.ops.size(T0, dim=2)
    S21 = fd.ops.add(S19, S20)
    S22 = fd.define_scalar(0.00000, dtype=DataType.Float)
    S23 = fd.define_scalar(0, dtype=DataType.Int)
    S24 = fd.define_scalar(0, dtype=DataType.Int)
    S25 = fd.define_scalar(0, dtype=DataType.Int)
    S26 = fd.define_scalar(0, dtype=DataType.Int)
    S27 = fd.define_scalar(0, dtype=DataType.Int)
    V29 = fd.define_vector([S21, S27, S25, S26, S23, S24], dtype=DataType.Int)
    T28 = fd.ops.pad(T0, V29, S22)
    # Concatenate the two padded tensors along dim 2. manual_padding=1
    # presumably tells cat that its inputs were already padded above; verify
    # against the nvFuser frontend docs.
    T30 = fd.ops.cat([T17, T28], dim=2, manual_padding=1)
    T31 = fd.ops.cos(T30)
    # Multiply by the Double scalar 1.0 (value-preserving, kept as recorded).
    S32 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T33 = fd.ops.mul(T31, S32)
    T34 = fd.ops.cast(T33, dtype=DataType.BFloat16)
    # Insert a new broadcast dimension at position 1.
    T35 = fd.ops.broadcast(T34, is_broadcast_dim=[False, True, False, False])
    S36 = fd.define_scalar(32, dtype=DataType.Int)
    S37 = fd.ops.size(T35, dim=2)
    S38 = fd.ops.size(T35, dim=3)
    # Expand dim 1 to 32. Dim 0 reuses S2, which is size(T1, dim=0).
    V39 = fd.define_vector([S2, S36, S37, S38], dtype=DataType.Int)
    T40 = fd.ops.expand(T35, shape=V39)
    T41 = fd.ops.cast(T40, dtype=DataType.Float)
    # Elementwise product of the permuted T1 data and the expanded cos branch.
    T42 = fd.ops.mul(T9, T41)
    # sin branch, computed from the same concatenated T30 as the cos branch.
    T43 = fd.ops.sin(T30)
    fd.add_output(T8)
    fd.add_output(T35)
    fd.add_output(T42)
    fd.add_output(T43)
# Record the fusion by running nvfuser_fusion_id6 inside a FusionDefinition context.
with FusionDefinition() as fd:
    nvfuser_fusion_id6(fd)
# Positional inputs. Their shapes and dtypes match T0 (Float, [1, 6, 32]) and
# T1 (BFloat16, [1, 6, 2048]) as declared in nvfuser_fusion_id6.
# NOTE(review): T0 is built with strides (192, 1, 6), which makes dim 1 the
# stride-1 dimension. nvfuser_fusion_id6 declares T0 with
# contiguity=[None, True, True] and no stride_order. Confirm this layout
# mismatch is intended by the repro.
inputs = [
    torch.randn(192, dtype=torch.float32, device='cuda:0').as_strided((1, 6, 32), (192, 1, 6)),
    torch.testing.make_tensor((1, 6, 2048), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
# Error: (the error message itself was not included in this capture)
This particular error seems to be because [remainder of comment not captured]
# Small repro using segment 15 of the original repro in the description.
# CUDA devices:
# 0: NVIDIA GeForce RTX 3090 Ti
# torch version: 2.6.0a0+gitffb7a08
# cuda version: 12.6
# nvfuser version: 0.2.23+git67127c9
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    """Record the segment-15 repro fusion into ``fd``.

    Declared input:
      T0: BFloat16 tensor, shape=[1, 6, 512], contiguity=[None, True, True].

    Registered outputs, in order: T7 (with an explicit stride_order) and T14.

    NOTE(review): the reshape -> permute -> broadcast -> expand -> reshape chain
    looks like HuggingFace Llama's repeat_kv for grouped-query attention
    (8 heads repeated 4x into 32). This is an inference, not confirmed here.
    This is a bug repro, so keep the op order unchanged.
    """
    T0 = fd.define_tensor(shape=[1, 6, 512], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False)
    S1 = fd.ops.size(T0, dim=0)
    S2 = fd.ops.size(T0, dim=1)
    S3 = fd.define_scalar(8, dtype=DataType.Int)
    S4 = fd.define_scalar(64, dtype=DataType.Int)
    # Split the last dim (512 = 8 * 64): [S1, S2, 512] -> [S1, S2, 8, 64].
    V5 = fd.define_vector([S1, S2, S3, S4], dtype=DataType.Int)
    T6 = fd.ops.reshape(T0, new_shape=V5)
    # Swap dims 1 and 2: -> [S1, 8, S2, 64].
    T7 = fd.ops.permute(T6, dims=[0, 2, 1, 3])
    # Insert a broadcast dim at position 2: -> [S1, 8, 1, S2, 64].
    T8 = fd.ops.broadcast(T7, is_broadcast_dim=[False, False, True, False, False])
    S9 = fd.define_scalar(4, dtype=DataType.Int)
    # Expand the broadcast dim to 4: -> [S1, 8, 4, S2, 64].
    V10 = fd.define_vector([S1, S3, S9, S2, S4], dtype=DataType.Int)
    T11 = fd.ops.expand(T8, shape=V10)
    S12 = fd.define_scalar(32, dtype=DataType.Int)
    # Merge the 8 and 4 dims (8 * 4 = 32): -> [S1, 32, S2, 64].
    V13 = fd.define_vector([S1, S12, S2, S4], dtype=DataType.Int)
    T14 = fd.ops.reshape(T11, new_shape=V13)
    # T7 is emitted with a non-default stride_order. See the nvFuser frontend
    # docs for the exact convention of this argument.
    fd.add_output(T7, stride_order=[3, 1, 2, 0])
    fd.add_output(T14)
# Record the fusion by running nvfuser_fusion_id0 inside a FusionDefinition context.
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)
# Single positional input matching T0 (BFloat16, [1, 6, 512]).
inputs = [
    torch.testing.make_tensor((1, 6, 512), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
# The error is raised at Fuser/csrc/scheduler/pointwise.cpp, line 890 (commit 9d37348).
The fusion at this point looks like
The error occurs in [remainder of comment not captured]
Duplicate of #3512.
The error:
Error from segmentation group 15: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/transform_iter.cpp":546, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Error during replay, a transformation was called that conflicts with an rfactor call.
The Thunder repro: (this requires enabling linears to be consumed by nvFuser)
The text was updated successfully, but these errors were encountered: