
Problem with large fusions from HuggingFace Llama #3537

Closed
kevinstephano opened this issue Dec 6, 2024 · 5 comments

kevinstephano commented Dec 6, 2024

The error:

Error from segmentation group 15:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/transform_iter.cpp":546, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Error during replay, a transformation was called that conflicts with an rfactor call.

The repro script:
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id2(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 6], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[1, 1, 6, 6], contiguity=[None, None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T2 = fd.define_tensor(shape=[1, 32, 6], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[1, 6, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[2048, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T7 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T13 = fd.ops.broadcast_in_dim(T0, shape=[1, 1, 1, 6], broadcast_dims=[0, 3])
    T19 = fd.ops.broadcast_in_dim(T13, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3])
    T20 = fd.ops.cast(T1, dtype=DataType.Float)
    T21 = fd.ops.cast(T19, dtype=DataType.Float)
    T22 = fd.ops.add(T20, T21)
    T23 = fd.ops.cast(T22, dtype=DataType.BFloat16)
    S24 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T25 = fd.ops.eq(T23, S24)
    S26 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double)
    T27 = fd.ops.where(T25, S26, T1)
    T28 = fd.ops.permute(T2, dims=[0, 2, 1])
    T29 = fd.ops.cat([T28, T28], dim=-1, manual_padding=0)
    T30 = fd.ops.cos(T29)
    T31 = fd.ops.sin(T29)
    S32 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T33 = fd.ops.mul(T30, S32)
    S34 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T35 = fd.ops.mul(T31, S34)
    T36 = fd.ops.cast(T33, dtype=DataType.BFloat16)
    T37 = fd.ops.cast(T35, dtype=DataType.BFloat16)                                                                                                                                                                       
    T38 = fd.ops.cast(T3, dtype=DataType.Float)
    S39 = fd.define_scalar(2.00000, dtype=DataType.Double)                                                                                                                                                                
    T40 = fd.ops.pow(T38, S39)
    T41 = fd.ops.sum(T40, dims=[2], keepdim=False, dtype=DataType.Null)
    T46 = fd.ops.broadcast_in_dim(T41, shape=[1, 6, 1], broadcast_dims=[0, 1])                                                                                                                                            
    S47 = fd.define_scalar(2048.00, dtype=DataType.Double)
    S48 = fd.ops.reciprocal(S47)                                                                                                                                                                                          
    T49 = fd.ops.mul(T46, S48)                                                                                                                                                                                            
    S50 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T51 = fd.ops.add(T49, S50)
    T52 = fd.ops.rsqrt(T51)
    T57 = fd.ops.broadcast_in_dim(T52, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2])
    T58 = fd.ops.mul(T38, T57)
    T63 = fd.ops.broadcast_in_dim(T4, shape=[1, 6, 2048], broadcast_dims=[2])
    T64 = fd.ops.cast(T63, dtype=DataType.Float)
    T65 = fd.ops.mul(T64, T58)
    T66 = fd.ops.cast(T65, dtype=DataType.BFloat16)
    T67 = fd.ops.linear(T66, T5)
    T68 = fd.ops.linear(T66, T6)
    T69 = fd.ops.linear(T66, T7)
    T75 = fd.ops.reshape(T67, new_shape=[1, 6, 32, 64])
    T76 = fd.ops.permute(T75, dims=[0, 2, 1, 3])
    T82 = fd.ops.reshape(T68, new_shape=[1, 6, 8, 64])
    T83 = fd.ops.permute(T82, dims=[0, 2, 1, 3])
    T89 = fd.ops.reshape(T69, new_shape=[1, 6, 8, 64])
    T90 = fd.ops.permute(T89, dims=[0, 2, 1, 3])
    T96 = fd.ops.broadcast_in_dim(T36, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3])
    T102 = fd.ops.broadcast_in_dim(T37, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3])
    T108 = fd.ops.broadcast_in_dim(T96, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T109 = fd.ops.cast(T76, dtype=DataType.Float)
    T110 = fd.ops.cast(T108, dtype=DataType.Float)
    T111 = fd.ops.mul(T109, T110)
    T127 = fd.ops.slice(T76, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0)
    T143 = fd.ops.slice(T76, start_indices=[0, 0, 0, 32], end_indices=[1, 32, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T144 = fd.ops.cast(T143, dtype=DataType.Float)
    T145 = fd.ops.neg(T144)
    T146 = fd.ops.cast(T145, dtype=DataType.BFloat16)
    T147 = fd.ops.cat([T146, T127], dim=-1, manual_padding=0)
    T153 = fd.ops.broadcast_in_dim(T102, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T154 = fd.ops.cast(T147, dtype=DataType.Float)
    T155 = fd.ops.cast(T153, dtype=DataType.Float)
    T156 = fd.ops.mul(T154, T155)
    T157 = fd.ops.add(T111, T156)
    T158 = fd.ops.cast(T157, dtype=DataType.BFloat16)
    T164 = fd.ops.broadcast_in_dim(T96, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T165 = fd.ops.cast(T83, dtype=DataType.Float)
    T166 = fd.ops.cast(T164, dtype=DataType.Float)
    T167 = fd.ops.mul(T165, T166)
    T183 = fd.ops.slice(T83, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0)
    T199 = fd.ops.slice(T83, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T200 = fd.ops.cast(T199, dtype=DataType.Float)
    T201 = fd.ops.neg(T200)
    T202 = fd.ops.cast(T201, dtype=DataType.BFloat16)
    T203 = fd.ops.cat([T202, T183], dim=-1, manual_padding=0)
    T209 = fd.ops.broadcast_in_dim(T102, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T210 = fd.ops.cast(T203, dtype=DataType.Float)
    T211 = fd.ops.cast(T209, dtype=DataType.Float)
    T212 = fd.ops.mul(T210, T211)
    T213 = fd.ops.add(T167, T212)
    T214 = fd.ops.cast(T213, dtype=DataType.BFloat16)
    T221 = fd.ops.broadcast_in_dim(T214, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
    T228 = fd.ops.broadcast_in_dim(T221, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T234 = fd.ops.reshape(T228, new_shape=[1, 32, 6, 64])
    T241 = fd.ops.broadcast_in_dim(T90, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
    T248 = fd.ops.broadcast_in_dim(T241, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T254 = fd.ops.reshape(T248, new_shape=[1, 32, 6, 64])
    T255 = fd.ops.stride_order(T158, stride_order=[3, 2, 1, 0])
    T256 = fd.ops.stride_order(T234, stride_order=[3, 2, 1, 0])
    T257 = fd.ops.stride_order(T254, stride_order=[3, 2, 1, 0])
    fd.add_output(T27)
    fd.add_output(T90)
    fd.add_output(T214)
    fd.add_output(T255)
    fd.add_output(T256)
    fd.add_output(T257)

with FusionDefinition() as fd:
    nvfuser_fusion_id2(fd)

inputs = [
    torch.testing.make_tensor((1, 6), dtype=torch.int64, device='cuda:0'),
    torch.testing.make_tensor((1, 1, 6, 6), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 32, 6), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1, 6, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((2048, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((512, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((512, 2048), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)

The Thunder repro (this requires enabling linears to be consumed by nvFuser; see the note after the script):

import torch
import thunder
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from typing import Callable
from functools import partial, wraps
from collections import OrderedDict

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}

config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

args = dict(
    cache_positions=torch.arange(6, device="cuda"),
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)

def cuda_timer(warmup_iters: int = 10, timing_iters: int = 40):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()

            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator

@cuda_timer()
def run_model(mymodel, args) :
    res = mymodel(**args)

def eager(fn):
    return fn

executors = OrderedDict()
executors['Thunder-nvFuser'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa', 'nvfuser'])
executors['torch-eager'] = eager

for name, func in executors.items():
    exec_model = func(model)
    kernel_time = run_model(exec_model, args)
    print(f"{name} {kernel_time:.03f} ms")

kevinstephano commented:

This might be a smaller repro:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id6(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 6, 32], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(shape=[1, 6, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False)
    S2 = fd.ops.size(T1, dim=0)
    S3 = fd.ops.size(T1, dim=1)
    S4 = fd.define_scalar(32, dtype=DataType.Int)
    S5 = fd.define_scalar(64, dtype=DataType.Int)
    V6 = fd.define_vector([S2, S3, S4, S5], dtype=DataType.Int)
    T7 = fd.ops.reshape(T1, new_shape=V6)
    T8 = fd.ops.permute(T7, dims=[0, 2, 1, 3])
    T9 = fd.ops.cast(T8, dtype=DataType.Float)
    S10 = fd.define_scalar(0.00000, dtype=DataType.Float)
    S11 = fd.define_scalar(0, dtype=DataType.Int)
    S12 = fd.define_scalar(0, dtype=DataType.Int)
    S13 = fd.define_scalar(0, dtype=DataType.Int)
    S14 = fd.define_scalar(0, dtype=DataType.Int)
    S15 = fd.define_scalar(0, dtype=DataType.Int)
    S16 = fd.define_scalar(32, dtype=DataType.Int)
    V18 = fd.define_vector([S15, S16, S13, S14, S11, S12], dtype=DataType.Int)
    T17 = fd.ops.pad(T0, V18, S10)
    S19 = fd.define_scalar(0, dtype=DataType.Int)
    S20 = fd.ops.size(T0, dim=2)
    S21 = fd.ops.add(S19, S20)
    S22 = fd.define_scalar(0.00000, dtype=DataType.Float)
    S23 = fd.define_scalar(0, dtype=DataType.Int)
    S24 = fd.define_scalar(0, dtype=DataType.Int)
    S25 = fd.define_scalar(0, dtype=DataType.Int)
    S26 = fd.define_scalar(0, dtype=DataType.Int)
    S27 = fd.define_scalar(0, dtype=DataType.Int)
    V29 = fd.define_vector([S21, S27, S25, S26, S23, S24], dtype=DataType.Int)
    T28 = fd.ops.pad(T0, V29, S22)
    T30 = fd.ops.cat([T17, T28], dim=2, manual_padding=1)
    T31 = fd.ops.cos(T30)
    S32 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T33 = fd.ops.mul(T31, S32)
    T34 = fd.ops.cast(T33, dtype=DataType.BFloat16)
    T35 = fd.ops.broadcast(T34, is_broadcast_dim=[False, True, False, False])
    S36 = fd.define_scalar(32, dtype=DataType.Int)
    S37 = fd.ops.size(T35, dim=2)
    S38 = fd.ops.size(T35, dim=3)
    V39 = fd.define_vector([S2, S36, S37, S38], dtype=DataType.Int)
    T40 = fd.ops.expand(T35, shape=V39)
    T41 = fd.ops.cast(T40, dtype=DataType.Float)
    T42 = fd.ops.mul(T9, T41)
    T43 = fd.ops.sin(T30)
    fd.add_output(T8)
    fd.add_output(T35)
    fd.add_output(T42)
    fd.add_output(T43)

with FusionDefinition() as fd:
    nvfuser_fusion_id6(fd)

inputs = [
    torch.randn(192, dtype=torch.float32, device='cuda:0').as_strided((1, 6, 32), (192, 1, 6)),
    torch.testing.make_tensor((1, 6, 2048), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)

Error:

RuntimeError: Stride mismatch with contiguity info.  allocation domain: bS0{1}, iS1{6}, iS2{32}: sizes: [1, 6, 32]: strides: [192, 1, 6]; contiguity: n, t, t; dim: 2; expected stride: 1; actual stride: 6

jacobhinkle commented Dec 6, 2024

> This might be a smaller repro:

This particular error seems to be because T0 is defined with the default stride order, but inputs[0] has strides (192, 1, 6), which imply a stride order of [2, 0, 1].
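
A minimal sketch of two ways the definition and the input could be made to agree, assuming the intent of the smaller repro is the permuted (1, 32, 6) -> (1, 6, 32) layout from the original script; these are illustrations, not necessarily the intended fix:

import torch
from nvfuser import FusionDefinition, DataType

# Option 1: declare T0 with the stride order implied by strides (192, 1, 6),
# i.e. dim 0 outermost, then dim 2, then dim 1 innermost -> stride_order=[2, 0, 1].
def nvfuser_fusion_id6_alt(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[1, 6, 32], contiguity=[None, True, True],
                          dtype=DataType.Float, is_cpu=False, stride_order=[2, 0, 1])
    ...  # remaining ops unchanged from the smaller repro above

# Option 2: keep the definition as-is and hand nvFuser a contiguous input instead.
inputs = [
    torch.randn(192, dtype=torch.float32, device='cuda:0')
         .as_strided((1, 6, 32), (192, 1, 6))
         .contiguous(),
    torch.testing.make_tensor((1, 6, 2048), dtype=torch.bfloat16, device='cuda:0'),
]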

jacobhinkle commented Dec 6, 2024

Small repro using segment 15 of the original repro in the description:

# CUDA devices:
#  0: NVIDIA GeForce RTX 3090 Ti
# torch version: 2.6.0a0+gitffb7a08
# cuda version: 12.6
# nvfuser version: 0.2.23+git67127c9
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 6, 512], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False)
    S1 = fd.ops.size(T0, dim=0)
    S2 = fd.ops.size(T0, dim=1)
    S3 = fd.define_scalar(8, dtype=DataType.Int)
    S4 = fd.define_scalar(64, dtype=DataType.Int)
    V5 = fd.define_vector([S1, S2, S3, S4], dtype=DataType.Int)
    T6 = fd.ops.reshape(T0, new_shape=V5)
    T7 = fd.ops.permute(T6, dims=[0, 2, 1, 3])
    T8 = fd.ops.broadcast(T7, is_broadcast_dim=[False, False, True, False, False])
    S9 = fd.define_scalar(4, dtype=DataType.Int)
    V10 = fd.define_vector([S1, S3, S9, S2, S4], dtype=DataType.Int)
    T11 = fd.ops.expand(T8, shape=V10)
    S12 = fd.define_scalar(32, dtype=DataType.Int)
    V13 = fd.define_vector([S1, S12, S2, S4], dtype=DataType.Int)
    T14 = fd.ops.reshape(T11, new_shape=V13)
    fd.add_output(T7, stride_order=[3, 1, 2, 0])
    fd.add_output(T14)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.testing.make_tensor((1, 6, 512), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)

The error is happening at this line:

spanning_tree.traverse(&propagator);

The fusion at this point looks like:

%kernel {
T6_l___bfloat[iS88{24}, iS89{1}, iS87{128}]
   = Set( T0_g___bfloat[iS95{24}, iS96{1}, iS94{128}], cache_op=Streaming )
T1_l___bfloat[iS81{24}, iS82{1}, iS80{128}] = view( T6_l___bfloat[iS88{24}, iS89{1}, iS87{128}] )
T2_l___bfloat[iS74{24}, iS75{1}, iS73{128}]
   = Set.Permute( T1_l___bfloat[iS81{24}, iS82{1}, iS80{128}], cache_op=Streaming )
T8_l___bfloat[iS67{24}, iS68{1}, iS66{128}]
   = Set( T2_l___bfloat[iS74{24}, iS75{1}, iS73{128}], cache_op=Streaming )
T7_g___bfloat[iblockIdx.x60{24}, iUS61{1}, ithreadIdx.x59{128}]
   = Set( T8_l___bfloat[iS67{24}, iS68{1}, iS66{128}], cache_op=Streaming )
T3_l___bfloat[iS102{24}, iS103{1}, iS101{128}, bS15{1}]
   = broadcast( T2_l___bfloat[iS74{24}, iS75{1}, iS73{128}] )
T4_l___bfloat[iS109{24}, iS110{1}, iS108{128}, bS20{1 ex 4}] = expand( T3_l___bfloat[iS102{24}, iS103{1}, iS101{128}, bS15{1}], {1, 8, 4, 6, 64} )
T9_l___bfloat[iS30{32}rf, iS27{64}, bS23{1}, iS26{6}] = view( T4_l___bfloat[iS109{24}, iS110{1}, iS108{128}, bS20{1 ex 4}] )
T5_g___bfloat[iS46{32}, iS48{64}, bS45{1}, iS47{6}]
   = Set( T9_l___bfloat[iS30{32}rf, iS27{64}, bS23{1}, iS26{6}], cache_op=Streaming )

TransformPrinter :
T0_g___bfloat[iS95{24}, iS96{1}, iS94{128}]
 logical domain : (bS0{1}, iS1{6}, iS2{512})
 contiguity: n t t
  Outer split: iS2{512} by factor 8 -> iS51{8}, iS52{64}
  Merge: iS1{6} and iS52{64} -> iS90{384}
  Merge: iS51{8} and iS90{384} -> iS91{3072}
  Merge: bS0{1} and iS91{3072} -> iS92{3072}
  Split: iS92{3072} by factor 128 -> iS93{24}, iS94{128}
  Split: iS93{24} by factor 1 -> iS95{24}, iS96{1}
 loop domain : (iS95{24}, iS96{1}, iS94{128})
T6_l___bfloat[iS88{24}, iS89{1}, iS87{128}]
 logical domain : (bS34{1}, iS35{6}, iS36{512})
 contiguity: n t t
  Outer split: iS36{512} by factor 8 -> iS49{8}, iS50{64}
  Merge: iS35{6} and iS50{64} -> iS83{384}
  Merge: iS49{8} and iS83{384} -> iS84{3072}
  Merge: bS34{1} and iS84{3072} -> iS85{3072}
  Split: iS85{3072} by factor 128 -> iS86{24}, iS87{128}
  Split: iS86{24} by factor 1 -> iS88{24}, iS89{1}
 loop domain : (iS88{24}, iS89{1}, iS87{128})
T1_l___bfloat[iS81{24}, iS82{1}, iS80{128}]
 root domain : (bS3{1}, iS4{6}, iS6{512}rf)
  Outer split: iS6{512}rf by factor 8 -> iS7{8}rf, iS8{64}rf
 logical domain : (bS3{1}, iS4{6}, iS7{8}rf, iS8{64}rf)
 allocation domain : (bS3{1}, iS4{6}, iS7{8}rf, iS8{64}rf)
 contiguity: n t t t
  Merge: iS4{6} and iS8{64}rf -> iS76{384}
  Merge: iS7{8}rf and iS76{384} -> iS77{3072}
  Merge: bS3{1} and iS77{3072} -> iS78{3072}
  Split: iS78{3072} by factor 128 -> iS79{24}, iS80{128}
  Split: iS79{24} by factor 1 -> iS81{24}, iS82{1}
 loop domain : (iS81{24}, iS82{1}, iS80{128})
T2_l___bfloat[iS74{24}, iS75{1}, iS73{128}]
 root domain : (bS9{1}, iS10{6}, iS11{8}, iS12{64})
 logical domain : (bS9{1}, iS11{8}, iS10{6}, iS12{64})
 allocation domain : (bS9{1}, iS10{6}, iS11{8}, iS12{64})
 contiguity: n t t t
  Merge: iS10{6} and iS12{64} -> iS69{384}
  Merge: iS11{8} and iS69{384} -> iS70{3072}
  Merge: bS9{1} and iS70{3072} -> iS71{3072}
  Split: iS71{3072} by factor 128 -> iS72{24}, iS73{128}
  Split: iS72{24} by factor 1 -> iS74{24}, iS75{1}
 loop domain : (iS74{24}, iS75{1}, iS73{128})
T8_l___bfloat[iS67{24}, iS68{1}, iS66{128}]
 logical domain : (bS37{1}, iS38{8}, iS39{6}, iS40{64})
 allocation domain : (bS37{1}, iS39{6}, iS38{8}, iS40{64})
 contiguity: n t t t
  Merge: iS39{6} and iS40{64} -> iS62{384}
  Merge: iS38{8} and iS62{384} -> iS63{3072}
  Merge: bS37{1} and iS63{3072} -> iS64{3072}
  Split: iS64{3072} by factor 128 -> iS65{24}, iS66{128}
  Split: iS65{24} by factor 1 -> iS67{24}, iS68{1}
 loop domain : (iS67{24}, iS68{1}, iS66{128})
T7_g___bfloat[iblockIdx.x60{24}, iUS61{1}, ithreadIdx.x59{128}]
 logical domain : (bS41{1}, iS42{8}, iS43{6}, iS44{64})
 allocation domain : (bS41{1}, iS43{6}, iS42{8}, iS44{64})
 contiguity: n t t t
  Merge: iS43{6} and iS44{64} -> iS55{384}
  Merge: iS42{8} and iS55{384} -> iS56{3072}
  Merge: bS41{1} and iS56{3072} -> iS57{3072}
  Split: iS57{3072} by factor 128 -> iS58{24}, ithreadIdx.x59{128}
  Split: iS58{24} by factor 1 -> iblockIdx.x60{24}, iUS61{1}
 loop domain : (iblockIdx.x60{24}, iUS61{1}, ithreadIdx.x59{128})
T3_l___bfloat[iS102{24}, iS103{1}, iS101{128}, bS15{1}]
 logical domain : (bS13{1}, iS14{8}, bS15{1}, iS16{6}, iS17{64})
 allocation domain : (bS13{1}, iS16{6}, iS14{8}, iS17{64}, bS15{1})
 contiguity: n t t t n
  Merge: iS16{6} and iS17{64} -> iS97{384}
  Merge: iS14{8} and iS97{384} -> iS98{3072}
  Merge: bS13{1} and iS98{3072} -> iS99{3072}
  Split: iS99{3072} by factor 128 -> iS100{24}, iS101{128}
  Split: iS100{24} by factor 1 -> iS102{24}, iS103{1}
 loop domain : (iS102{24}, iS103{1}, iS101{128}, bS15{1})
T4_l___bfloat[iS109{24}, iS110{1}, iS108{128}, bS20{1 ex 4}]
 logical domain : (bS18{1}, iS19{8}, bS20{1 ex 4}, iS21{6}, iS22{64})
 allocation domain : (bS18{1}, iS21{6}, iS19{8}, iS22{64}, bS20{1 ex 4})
 contiguity: n t t t n
  Merge: iS21{6} and iS22{64} -> iS104{384}
  Merge: iS19{8} and iS104{384} -> iS105{3072}
  Merge: bS18{1} and iS105{3072} -> iS106{3072}
  Split: iS106{3072} by factor 128 -> iS107{24}, iS108{128}
  Split: iS107{24} by factor 1 -> iS109{24}, iS110{1}
 loop domain : (iS109{24}, iS110{1}, iS108{128}, bS20{1 ex 4})
T9_l___bfloat[iS30{32}rf, iS27{64}, bS23{1}, iS26{6}]
 root domain : (bS23{1}, iS28{8}rf, iS29{4}rf, iS26{6}, iS27{64})
  Merge: iS28{8}rf and iS29{4}rf -> iS30{32}rf
 logical domain : (bS23{1}, iS30{32}rf, iS26{6}, iS27{64})
 contiguity: n t t t
 loop domain : (iS30{32}rf, iS27{64}, bS23{1}, iS26{6})
T5_g___bfloat[iS46{32}, iS48{64}, bS45{1}, iS47{6}]
 logical domain : (bS45{1}, iS46{32}, iS47{6}, iS48{64})
 contiguity: n t t t
 loop domain : (iS46{32}, iS48{64}, bS45{1}, iS47{6})
} // %kernel

The error occurs in TransformReplay::replayCasP when propagating from consumer T9 to producer T4, which corresponds to the view op.
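
For completeness, a hedged distillation of the op pattern this segment exercises (insert a broadcast axis, expand it to 4, then merge it away with a reshape), mirroring T221/T228/T234 from the original script. The shapes are taken from the trace above; this snippet is illustrative and is not claimed to reproduce the assert on its own.

import torch
from nvfuser import FusionDefinition, DataType

def repeat_kv_pattern(fd: FusionDefinition) -> None:
    # [1, 8, 6, 64] key/value head layout, as in the segment above
    T0 = fd.define_tensor(shape=[1, 8, 6, 64], contiguity=[None, True, True, True],
                          dtype=DataType.BFloat16, is_cpu=False)
    # Insert a broadcast axis after the head dim, then expand it to 4.
    T1 = fd.ops.broadcast_in_dim(T0, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
    T2 = fd.ops.broadcast_in_dim(T1, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
    # The merge of the 8 and 4 axes here is the rfactored view transform that
    # replayCasP propagates back into its producer.
    T3 = fd.ops.reshape(T2, new_shape=[1, 32, 6, 64])
    fd.add_output(T3)

with FusionDefinition() as fd:
    repeat_kv_pattern(fd)

inputs = [torch.testing.make_tensor((1, 8, 6, 64), dtype=torch.bfloat16, device='cuda:0')]
fd.execute(inputs)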

jacobhinkle commented Dec 6, 2024

This looks like a dupe of #3512. I confirmed this is fixed by #3513.

kevinstephano commented:

Duplicate of #3512.
