Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Squeezed IterDomain ?S536{1} must concretize to IterType::Broadcast but found ?S536{1}. #2359

Closed
wujingyue opened this issue Jun 6, 2024 · 7 comments · Fixed by #2363
Closed
Assignees

Comments

@wujingyue
Copy link
Collaborator

wujingyue commented Jun 6, 2024

This happened when I ran the transformer block with batch_size=1. It can be reproduced by

  1. checking out https://github.com/Lightning-AI/lightning-thunder/tree/wjy/sharded, and
  2. running pytest thunder/benchmarks/targets.py -k test_nanogpt_block_grad[thunder] -s.

I'm unsure whether it's a Thunder bug or nvFuser bug. I suspect define_tensor needs to say shape=[1,...] when the batch size is one?

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    """Populate ``fd`` with the fusion that reproduces the reported failure.

    This is a machine-generated fusion definition (the T*/S*/V* names are
    sequential tensor/scalar/vector ids). Per the issue text it comes from a
    transformer block run with batch size 1, which is why the leading extent
    is the constant 1 throughout. The exact sequence of ops is the repro, so
    do not reorder or deduplicate anything here.
    """
    # --- Fusion inputs -----------------------------------------------------
    # T0-T11 have fully symbolic extents (-1). T1/T0, T3/T2, T9/T8 and T11/T10
    # are the second/third arguments of the four fd.ops.linear calls below;
    # T5/T4 and T7/T6 are the per-feature multiplier/addend of the two
    # normalizations.
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T6 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T8 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T9 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T10 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T11 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    # The only input declared with a constant extent: dim 0 is size 1 with
    # contiguity None (the batch-size-1 dimension mentioned in the issue).
    T12 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    # --- First normalization over the last dim of T12 ----------------------
    # (x - mean) * rsqrt(var + 1e-5), then * T5 + T4; this looks like a layer
    # norm, computed in float32.
    T13 = fd.ops.cast(T12, dtype=DataType.Float)
    T14, T15 = fd.ops.var_mean(T13, dims=[2], correction=0, keepdim=False)
    S16 = fd.define_scalar(1, dtype=DataType.Int)
    S17 = fd.define_scalar(2048, dtype=DataType.Int)
    S18 = fd.define_scalar(1, dtype=DataType.Int)
    V19 = fd.define_vector([S16, S17, S18], dtype=DataType.Int)
    T20 = fd.ops.broadcast_in_dim(T14, shape=V19, broadcast_dims=[0, 1])
    S21 = fd.define_scalar(1, dtype=DataType.Int)
    S22 = fd.define_scalar(2048, dtype=DataType.Int)
    S23 = fd.define_scalar(1, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T15, shape=V24, broadcast_dims=[0, 1])
    S26 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T27 = fd.ops.add(T20, S26)
    T28 = fd.ops.rsqrt(T27)
    S29 = fd.define_scalar(1, dtype=DataType.Int)
    S30 = fd.define_scalar(2048, dtype=DataType.Int)
    S31 = fd.define_scalar(12288, dtype=DataType.Int)
    V32 = fd.define_vector([S29, S30, S31], dtype=DataType.Int)
    T33 = fd.ops.broadcast_in_dim(T25, shape=V32, broadcast_dims=[0, 1, 2])
    T34 = fd.ops.sub(T13, T33)
    S35 = fd.define_scalar(1, dtype=DataType.Int)
    S36 = fd.define_scalar(2048, dtype=DataType.Int)
    S37 = fd.define_scalar(12288, dtype=DataType.Int)
    V38 = fd.define_vector([S35, S36, S37], dtype=DataType.Int)
    T39 = fd.ops.broadcast_in_dim(T28, shape=V38, broadcast_dims=[0, 1, 2])
    T40 = fd.ops.mul(T34, T39)
    S41 = fd.define_scalar(1, dtype=DataType.Int)
    S42 = fd.define_scalar(2048, dtype=DataType.Int)
    S43 = fd.define_scalar(12288, dtype=DataType.Int)
    V44 = fd.define_vector([S41, S42, S43], dtype=DataType.Int)
    T45 = fd.ops.broadcast_in_dim(T5, shape=V44, broadcast_dims=[2])
    T46 = fd.ops.cast(T45, dtype=DataType.Float)
    T47 = fd.ops.mul(T40, T46)
    S48 = fd.define_scalar(1, dtype=DataType.Int)
    S49 = fd.define_scalar(2048, dtype=DataType.Int)
    S50 = fd.define_scalar(12288, dtype=DataType.Int)
    V51 = fd.define_vector([S48, S49, S50], dtype=DataType.Int)
    T52 = fd.ops.broadcast_in_dim(T4, shape=V51, broadcast_dims=[2])
    T53 = fd.ops.cast(T52, dtype=DataType.Float)
    T54 = fd.ops.add(T47, T53)
    T55 = fd.ops.cast(T54, dtype=DataType.BFloat16)
    # --- First linear: drop the size-1 dim, linear(T1, T0), restore it -----
    # The [1, 2048, N] -> [2048, N] reshape removes the leading size-1 dim.
    # NOTE(review): this reshape pattern is the suspected trigger discussed
    # in the issue thread -- confirm against the fix in #2363.
    S56 = fd.define_scalar(2048, dtype=DataType.Int)
    S57 = fd.define_scalar(12288, dtype=DataType.Int)
    V58 = fd.define_vector([S56, S57], dtype=DataType.Int)
    T59 = fd.ops.reshape(T55, new_shape=V58)
    T60 = fd.ops.linear(T59, T1, T0)
    S61 = fd.define_scalar(1, dtype=DataType.Int)
    S62 = fd.define_scalar(2048, dtype=DataType.Int)
    S63 = fd.define_scalar(36864, dtype=DataType.Int)
    V64 = fd.define_vector([S61, S62, S63], dtype=DataType.Int)
    T65 = fd.ops.reshape(T60, new_shape=V64)
    # Split the 36864-wide last dim into three 12288-wide slices
    # (presumably query/key/value of an attention layer).
    T66 = fd.ops.slice(T65, start_indices=[0, 0, 0], end_indices=[1, 2048, 12288], strides=[1, 1, 1])
    T67 = fd.ops.slice(T65, start_indices=[0, 0, 12288], end_indices=[1, 2048, 24576], strides=[1, 1, 1])
    T68 = fd.ops.slice(T65, start_indices=[0, 0, 24576], end_indices=[1, 2048, 36864], strides=[1, 1, 1])
    # Each slice: reshape to [1, 2048, 96, 128], then permute to
    # [1, 96, 2048, 128].
    S69 = fd.define_scalar(1, dtype=DataType.Int)
    S70 = fd.define_scalar(2048, dtype=DataType.Int)
    S71 = fd.define_scalar(96, dtype=DataType.Int)
    S72 = fd.define_scalar(128, dtype=DataType.Int)
    V73 = fd.define_vector([S69, S70, S71, S72], dtype=DataType.Int)
    T74 = fd.ops.reshape(T67, new_shape=V73)
    T75 = fd.ops.permute(T74, dims=[0, 2, 1, 3])
    S76 = fd.define_scalar(1, dtype=DataType.Int)
    S77 = fd.define_scalar(2048, dtype=DataType.Int)
    S78 = fd.define_scalar(96, dtype=DataType.Int)
    S79 = fd.define_scalar(128, dtype=DataType.Int)
    V80 = fd.define_vector([S76, S77, S78, S79], dtype=DataType.Int)
    T81 = fd.ops.reshape(T66, new_shape=V80)
    T82 = fd.ops.permute(T81, dims=[0, 2, 1, 3])
    S83 = fd.define_scalar(1, dtype=DataType.Int)
    S84 = fd.define_scalar(2048, dtype=DataType.Int)
    S85 = fd.define_scalar(96, dtype=DataType.Int)
    S86 = fd.define_scalar(128, dtype=DataType.Int)
    V87 = fd.define_vector([S83, S84, S85, S86], dtype=DataType.Int)
    T88 = fd.ops.reshape(T68, new_shape=V87)
    T89 = fd.ops.permute(T88, dims=[0, 2, 1, 3])
    # --- Scaled matmul of the first slice with the transposed second slice -
    # Both operands are multiplied by 0.297302 before the matmul.
    T90 = fd.ops.cast(T82, dtype=DataType.Float)
    S91 = fd.define_scalar(0.297302, dtype=DataType.Double)
    T92 = fd.ops.mul(T90, S91)
    T93 = fd.ops.cast(T92, dtype=DataType.BFloat16)
    T94 = fd.ops.permute(T75, dims=[0, 1, 3, 2])
    T95 = fd.ops.cast(T94, dtype=DataType.Float)
    S96 = fd.define_scalar(0.297302, dtype=DataType.Double)
    T97 = fd.ops.mul(T95, S96)
    T98 = fd.ops.cast(T97, dtype=DataType.BFloat16)
    T99 = fd.ops.matmul(T93, T98)
    # --- 2048x2048 mask: row index >= column index -------------------------
    # Built from two iotas broadcast to a column and a row; entries where the
    # mask is false are replaced by -inf below (presumably a causal mask).
    S100 = fd.define_scalar(2048, dtype=DataType.Int)
    S101 = fd.define_scalar(0, dtype=DataType.Int)
    S102 = fd.define_scalar(1, dtype=DataType.Int)
    T103 = fd.ops.iota(S100, S101, S102, dtype=DataType.Int)
    S104 = fd.define_scalar(2048, dtype=DataType.Int)
    S105 = fd.define_scalar(1, dtype=DataType.Int)
    V106 = fd.define_vector([S104, S105], dtype=DataType.Int)
    T107 = fd.ops.broadcast_in_dim(T103, shape=V106, broadcast_dims=[0])
    S108 = fd.define_scalar(2048, dtype=DataType.Int)
    S109 = fd.define_scalar(0, dtype=DataType.Int)
    S110 = fd.define_scalar(1, dtype=DataType.Int)
    T111 = fd.ops.iota(S108, S109, S110, dtype=DataType.Int)
    S112 = fd.define_scalar(1, dtype=DataType.Int)
    S113 = fd.define_scalar(2048, dtype=DataType.Int)
    V114 = fd.define_vector([S112, S113], dtype=DataType.Int)
    T115 = fd.ops.broadcast_in_dim(T111, shape=V114, broadcast_dims=[1])
    S116 = fd.define_scalar(0, dtype=DataType.Int)
    T117 = fd.ops.add(T107, S116)
    S118 = fd.define_scalar(2048, dtype=DataType.Int)
    S119 = fd.define_scalar(2048, dtype=DataType.Int)
    V120 = fd.define_vector([S118, S119], dtype=DataType.Int)
    T121 = fd.ops.broadcast_in_dim(T117, shape=V120, broadcast_dims=[0, 1])
    S122 = fd.define_scalar(2048, dtype=DataType.Int)
    S123 = fd.define_scalar(2048, dtype=DataType.Int)
    V124 = fd.define_vector([S122, S123], dtype=DataType.Int)
    T125 = fd.ops.broadcast_in_dim(T115, shape=V124, broadcast_dims=[0, 1])
    T126 = fd.ops.ge(T121, T125)
    S127 = fd.define_scalar(1, dtype=DataType.Int)
    S128 = fd.define_scalar(96, dtype=DataType.Int)
    S129 = fd.define_scalar(2048, dtype=DataType.Int)
    S130 = fd.define_scalar(2048, dtype=DataType.Int)
    V131 = fd.define_vector([S127, S128, S129, S130], dtype=DataType.Int)
    T132 = fd.ops.broadcast_in_dim(T126, shape=V131, broadcast_dims=[2, 3])
    S133 = fd.define_scalar(float("-inf"), dtype=DataType.Double)
    T134 = fd.ops.where(T132, T99, S133)
    # --- Softmax-style normalization over the last dim ---------------------
    # exp(x - max(x)) * reciprocal(sum(exp(x - max(x)))).
    T135 = fd.ops.cast(T134, dtype=DataType.Float)
    T136 = fd.ops.max(T135, dims=[3], keepdim=False, dtype=DataType.Null)
    S137 = fd.define_scalar(1, dtype=DataType.Int)
    S138 = fd.define_scalar(96, dtype=DataType.Int)
    S139 = fd.define_scalar(2048, dtype=DataType.Int)
    S140 = fd.define_scalar(1, dtype=DataType.Int)
    V141 = fd.define_vector([S137, S138, S139, S140], dtype=DataType.Int)
    T142 = fd.ops.broadcast_in_dim(T136, shape=V141, broadcast_dims=[0, 1, 2])
    S143 = fd.define_scalar(1, dtype=DataType.Int)
    S144 = fd.define_scalar(96, dtype=DataType.Int)
    S145 = fd.define_scalar(2048, dtype=DataType.Int)
    S146 = fd.define_scalar(2048, dtype=DataType.Int)
    V147 = fd.define_vector([S143, S144, S145, S146], dtype=DataType.Int)
    T148 = fd.ops.broadcast_in_dim(T142, shape=V147, broadcast_dims=[0, 1, 2, 3])
    T149 = fd.ops.sub(T135, T148)
    T150 = fd.ops.exp(T149)
    T151 = fd.ops.sum(T150, dims=[3], keepdim=False, dtype=DataType.Null)
    S152 = fd.define_scalar(1, dtype=DataType.Int)
    S153 = fd.define_scalar(96, dtype=DataType.Int)
    S154 = fd.define_scalar(2048, dtype=DataType.Int)
    S155 = fd.define_scalar(1, dtype=DataType.Int)
    V156 = fd.define_vector([S152, S153, S154, S155], dtype=DataType.Int)
    T157 = fd.ops.broadcast_in_dim(T151, shape=V156, broadcast_dims=[0, 1, 2])
    S158 = fd.define_scalar(1, dtype=DataType.Int)
    S159 = fd.define_scalar(96, dtype=DataType.Int)
    S160 = fd.define_scalar(2048, dtype=DataType.Int)
    S161 = fd.define_scalar(2048, dtype=DataType.Int)
    V162 = fd.define_vector([S158, S159, S160, S161], dtype=DataType.Int)
    T163 = fd.ops.broadcast_in_dim(T157, shape=V162, broadcast_dims=[0, 1, 2, 3])
    T164 = fd.ops.reciprocal(T163)
    T165 = fd.ops.mul(T150, T164)
    T166 = fd.ops.cast(T165, dtype=DataType.BFloat16)
    # --- Random mask (uniform < 0.9) and rescale by 1.11111 ----------------
    # This looks like dropout with keep probability 0.9.
    S167 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S168 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S169 = fd.define_scalar(1, dtype=DataType.Int)
    S170 = fd.define_scalar(96, dtype=DataType.Int)
    S171 = fd.define_scalar(2048, dtype=DataType.Int)
    S172 = fd.define_scalar(2048, dtype=DataType.Int)
    V173 = fd.define_vector([S169, S170, S171, S172], dtype=DataType.Int)
    T174 = fd.ops.uniform(S167, S168, shape=V173, dtype=DataType.BFloat16)
    S175 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T176 = fd.ops.lt(T174, S175)
    T177 = fd.ops.cast(T176, dtype=DataType.Float)
    T178 = fd.ops.mul(T165, T177)
    S179 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T180 = fd.ops.mul(T178, S179)
    T181 = fd.ops.cast(T180, dtype=DataType.BFloat16)
    # --- Matmul with the third slice, then second linear (T3, T2) ----------
    # The result is permuted back, reshaped to [1, 2048, 12288], and then has
    # its size-1 leading dim removed again before the linear.
    T182 = fd.ops.matmul(T181, T89)
    T183 = fd.ops.permute(T182, dims=[0, 2, 1, 3])
    T184 = fd.ops.stride_order(T183, stride_order=[3, 2, 1, 0])
    S185 = fd.define_scalar(1, dtype=DataType.Int)
    S186 = fd.define_scalar(2048, dtype=DataType.Int)
    S187 = fd.define_scalar(12288, dtype=DataType.Int)
    V188 = fd.define_vector([S185, S186, S187], dtype=DataType.Int)
    T189 = fd.ops.reshape(T184, new_shape=V188)
    S190 = fd.define_scalar(2048, dtype=DataType.Int)
    S191 = fd.define_scalar(12288, dtype=DataType.Int)
    V192 = fd.define_vector([S190, S191], dtype=DataType.Int)
    T193 = fd.ops.reshape(T189, new_shape=V192)
    T194 = fd.ops.linear(T193, T3, T2)
    S195 = fd.define_scalar(1, dtype=DataType.Int)
    S196 = fd.define_scalar(2048, dtype=DataType.Int)
    S197 = fd.define_scalar(12288, dtype=DataType.Int)
    V198 = fd.define_vector([S195, S196, S197], dtype=DataType.Int)
    T199 = fd.ops.reshape(T194, new_shape=V198)
    # --- Second random mask/rescale, then add to the float input T13 -------
    S200 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S201 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S202 = fd.define_scalar(1, dtype=DataType.Int)
    S203 = fd.define_scalar(2048, dtype=DataType.Int)
    S204 = fd.define_scalar(12288, dtype=DataType.Int)
    V205 = fd.define_vector([S202, S203, S204], dtype=DataType.Int)
    T206 = fd.ops.uniform(S200, S201, shape=V205, dtype=DataType.BFloat16)
    S207 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T208 = fd.ops.lt(T206, S207)
    T209 = fd.ops.cast(T199, dtype=DataType.Float)
    T210 = fd.ops.cast(T208, dtype=DataType.Float)
    T211 = fd.ops.mul(T209, T210)
    S212 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T213 = fd.ops.mul(T211, S212)
    T214 = fd.ops.add(T13, T213)
    # --- Second normalization over the last dim of T214 (uses T7 and T6) ---
    T215, T216 = fd.ops.var_mean(T214, dims=[2], correction=0, keepdim=False)
    S217 = fd.define_scalar(1, dtype=DataType.Int)
    S218 = fd.define_scalar(2048, dtype=DataType.Int)
    S219 = fd.define_scalar(1, dtype=DataType.Int)
    V220 = fd.define_vector([S217, S218, S219], dtype=DataType.Int)
    T221 = fd.ops.broadcast_in_dim(T215, shape=V220, broadcast_dims=[0, 1])
    S222 = fd.define_scalar(1, dtype=DataType.Int)
    S223 = fd.define_scalar(2048, dtype=DataType.Int)
    S224 = fd.define_scalar(1, dtype=DataType.Int)
    V225 = fd.define_vector([S222, S223, S224], dtype=DataType.Int)
    T226 = fd.ops.broadcast_in_dim(T216, shape=V225, broadcast_dims=[0, 1])
    S227 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T228 = fd.ops.add(T221, S227)
    T229 = fd.ops.rsqrt(T228)
    S230 = fd.define_scalar(1, dtype=DataType.Int)
    S231 = fd.define_scalar(2048, dtype=DataType.Int)
    S232 = fd.define_scalar(12288, dtype=DataType.Int)
    V233 = fd.define_vector([S230, S231, S232], dtype=DataType.Int)
    T234 = fd.ops.broadcast_in_dim(T226, shape=V233, broadcast_dims=[0, 1, 2])
    T235 = fd.ops.sub(T214, T234)
    S236 = fd.define_scalar(1, dtype=DataType.Int)
    S237 = fd.define_scalar(2048, dtype=DataType.Int)
    S238 = fd.define_scalar(12288, dtype=DataType.Int)
    V239 = fd.define_vector([S236, S237, S238], dtype=DataType.Int)
    T240 = fd.ops.broadcast_in_dim(T229, shape=V239, broadcast_dims=[0, 1, 2])
    T241 = fd.ops.mul(T235, T240)
    S242 = fd.define_scalar(1, dtype=DataType.Int)
    S243 = fd.define_scalar(2048, dtype=DataType.Int)
    S244 = fd.define_scalar(12288, dtype=DataType.Int)
    V245 = fd.define_vector([S242, S243, S244], dtype=DataType.Int)
    T246 = fd.ops.broadcast_in_dim(T7, shape=V245, broadcast_dims=[2])
    T247 = fd.ops.cast(T246, dtype=DataType.Float)
    T248 = fd.ops.mul(T241, T247)
    S249 = fd.define_scalar(1, dtype=DataType.Int)
    S250 = fd.define_scalar(2048, dtype=DataType.Int)
    S251 = fd.define_scalar(12288, dtype=DataType.Int)
    V252 = fd.define_vector([S249, S250, S251], dtype=DataType.Int)
    T253 = fd.ops.broadcast_in_dim(T6, shape=V252, broadcast_dims=[2])
    T254 = fd.ops.cast(T253, dtype=DataType.Float)
    T255 = fd.ops.add(T248, T254)
    T256 = fd.ops.cast(T255, dtype=DataType.BFloat16)
    # --- Third linear (T9, T8): 12288 -> 49152 features ---------------------
    S257 = fd.define_scalar(2048, dtype=DataType.Int)
    S258 = fd.define_scalar(12288, dtype=DataType.Int)
    V259 = fd.define_vector([S257, S258], dtype=DataType.Int)
    T260 = fd.ops.reshape(T256, new_shape=V259)
    T261 = fd.ops.linear(T260, T9, T8)
    S262 = fd.define_scalar(1, dtype=DataType.Int)
    S263 = fd.define_scalar(2048, dtype=DataType.Int)
    S264 = fd.define_scalar(49152, dtype=DataType.Int)
    V265 = fd.define_vector([S262, S263, S264], dtype=DataType.Int)
    T266 = fd.ops.reshape(T261, new_shape=V265)
    # --- Activation: 0.5*x * (1 + tanh(0.797885 * (x + 0.044715 * x^3))) ----
    # This matches the tanh approximation of GELU.
    T267 = fd.ops.cast(T266, dtype=DataType.Float)
    T268 = fd.ops.mul(T267, T267)
    T269 = fd.ops.mul(T268, T267)
    S270 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T271 = fd.ops.mul(S270, T267)
    S272 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T273 = fd.ops.mul(S272, T269)
    T274 = fd.ops.add(T267, T273)
    S275 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T276 = fd.ops.mul(S275, T274)
    T277 = fd.ops.tanh(T276)
    S278 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T279 = fd.ops.add(S278, T277)
    T280 = fd.ops.mul(T271, T279)
    T281 = fd.ops.cast(T280, dtype=DataType.BFloat16)
    # --- Fourth linear (T11, T10): 49152 -> 12288 features ------------------
    S282 = fd.define_scalar(2048, dtype=DataType.Int)
    S283 = fd.define_scalar(49152, dtype=DataType.Int)
    V284 = fd.define_vector([S282, S283], dtype=DataType.Int)
    T285 = fd.ops.reshape(T281, new_shape=V284)
    T286 = fd.ops.linear(T285, T11, T10)
    S287 = fd.define_scalar(1, dtype=DataType.Int)
    S288 = fd.define_scalar(2048, dtype=DataType.Int)
    S289 = fd.define_scalar(12288, dtype=DataType.Int)
    V290 = fd.define_vector([S287, S288, S289], dtype=DataType.Int)
    T291 = fd.ops.reshape(T286, new_shape=V290)
    # --- Third random mask/rescale, then add to T214 ------------------------
    S292 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S293 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S294 = fd.define_scalar(1, dtype=DataType.Int)
    S295 = fd.define_scalar(2048, dtype=DataType.Int)
    S296 = fd.define_scalar(12288, dtype=DataType.Int)
    V297 = fd.define_vector([S294, S295, S296], dtype=DataType.Int)
    T298 = fd.ops.uniform(S292, S293, shape=V297, dtype=DataType.BFloat16)
    S299 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T300 = fd.ops.lt(T298, S299)
    T301 = fd.ops.cast(T291, dtype=DataType.Float)
    T302 = fd.ops.cast(T300, dtype=DataType.Float)
    T303 = fd.ops.mul(T301, T302)
    S304 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T305 = fd.ops.mul(T303, S304)
    T306 = fd.ops.add(T214, T305)
    T307 = fd.ops.cast(T306, dtype=DataType.BFloat16)
    # --- Fusion outputs -----------------------------------------------------
    # The final result T307 plus intermediates (means, rsqrt terms, random
    # masks, softmax result, masked softmax result) -- presumably saved for
    # the backward pass; the registration order is part of the repro.
    fd.add_output(T216)
    fd.add_output(T229)
    fd.add_output(T300)
    fd.add_output(T307)
    fd.add_output(T15)
    fd.add_output(T166)
    fd.add_output(T176)
    fd.add_output(T28)
    fd.add_output(T181)
    fd.add_output(T208)

