Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ComputeAtMap build failure: "No potential concrete_id's found for disjoint set" #2514

Open
jjsjann123 opened this issue Jul 2, 2024 · 3 comments
Assignees

Comments

@jjsjann123
Copy link
Collaborator

Repro (C++ test):

// Repro for the ComputeAtMap build failure "No potential concrete_id's found
// for disjoint set" (INTERNAL ASSERT in ComputeAtMap::computeConcreteId).
//
// tv0 has a size-1 leading axis and is used on two paths:
//   1. relu/neg/add -> tv4, which is concatenated with tv1 along dim 0;
//   2. directly, in the final add against the size-5 concatenated result.
// In the reported failure, the PERMISSIVE disjoint set contains both tv0's
// broadcast axis (bS0{1}) and the concatenated axis (iS16{5}, iS18{5}), and
// no candidate concrete ID is found for that set.
TEST_F(NVFuserTest, Repro) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // Fixed input shapes: tv0 is [1, 10], tv1 is [4, 10].
  TensorView* tv0 = makeContigConcreteTensor({1, 10});
  TensorView* tv1 = makeContigConcreteTensor({4, 10});
  TensorView* tv2 = relu(tv0);
  TensorView* tv3 = neg(tv0);
  TensorView* tv4 = add(tv2, tv3);
  // Concatenate along dim 0: [1, 10] ++ [4, 10] -> [5, 10].
  TensorView* tv5 = cat({tv4, tv1}, /*dim=*/0);
  // Reuse tv0 after the cat: its size-1 (broadcast) axis 0 is combined with
  // the size-5 concatenated axis of tv5.
  TensorView* tv6 = add(tv5, tv0);

  fusion->addInput(tv0);
  fusion->addInput(tv1);
  fusion->addOutput(tv6);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); 
  at::Tensor t0 = at::randn({1, 10}, options);
  at::Tensor t1 = at::randn({4, 10}, options);
  std::vector<c10::IValue> aten_inputs = {t0, t1};

  FusionExecutorCache fec(std::move(fusion));
  // Per the reported trace, the assert fires inside this call: segmentation
  // asks the pointwise scheduler whether it can schedule the fusion, and that
  // scheduler builds a ComputeAtMap.
  auto out_tensors = fec.runFusionWithInputs(aten_inputs);

  testValidate(fec.fusion(), out_tensors, aten_inputs, __LINE__, __FILE__);
}

This hits the following internal assert:

C++ exception with description "!maybe_concrete_ids.vector().empty() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/compute_at_map.cpp":1111, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. No potential concrete_id's found for disjoint set { bS0{1}; bS4{1}; bS6{1}; bS8{1}; bS10{1}rf; iS14{( ( 0 + 1 ) + 4 )}rf; iS11{5}rf; iS16{5}; iS18{5} }
Exception raised from computeConcreteId at /opt/pytorch/nvfuser/csrc/compute_at_map.cpp:1111 (most recent call first):

Full trace:

#0  0x00007fffb9c824a1 in __cxa_throw () from /lib/x86_64-linux-gnu/libstdc++.so.6
#1  0x0000555555ac7b9e in nvfuser::nvfCheckFail (func=0x555556d22ce6 "computeConcreteId", 
    file=0x555556d21e90 "/opt/pytorch/nvfuser/csrc/compute_at_map.cpp", line=1111, 
    msg="!maybe_concrete_ids.vector().empty() INTERNAL ASSERT FAILED at \"/opt/pytorch/nvfuser/csrc/compute_at_map.cpp\":1111, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/"...)
    at /opt/pytorch/nvfuser/csrc/exceptions.cpp:274
#2  0x0000555555ac7db3 in nvfuser::nvfErrorFail (func=0x555556d22ce6 "computeConcreteId", 
    file=0x555556d21e90 "/opt/pytorch/nvfuser/csrc/compute_at_map.cpp", line=1111, 
    condMsg=0x555556d22f28 "!maybe_concrete_ids.vector().empty() INTERNAL ASSERT FAILED at \"/opt/pytorch/nvfuser/csrc/compute_at_map.cpp\":1111, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/"..., 
    userMsg="No potential concrete_id's found for disjoint set { bS0{1}; bS4{1}; bS6{1}; bS8{1}; bS10{1}rf; iS14{( ( 0 + 1 ) + 4 )}rf; iS11{5}rf; iS16{5}; iS18{5} }") at /opt/pytorch/nvfuser/csrc/exceptions.cpp:300
#3  0x0000555555799b62 in nvfuser::ComputeAtMap::computeConcreteId (this=0x7fffffffb720, id=0x7fff21816100, 
    mode=nvfuser::IdMappingMode::PERMISSIVE) at /opt/pytorch/nvfuser/csrc/compute_at_map.cpp:1111
#4  0x000055555579a3b9 in nvfuser::ComputeAtMap::buildConcreteIds (this=0x7fffffffb720)
    at /opt/pytorch/nvfuser/csrc/compute_at_map.cpp:1179
