diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 9cb9088c776..43e7f4ca563 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -219,6 +219,7 @@ bool DomainMap::areAllTargetIdsCoveredBy( // it's safe for target_tv to have them. std::unordered_set covered_source_ids; for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { + NVF_ERROR(source_id_ref->definition() == nullptr || id->definition()->isA()); covered_source_ids.insert(source_id_ref); } // It's safe to have unmapped broadcast IterDomain. There're quite a few tests diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 713c2810e33..11678ee9e8f 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -826,9 +826,6 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { // tv4 is a valid reference EXPECT_TRUE(domain_map.isValidReference(tv4)); - // check reference tv selection - EXPECT_FALSE(domain_map.findReferenceTensorView() == tv4); - // validate generated kernel auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({4, 7}, options); @@ -873,9 +870,6 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { // tv4 is a valid reference EXPECT_TRUE(domain_map.isValidReference(tv4)); - // check reference tv selection - EXPECT_FALSE(domain_map.findReferenceTensorView() == tv4); - // validate generated kernel auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 4}, options);