# Build the fusion definition by recording the ops from nvfuser_fusion_id0.
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

# One random bfloat16 CUDA tensor per define_tensor call, in the same order
# (T0..T12). Each is allocated flat and then viewed with explicit
# sizes/strides via as_strided.
inputs = [
    torch.randn((36864,), dtype=torch.bfloat16, device='cuda:0').as_strided((36864,), (1,)),
    torch.randn((452984832,), dtype=torch.bfloat16, device='cuda:0').as_strided((36864, 12288), (12288, 1)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((150994944,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288, 12288), (12288, 1)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((49152,), dtype=torch.bfloat16, device='cuda:0').as_strided((49152,), (1,)),
    torch.randn((603979776,), dtype=torch.bfloat16, device='cuda:0').as_strided((49152, 12288), (12288, 1)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((603979776,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288, 49152), (49152, 1)),
    # T12: the activation with a leading (batch) extent of 1.
    torch.randn((25165824,), dtype=torch.bfloat16, device='cuda:0').as_strided((1, 2048, 12288), (25165824, 12288, 1)),
]
# Executing the fusion raises the RuntimeError shown in the traceback below.
fd.execute(inputs)
Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 145, in execute
    result = self._execute(
RuntimeError: Squeezed IterDomain ?S536{1} must concretize to IterType::Broadcast but found ?S536{1}
Exception raised from checkConcretization at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:1406 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x74fdf03fca67 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::SqueezeOp::checkConcretization(nvfuser::Val*, nvfuser::Val*) const + 0x654 (0x74fdf08b7db4 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x41627b (0x74fdf06ec27b in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x41976a (0x74fdf06ef76a in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::DynamicTransform::concretizeFusion(nvfuser::Fusion*, nvfuser::DynamicTransformConcretizationInfo const*) + 0xa2 (0x74fdf06ef9e2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x6601f2 (0x74fdf09361f2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x1e7 (0x74fdf09373f7 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x3ec (0x74fdf0b283fc in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x19e88e (0x74fdf047488e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x2153ff (0x74fdf04eb3ff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x2a9be0 (0x74fdf057fbe0 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x15a10e (0x5778b421b10e in /usr/bin/python3)
frame #12: _PyObject_MakeTpCall + 0x25b (0x5778b4211a7b in /usr/bin/python3)
frame #13: <unknown function> + 0x168acb (0x5778b4229acb in /usr/bin/python3)
frame #14: _PyEval_EvalFrameDefault + 0x198c (0x5778b420553c in /usr/bin/python3)
frame #15: <unknown function> + 0x1687f1 (0x5778b42297f1 in /usr/bin/python3)
frame #16: PyObject_Call + 0x122 (0x5778b422a492 in /usr/bin/python3)
frame #17: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #18: _PyObject_FastCallDictTstate + 0xc4 (0x5778b4210c14 in /usr/bin/python3)
frame #19: _PyObject_Call_Prepend + 0xc1 (0x5778b42268d1 in /usr/bin/python3)
frame #20: <unknown function> + 0x280700 (0x5778b4341700 in /usr/bin/python3)
frame #21: _PyObject_MakeTpCall + 0x25b (0x5778b4211a7b in /usr/bin/python3)
frame #22: _PyEval_EvalFrameDefault + 0x64e6 (0x5778b420a096 in /usr/bin/python3)
frame #23: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #24: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #25: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #26: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #27: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #28: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #29: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #30: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #31: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #32: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #33: <unknown function> + 0x16893e (0x5778b422993e in /usr/bin/python3)
frame #34: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #35: <unknown function> + 0x16893e (0x5778b422993e in /usr/bin/python3)
frame #36: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #37: _PyObject_FastCallDictTstate + 0xc4 (0x5778b4210c14 in /usr/bin/python3)
frame #38: _PyObject_Call_Prepend + 0x5c (0x5778b422686c in /usr/bin/python3)
frame #39: <unknown function> + 0x280700 (0x5778b4341700 in /usr/bin/python3)
frame #40: PyObject_Call + 0xbb (0x5778b422a42b in /usr/bin/python3)
frame #41: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #42: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #43: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #44: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #45: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #46: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #47: _PyEval_EvalFrameDefault + 0x6bd (0x5778b420426d in /usr/bin/python3)
frame #48: <unknown function> + 0x1687f1 (0x5778b42297f1 in /usr/bin/python3)
frame #49: _PyEval_EvalFrameDefault + 0x198c (0x5778b420553c in /usr/bin/python3)
frame #50: <unknown function> + 0x1687f1 (0x5778b42297f1 in /usr/bin/python3)
frame #51: _PyEval_EvalFrameDefault + 0x198c (0x5778b420553c in /usr/bin/python3)
frame #52: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #53: PyObject_Call + 0x122 (0x5778b422a492 in /usr/bin/python3)
frame #54: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #55: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #56: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #57: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #58: _PyEval_EvalFrameDefault + 0x614a (0x5778b4209cfa in /usr/bin/python3)
frame #59: <unknown function> + 0x1687f1 (0x5778b42297f1 in /usr/bin/python3)
frame #60: _PyEval_EvalFrameDefault + 0x614a (0x5778b4209cfa in /usr/bin/python3)
frame #61: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #62: _PyObject_FastCallDictTstate + 0x16d (0x5778b4210cbd in /usr/bin/python3)
frame #63: _PyObject_Call_Prepend + 0x5c (0x5778b422686c in /usr/bin/python3)
@jjsjann123
Copy link
Collaborator

@jjsjann123 tagging myself. I think it's the reshape that's not specifying the output iterdomain properly.
Let me see if I can simplify the example for @jacobhinkle

@jjsjann123
Copy link
Collaborator

jjsjann123 commented Jun 6, 2024

Also, a note to myself: how should reshape work for a dynamic scalar? Should we special-case instances where a scalar is 1? Should this be an operation-imposed constraint that we add to the prologue trace?

reshape(a, new_shape=[i, j, k]): if we encounter an entry in new_shape that is 1, we cannot keep it as a symbolic value down the road; otherwise we might also run into squeeze asserting on it. Or does this mean we should have relaxed the check in squeeze? Linking issue Lightning-AI/lightning-thunder#262

@wujingyue
Copy link
Collaborator Author

I suspect define_tensor needs to say shape=[1,...] when the batch size is one?

Never mind. https://github.com/Lightning-AI/lightning-thunder/blob/126940750c8e498a89376e6c787985448c79808a/thunder/executors/nvfuserex_impl.py#L299 indeed kicked in. The first dimension of T12 is 1 in the above example.

@jjsjann123
Copy link
Collaborator

jjsjann123 commented Jun 6, 2024

I don't see any issue with this. The reshape with size '1, xxx' is doing the right thing about translating to a broadcast.

@jjsjann123
Copy link
Collaborator

A smaller repro.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    """Minimal repro: slice a symbolic 3-D tensor, then reshape twice.

    The input has fully symbolic extents, so the slice output extents are
    only known at concretization time. The first reshape introduces a
    leading size-1 dim and the second removes it again; executing this
    fusion hits the "Squeezed IterDomain ... must concretize to
    IterType::Broadcast" error reported in the issue.
    """
    # 3-D float input with all extents symbolic (-1).
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    # Take the [0:1, 0:2, 0:4] corner -> 1 * 2 * 4 = 8 elements.
    T1 = fd.ops.slice(T0, start_indices=[0, 0, 0], end_indices=[1, 2, 4], strides=[1, 1, 1])
    S2 = fd.define_scalar(1, dtype=DataType.Int)
    S3 = fd.define_scalar(8, dtype=DataType.Int)
    # Target shapes: V4 = [1, 8], V5 = [8]. Note S3 is shared by both vectors.
    V4 = fd.define_vector([S2, S3], dtype=DataType.Int)
    V5 = fd.define_vector([S3], dtype=DataType.Int)
    # Going through the intermediate [1, 8] shape is what triggers the failure.
    T6 = fd.ops.reshape(T1, new_shape=V4)
    T7 = fd.ops.reshape(T6, new_shape=V5)
    # this works fine
    # T7 = fd.ops.reshape(T1, new_shape=V5)
    fd.add_output(T7)

# Build the fusion definition by recording the ops from nvfuser_fusion_id0.
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

# A single float32 CUDA input: 100 random values viewed as shape (2, 5, 10)
# with strides (50, 10, 1).
inputs = [
    torch.randn((100,), dtype=torch.float32, device='cuda:0').as_strided((2, 5, 10), (50, 10, 1)),
]
# Executing the fusion reproduces the concretization error from the issue.
fd.execute(inputs)

cc'ing @jacobhinkle: this does indeed look like a concretization bug.

@jacobhinkle
Copy link
Collaborator

I'm back at keyboard. Looks like this is just a check that we should avoid at least in this case. Commenting out the call to checkConcretizedUses inside concretizeReshape makes the test succeed. I'll push a PR soon.

@jacobhinkle
Copy link
Collaborator

A smaller repro.

Thank you!

jacobhinkle added a commit that referenced this issue Jun 7, 2024
This test is actually tougher than the big repro on #2359. It
necessitates either removing the check altogether or the path I took
which is to concretize Resized to broadcast extents as constant 1 so
that we can evaluate max(0, min(i0, 1)) as 1 without calling
simplifyExpr. A less invasive solution would be to remove the extent
check in `SqueezeOp::checkConcretization`.

We could also just remove `Expr::checkConcretization` and
`checkConcretizedUses`. They are only used for SqueezeOp currently and
are not adding much value anyway probably.
wujingyue pushed a commit that referenced this issue Jun 8, 2024
The included test is the small one provided by @jjsjann123 in #2359 and
it's actually tougher than the original repro. It necessitates either
removing the check that concretized squeezed extents are constant 1 or
to concretize Resized to broadcast extents as constant 1 so that we can
evaluate `max(0, min(i0, 1))` as `oneVal()` without calling
`simplifyExpr`. I went with removing the check, which means in this
example we have broadcast dimension with a dynamic shape like `max(0,
min(i0, 1))`. Since we're concretizing to Broadcast, we know that
dimension is not zero; if it were then we'd concretize to `Iteration`
and `SqueezeOp::checkConcretization` would fail the IterType check.
Still, I don't love that the expression cannot be simplified so it
appears in the kernel (`i9` and `i10`):
```c++
// Generated pointwise kernel for the small repro: copies the sliced region of
// the 3-D input T0 into the 2-D output T4, one element per thread.
__global__ void nvfuser_pointwise_f0_c1_r0_g1(Tensor<float, 3, 3> T0, Tensor<float, 2, 2> T4) {
  // Linear element index for this thread (128 is the per-block multiplier).
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128LL * ((nvfuser_index_t)blockIdx.x));
  Tensor<float, 3, 3> s1;
  s1.data = T0.data;
  s1.logical_size = T0.logical_size;
  s1.alloc_stride = T0.alloc_stride;
  Array<nvfuser_index_t, 3, 1> a2;
  a2 = s1.logical_size;
  // i3: extent of input dim 2; i4: that extent clamped to [0, 4].
  nvfuser_index_t i3;
  i3 = a2[2LL];
  nvfuser_index_t i4;
  i4 = max(0LL, (min(4LL, i3)));
  // i5/i6 recompute the same clamp as i4 and are not read below.
  nvfuser_index_t i5;
  i5 = min(i3, 4LL);
  nvfuser_index_t i6;
  i6 = max(0LL, i5);
  Array<nvfuser_index_t, 3, 1> a7;
  a7 = s1.logical_size;
  // i8: extent of input dim 0. i9/i10 are the unsimplified
  // max(0, min(extent, 1)) expression discussed in the commit message;
  // i10 is never read below.
  nvfuser_index_t i8;
  i8 = a7[0LL];
  nvfuser_index_t i9;
  i9 = min(i8, 1LL);
  nvfuser_index_t i10;
  i10 = max(0LL, i9);
  Array<nvfuser_index_t, 3, 1> a11;
  a11 = s1.logical_size;
  // i12: extent of input dim 1; i13: product of the clamped dim-1 and dim-2
  // extents, used as the bound on i0 below.
  nvfuser_index_t i12;
  i12 = a11[1LL];
  nvfuser_index_t i13;
  i13 = (max(0LL, (min(2LL, i12)))) * i4;
  nvfuser_index_t i14;
  i14 = i0 % i13;
  // i15/i16 recompute the dim-1 clamp and are not read below.
  nvfuser_index_t i15;
  i15 = min(i12, 2LL);
  nvfuser_index_t i16;
  i16 = max(0LL, i15);
  // Bounds check: only threads that map to a valid element do the copy.
  if ((i0 < i13)) {
    float T1[1LL];
    T1[0LL] = 0LL;
    // Load from T0 using the original (unsliced) extents i3 and i12 as strides.
    T1[0LL]
       = T0[((((i3 * i12) * (i0 / i13)) + (i3 * (i14 / i4))) + (i14 % i4))];
    float T5[1LL];
    T5[0LL]
       = T1[0LL];
    T4[i0]
       = T5[0LL];
  }
}
```
If you look closely though, `i10` is not used so it will be DCEd anyway.
Still, it might be nice to concretize broadcast extents to 1 just to
clean up these expressions if they appear downstream. I tried that
hastily but ran into some issues so I'll leave it for another PR.

Fixes #2359
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

3 participants