#5  0x0000555555797ec2 in nvfuser::ComputeAtMap::build (this=0x7fffffffb720, fusion=0x7fff2106f380)
    at /opt/pytorch/nvfuser/csrc/compute_at_map.cpp:780
#6  0x0000555555797df7 in nvfuser::ComputeAtMap::ComputeAtMap (this=0x7fffffffb720, fusion=0x7fff2106f380, 
    allow_self_mapping=false) at /opt/pytorch/nvfuser/csrc/compute_at_map.cpp:775
#7  0x00005555560c877e in nvfuser::pointwise_utils::DomainMap::DomainMap (this=0x7fffffffb710, fusion=0x7fff2106f380)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise_utils.cpp:113
#8  0x00005555560beb57 in nvfuser::(anonymous namespace)::DomainMap::DomainMap (this=0x7fffffffb710)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:121
#9  0x00005555560bebc0 in nvfuser::getReferenceTensorView (fusion=0x7fff2106f380)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:525
#10 0x00005555560bec74 in nvfuser::hasReferenceTensorView (fusion=0x7fff2106f380)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:531
#11 0x00005555560bbbd3 in nvfuser::PointWiseScheduler::canScheduleCompileTime (fusion=0x7fff2106f380)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:38
#12 0x00005555560ef1ab in nvfuser::(anonymous namespace)::checkCanSchedule<nvfuser::PointWiseScheduler> (fusion=0x7fff2106f380, 
    runtime_info=..., data_cache=0x0) at /opt/pytorch/nvfuser/csrc/scheduler/registry.cpp:190
#13 0x00005555560ed7ac in nvfuser::SchedulerEntry::canSchedule (sh=nvfuser::ScheduleHeuristic::PointWise, fusion=0x7fff2106f380, 
    runtime_info=..., data_cache=0x0) at /opt/pytorch/nvfuser/csrc/scheduler/registry.cpp:210
#14 0x00005555560edcfe in nvfuser::SchedulerEntry::proposeHeuristics (fusion=0x7fff2106f380, runtime_info=...)
    at /opt/pytorch/nvfuser/csrc/scheduler/registry.cpp:299
#15 0x0000555555c0505e in nvfuser::SegmentCandidateFinder::segment (fusion=std::unique_ptr<nvfuser::Fusion> = {...}, 
    inputs=0x7fffffffcde0, runtime_info=...) at /opt/pytorch/nvfuser/csrc/fusion_segmenter.cpp:1987
#16 0x0000555555e92968 in nvfuser::FusionKernelRuntime::FusionKernelRuntime (this=0x7fff218e40c0, 
    fusion=std::unique_ptr<nvfuser::Fusion> = {...}, args=..., serde_buffer=0x0, 
    forced_index_type=std::optional<nvfuser::PrimDataType> [no contained value], fusion_id=0, concrete_id=1, runtime_id=0, 
    auto_schedule=true) at /opt/pytorch/nvfuser/csrc/kernel_cache.cpp:1061
#17 0x0000555555ea30b5 in std::make_unique<nvfuser::FusionKernelRuntime, std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const&, decltype(nullptr), std::optional<nvfuser::PrimDataType>&, long&, long&, unsigned long, bool const&>(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >&&, nvfuser::KernelArgumentHolder const&, decltype(nullptr)&&, std::optional<nvfuser::PrimDataType>&, long&, long&, unsigned long&&, bool const&) ()
    at /usr/include/c++/11/bits/unique_ptr.h:962
#18 0x0000555555e91115 in nvfuser::FusionExecutorCache::getKernelRuntimeFor (this=0x7fffffffd3c0, args=..., 
    forced_index_type=std::optional<nvfuser::PrimDataType> [no contained value]) at /opt/pytorch/nvfuser/csrc/kernel_cache.cpp:833
#19 0x0000555555e8f393 in nvfuser::FusionExecutorCache::runFusionWithInputs (this=0x7fffffffd3c0, inputs=..., 
    forced_index_type=std::optional<nvfuser::PrimDataType> [no contained value], 
    selected_device=std::optional<signed char> [no contained value]) at /opt/pytorch/nvfuser/csrc/kernel_cache.cpp:547

@jjsjann123
Copy link
Collaborator Author

In the example above:

(gdb) print id->fusion()->printMath(0)
Inputs:
  T0_g[ bS0{1}, iS1{10} ], float
  T1_g[ iS2{4}, iS3{10} ], float
Outputs:
  T8_g[ iS18{5}, iS19{10} ], float

%kernel_math {
Resize: bS10{1}rf by 0 and 4 -> iS11{5}rf
i21 = 0 + 1;
i30 = i21 + 4;
i37 = i21 + 4;
Resize: iS13{4}rf by ( 0 + 1 ) and 0 -> iS14{( ( 0 + 1 ) + 4 )}rf
T2_l[ bS4{1}, iS5{10} ]
   = relu(T0_g[ bS0{1}, iS1{10} ]);
T3_l[ bS6{1}, iS7{10} ]
   = -T0_g[ bS0{1}, iS1{10} ];
T4_l[ bS8{1}, iS9{10} ]
   = T2_l[ bS4{1}, iS5{10} ]
   + T3_l[ bS6{1}, iS7{10} ];
T5_l[ iS11{5}rf, iS12{10} ]
   = pad( T4_l[ bS8{1}, iS9{10} ], {0, 4, 0, 0} )
T6_l[ iS14{( ( 0 + 1 ) + 4 )}rf, iS15{10} ]
   = pad( T1_g[ iS2{4}, iS3{10} ], {i21, 0, 0, 0} )
T7_l[ iS16{5}, iS17{10} ]
   = cat( T5_l[ iS11{5}rf, iS12{10} ], T6_l[ iS14{( ( 0 + 1 ) + 4 )}rf, iS15{10} ], 0 )
T8_g[ iS18{5}, iS19{10} ]
   = T7_l[ iS16{5}, iS17{10} ]
   + T0_g[ bS0{1}, iS1{10} ];
b50 = blockIdx.x >= 0;
b52 = gridDim.x > 0;
b54 = blockIdx.x < gridDim.x;
b56 = blockIdx.y >= 0;
b58 = gridDim.y > 0;
b60 = blockIdx.y < gridDim.y;
b62 = blockIdx.z >= 0;
b64 = gridDim.z > 0;
b66 = blockIdx.z < gridDim.z;
b68 = threadIdx.x >= 0;
b70 = blockDim.x > 0;
b72 = threadIdx.x < blockDim.x;
b74 = threadIdx.y >= 0;
b76 = blockDim.y > 0;
b78 = threadIdx.y < blockDim.y;
b80 = threadIdx.z >= 0;
b82 = blockDim.z > 0;
b84 = threadIdx.z < blockDim.z;
s86 = getMetaData(T0_g[ bS0{1}, iS1{10} ])
s87 = getMetaData(T1_g[ iS2{4}, iS3{10} ])
}
No potential concrete_id's found for disjoint set { bS0{1}; bS4{1}; bS6{1}; bS8{1}; bS10{1}rf; iS14{( ( 0 + 1 ) + 4 )}rf; iS11{5}rf; iS16{5}; iS18{5} }")

It looks like, after the last swap, we somehow ended up with an empty `maybe_concrete_ids`.

@jjsjann123
Copy link
Collaborator Author

This one still reproduces. I thought we could use `NVFUSER_ENABLE=id_model` to avoid the compute_at path, but that doesn't seem to help here. cc @naoyam

@jjsjann123
Copy link
Collaborator Author

jjsjann123 commented Dec 12, 2024

Found another smaller example hitting the same error.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    """Define a fusion that slices T0 and adds the slice back onto T0.

    T1 = T0[0:1, 0:4] (for the 4x4 input used below this is a single row),
    and T2 = T1 + T0 is the only output.
    """
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.ops.slice(T0, start_indices=[0, 0], end_indices=[1, 4], strides=[1, 1])
    T2 = fd.ops.add(T1, T0)
    fd.add_output(T2)

with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)

t0 = torch.randn(4, 4, dtype=torch.float, device='cuda:0')
# Defining the fusion alone does not reach the failing code: the assert is
# raised while scheduling, which only happens when the fusion is executed.
fd.execute([t0])

C++ test:

// C++ counterpart of the Python slice repro: slice a 2-D input down to
// extent 1 along one axis and add the result back onto the unsliced input.
// Differences from the Python snippet: this test slices axis 1 (the Python
// version slices axis 0), and it runs on a [2, 4] input (Python uses [4, 4]).
TEST_F(PointwiseTest, SlicePlayground) {
  // NOTE(review): presumably keeps MarkAliasesPreparePass from running for
  // the duration of this test, so the slice is not handled as an alias;
  // TODO confirm.
  preseg_passes::OptimizationPassGuard<preseg_passes::MarkAliasesPreparePass>
      optimization_guard(false);
  auto fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  // tv0 {i1, i0}: 2-D contiguous input with symbolic extents.
  TensorView* tv0 = makeContigTensor(2);
  fusion->addInput(tv0);

  // Axis 0: default Slice() (presumably the full range).
  // Axis 1: start 0, stop 1, step 1 -> resized to extent 1.
  // tv2 {i1, b2}
  auto tv2 = slice(
      tv0,
      {Slice(),
       {IrBuilder::create<Val>(0L),
        IrBuilder::create<Val>(1L),
        IrBuilder::create<Val>(1L)}});
  // tv3 {i1, i0}: tv2's extent-1 axis 1 is combined with tv0's axis i0.
  auto tv3 = add(tv0, tv2);
  fusion->addOutput(tv3);

  // validate generated kernel
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({2, 4}, options);
  std::vector<c10::IValue> aten_inputs = {t0};
  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
  testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants