From f0ce0e33bedde15063d3400dcd799c555ea0d4bb Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 2 Dec 2024 11:45:24 -0800 Subject: [PATCH 01/58] will this work? --- csrc/scheduler/pointwise.cpp | 6 +++--- csrc/scheduler/pointwise_utils.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index bc7a0fb32c6..f0de641be59 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -43,7 +43,7 @@ class DomainMap : public pointwise_utils::DomainMap { if (isValidReference(output_tv) && hasMinimumSize(output_tv, minimum_num_axes) && !output_tv->isFusionInput()) { - int64_t n_dims = pointwise_utils::nRootDims(output_tv); + int64_t n_dims = pointwise_utils::nLogicalDims(output_tv); if (n_dims > max_dims) { result = output_tv; max_dims = n_dims; @@ -529,11 +529,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { int64_t max_dims = 0; for (auto inp : input_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims); } for (auto out : output_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims); } // If everything is zero dim tensors, just return. 
diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 56db0ee0806..dac71d23e8e 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -63,11 +63,11 @@ class DomainMap { // Returns number of non-reduction/non-broadcas/non-device dims in logical // domain -inline int64_t nRootDims(const TensorView* tv) { +inline int64_t nLogicalDims(const TensorView* tv) { auto logical_dom = tv->getLogicalDomain(); int64_t tv_n_dims = 0; for (auto dim : logical_dom) { - if (!dim->isReduction() && !dim->isBroadcast() && !dim->isDeviceDim()) { + if (!dim->isReduction() && !dim->isBroadcast() && !id->hasExpandedExtent() && !dim->isDeviceDim()) { tv_n_dims++; } } From 70e31bf4b9108d1a316489a244ff03af280d14bc Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 2 Dec 2024 11:50:25 -0800 Subject: [PATCH 02/58] errr --- csrc/scheduler/pointwise_utils.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index dac71d23e8e..46bece8ccf3 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -67,7 +67,8 @@ inline int64_t nLogicalDims(const TensorView* tv) { auto logical_dom = tv->getLogicalDomain(); int64_t tv_n_dims = 0; for (auto dim : logical_dom) { - if (!dim->isReduction() && !dim->isBroadcast() && !id->hasExpandedExtent() && !dim->isDeviceDim()) { + if (!dim->isReduction() && !dim->isBroadcast() && + !dim->hasExpandedExtent() && !dim->isDeviceDim()) { tv_n_dims++; } } From 5f09e36f60ae18d96471416af4ccf8ced637ba9d Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 2 Dec 2024 11:51:43 -0800 Subject: [PATCH 03/58] missed a few renaming --- csrc/scheduler/transpose.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 7e320f99a91..93be8309113 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ 
-171,7 +171,7 @@ class DomainMap : public pointwise_utils::DomainMap { int64_t max_dims = -1; for (auto tv : group) { if (isValidReference(tv)) { - int64_t dims = (int64_t)pointwise_utils::nRootDims(tv); + int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv); if (dims > max_dims) { result = tv; max_dims = dims; @@ -1045,11 +1045,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { int64_t max_dims = 0; for (auto inp : input_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims); } for (auto out : output_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims); } // If everything is zero dim tensors, just return. From 9ad9edb4533cf66c4fdd7bdac37dc066a2988298 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 2 Dec 2024 12:41:33 -0800 Subject: [PATCH 04/58] WIP --- csrc/scheduler/pointwise_utils.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 2f4f119fc46..972e609f31c 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -64,7 +64,9 @@ bool canIgnoreIndexedInputDomainID( TensorView* input_tv, IterDomain* root_id, const ComputeAtMap& ca_map) { - NVF_ERROR(input_tv->isFusionInput()); + if (!input_tv->isFusionInput()) { + return false; + } for (auto use : input_tv->uses()) { if (auto select = dynamic_cast(use)) { if (root_id != select->getIndexedID()) { @@ -245,6 +247,11 @@ bool DomainMap::isValidReference(TensorView* tv) const { return false; } } + for (auto output_tv : ir_utils::filterByType(fusion_->outputs())) { + if (!areAllInputIdsMappedTo(output_tv, tv)) { + return false; + } + } return true; } From 654020106e37fc1c4c25d9819780479aaaea6ad9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 2 Dec 2024 14:37:28 -0800 
Subject: [PATCH 05/58] test added --- tests/python/test_pointwise.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/python/test_pointwise.py b/tests/python/test_pointwise.py index e24a139555e..a2a66ced616 100644 --- a/tests/python/test_pointwise.py +++ b/tests/python/test_pointwise.py @@ -421,3 +421,36 @@ def fusion_func(fd: FusionDefinition): with pytest.raises(RuntimeError, match="No executor supports provided fusion."): _ = fd.execute(inputs) + + +def test_pointwise_issue(): + inputs = [ + torch.testing.make_tensor( + (1, 2048, 512), dtype=torch.bfloat16, device="cuda:0" + ), + ] + + def fusion_func(fd: FusionDefinition): + T3 = fd.define_tensor( + shape=[1, 2048, 512], + contiguity=[None, True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[2, 1, 0], + ) + T33 = fd.ops.reshape(T3, new_shape=[1, 2048, 8, 64]) + T34 = fd.ops.permute(T33, dims=[0, 2, 1, 3]) + T185 = fd.ops.broadcast_in_dim( + T34, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4] + ) + T192 = fd.ops.broadcast_in_dim( + T185, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4] + ) + T198 = fd.ops.reshape(T192, new_shape=[1, 32, 2048, 64]) + fd.add_output(T34) + fd.add_output(T198) + + with FusionDefinition() as fd: + fusion_func(fd) + + _ = fd.execute(inputs) From ed56c75c11f77a267a36dd54539da1810fde8c62 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 3 Dec 2024 20:44:49 -0800 Subject: [PATCH 06/58] WIP --- csrc/scheduler/pointwise_utils.cpp | 32 ++++++++++++++++++++++++++---- csrc/scheduler/pointwise_utils.h | 2 +- csrc/scheduler/transpose.cpp | 4 ++-- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 972e609f31c..72bb15051a1 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -144,6 +144,25 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return 
in_concrete_ids.empty(); } +bool DomainMap::areAllProducerIdsMappedTo(TensorView* target, TensorView* reference_tv) + const { + + // reverse traversal to collect all producer ids of reference_tv + VectorOfUniqueEntries>> + all_covered_exact_sets; + std::for_each(reference_tv->getLogicalDomain().begin(), reference_tv->getLogicalDomain().end(), [&](IterDomain* id) { + all_covered_exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + }); + all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); + + for (auto id : target->getLogicalDomain()) { + auto inp_ids = getInputDisjointSetsOf(id); + // check if all inp_ids are mapped in all_covered_exact_sets + } + + return true; +} + // Reference domains must exactly match with the input domains. See // also PR #661 IterDomain* DomainMap::getMappedInputConcreteID( @@ -236,7 +255,7 @@ IterDomain* DomainMap::anyMapped( // Determine if output TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input. 
-bool DomainMap::isValidReference(TensorView* tv) const { +bool DomainMap::isValidReference(TensorView* tv, bool check_output_coverage) const { for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { if (input_tv->uses().empty()) { continue; @@ -247,9 +266,14 @@ bool DomainMap::isValidReference(TensorView* tv) const { return false; } } - for (auto output_tv : ir_utils::filterByType(fusion_->outputs())) { - if (!areAllInputIdsMappedTo(output_tv, tv)) { - return false; + if (check_output_coverage) { + for (auto output_tv : ir_utils::filterByType(fusion_->outputs())) { + if (output_tv == tv) { + continue; + } + if (!areAllProducerIdsMappedTo(output_tv, tv)) { + return false; + } } } return true; diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 46bece8ccf3..d02c40b36be 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -29,7 +29,7 @@ class DomainMap { // Determine if a TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input. - bool isValidReference(TensorView* tv) const; + bool isValidReference(TensorView* tv, bool check_output_coverage=true) const; protected: // Determine if all IterDomains are mapped between input and the given tvs diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 93be8309113..f3ec0322682 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -166,7 +166,7 @@ class DomainMap : public pointwise_utils::DomainMap { // be a traversal path to an input. This type of analysis is // expected to be possible much more easily with the new indexing // graph (#32), so we should revisit once it becomes available. 
- TensorView* findReferenceFor(const std::vector& group) const { + TensorView* findReferenceFor(const std::vector& group, bool check_output = true) const { TensorView* result = nullptr; int64_t max_dims = -1; for (auto tv : group) { @@ -202,7 +202,7 @@ class DomainMap : public pointwise_utils::DomainMap { return false; } auto ref1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]); - auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1]); + auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1], false); if (ref1 == nullptr || ref2 == nullptr) { return false; } From f1e7e0ac8008952b5b3a291ec3a1f029e4886431 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 10:48:30 -0800 Subject: [PATCH 07/58] WIP --- csrc/scheduler/pointwise_utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 72bb15051a1..3d5b18a8050 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -144,7 +144,7 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } -bool DomainMap::areAllProducerIdsMappedTo(TensorView* target, TensorView* reference_tv) +bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* reference_tv) const { // reverse traversal to collect all producer ids of reference_tv @@ -155,7 +155,7 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target, TensorView* refere }); all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); - for (auto id : target->getLogicalDomain()) { + for (auto id : target_tv->getLogicalDomain()) { auto inp_ids = getInputDisjointSetsOf(id); // check if all inp_ids are mapped in all_covered_exact_sets } From 9d174c9b50e739dd945f93a20dbcf99bf705d8ff Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 10:52:59 -0800 Subject: [PATCH 08/58] declaration --- 
csrc/scheduler/pointwise_utils.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index d02c40b36be..1310d149818 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -36,6 +36,10 @@ class DomainMap { bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) const; + // Determine if all IterDomains are mapped between input and the given tvs + bool areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* reference_tv) + const; + virtual IterDomain* getMappedInputConcreteID( const std::unordered_set& in_concrete_ids, IterDomain* out_id) const; From bf425ebf7ecf956271e5b79e9cf51baaeea7c337 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 11:07:16 -0800 Subject: [PATCH 09/58] WIP --- csrc/scheduler/pointwise_utils.cpp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 3d5b18a8050..08ad83d6e7d 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -155,9 +155,27 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref }); all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); + std::vector covered_concrete_ids; + for (const auto& exact_set_ptr : all_covered_exact_sets) { + auto exact_concrete_id = ca_map_.getConcreteMappedID( + exact_set_ptr->front(), IdMappingMode::EXACT); + covered_concrete_ids.push_back(exact_concrete_id); + } + for (auto id : target_tv->getLogicalDomain()) { - auto inp_ids = getInputDisjointSetsOf(id); - // check if all inp_ids are mapped in all_covered_exact_sets + if (getMappedInputConcreteID(covered_concrete_ids, id) != nullptr) { + continue; + } + + auto inp_id_sets = ca_map_.getInputDisjointSetsOf(id); + // check if all inp_ids are mapped in covered_concrete_ids + for (auto inp_id_set : 
inp_id_sets) { + auto exact_inp_id = ca_map_.getConcreteMappedID( + inp_id_set->front(), IdMappingMode::EXACT); + if (getMappedInputConcreteID(covered_concrete_ids, exact_inp_id) == nullptr) { + return false; + } + } } return true; From aef13ac9d6686bbc27b3a1e9949a6c6c94ef8c80 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 11:14:58 -0800 Subject: [PATCH 10/58] WIP --- csrc/scheduler/pointwise_utils.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 08ad83d6e7d..4da4ef3a681 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -155,11 +155,11 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref }); all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); - std::vector covered_concrete_ids; + std::unordered_set covered_concrete_ids; for (const auto& exact_set_ptr : all_covered_exact_sets) { auto exact_concrete_id = ca_map_.getConcreteMappedID( exact_set_ptr->front(), IdMappingMode::EXACT); - covered_concrete_ids.push_back(exact_concrete_id); + covered_concrete_ids.insert(exact_concrete_id); } for (auto id : target_tv->getLogicalDomain()) { @@ -167,7 +167,7 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref continue; } - auto inp_id_sets = ca_map_.getInputDisjointSetsOf(id); + auto inp_id_sets = ca_map_.getAllDisjointSetProducers(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); // check if all inp_ids are mapped in covered_concrete_ids for (auto inp_id_set : inp_id_sets) { auto exact_inp_id = ca_map_.getConcreteMappedID( From a9ae516785f7f98f50f9f9ddf9c80a79b7405bd5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 12:49:45 -0800 Subject: [PATCH 11/58] refactor the traversal --- csrc/scheduler/pointwise_utils.cpp | 38 ++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 12 deletions(-) 
diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 4da4ef3a681..4f9a23f8227 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -157,25 +157,39 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref std::unordered_set covered_concrete_ids; for (const auto& exact_set_ptr : all_covered_exact_sets) { - auto exact_concrete_id = ca_map_.getConcreteMappedID( - exact_set_ptr->front(), IdMappingMode::EXACT); - covered_concrete_ids.insert(exact_concrete_id); + covered_concrete_ids.insert(exact_set_ptr->front()); } + auto producers = ca_map_.idGraph().producers(); for (auto id : target_tv->getLogicalDomain()) { - if (getMappedInputConcreteID(covered_concrete_ids, id) != nullptr) { - continue; - } + std::stack frontier; + frontier.push_back(id); - auto inp_id_sets = ca_map_.getAllDisjointSetProducers(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); - // check if all inp_ids are mapped in covered_concrete_ids - for (auto inp_id_set : inp_id_sets) { - auto exact_inp_id = ca_map_.getConcreteMappedID( - inp_id_set->front(), IdMappingMode::EXACT); - if (getMappedInputConcreteID(covered_concrete_ids, exact_inp_id) == nullptr) { + while (!frontier.empty()) { + IterDomain* t = frontier.back(); + frontier.pop_back(); + if (getMappedInputConcreteID(covered_concrete_ids, t) != nullptr) { + continue; + } + + auto p_iter = producers.find(t); + // no definition, mismatch found, we'll return false; + if (p_iter == producers.end()) { return false; } + + std::copy(p_iter->begin(), p_iter->end(), std::back_inserter(frontier)); } + + // // auto inp_id_sets = ca_map_.getAllDisjointSetProducers({ca_map_.disjointSetOf(id, IdMappingMode::EXACT)}); + // // check if all inp_ids are mapped in covered_concrete_ids + // for (auto inp_id_set : inp_id_sets) { + // // auto exact_inp_id = ca_map_.getConcreteMappedID( + // // inp_id_set->front(), IdMappingMode::EXACT); + // if 
(getMappedInputConcreteID(covered_concrete_ids, exact_inp_id) == nullptr) { + // return false; + // } + // } } return true; From d9e8dc0243044a58f9c4c47a8c4f73a4201b0702 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 13:59:06 -0800 Subject: [PATCH 12/58] WIP --- csrc/scheduler/pointwise_utils.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 4f9a23f8227..019644b916e 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -162,7 +162,7 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref auto producers = ca_map_.idGraph().producers(); for (auto id : target_tv->getLogicalDomain()) { - std::stack frontier; + std::vector frontier; frontier.push_back(id); while (!frontier.empty()) { @@ -174,11 +174,11 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref auto p_iter = producers.find(t); // no definition, mismatch found, we'll return false; - if (p_iter == producers.end()) { + if (p_iter == producers.end() || p_iter->second.empty()) { return false; } - std::copy(p_iter->begin(), p_iter->end(), std::back_inserter(frontier)); + std::copy(p_iter->second.begin(), p_iter->second.end(), std::back_inserter(frontier)); } // // auto inp_id_sets = ca_map_.getAllDisjointSetProducers({ca_map_.disjointSetOf(id, IdMappingMode::EXACT)}); From 73338062e6258a8edb2d5df948f2937b10e650ec Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 14:09:22 -0800 Subject: [PATCH 13/58] scratch that, it's getting out of hand --- csrc/scheduler/pointwise_utils.cpp | 53 +----------------------------- csrc/scheduler/pointwise_utils.h | 4 --- 2 files changed, 1 insertion(+), 56 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 019644b916e..f1d4de9253b 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ 
b/csrc/scheduler/pointwise_utils.cpp @@ -144,57 +144,6 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } -bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* reference_tv) - const { - - // reverse traversal to collect all producer ids of reference_tv - VectorOfUniqueEntries>> - all_covered_exact_sets; - std::for_each(reference_tv->getLogicalDomain().begin(), reference_tv->getLogicalDomain().end(), [&](IterDomain* id) { - all_covered_exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); - }); - all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); - - std::unordered_set covered_concrete_ids; - for (const auto& exact_set_ptr : all_covered_exact_sets) { - covered_concrete_ids.insert(exact_set_ptr->front()); - } - - auto producers = ca_map_.idGraph().producers(); - for (auto id : target_tv->getLogicalDomain()) { - std::vector frontier; - frontier.push_back(id); - - while (!frontier.empty()) { - IterDomain* t = frontier.back(); - frontier.pop_back(); - if (getMappedInputConcreteID(covered_concrete_ids, t) != nullptr) { - continue; - } - - auto p_iter = producers.find(t); - // no definition, mismatch found, we'll return false; - if (p_iter == producers.end() || p_iter->second.empty()) { - return false; - } - - std::copy(p_iter->second.begin(), p_iter->second.end(), std::back_inserter(frontier)); - } - - // // auto inp_id_sets = ca_map_.getAllDisjointSetProducers({ca_map_.disjointSetOf(id, IdMappingMode::EXACT)}); - // // check if all inp_ids are mapped in covered_concrete_ids - // for (auto inp_id_set : inp_id_sets) { - // // auto exact_inp_id = ca_map_.getConcreteMappedID( - // // inp_id_set->front(), IdMappingMode::EXACT); - // if (getMappedInputConcreteID(covered_concrete_ids, exact_inp_id) == nullptr) { - // return false; - // } - // } - } - - return true; -} - // Reference domains must exactly match with the input domains. 
See // also PR #661 IterDomain* DomainMap::getMappedInputConcreteID( @@ -303,7 +252,7 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_output_coverage) con if (output_tv == tv) { continue; } - if (!areAllProducerIdsMappedTo(output_tv, tv)) { + if (DependencyCheck::isDependencyOf(tv, output_tv)) { return false; } } diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 1310d149818..d02c40b36be 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -36,10 +36,6 @@ class DomainMap { bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) const; - // Determine if all IterDomains are mapped between input and the given tvs - bool areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* reference_tv) - const; - virtual IterDomain* getMappedInputConcreteID( const std::unordered_set& in_concrete_ids, IterDomain* out_id) const; From f6ad363e8169e516e4d38ef6175397b0016dd3db Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 14:17:59 -0800 Subject: [PATCH 14/58] Revert "scratch that, it's getting out of hand" This reverts commit 73338062e6258a8edb2d5df948f2937b10e650ec. 
--- csrc/scheduler/pointwise_utils.cpp | 53 +++++++++++++++++++++++++++++- csrc/scheduler/pointwise_utils.h | 4 +++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index f1d4de9253b..019644b916e 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -144,6 +144,57 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } +bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* reference_tv) + const { + + // reverse traversal to collect all producer ids of reference_tv + VectorOfUniqueEntries>> + all_covered_exact_sets; + std::for_each(reference_tv->getLogicalDomain().begin(), reference_tv->getLogicalDomain().end(), [&](IterDomain* id) { + all_covered_exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + }); + all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); + + std::unordered_set covered_concrete_ids; + for (const auto& exact_set_ptr : all_covered_exact_sets) { + covered_concrete_ids.insert(exact_set_ptr->front()); + } + + auto producers = ca_map_.idGraph().producers(); + for (auto id : target_tv->getLogicalDomain()) { + std::vector frontier; + frontier.push_back(id); + + while (!frontier.empty()) { + IterDomain* t = frontier.back(); + frontier.pop_back(); + if (getMappedInputConcreteID(covered_concrete_ids, t) != nullptr) { + continue; + } + + auto p_iter = producers.find(t); + // no definition, mismatch found, we'll return false; + if (p_iter == producers.end() || p_iter->second.empty()) { + return false; + } + + std::copy(p_iter->second.begin(), p_iter->second.end(), std::back_inserter(frontier)); + } + + // // auto inp_id_sets = ca_map_.getAllDisjointSetProducers({ca_map_.disjointSetOf(id, IdMappingMode::EXACT)}); + // // check if all inp_ids are mapped in covered_concrete_ids + // for (auto inp_id_set : 
inp_id_sets) { + // // auto exact_inp_id = ca_map_.getConcreteMappedID( + // // inp_id_set->front(), IdMappingMode::EXACT); + // if (getMappedInputConcreteID(covered_concrete_ids, exact_inp_id) == nullptr) { + // return false; + // } + // } + } + + return true; +} + // Reference domains must exactly match with the input domains. See // also PR #661 IterDomain* DomainMap::getMappedInputConcreteID( @@ -252,7 +303,7 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_output_coverage) con if (output_tv == tv) { continue; } - if (DependencyCheck::isDependencyOf(tv, output_tv)) { + if (!areAllProducerIdsMappedTo(output_tv, tv)) { return false; } } diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index d02c40b36be..1310d149818 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -36,6 +36,10 @@ class DomainMap { bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) const; + // Determine if all IterDomains are mapped between input and the given tvs + bool areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* reference_tv) + const; + virtual IterDomain* getMappedInputConcreteID( const std::unordered_set& in_concrete_ids, IterDomain* out_id) const; From cef0b831820740699ec09e61f9f8d2580ec45530 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 14:46:48 -0800 Subject: [PATCH 15/58] try focus on expanded dimensions --- csrc/scheduler/pointwise_utils.cpp | 47 +++++++++++++----------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 019644b916e..5cf606629e1 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -155,41 +155,34 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref }); all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); - std::unordered_set 
covered_concrete_ids; + std::unordered_set covered_expanded_ids; for (const auto& exact_set_ptr : all_covered_exact_sets) { - covered_concrete_ids.insert(exact_set_ptr->front()); + IterDomain* id = exact_set_ptr->front(); + if (id->hasExpandedExtent()) { + covered_expanded_ids.insert(exact_set_ptr->front()); + } } - - auto producers = ca_map_.idGraph().producers(); + // stand alone expanded id can be left alone without causing replay issue. for (auto id : target_tv->getLogicalDomain()) { - std::vector frontier; - frontier.push_back(id); + if (id->hasExpandedExtent()) { + covered_expanded_ids.insert(exact_set_ptr->front()); + } + } - while (!frontier.empty()) { - IterDomain* t = frontier.back(); - frontier.pop_back(); - if (getMappedInputConcreteID(covered_concrete_ids, t) != nullptr) { - continue; - } + VectorOfUniqueEntries>> + all_expected_exact_sets; + std::for_each(target_tv->getLogicalDomain().begin(), target_tv->getLogicalDomain().end(), [&](IterDomain* id) { + all_expected_exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + }); + all_expected_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_expected_exact_sets)); - auto p_iter = producers.find(t); - // no definition, mismatch found, we'll return false; - if (p_iter == producers.end() || p_iter->second.empty()) { + for (const auto& exact_set_ptr : all_expected_exact_sets) { + IterDomain* id = exact_set_ptr->front(); + if (id->hasExpandedExtent()) { + if (!getMappedInputConcreteID(covered_expanded_ids, id)) { return false; } - - std::copy(p_iter->second.begin(), p_iter->second.end(), std::back_inserter(frontier)); } - - // // auto inp_id_sets = ca_map_.getAllDisjointSetProducers({ca_map_.disjointSetOf(id, IdMappingMode::EXACT)}); - // // check if all inp_ids are mapped in covered_concrete_ids - // for (auto inp_id_set : inp_id_sets) { - // // auto exact_inp_id = ca_map_.getConcreteMappedID( - // // inp_id_set->front(), IdMappingMode::EXACT); - // if 
(getMappedInputConcreteID(covered_concrete_ids, exact_inp_id) == nullptr) { - // return false; - // } - // } } return true; From a557a8b2a380a1f0012135ca65c810c2f931ff8e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 17:11:48 -0800 Subject: [PATCH 16/58] wip --- csrc/scheduler/pointwise_utils.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 5cf606629e1..69c58b5380c 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -155,17 +155,18 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref }); all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); - std::unordered_set covered_expanded_ids; + std::unordered_set covered_source_ids; for (const auto& exact_set_ptr : all_covered_exact_sets) { IterDomain* id = exact_set_ptr->front(); - if (id->hasExpandedExtent()) { - covered_expanded_ids.insert(exact_set_ptr->front()); + // if (id->hasExpandedExtent()) { + if (ca_map_.uniqueExactDefinitions(id).empty()) { + covered_source_ids.insert(id); } } // stand alone expanded id can be left alone without causing replay issue. 
for (auto id : target_tv->getLogicalDomain()) { - if (id->hasExpandedExtent()) { - covered_expanded_ids.insert(exact_set_ptr->front()); + if (ca_map_.uniqueExactDefinitions(id).empty()) { + covered_source_ids.insert(id); } } @@ -178,8 +179,8 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref for (const auto& exact_set_ptr : all_expected_exact_sets) { IterDomain* id = exact_set_ptr->front(); - if (id->hasExpandedExtent()) { - if (!getMappedInputConcreteID(covered_expanded_ids, id)) { + if (ca_map_.uniqueExactDefinitions(id).empty()) { + if (!getMappedInputConcreteID(covered_source_ids, id)) { return false; } } From f88ebf719bab7181ae7c4601119070cb9190d1fd Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 4 Dec 2024 22:30:13 -0800 Subject: [PATCH 17/58] lintrunner --- csrc/scheduler/pointwise_utils.cpp | 38 ++++++++++++++++++++---------- csrc/scheduler/pointwise_utils.h | 8 ++++--- csrc/scheduler/transpose.cpp | 4 +++- 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 69c58b5380c..3c70deacfa2 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -144,16 +144,21 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } -bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* reference_tv) - const { - +bool DomainMap::areAllProducerIdsMappedTo( + TensorView* target_tv, + TensorView* reference_tv) const { // reverse traversal to collect all producer ids of reference_tv VectorOfUniqueEntries>> all_covered_exact_sets; - std::for_each(reference_tv->getLogicalDomain().begin(), reference_tv->getLogicalDomain().end(), [&](IterDomain* id) { - all_covered_exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); - }); - all_covered_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); + std::for_each( 
+ reference_tv->getLogicalDomain().begin(), + reference_tv->getLogicalDomain().end(), + [&](IterDomain* id) { + all_covered_exact_sets.pushBack( + ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + }); + all_covered_exact_sets.pushBack( + ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); std::unordered_set covered_source_ids; for (const auto& exact_set_ptr : all_covered_exact_sets) { @@ -172,10 +177,15 @@ bool DomainMap::areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* ref VectorOfUniqueEntries>> all_expected_exact_sets; - std::for_each(target_tv->getLogicalDomain().begin(), target_tv->getLogicalDomain().end(), [&](IterDomain* id) { - all_expected_exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); - }); - all_expected_exact_sets.pushBack(ca_map_.getAllDisjointSetProducers(all_expected_exact_sets)); + std::for_each( + target_tv->getLogicalDomain().begin(), + target_tv->getLogicalDomain().end(), + [&](IterDomain* id) { + all_expected_exact_sets.pushBack( + ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + }); + all_expected_exact_sets.pushBack( + ca_map_.getAllDisjointSetProducers(all_expected_exact_sets)); for (const auto& exact_set_ptr : all_expected_exact_sets) { IterDomain* id = exact_set_ptr->front(); @@ -281,7 +291,8 @@ IterDomain* DomainMap::anyMapped( // Determine if output TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input. 
-bool DomainMap::isValidReference(TensorView* tv, bool check_output_coverage) const { +bool DomainMap::isValidReference(TensorView* tv, bool check_output_coverage) + const { for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { if (input_tv->uses().empty()) { continue; @@ -293,7 +304,8 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_output_coverage) con } } if (check_output_coverage) { - for (auto output_tv : ir_utils::filterByType(fusion_->outputs())) { + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { if (output_tv == tv) { continue; } diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 1310d149818..dea904b41be 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -29,7 +29,8 @@ class DomainMap { // Determine if a TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input. - bool isValidReference(TensorView* tv, bool check_output_coverage=true) const; + bool isValidReference(TensorView* tv, bool check_output_coverage = true) + const; protected: // Determine if all IterDomains are mapped between input and the given tvs @@ -37,8 +38,9 @@ class DomainMap { const; // Determine if all IterDomains are mapped between input and the given tvs - bool areAllProducerIdsMappedTo(TensorView* target_tv, TensorView* reference_tv) - const; + bool areAllProducerIdsMappedTo( + TensorView* target_tv, + TensorView* reference_tv) const; virtual IterDomain* getMappedInputConcreteID( const std::unordered_set& in_concrete_ids, diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index f3ec0322682..99c94c74a93 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -166,7 +166,9 @@ class DomainMap : public pointwise_utils::DomainMap { // be a traversal path to an input. 
This type of analysis is // expected to be possible much more easily with the new indexing // graph (#32), so we should revisit once it becomes available. - TensorView* findReferenceFor(const std::vector& group, bool check_output = true) const { + TensorView* findReferenceFor( + const std::vector& group, + bool check_output = true) const { TensorView* result = nullptr; int64_t max_dims = -1; for (auto tv : group) { From ea89b69baff6dfe0c62d97e269e4e218dad4d0b8 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 5 Dec 2024 02:16:28 -0800 Subject: [PATCH 18/58] comment added --- csrc/scheduler/pointwise_utils.cpp | 148 ++++++++++++++++++----------- csrc/scheduler/pointwise_utils.h | 15 +-- csrc/scheduler/transpose.cpp | 6 +- tests/python/test_pointwise.py | 6 +- 4 files changed, 109 insertions(+), 66 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 3c70deacfa2..95605e68ce9 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -144,58 +144,99 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } -bool DomainMap::areAllProducerIdsMappedTo( - TensorView* target_tv, +// Note: ideally we would want to chck that reference_tv contains (not +// necessarily maps) all iter domain in output_tv, so that transformation +// applied on reference_tv can be propagated to output_tv. But we don't have +// an easy way to check that. +// Instead of that, this function checks that all source iter domains involved +// in transformation on output_tv is covered by reference_tv. We do so by +// traverse all disjoint set producers on both tvs and filter them with +// `ca_map_.uniqueExactDefinitions(id).empty()`. +// +// ------ +// +// e.g 0. 
+// T34 [i0, i1] +// T185 [i0, b2, i1] = broadcast(T34) +// T192 [i0, b3(ex), i1] = expand(T185) +// T198 [i0, b3(ex)*i1] = reshape(T192) +// output(T34) +// output(T198) +// +// if we consider taking T34 as reference_tv. T198 is the output_tv. We can't +// replay T34's transform of merging all the dimensions to T198, since b3(ex)*i1 +// can't be reversed. The check in this function would give us T34 with source +// i0, i1; where T198 would have source i0, b3, i1, where b3 isn't contained in +// T34. Hence we'll reject this referenc_tv. +// +// ------ +// +// e.g 1. +// T0 [i0, i1] +// T1 [i2, i0, i1] +// T2 [i0*i1] = reshape(T0) +// T3 [b3, i0, i1] = broadcast(T0) +// T4 [i2, i0, i1] = add(T1, T3) +// output(T2) +// output(T4) +// +// the example above should be able to pick T4 as reference_tv. T2's source i0, +// i1 are both contained by the source of T4, so this example could be scheduled +// as a single fusion. +bool DomainMap::areAllOutputIdsMappedTo( + TensorView* output_tv, TensorView* reference_tv) const { - // reverse traversal to collect all producer ids of reference_tv - VectorOfUniqueEntries>> - all_covered_exact_sets; - std::for_each( - reference_tv->getLogicalDomain().begin(), - reference_tv->getLogicalDomain().end(), - [&](IterDomain* id) { - all_covered_exact_sets.pushBack( - ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); - }); - all_covered_exact_sets.pushBack( - ca_map_.getAllDisjointSetProducers(all_covered_exact_sets)); - + // traverse back to collect all disjoint set producers from the logical domain + // of tv. 
+ auto get_source_producers = [&ca_map_](TensorView* tv) { + VectorOfUniqueEntries>> + all_producer_sets; + std::for_each( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + [&](IterDomain* id) { + all_producer_sets.pushBack( + ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + }); + all_producer_sets.pushBack( + ca_map_.getAllDisjointSetProducers(all_producer_sets)); + + std::vector source_ids; + std::copy_if( + all_producer_sets.begin(), + all_producer_sets.end(), + std::back_inserter(source_ids), + [&ca_map_](const std::shared_ptr>& + producer_set_ptr) { + IterDomain* id = producer_set_ptr->front(); + return ca_map_.uniqueExactDefinitions(id).empty(); + }); + return source_ids; + }; + + // this contains all source iter domain that's covered by reference_tv, so + // it's safe for output_tv to have them. std::unordered_set covered_source_ids; - for (const auto& exact_set_ptr : all_covered_exact_sets) { - IterDomain* id = exact_set_ptr->front(); - // if (id->hasExpandedExtent()) { - if (ca_map_.uniqueExactDefinitions(id).empty()) { - covered_source_ids.insert(id); - } + for (IterDomain* id : get_source_producers(reference_tv)) { + covered_source_ids.insert(id); } - // stand alone expanded id can be left alone without causing replay issue. - for (auto id : target_tv->getLogicalDomain()) { + // it's safe to have source iter domain on output_tv that's not in + // reference_tv, since they are not involved in any transforms. 
+ for (auto id : output_tv->getLogicalDomain()) { if (ca_map_.uniqueExactDefinitions(id).empty()) { covered_source_ids.insert(id); } } - VectorOfUniqueEntries>> - all_expected_exact_sets; - std::for_each( - target_tv->getLogicalDomain().begin(), - target_tv->getLogicalDomain().end(), - [&](IterDomain* id) { - all_expected_exact_sets.pushBack( - ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); - }); - all_expected_exact_sets.pushBack( - ca_map_.getAllDisjointSetProducers(all_expected_exact_sets)); - - for (const auto& exact_set_ptr : all_expected_exact_sets) { - IterDomain* id = exact_set_ptr->front(); - if (ca_map_.uniqueExactDefinitions(id).empty()) { - if (!getMappedInputConcreteID(covered_source_ids, id)) { - return false; - } + // Check all source iter domain involved in producing output_tv + for (IterDomain* id : get_source_producers(output_tv)) { + // if we find any source id that's not contained, it's possible our + // propagation would fail since transformation involving this iter domain + // can't be resolved. + if (!getMappedInputConcreteID(covered_source_ids, id)) { + return false; } } - return true; } @@ -290,9 +331,9 @@ IterDomain* DomainMap::anyMapped( } // Determine if output TensorView is a valid reference tensor for this fusion. -// The reference tensor must map to all the iterDomains in each input. -bool DomainMap::isValidReference(TensorView* tv, bool check_output_coverage) - const { +// The reference tensor must map to all the iterDomains in each input and +// output. 
+bool DomainMap::isValidReference(TensorView* tv) const { for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { if (input_tv->uses().empty()) { continue; @@ -303,15 +344,14 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_output_coverage) return false; } } - if (check_output_coverage) { - for (auto output_tv : - ir_utils::filterByType(fusion_->outputs())) { - if (output_tv == tv) { - continue; - } - if (!areAllProducerIdsMappedTo(output_tv, tv)) { - return false; - } + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { + // no need to check for self. + if (output_tv == tv) { + continue; + } + if (!areAllOutputIdsMappedTo(output_tv, tv)) { + return false; } } return true; diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index dea904b41be..d90ac36baad 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -28,19 +28,20 @@ class DomainMap { } // Determine if a TensorView is a valid reference tensor for this fusion. - // The reference tensor must map to all the iterDomains in each input. - bool isValidReference(TensorView* tv, bool check_output_coverage = true) - const; + // The reference tensor must map to all the iterDomains in each input and + // output. 
+ bool isValidReference(TensorView* tv) const; protected: // Determine if all IterDomains are mapped between input and the given tvs bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) const; - // Determine if all IterDomains are mapped between input and the given tvs - bool areAllProducerIdsMappedTo( - TensorView* target_tv, - TensorView* reference_tv) const; + // Determine if all source IterDomains in output_tv are mapped to the + // reference_tv, this ensures transformations from reference_tv can be + // propagated to output_tv + bool areAllOutputIdsMappedTo(TensorView* output_tv, TensorView* reference_tv) + const; virtual IterDomain* getMappedInputConcreteID( const std::unordered_set& in_concrete_ids, diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 99c94c74a93..93be8309113 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -166,9 +166,7 @@ class DomainMap : public pointwise_utils::DomainMap { // be a traversal path to an input. This type of analysis is // expected to be possible much more easily with the new indexing // graph (#32), so we should revisit once it becomes available. 
- TensorView* findReferenceFor( - const std::vector& group, - bool check_output = true) const { + TensorView* findReferenceFor(const std::vector& group) const { TensorView* result = nullptr; int64_t max_dims = -1; for (auto tv : group) { @@ -204,7 +202,7 @@ class DomainMap : public pointwise_utils::DomainMap { return false; } auto ref1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]); - auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1], false); + auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1]); if (ref1 == nullptr || ref2 == nullptr) { return false; } diff --git a/tests/python/test_pointwise.py b/tests/python/test_pointwise.py index a2a66ced616..025bfcf4f7b 100644 --- a/tests/python/test_pointwise.py +++ b/tests/python/test_pointwise.py @@ -423,13 +423,17 @@ def fusion_func(fd: FusionDefinition): _ = fd.execute(inputs) -def test_pointwise_issue(): +def test_pointwise_issue_3512(): inputs = [ torch.testing.make_tensor( (1, 2048, 512), dtype=torch.bfloat16, device="cuda:0" ), ] + # T34 and T198 are both candidate for reference tv in pointwise scheduler. + # We can only pick T198 for scheduling though, because a expanded dimension + # is merged by the reshape that produces T198, which means transformation + # on T34 wouldn't be able to propagate from T192 to T198. 
def fusion_func(fd: FusionDefinition): T3 = fd.define_tensor( shape=[1, 2048, 512], From 0fc0dc10cf3d289e2ce4cf0f68ac223e25b1bf24 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 5 Dec 2024 02:36:14 -0800 Subject: [PATCH 19/58] fixing --- csrc/scheduler/pointwise_utils.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 95605e68ce9..91b7eb8a012 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -188,7 +188,7 @@ bool DomainMap::areAllOutputIdsMappedTo( TensorView* reference_tv) const { // traverse back to collect all disjoint set producers from the logical domain // of tv. - auto get_source_producers = [&ca_map_](TensorView* tv) { + auto get_source_producers = [this](TensorView* tv) { VectorOfUniqueEntries>> all_producer_sets; std::for_each( @@ -202,14 +202,16 @@ bool DomainMap::areAllOutputIdsMappedTo( ca_map_.getAllDisjointSetProducers(all_producer_sets)); std::vector source_ids; - std::copy_if( - all_producer_sets.begin(), - all_producer_sets.end(), - std::back_inserter(source_ids), - [&ca_map_](const std::shared_ptr>& - producer_set_ptr) { + std::for_each( + all_producer_sets.vector().begin(), + all_producer_sets.vector().end(), + [&source_ids, + this](const std::shared_ptr>& + producer_set_ptr) { IterDomain* id = producer_set_ptr->front(); - return ca_map_.uniqueExactDefinitions(id).empty(); + if (ca_map_.uniqueExactDefinitions(id).empty()) { + source_ids.push_back(id); + } }); return source_ids; }; From 94e2ddf08a1eab9e2aa82bef1499eff61ce35d74 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 6 Dec 2024 09:39:56 -0800 Subject: [PATCH 20/58] Apply suggestions from code review Co-authored-by: Naoya Maruyama Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> --- csrc/scheduler/pointwise_utils.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 91b7eb8a012..bcdb30715c9 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -144,8 +144,8 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } -// Note: ideally we would want to chck that reference_tv contains (not -// necessarily maps) all iter domain in output_tv, so that transformation +// Note: ideally we would want to check that reference_tv contains (not +// necessarily maps) all iter domains in output_tv, so that transformation // applied on reference_tv can be propagated to output_tv. But we don't have // an easy way to check that. // Instead of that, this function checks that all source iter domains involved @@ -167,7 +167,7 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) // replay T34's transform of merging all the dimensions to T198, since b3(ex)*i1 // can't be reversed. The check in this function would give us T34 with source // i0, i1; where T198 would have source i0, b3, i1, where b3 isn't contained in -// T34. Hence we'll reject this referenc_tv. +// T34. Hence we'll reject this reference_tv. 
// // ------ // From 63284b6545c15acda5ea98dd6ed09ac08ea07c71 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 6 Dec 2024 09:48:02 -0800 Subject: [PATCH 21/58] reverting unintended changes --- csrc/scheduler/pointwise_utils.cpp | 4 +--- csrc/scheduler/pointwise_utils.h | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index bcdb30715c9..355c379e62f 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -64,9 +64,7 @@ bool canIgnoreIndexedInputDomainID( TensorView* input_tv, IterDomain* root_id, const ComputeAtMap& ca_map) { - if (!input_tv->isFusionInput()) { - return false; - } + NVF_ERROR(input_tv->isFusionInput()); for (auto use : input_tv->uses()) { if (auto select = dynamic_cast(use)) { if (root_id != select->getIndexedID()) { diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index d90ac36baad..86973d9b96e 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -74,8 +74,7 @@ inline int64_t nLogicalDims(const TensorView* tv) { auto logical_dom = tv->getLogicalDomain(); int64_t tv_n_dims = 0; for (auto dim : logical_dom) { - if (!dim->isReduction() && !dim->isBroadcast() && - !dim->hasExpandedExtent() && !dim->isDeviceDim()) { + if (!dim->isReduction() && !dim->isBroadcast() && !dim->isDeviceDim()) { tv_n_dims++; } } From e39ec5820254217da0a3351707a9749f9a94dd53 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 6 Dec 2024 11:10:17 -0800 Subject: [PATCH 22/58] adding unit tests --- tests/cpp/test_pointwise.cpp | 37 ++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index bb1c6bd7bfb..c9b3d43e393 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -773,4 +774,40 @@ 
TEST_F(PointwiseTest, VectorizePadLoweringPermuted) { EXPECT_TRUE(found_vectorize); testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } + +TEST_F(PointwiseTest, DomainMapTestEg0) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + auto tv1 = relu(tv0); + fusion->addOutput(tv1); + auto tv2 = broadcast(tv1, {false, true, false}); + auto tv3 = expand( + tv2, + {tv1->axis(0)->extent(), + IrBuilder::create(4), + tv1->axis(2)->extent()}); + auto tv4 = reshape(tv3, {2, 4, 3}, {2, 12}); + fusion->addOutput(tv4); + + pointwise_utils::DomainMap domain_map(fusion); + // tv1 can't map to tv4 + EXPECT_FALSE(domain_map.areAllOutputIdsMappedTo(tv4, tv1)); + + // tv1 can map to tv3 + EXPECT_TRUE(domain_map.areAllOutputIdsMappedTo(tv3, tv1)); + + // tv4 can map to tv1 + EXPECT_TRUE(domain_map.areAllOutputIdsMappedTo(tv1, tv4)); + + // tv1 is not a valid reference + EXPECT_FALSE(domain_map.isValidReference(tv1)); + + // tv4 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv4)); +} + } // namespace nvfuser From fa4d8ababfd90e335d57cd05a69d5b505c0b6456 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 6 Dec 2024 11:22:29 -0800 Subject: [PATCH 23/58] WIP --- tests/cpp/test_pointwise.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index c9b3d43e393..e1334a8dade 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -50,6 +50,14 @@ bool hasVectorizationCache(TensorView* tv) { return false; } +class DomainMapUnitTest : public : pointwise_utils::DomainMap { + public: + bool testOutputMapping(TensorView* output_tv, TensorView* reference_tv) + const { + return areAllOutputIdsMappedTo(output_tv, reference_tv); + } +}; + } // namespace TEST_F(PointwiseTest, VectorizeStrideContiguity2D) { @@ -793,15 +801,15 @@ 
TEST_F(PointwiseTest, DomainMapTestEg0) { auto tv4 = reshape(tv3, {2, 4, 3}, {2, 12}); fusion->addOutput(tv4); - pointwise_utils::DomainMap domain_map(fusion); + DomainMapUnitTest domain_map(fusion); // tv1 can't map to tv4 - EXPECT_FALSE(domain_map.areAllOutputIdsMappedTo(tv4, tv1)); + EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv1)); // tv1 can map to tv3 - EXPECT_TRUE(domain_map.areAllOutputIdsMappedTo(tv3, tv1)); + EXPECT_TRUE(domain_map.testOutputMapping(tv3, tv1)); // tv4 can map to tv1 - EXPECT_TRUE(domain_map.areAllOutputIdsMappedTo(tv1, tv4)); + EXPECT_TRUE(domain_map.testOutputMapping(tv1, tv4)); // tv1 is not a valid reference EXPECT_FALSE(domain_map.isValidReference(tv1)); From 66bc53349998bda2ddc645169787bc8602b67f4c Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 6 Dec 2024 12:07:15 -0800 Subject: [PATCH 24/58] unit test --- tests/cpp/test_pointwise.cpp | 37 +++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index e1334a8dade..9737d67037a 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -50,8 +50,9 @@ bool hasVectorizationCache(TensorView* tv) { return false; } -class DomainMapUnitTest : public : pointwise_utils::DomainMap { +class DomainMapUnitTest : public pointwise_utils::DomainMap { public: + DomainMapUnitTest(Fusion* fusion) : pointwise_utils::DomainMap(fusion){}; bool testOutputMapping(TensorView* output_tv, TensorView* reference_tv) const { return areAllOutputIdsMappedTo(output_tv, reference_tv); @@ -795,9 +796,9 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { auto tv2 = broadcast(tv1, {false, true, false}); auto tv3 = expand( tv2, - {tv1->axis(0)->extent(), + {tv2->axis(0)->extent(), IrBuilder::create(4), - tv1->axis(2)->extent()}); + tv2->axis(2)->extent()}); auto tv4 = reshape(tv3, {2, 4, 3}, {2, 12}); fusion->addOutput(tv4); @@ -818,4 +819,34 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { 
EXPECT_TRUE(domain_map.isValidReference(tv4)); } +TEST_F(PointwiseTest, DomainMapTestEg1) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + TensorView* tv1 = makeContigTensor(3); + fusion->addInput(tv1); + auto tv2 = reshape(tv0, {2, 4}, {8}); + fusion->addOutput(tv2); + + auto tv3 = broadcast(tv0, {true, false, false}); + auto tv4 = add(tv1, tv3); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // tv2 can't map to tv4 + EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv2)); + + // tv2 can map to tv4 + EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); + + // tv2 is not a valid reference + EXPECT_FALSE(domain_map.isValidReference(tv2)); + + // tv4 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv4)); +} + } // namespace nvfuser From 26054c3fb9c91d1b5c96adeac20364716e5635f9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 6 Dec 2024 12:12:57 -0800 Subject: [PATCH 25/58] WIP --- tests/cpp/test_pointwise.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 9737d67037a..6f1d3a8d5f9 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -836,13 +836,13 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { fusion->addOutput(tv4); DomainMapUnitTest domain_map(fusion); - // tv2 can't map to tv4 - EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv2)); + // tv2 can map to tv4, because the missing tv4->axis(0) is a dangling ID. 
+ EXPECT_TRUE(domain_map.testOutputMapping(tv4, tv2)); // tv2 can map to tv4 EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); - // tv2 is not a valid reference + // However, tv2 is not a valid reference, since it doesn't cover all input IDs EXPECT_FALSE(domain_map.isValidReference(tv2)); // tv4 is a valid reference From 3d2b926ada94089fef21b610467d8a3e6ff0d6ea Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 6 Dec 2024 16:22:54 -0800 Subject: [PATCH 26/58] WIP, seems to found another issue here --- tests/cpp/test_pointwise.cpp | 44 ++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 6f1d3a8d5f9..408a24833d6 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -849,4 +849,48 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { EXPECT_TRUE(domain_map.isValidReference(tv4)); } +TEST_F(PointwiseTest, DomainMapFactory) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv1); + + auto tv2 = broadcast(tv0, {true, true, false}); + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + auto size_val = IrBuilder::create(4.0, DataType::Int); + auto one_val = IrBuilder::create(1, DataType::Int); + auto tv4 = rand({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); + auto tv5 = add(tv2, tv4); + fusion->addOutput(tv5); + + DomainMapUnitTest domain_map(fusion); + + // tv2 can't map to tv4 + // EXPECT_TRUE(domain_map.testOutputMapping(tv4, tv2)); + + // tv2 can map to tv4 + // EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); + + // tv3 should not be a valid reference + // EXPECT_TRUE(domain_map.isValidReference(tv3)); + // EXPECT_TRUE(domain_map.isValidReference(tv5)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + + auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = + at::empty_strided({25}, {1}, options); + at::Tensor input1 = + at::empty_strided({7, 25}, {25, 1}, options); + auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); + // EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); + testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); +} + } // namespace nvfuser From bb659f83e7d79de8db2e2da62768fabfdbab4240 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 7 Dec 2024 12:19:07 -0800 Subject: [PATCH 27/58] revert unsafe exception --- csrc/scheduler/pointwise_utils.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 355c379e62f..8c5e8ec72ff 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -222,11 +222,11 @@ bool DomainMap::areAllOutputIdsMappedTo( } // it's safe to have source iter domain on output_tv that's not in // reference_tv, since they are not involved in any transforms. 
- for (auto id : output_tv->getLogicalDomain()) { - if (ca_map_.uniqueExactDefinitions(id).empty()) { - covered_source_ids.insert(id); - } - } + // for (auto id : output_tv->getLogicalDomain()) { + // if (ca_map_.uniqueExactDefinitions(id).empty()) { + // covered_source_ids.insert(id); + // } + // } // Check all source iter domain involved in producing output_tv for (IterDomain* id : get_source_producers(output_tv)) { From 45bb78536896c180eddd92186b1f46b74f84ef5f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 7 Dec 2024 14:32:47 -0800 Subject: [PATCH 28/58] moving tests to uniform --- tests/cpp/test_pointwise.cpp | 44 ------------------------------ tests/cpp/test_rng.cpp | 53 ++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 44 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 408a24833d6..6f1d3a8d5f9 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -849,48 +849,4 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { EXPECT_TRUE(domain_map.isValidReference(tv4)); } -TEST_F(PointwiseTest, DomainMapFactory) { - auto fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = makeContigTensor(1); - fusion->addInput(tv0); - TensorView* tv1 = makeContigTensor(2); - fusion->addInput(tv1); - - auto tv2 = broadcast(tv0, {true, true, false}); - auto tv3 = add(tv2, tv1); - fusion->addOutput(tv3); - - auto size_val = IrBuilder::create(4.0, DataType::Int); - auto one_val = IrBuilder::create(1, DataType::Int); - auto tv4 = rand({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); - auto tv5 = add(tv2, tv4); - fusion->addOutput(tv5); - - DomainMapUnitTest domain_map(fusion); - - // tv2 can't map to tv4 - // EXPECT_TRUE(domain_map.testOutputMapping(tv4, tv2)); - - // tv2 can map to tv4 - // EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); - - // tv3 should not be a valid reference - // 
EXPECT_TRUE(domain_map.isValidReference(tv3)); - // EXPECT_TRUE(domain_map.isValidReference(tv5)); - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = - at::empty_strided({25}, {1}, options); - at::Tensor input1 = - at::empty_strided({7, 25}, {25, 1}, options); - auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); - // EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); - testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); -} - } // namespace nvfuser diff --git a/tests/cpp/test_rng.cpp b/tests/cpp/test_rng.cpp index c8f7c545ae6..00a2c795478 100644 --- a/tests/cpp/test_rng.cpp +++ b/tests/cpp/test_rng.cpp @@ -553,4 +553,57 @@ TEST_F(RNGTest, DifferentOffsets) { } } +TEST_F(PointwiseTest, DomainMapFactory) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv1); + + auto tv2 = broadcast(tv0, {true, true, false}); + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + auto size_val = IrBuilder::create(4.0, DataType::Int); + auto one_val = IrBuilder::create(1, DataType::Int); + auto tv4 = rand({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); + auto tv5 = add(tv2, tv4); + fusion->addOutput(tv5); + + DomainMapUnitTest domain_map(fusion); + + // tv2 can't map to tv4 + // EXPECT_TRUE(domain_map.testOutputMapping(tv4, tv2)); + + // tv2 can map to tv4 + // EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); + + // tv3 should not be a valid reference + // EXPECT_TRUE(domain_map.isValidReference(tv3)); + // EXPECT_TRUE(domain_map.isValidReference(tv5)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = + 
at::empty_strided({25}, {1}, options); + at::Tensor input1 = + at::empty_strided({7, 25}, {25, 1}, options); + + at::manual_seed(0); + auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); + // EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); + testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); + + auto t3 = (input0 + input1).as_strided({1, 7, 25}, {7*25, 25, 1}); + at::manual_seed(0); + auto t4 = generate_uniform(4 * 25, at::kFloat).as_strided({4, 1, 25}, {25, 25, 1}); + auto t5 = t4 + input0; + + testValidate(fusion, cg_outputs, {input0, input1}, {t3, t5}, __LINE__, __FILE__); +} + } // namespace nvfuser From b7440865021b4f088a8db1628ebcdc0e63c14afd Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 7 Dec 2024 14:36:13 -0800 Subject: [PATCH 29/58] Revert "moving tests to uniform" This reverts commit 45bb78536896c180eddd92186b1f46b74f84ef5f. --- tests/cpp/test_pointwise.cpp | 44 ++++++++++++++++++++++++++++++ tests/cpp/test_rng.cpp | 53 ------------------------------------ 2 files changed, 44 insertions(+), 53 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 6f1d3a8d5f9..408a24833d6 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -849,4 +849,48 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { EXPECT_TRUE(domain_map.isValidReference(tv4)); } +TEST_F(PointwiseTest, DomainMapFactory) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv1); + + auto tv2 = broadcast(tv0, {true, true, false}); + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + auto size_val = IrBuilder::create(4.0, DataType::Int); + auto one_val = IrBuilder::create(1, DataType::Int); + auto tv4 = rand({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); + auto tv5 = add(tv2, tv4); + 
fusion->addOutput(tv5); + + DomainMapUnitTest domain_map(fusion); + + // tv2 can't map to tv4 + // EXPECT_TRUE(domain_map.testOutputMapping(tv4, tv2)); + + // tv2 can map to tv4 + // EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); + + // tv3 should not be a valid reference + // EXPECT_TRUE(domain_map.isValidReference(tv3)); + // EXPECT_TRUE(domain_map.isValidReference(tv5)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = + at::empty_strided({25}, {1}, options); + at::Tensor input1 = + at::empty_strided({7, 25}, {25, 1}, options); + auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); + // EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); + testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/tests/cpp/test_rng.cpp b/tests/cpp/test_rng.cpp index 00a2c795478..c8f7c545ae6 100644 --- a/tests/cpp/test_rng.cpp +++ b/tests/cpp/test_rng.cpp @@ -553,57 +553,4 @@ TEST_F(RNGTest, DifferentOffsets) { } } -TEST_F(PointwiseTest, DomainMapFactory) { - auto fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = makeContigTensor(1); - fusion->addInput(tv0); - TensorView* tv1 = makeContigTensor(2); - fusion->addInput(tv1); - - auto tv2 = broadcast(tv0, {true, true, false}); - auto tv3 = add(tv2, tv1); - fusion->addOutput(tv3); - - auto size_val = IrBuilder::create(4.0, DataType::Int); - auto one_val = IrBuilder::create(1, DataType::Int); - auto tv4 = rand({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); - auto tv5 = add(tv2, tv4); - fusion->addOutput(tv5); - - DomainMapUnitTest domain_map(fusion); - - // tv2 can't map to tv4 - // EXPECT_TRUE(domain_map.testOutputMapping(tv4, tv2)); - - // tv2 can map to tv4 - // EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); - - // tv3 should not be a valid reference - 
// EXPECT_TRUE(domain_map.isValidReference(tv3)); - // EXPECT_TRUE(domain_map.isValidReference(tv5)); - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = - at::empty_strided({25}, {1}, options); - at::Tensor input1 = - at::empty_strided({7, 25}, {25, 1}, options); - - at::manual_seed(0); - auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); - // EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); - testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); - - auto t3 = (input0 + input1).as_strided({1, 7, 25}, {7*25, 25, 1}); - at::manual_seed(0); - auto t4 = generate_uniform(4 * 25, at::kFloat).as_strided({4, 1, 25}, {25, 25, 1}); - auto t5 = t4 + input0; - - testValidate(fusion, cg_outputs, {input0, input1}, {t3, t5}, __LINE__, __FILE__); -} - } // namespace nvfuser From 3a16c65d76408226f978c9ef6a308d961388f0a9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 7 Dec 2024 14:37:02 -0800 Subject: [PATCH 30/58] do not use random for validation --- tests/cpp/test_pointwise.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 408a24833d6..26c9fdf9167 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -865,7 +865,7 @@ TEST_F(PointwiseTest, DomainMapFactory) { auto size_val = IrBuilder::create(4.0, DataType::Int); auto one_val = IrBuilder::create(1, DataType::Int); - auto tv4 = rand({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); + auto tv4 = ones({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); auto tv5 = add(tv2, tv4); fusion->addOutput(tv5); From 3b9c97f85b421b88c2e923c6dc6c67b01f602fef Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 7 Dec 2024 14:48:17 -0800 Subject: [PATCH 31/58] fixing tests --- tests/cpp/test_pointwise.cpp | 19 +++++++++---------- 1 file changed, 9 
insertions(+), 10 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 26c9fdf9167..6214e5430d9 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -806,8 +806,8 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { // tv1 can't map to tv4 EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv1)); - // tv1 can map to tv3 - EXPECT_TRUE(domain_map.testOutputMapping(tv3, tv1)); + // tv1 can't map to tv3, because it's missing the broadcast dimension + EXPECT_FALSE(domain_map.testOutputMapping(tv3, tv1)); // tv4 can map to tv1 EXPECT_TRUE(domain_map.testOutputMapping(tv1, tv4)); @@ -836,8 +836,8 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { fusion->addOutput(tv4); DomainMapUnitTest domain_map(fusion); - // tv2 can map to tv4, because the missing tv4->axis(0) is a dangling ID. - EXPECT_TRUE(domain_map.testOutputMapping(tv4, tv2)); + // tv2 can't map to tv4, because it misses tv4->axis(0) + EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv2)); // tv2 can map to tv4 EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); @@ -872,14 +872,13 @@ TEST_F(PointwiseTest, DomainMapFactory) { DomainMapUnitTest domain_map(fusion); // tv2 can't map to tv4 - // EXPECT_TRUE(domain_map.testOutputMapping(tv4, tv2)); - - // tv2 can map to tv4 - // EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); + EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv2)); + // tv4 can't map to tv2 + EXPECT_FALSE(domain_map.testOutputMapping(tv2, tv4)); // tv3 should not be a valid reference - // EXPECT_TRUE(domain_map.isValidReference(tv3)); - // EXPECT_TRUE(domain_map.isValidReference(tv5)); + EXPECT_FALSE(domain_map.isValidReference(tv3)); + EXPECT_FALSE(domain_map.isValidReference(tv5)); FusionExecutorCache executor_cache(std::move(fusion_ptr)); From 54176a7d902b19332bcfa2a57912d3965bc1a77e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 7 Dec 2024 15:02:13 -0800 Subject: [PATCH 32/58] fixing tests and comments --- csrc/scheduler/pointwise_utils.cpp 
| 21 ++++++++++++++------- tests/cpp/test_pointwise.cpp | 8 +++----- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 8c5e8ec72ff..3c495e30f1b 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -220,13 +220,20 @@ bool DomainMap::areAllOutputIdsMappedTo( for (IterDomain* id : get_source_producers(reference_tv)) { covered_source_ids.insert(id); } - // it's safe to have source iter domain on output_tv that's not in - // reference_tv, since they are not involved in any transforms. - // for (auto id : output_tv->getLogicalDomain()) { - // if (ca_map_.uniqueExactDefinitions(id).empty()) { - // covered_source_ids.insert(id); - // } - // } + // Note: there's certain cases where it's safe to have dangling IDs, + // e.g + // T34 [i0, i1] + // T185 [i0, b2, i1] = broadcast(T34) + // T192 [i0, b3(ex), i1] = expand(T185) + // It's safe to propagate T34 to T192, since b3(ex) is not involved in the + // propagation. But this isn't generally safe. 
If the above example is changed + // to e.g + // T34 [i0, i1] + // T185 [i0, b2, i1] = broadcast(T34) + // T186 [i0, i4, i1] = ones({i0, i4, i1}) + // T193 [i0, i4, i1] = add(T34, T186) + // It's unsafe to propagate from T34 to T193, see issue + // https://github.com/NVIDIA/Fuser/issues/3542 // Check all source iter domain involved in producing output_tv for (IterDomain* id : get_source_producers(output_tv)) { diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 6214e5430d9..08574c133b3 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -803,7 +803,7 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { fusion->addOutput(tv4); DomainMapUnitTest domain_map(fusion); - // tv1 can't map to tv4 + // tv1 can't map to tv4, since we are missing the expanded ID EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv1)); // tv1 can't map to tv3, because it's missing the broadcast dimension @@ -883,10 +883,8 @@ TEST_F(PointwiseTest, DomainMapFactory) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = - at::empty_strided({25}, {1}, options); - at::Tensor input1 = - at::empty_strided({7, 25}, {25, 1}, options); + at::Tensor input0 = at::empty_strided({25}, {1}, options); + at::Tensor input1 = at::empty_strided({7, 25}, {25, 1}, options); auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); // EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); From 5b668d6e02c55d998f14752294d8608c7774420b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 9 Dec 2024 08:15:20 -0800 Subject: [PATCH 33/58] skip the check for transpose scheduler to ensure no performance regression --- csrc/scheduler/pointwise_utils.cpp | 29 ++++++++++++++++++----------- csrc/scheduler/pointwise_utils.h | 3 ++- csrc/scheduler/transpose.cpp | 6 +++++- 3 files changed, 25 
insertions(+), 13 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 3c495e30f1b..1e6b70b1c3e 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -338,9 +338,10 @@ IterDomain* DomainMap::anyMapped( } // Determine if output TensorView is a valid reference tensor for this fusion. -// The reference tensor must map to all the iterDomains in each input and -// output. -bool DomainMap::isValidReference(TensorView* tv) const { +// The reference tensor must map to all the iterDomains in each input (and +// output, when check_coverage_to_output is set as true) +bool DomainMap::isValidReference(TensorView* tv, bool check_coverage_to_output) + const { for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { if (input_tv->uses().empty()) { continue; @@ -351,14 +352,20 @@ bool DomainMap::isValidReference(TensorView* tv) const { return false; } } - for (auto output_tv : - ir_utils::filterByType(fusion_->outputs())) { - // no need to check for self. - if (output_tv == tv) { - continue; - } - if (!areAllOutputIdsMappedTo(output_tv, tv)) { - return false; + // The check on outputs are optional, transpose scheduler might propose a + // secondary reference that only applies to a subset of IO tensors. Ideally we + // should have a more robust check and consider the IO groups instead of + // blindly skip outputs. + if (check_coverage_to_output) { + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { + // no need to check for self. + if (output_tv == tv) { + continue; + } + if (!areAllOutputIdsMappedTo(output_tv, tv)) { + return false; + } } } return true; diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 86973d9b96e..9e3c631d193 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -30,7 +30,8 @@ class DomainMap { // Determine if a TensorView is a valid reference tensor for this fusion. 
// The reference tensor must map to all the iterDomains in each input and // output. - bool isValidReference(TensorView* tv) const; + bool isValidReference(TensorView* tv, bool check_coverage_to_output = true) + const; protected: // Determine if all IterDomains are mapped between input and the given tvs diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 93be8309113..ff4f00f355f 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -170,7 +170,11 @@ class DomainMap : public pointwise_utils::DomainMap { TensorView* result = nullptr; int64_t max_dims = -1; for (auto tv : group) { - if (isValidReference(tv)) { + // since transpose scheduler have different set of reference, we skip IDs + // coverage check of the reference on outputs of the fusion. Note that + // this is not ideal, we would want to instead have reference tensor + // checked against all its target IO tensors. + if (isValidReference(tv, false)) { int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv); if (dims > max_dims) { result = tv; From 3112ebd9f3ad2a931d00d12a33fd8228762fef9b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 9 Dec 2024 08:40:19 -0800 Subject: [PATCH 34/58] allowing unmatched broadcast dimension --- csrc/scheduler/pointwise_utils.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 1e6b70b1c3e..c9beb026307 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -237,6 +237,10 @@ bool DomainMap::areAllOutputIdsMappedTo( // Check all source iter domain involved in producing output_tv for (IterDomain* id : get_source_producers(output_tv)) { + // It's safe to have unmapped broadcast dimension + if (id->isBraodcast()) { + continue; + } // if we find any source id that's not contained, it's possible our // propagation would fail since transformation involving this iter domain // can't be resolved. 
From db4cabc128510147b58ccfcf63ff0ec5db1d5789 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 9 Dec 2024 08:42:22 -0800 Subject: [PATCH 35/58] CLANGFORMAT --- tests/cpp/test_pointwise.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 08574c133b3..748d4cf7bcb 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -52,7 +52,7 @@ bool hasVectorizationCache(TensorView* tv) { class DomainMapUnitTest : public pointwise_utils::DomainMap { public: - DomainMapUnitTest(Fusion* fusion) : pointwise_utils::DomainMap(fusion){}; + DomainMapUnitTest(Fusion* fusion) : pointwise_utils::DomainMap(fusion) {}; bool testOutputMapping(TensorView* output_tv, TensorView* reference_tv) const { return areAllOutputIdsMappedTo(output_tv, reference_tv); From 55ddfc860dfc09ffff1beef86d85d6e81b5b66ac Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 9 Dec 2024 09:33:04 -0800 Subject: [PATCH 36/58] TYPO --- csrc/scheduler/pointwise_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index c9beb026307..ba082199d27 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -238,7 +238,7 @@ bool DomainMap::areAllOutputIdsMappedTo( // Check all source iter domain involved in producing output_tv for (IterDomain* id : get_source_producers(output_tv)) { // It's safe to have unmapped broadcast dimension - if (id->isBraodcast()) { + if (id->isBroadcast()) { continue; } // if we find any source id that's not contained, it's possible our From 325f5bb2f2c40c105aa5d8208b2032cf7bcd64c6 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 9 Dec 2024 10:02:50 -0800 Subject: [PATCH 37/58] lifting the broadcast exception, in case we change how expand is modeled in fusion later --- csrc/scheduler/pointwise_utils.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 
deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index ba082199d27..60c8ff4c1b5 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -220,6 +220,13 @@ bool DomainMap::areAllOutputIdsMappedTo( for (IterDomain* id : get_source_producers(reference_tv)) { covered_source_ids.insert(id); } + // It's safe to have unmapped broadcast IterDomain. There're quite a few tests + // expecting pointwise scheduler to handle this pattern + for (IterDomain* id : output_tv->getLogicalDomain()) { + if (id->isBroadcast()) { + covered_source_ids.insert(id); + } + } // Note: there's certain cases where it's safe to have dangling IDs, // e.g // T34 [i0, i1] @@ -237,10 +244,6 @@ bool DomainMap::areAllOutputIdsMappedTo( // Check all source iter domain involved in producing output_tv for (IterDomain* id : get_source_producers(output_tv)) { - // It's safe to have unmapped broadcast dimension - if (id->isBroadcast()) { - continue; - } // if we find any source id that's not contained, it's possible our // propagation would fail since transformation involving this iter domain // can't be resolved. 
From 22a7561a2df528f954a6c58e903045a176e8d287 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 9 Dec 2024 11:32:57 -0800 Subject: [PATCH 38/58] fixing false negative tests --- tests/cpp/test_pointwise.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 748d4cf7bcb..fd79ef9227b 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -803,11 +803,12 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { fusion->addOutput(tv4); DomainMapUnitTest domain_map(fusion); - // tv1 can't map to tv4, since we are missing the expanded ID + // tv1 can't map to tv4, because the expanded ID participates in + // transformation EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv1)); - // tv1 can't map to tv3, because it's missing the broadcast dimension - EXPECT_FALSE(domain_map.testOutputMapping(tv3, tv1)); + // tv1 can map to tv3, because the missing ID is broadcast + EXPECT_TRUE(domain_map.testOutputMapping(tv3, tv1)); // tv4 can map to tv1 EXPECT_TRUE(domain_map.testOutputMapping(tv1, tv4)); @@ -871,10 +872,10 @@ TEST_F(PointwiseTest, DomainMapFactory) { DomainMapUnitTest domain_map(fusion); - // tv2 can't map to tv4 - EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv2)); - // tv4 can't map to tv2 - EXPECT_FALSE(domain_map.testOutputMapping(tv2, tv4)); + // tv3 can't map to tv4, because it's missing the expanded dimension + EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv3)); + // tv4 can't map to tv1, since it's missing tv1->axis(0) + EXPECT_FALSE(domain_map.testOutputMapping(tv1, tv4)); // tv3 should not be a valid reference EXPECT_FALSE(domain_map.isValidReference(tv3)); From 49767da86026ef9fe94e4d055a09825fe347da9d Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 10 Dec 2024 15:45:20 -0800 Subject: [PATCH 39/58] WIP addressing review comments --- csrc/scheduler/pointwise_utils.cpp | 59 ++++++++++++++-------------- csrc/scheduler/pointwise_utils.h | 6 +-- 
tests/cpp/test_pointwise.cpp | 62 ++++++++++++++++++++---------- 3 files changed, 74 insertions(+), 53 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 60c8ff4c1b5..24d5238aa50 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -142,14 +142,13 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } -// Note: ideally we would want to check that reference_tv contains (not -// necessarily maps) all iter domains in output_tv, so that transformation -// applied on reference_tv can be propagated to output_tv. But we don't have -// an easy way to check that. -// Instead of that, this function checks that all source iter domains involved -// in transformation on output_tv is covered by reference_tv. We do so by -// traverse all disjoint set producers on both tvs and filter them with -// `ca_map_.uniqueExactDefinitions(id).empty()`. +// Note: ideally we would want to check that reference_tv contains all iter +// domains in target_tv, so that transformation applied on reference_tv can be +// propagated to target_tv. But we don't have an easy way to check that. Instead +// of that, this function checks that all source iter domains involved in +// transformation on target_tv is covered by reference_tv. Source iter domains +// of TensorViews are IDs that doesn't have an definition and are producers of +// any IDs on the logical domain of the given TensorView. // // ------ // @@ -161,7 +160,7 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) // output(T34) // output(T198) // -// if we consider taking T34 as reference_tv. T198 is the output_tv. We can't +// if we consider taking T34 as reference_tv. T198 is the target_tv. We can't // replay T34's transform of merging all the dimensions to T198, since b3(ex)*i1 // can't be reversed. 
The check in this function would give us T34 with source // i0, i1; where T198 would have source i0, b3, i1, where b3 isn't contained in @@ -181,34 +180,36 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) // the example above should be able to pick T4 as reference_tv. T2's source i0, // i1 are both contained by the source of T4, so this example could be scheduled // as a single fusion. -bool DomainMap::areAllOutputIdsMappedTo( - TensorView* output_tv, +bool DomainMap::areAllTargetIdsCoveredBy( + TensorView* target_tv, TensorView* reference_tv) const { - // traverse back to collect all disjoint set producers from the logical domain - // of tv. - auto get_source_producers = [this](TensorView* tv) { + auto get_source_iter_domains = [this](TensorView* tv) { + // traverse back to collect all disjoint set producer IDs for each ID in the + // logical domain of tv. VectorOfUniqueEntries>> all_producer_sets; std::for_each( tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), - [&](IterDomain* id) { + [&](IterDomain* tv_logical_id) { all_producer_sets.pushBack( - ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT)); }); all_producer_sets.pushBack( ca_map_.getAllDisjointSetProducers(all_producer_sets)); std::vector source_ids; + // filtering all producer IDs with empty definition to get source iter + // domains std::for_each( all_producer_sets.vector().begin(), all_producer_sets.vector().end(), [&source_ids, this](const std::shared_ptr>& producer_set_ptr) { - IterDomain* id = producer_set_ptr->front(); - if (ca_map_.uniqueExactDefinitions(id).empty()) { - source_ids.push_back(id); + IterDomain* producer_id = producer_set_ptr->front(); + if (ca_map_.uniqueExactDefinitions(producer_id).empty()) { + source_ids.push_back(producer_id); } }); return source_ids; @@ -217,14 +218,14 @@ bool DomainMap::areAllOutputIdsMappedTo( // this contains all source iter domain that's covered by 
reference_tv, so // it's safe for output_tv to have them. std::unordered_set covered_source_ids; - for (IterDomain* id : get_source_producers(reference_tv)) { - covered_source_ids.insert(id); + for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { + covered_source_ids.insert(source_id_ref); } // It's safe to have unmapped broadcast IterDomain. There're quite a few tests // expecting pointwise scheduler to handle this pattern - for (IterDomain* id : output_tv->getLogicalDomain()) { - if (id->isBroadcast()) { - covered_source_ids.insert(id); + for (IterDomain* id_out : output_tv->getLogicalDomain()) { + if (id_out->isBroadcast()) { + covered_source_ids.insert(id_out); } } // Note: there's certain cases where it's safe to have dangling IDs, @@ -238,16 +239,16 @@ bool DomainMap::areAllOutputIdsMappedTo( // T34 [i0, i1] // T185 [i0, b2, i1] = broadcast(T34) // T186 [i0, i4, i1] = ones({i0, i4, i1}) - // T193 [i0, i4, i1] = add(T34, T186) + // T193 [i0, i4, i1] = add(T185, T186) // It's unsafe to propagate from T34 to T193, see issue // https://github.com/NVIDIA/Fuser/issues/3542 // Check all source iter domain involved in producing output_tv - for (IterDomain* id : get_source_producers(output_tv)) { - // if we find any source id that's not contained, it's possible our + for (IterDomain* source_id_out : get_source_iter_domains(output_tv)) { + // if we find any source_id_out that's not contained, it's possible our // propagation would fail since transformation involving this iter domain // can't be resolved. 
- if (!getMappedInputConcreteID(covered_source_ids, id)) { + if (!getMappedInputConcreteID(covered_source_ids, source_id_out)) { return false; } } @@ -370,7 +371,7 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_coverage_to_output) if (output_tv == tv) { continue; } - if (!areAllOutputIdsMappedTo(output_tv, tv)) { + if (!areAllTargetIdsCoveredBy(output_tv, tv)) { return false; } } diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 9e3c631d193..ef73a2f917e 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -38,10 +38,10 @@ class DomainMap { bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) const; - // Determine if all source IterDomains in output_tv are mapped to the + // Determine if all source IterDomains in target_tv are contained by the // reference_tv, this ensures transformations from reference_tv can be - // propagated to output_tv - bool areAllOutputIdsMappedTo(TensorView* output_tv, TensorView* reference_tv) + // propagated to target_tv + bool areAllTargetIdsCoveredBy(TensorView* target_tv, TensorView* reference_tv) const; virtual IterDomain* getMappedInputConcreteID( diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index fd79ef9227b..21ee519a308 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -52,10 +52,10 @@ bool hasVectorizationCache(TensorView* tv) { class DomainMapUnitTest : public pointwise_utils::DomainMap { public: - DomainMapUnitTest(Fusion* fusion) : pointwise_utils::DomainMap(fusion) {}; - bool testOutputMapping(TensorView* output_tv, TensorView* reference_tv) + DomainMapUnitTest(Fusion* fusion) : pointwise_utils::DomainMap(fusion){}; + bool testTargetCoverage(TensorView* target_tv, TensorView* reference_tv) const { - return areAllOutputIdsMappedTo(output_tv, reference_tv); + return areAllTargetIdsCoveredBy(target_tv, reference_tv); } }; @@ -789,29 +789,36 @@ TEST_F(PointwiseTest, 
DomainMapTestEg0) { auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); + // tv0 {i0, i1} TensorView* tv0 = makeContigTensor(2); fusion->addInput(tv0); + // tv1 {i0, i1} auto tv1 = relu(tv0); fusion->addOutput(tv1); + // tv2 {i0, b2, i1} auto tv2 = broadcast(tv1, {false, true, false}); + // tv3 {i0, b3{1 ex 4}, i1} auto tv3 = expand( tv2, {tv2->axis(0)->extent(), IrBuilder::create(4), tv2->axis(2)->extent()}); + // Note that currently expand doesn't introduce an iter domain operation, so + // we don't see that i4 is produced by realizing the expanded extent of b3{1 + // ex 4} tv4 {i0, i4*i1} auto tv4 = reshape(tv3, {2, 4, 3}, {2, 12}); fusion->addOutput(tv4); DomainMapUnitTest domain_map(fusion); - // tv1 can't map to tv4, because the expanded ID participates in + // tv4 is not covered by tv1, because the expanded ID i4 participates in // transformation - EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv1)); + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv1)); - // tv1 can map to tv3, because the missing ID is broadcast - EXPECT_TRUE(domain_map.testOutputMapping(tv3, tv1)); + // tv3 is covered by tv1, because the missing ID b3{1 ex 4} is broadcast + EXPECT_TRUE(domain_map.testTargetCoverage(tv3, tv1)); - // tv4 can map to tv1 - EXPECT_TRUE(domain_map.testOutputMapping(tv1, tv4)); + // tv1 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv4)); // tv1 is not a valid reference EXPECT_FALSE(domain_map.isValidReference(tv1)); @@ -825,25 +832,30 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); + // tv0 {i0, i1} TensorView* tv0 = makeContigTensor(2); fusion->addInput(tv0); + // tv1 {i2, i0, i1} TensorView* tv1 = makeContigTensor(3); fusion->addInput(tv1); + // tv2 {i0*i1} auto tv2 = reshape(tv0, {2, 4}, {8}); fusion->addOutput(tv2); + // tv3 {b3, i0, i1} auto tv3 = broadcast(tv0, {true, false, false}); + // tv4 {i2, i0, i1} auto tv4 = add(tv1, tv3); fusion->addOutput(tv4); DomainMapUnitTest 
domain_map(fusion); - // tv2 can't map to tv4, because it misses tv4->axis(0) - EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv2)); + // tv4 is not covered by tv2, because it misses i2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); - // tv2 can map to tv4 - EXPECT_TRUE(domain_map.testOutputMapping(tv2, tv4)); + // tv2 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); - // However, tv2 is not a valid reference, since it doesn't cover all input IDs + // tv2 is not a valid reference EXPECT_FALSE(domain_map.isValidReference(tv2)); // tv4 is a valid reference @@ -855,39 +867,47 @@ TEST_F(PointwiseTest, DomainMapFactory) { auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); + // tv1 {i1} TensorView* tv0 = makeContigTensor(1); fusion->addInput(tv0); + // tv1 {i0, i1} TensorView* tv1 = makeContigTensor(2); fusion->addInput(tv1); + // tv2 {b2, b3, i1} auto tv2 = broadcast(tv0, {true, true, false}); + // Note: tv1 will be broadcasted to {b2, i0, i1} before the add. 
+ // tv3 {b2, i0, i1} auto tv3 = add(tv2, tv1); fusion->addOutput(tv3); auto size_val = IrBuilder::create(4.0, DataType::Int); auto one_val = IrBuilder::create(1, DataType::Int); + // factory method creates an iter domain out of thin air + // tv4 {i4{4}, b4, i1} auto tv4 = ones({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); + // tv5 {i4{4}, i0, i1} auto tv5 = add(tv2, tv4); fusion->addOutput(tv5); DomainMapUnitTest domain_map(fusion); - // tv3 can't map to tv4, because it's missing the expanded dimension - EXPECT_FALSE(domain_map.testOutputMapping(tv4, tv3)); - // tv4 can't map to tv1, since it's missing tv1->axis(0) - EXPECT_FALSE(domain_map.testOutputMapping(tv1, tv4)); + // tv4 is not covered by tv3, because it's missing i4{4} + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv3)); + // tv1 is not covered by tv4, since it's missing i0 + EXPECT_FALSE(domain_map.testTargetCoverage(tv1, tv4)); - // tv3 should not be a valid reference EXPECT_FALSE(domain_map.isValidReference(tv3)); + // tv5 has the same IDs as tv4, and is not a valid reference. EXPECT_FALSE(domain_map.isValidReference(tv5)); + // This fusion currently cannot be scheduled as a single kernel. The test + // verifies that it generates correct result. 
FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input0 = at::empty_strided({25}, {1}, options); at::Tensor input1 = at::empty_strided({7, 25}, {25, 1}, options); auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); - // EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); } From 2cb337256b892d3ebf4203e40a881d61aac9faf3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 10 Dec 2024 15:48:48 -0800 Subject: [PATCH 40/58] typo --- csrc/scheduler/pointwise_utils.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 24d5238aa50..7b0f22aa867 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -216,14 +216,14 @@ bool DomainMap::areAllTargetIdsCoveredBy( }; // this contains all source iter domain that's covered by reference_tv, so - // it's safe for output_tv to have them. + // it's safe for target_tv to have them. std::unordered_set covered_source_ids; for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { covered_source_ids.insert(source_id_ref); } // It's safe to have unmapped broadcast IterDomain. 
There're quite a few tests // expecting pointwise scheduler to handle this pattern - for (IterDomain* id_out : output_tv->getLogicalDomain()) { + for (IterDomain* id_out : target_tv->getLogicalDomain()) { if (id_out->isBroadcast()) { covered_source_ids.insert(id_out); } @@ -243,8 +243,8 @@ bool DomainMap::areAllTargetIdsCoveredBy( // It's unsafe to propagate from T34 to T193, see issue // https://github.com/NVIDIA/Fuser/issues/3542 - // Check all source iter domain involved in producing output_tv - for (IterDomain* source_id_out : get_source_iter_domains(output_tv)) { + // Check all source iter domain involved in producing target_tv + for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) { // if we find any source_id_out that's not contained, it's possible our // propagation would fail since transformation involving this iter domain // can't be resolved. From 73c66f824f3d1ae26a2e48f66f076850c7260471 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 10 Dec 2024 17:32:29 -0800 Subject: [PATCH 41/58] refactor the logic per review comments/discussions --- csrc/scheduler/pointwise_utils.cpp | 28 +++++++++++++++------------- csrc/scheduler/transpose.cpp | 7 ++++--- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 7b0f22aa867..ac51078bb0b 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -225,6 +225,11 @@ bool DomainMap::areAllTargetIdsCoveredBy( // expecting pointwise scheduler to handle this pattern for (IterDomain* id_out : target_tv->getLogicalDomain()) { if (id_out->isBroadcast()) { + // TODO: open an issue with a summary here on next step: fix split/merge to preserve expanded broadcasts, when this assert fails, we need to evaluate the refactor. 
+ NVF_ERROR(ca_map_.uniqueExactDefinitions(id_out).empty(), "broadcast IDs are not expected to have definitions"); + // Note that ideally we should also be able to handle merge/split on broadcast + // IDs, so we should really move this skip inside the loop below + // `get_source_iter_domains(target_tv)` and skip broadcast source IDs. covered_source_ids.insert(id_out); } } @@ -346,9 +351,8 @@ IterDomain* DomainMap::anyMapped( } // Determine if output TensorView is a valid reference tensor for this fusion. -// The reference tensor must map to all the iterDomains in each input (and -// output, when check_coverage_to_output is set as true) -bool DomainMap::isValidReference(TensorView* tv, bool check_coverage_to_output) +// The reference tensor must map to all the iterDomains in each input and output +bool DomainMap::isValidReference(TensorView* tv) const { for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { if (input_tv->uses().empty()) { @@ -364,16 +368,14 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_coverage_to_output) // secondary reference that only applies to a subset of IO tensors. Ideally we // should have a more robust check and consider the IO groups instead of // blindly skip outputs. - if (check_coverage_to_output) { - for (auto output_tv : - ir_utils::filterByType(fusion_->outputs())) { - // no need to check for self. - if (output_tv == tv) { - continue; - } - if (!areAllTargetIdsCoveredBy(output_tv, tv)) { - return false; - } + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { + // no need to check for self. 
+ if (output_tv == tv) { + continue; + } + if (!areAllTargetIdsCoveredBy(output_tv, tv)) { + return false; } } return true; diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index ff4f00f355f..9980692f421 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -174,7 +174,8 @@ class DomainMap : public pointwise_utils::DomainMap { // coverage check of the reference on outputs of the fusion. Note that // this is not ideal, we would want to instead have reference tensor // checked against all its target IO tensors. - if (isValidReference(tv, false)) { + // TODO: open an issue for this one. transpose scheduler is not supposed to reuse pointwise_utils::DomainMap::isValidRefrence. This function is too restrictive and doesn't align well with the scheme of transpose scheduler + if (isValidReference(tv)) { int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv); if (dims > max_dims) { result = tv; @@ -994,12 +995,12 @@ std::unique_ptr getTransposeHeuristics( << "max_io_dtype_size: " << max_io_dtype_size << "\n" << "group 1: " << ir_utils::toString(grouped_inputs_outputs[0]) << "\n" - << "reference1: " << reference1 << "\n" + << "reference1: " << reference1->toString() << "\n" << "inner_most_id1 position: " << inner_most_pos1_in_ref1 << " (in reference 1)\n" << "group 2: " << ir_utils::toString(grouped_inputs_outputs[1]) << "\n" - << "reference2: " << reference2 << "\n" + << "reference2: " << reference2->toString() << "\n" << "inner_most_id2 position: " << inner_most_pos2_in_ref1 << " (in reference 1)" << std::endl; if (hasSmallTransposeDimensions(tparams)) { From dbd59955816c980acbab6a6cb15795ee79918bc6 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 10 Dec 2024 17:33:58 -0800 Subject: [PATCH 42/58] fixing signature --- csrc/scheduler/pointwise_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index ef73a2f917e..ad4d6337ee5 100644 
--- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -30,7 +30,7 @@ class DomainMap { // Determine if a TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input and // output. - bool isValidReference(TensorView* tv, bool check_coverage_to_output = true) + bool isValidReference(TensorView* tv) const; protected: From 4ba5baabe10fb391ad918e533637f6a36a66f66f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 11:43:28 -0800 Subject: [PATCH 43/58] updating tests, removing asserts --- csrc/scheduler/pointwise_utils.cpp | 8 ++++++-- tests/cpp/test_pointwise.cpp | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index ac51078bb0b..9cb9088c776 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -225,11 +225,15 @@ bool DomainMap::areAllTargetIdsCoveredBy( // expecting pointwise scheduler to handle this pattern for (IterDomain* id_out : target_tv->getLogicalDomain()) { if (id_out->isBroadcast()) { - // TODO: open an issue with a summary here on next step: fix split/merge to preserve expanded broadcasts, when this assert fails, we need to evaluate the refactor. - NVF_ERROR(ca_map_.uniqueExactDefinitions(id_out).empty(), "broadcast IDs are not expected to have definitions"); + // if(!ca_map_.uniqueExactDefinitions(id_out).empty()) { + // continue; + // } + // Note that ideally we should also be able to handle merge/split on broadcast // IDs, so we should really move this skip inside the loop below // `get_source_iter_domains(target_tv)` and skip broadcast source IDs. 
+ // currently we have the issue that split/merge does not preserve expanded broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126 + covered_source_ids.insert(id_out); } } diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 21ee519a308..713c2810e33 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -825,6 +825,18 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { // tv4 is a valid reference EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // check reference tv selection + EXPECT_FALSE(domain_map.findReferenceTensorView() == tv4); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 7}, options); + std::vector aten_inputs = {t0}; + // NOTE: force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(PointwiseTest, DomainMapTestEg1) { @@ -860,6 +872,19 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { // tv4 is a valid reference EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // check reference tv selection + EXPECT_FALSE(domain_map.findReferenceTensorView() == tv4); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 4}, options); + at::Tensor t1 = at::randn({3, 2, 4}, options); + std::vector aten_inputs = {t0}; + // NOTE: force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(PointwiseTest, DomainMapFactory) { From 742f7f39ffd26eb007515f8853d0c55fa45d25df Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 13:03:12 -0800 Subject: [PATCH 44/58] removing checks that are not exposed by scheduler --- 
csrc/scheduler/pointwise_utils.cpp | 1 + tests/cpp/test_pointwise.cpp | 6 ------ 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 9cb9088c776..43e7f4ca563 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -219,6 +219,7 @@ bool DomainMap::areAllTargetIdsCoveredBy( // it's safe for target_tv to have them. std::unordered_set covered_source_ids; for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { + NVF_ERROR(source_id_ref->definition() == nullptr || id->definition()->isA()); covered_source_ids.insert(source_id_ref); } // It's safe to have unmapped broadcast IterDomain. There're quite a few tests diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 713c2810e33..11678ee9e8f 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -826,9 +826,6 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { // tv4 is a valid reference EXPECT_TRUE(domain_map.isValidReference(tv4)); - // check reference tv selection - EXPECT_FALSE(domain_map.findReferenceTensorView() == tv4); - // validate generated kernel auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({4, 7}, options); @@ -873,9 +870,6 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { // tv4 is a valid reference EXPECT_TRUE(domain_map.isValidReference(tv4)); - // check reference tv selection - EXPECT_FALSE(domain_map.findReferenceTensorView() == tv4); - // validate generated kernel auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 4}, options); From 145d902b41ef725074af414b6d9c01e3aa6c54f2 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 14:03:58 -0800 Subject: [PATCH 45/58] renaming things --- csrc/scheduler/pointwise_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 43e7f4ca563..aa3c42890dc 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -219,7 +219,7 @@ bool DomainMap::areAllTargetIdsCoveredBy( // it's safe for target_tv to have them. std::unordered_set covered_source_ids; for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { - NVF_ERROR(source_id_ref->definition() == nullptr || id->definition()->isA()); + NVF_ERROR(source_id_ref->definition() == nullptr || source_id_ref->definition()->isA()); covered_source_ids.insert(source_id_ref); } // It's safe to have unmapped broadcast IterDomain. There're quite a few tests From 19291c8f44a9adcc631f22e3c1f381d76a296c10 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 14:15:29 -0800 Subject: [PATCH 46/58] err somehow I missed this one --- tests/cpp/test_pointwise.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 11678ee9e8f..b29f5867ae8 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -874,7 +874,7 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 4}, options); at::Tensor t1 = at::randn({3, 2, 4}, options); - std::vector aten_inputs = {t0}; + std::vector aten_inputs = {t0, t1}; // NOTE: force pointwise scheduler here for unit test auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); From 526e6b748573d274f7743b7a7eadd417ada33f84 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 17:26:01 -0800 Subject: [PATCH 47/58] updating tests --- tests/cpp/test_pointwise.cpp | 96 ++++++++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index b29f5867ae8..351c00e28eb 100644 --- 
a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -906,7 +906,7 @@ TEST_F(PointwiseTest, DomainMapFactory) { // tv4 {i4{4}, b4, i1} auto tv4 = ones({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); // tv5 {i4{4}, i0, i1} - auto tv5 = add(tv2, tv4); + auto tv5 = mul(tv2, tv4); fusion->addOutput(tv5); DomainMapUnitTest domain_map(fusion); @@ -920,13 +920,103 @@ TEST_F(PointwiseTest, DomainMapFactory) { // tv5 has the same IDs as tv4, and is not a valid reference. EXPECT_FALSE(domain_map.isValidReference(tv5)); - // This fusion currently cannot be scheduled as a single kernel. The test - // verifies that it generates correct result. FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input0 = at::empty_strided({25}, {1}, options); at::Tensor input1 = at::empty_strided({7, 25}, {25, 1}, options); auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + SegmentedFusion* segmented_fusion = runtime->fusionSegments(); + // This fusion currently cannot be scheduled as a single kernel. It is expected to be segmented as: + // g{(pointwise) + // inputs: tv0, tv1 + // outputs: tv2, tv3 + // tv2 = broadcast(tv0) + // tv3 = add (tv2, broadcast(tv1)) + // } + // + // g{(pointwise) + // inputs: tv2 + // outputs: tv5 + // tv4 = full({4, 1, i0}) + // tv5 = mul(tv2, tv4) + // } + EXPECT_EQ(segmented_fusion->groups().size(), 2); + + for (SegmentedGroup* group : segmented_fusion->groups()) { + const std::vector& exprs = group->exprs(); + + size_t num_full = std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) {return expr->isA();}); + if (num_full != 0) { + // this is the segment contains the factory op. 
+ EXPECT_EQ(exprs.size(), 2); + EXPECT_EQ(num_full, 1); + auto binary_op_iter = std::find(exprs.begin(), exprs.end(), [](Expr* expr) {return expr->isA();}); + EXPECT_EQ(binary_op_iter->as()->getBinaryOpType(), BinaryOpType::Mul); + Fusion* group_fusion = group->getFusion(); + // validate that we have a valid reference in the segmented fusion + DomainMapUnitTest group_dm(group_fusion); + EXPECT_EQ(group_fusion->outputs().size(), 1); + EXPECT_TRUE(group_dm.isValidReference(group_fusion->outputs()[0])); + } else { + // validate segmentation has the correct ops + EXPECT_EQ(exprs.size(), 3); + EXPECT_EQ(std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) {return expr->isA();}), 2); + EXPECT_EQ(std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) {return expr->isA();}), 1); + Fusion* group_fusion = group->getFusion(); + auto output_add = std::find_if(group_fusion->outputs().begin(), group_fusion->outputs().end(), [](Val* val) { + return val->definition()->isA(); + }); + EXPECT_TRUE(output_add != group_fusion->outputs().end()); + DomainMapUnitTest group_dm(group_fusion); + // validate that the segmented fusion choose the add output as the reference + EXPECT_TRUE(group_dm.isValidReference(output_add->as())); + } + } + + testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapPad) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv1 {b1, i0} + TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); + fusion->addInput(tv0); + // tv1 {i2, b1, i0} + TensorView* tv1 = TensorViewBuilder().shape({-1, 1, -1}).build(); + fusion->addInput(tv1); + // tv2 {i2, b1, i0} + auto tv2 = add(tv1, tv0); + fusion->addOutput(tv2); + // i3 = resize(b1 + 4 + 4) + // tv3 {i3, i0} + auto tv3 = pad(tv0, {IrBuilder::create(0L), IrBuilder::create(0L), IrBuilder::create(4L), IrBuilder::create(4L)}); + // tv4 
{i3*i0} + auto tv4 = reshape(tv3, {9, 5}, {45}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is covered by tv2, because i3 is produced by b1 + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); + // tv1 is not covered by tv4, since it's missing i0 + EXPECT_FALSE(domain_map.testTargetCoverage(tv1, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv3)); + // tv5 has the same IDs as tv4, and is not a valid reference. + EXPECT_FALSE(domain_map.isValidReference(tv5)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor input1 = at::empty_strided({7, 1, 5}, {5, 5, 1}, options); + auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); } From 0ecc1f6874e37d2a1d4d673764ecd0f0775863bf Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 17:38:14 -0800 Subject: [PATCH 48/58] adding another test --- tests/cpp/test_pointwise.cpp | 125 ++++++++++++++++++++++++++++------- 1 file changed, 101 insertions(+), 24 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 351c00e28eb..d3de289ef8d 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -928,14 +928,14 @@ TEST_F(PointwiseTest, DomainMapFactory) { FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); SegmentedFusion* segmented_fusion = runtime->fusionSegments(); - // This fusion currently cannot be scheduled as a single kernel. It is expected to be segmented as: - // g{(pointwise) + // This fusion currently cannot be scheduled as a single kernel. 
It is + // expected to be segmented as: g{(pointwise) // inputs: tv0, tv1 // outputs: tv2, tv3 // tv2 = broadcast(tv0) // tv3 = add (tv2, broadcast(tv1)) // } - // + // // g{(pointwise) // inputs: tv2 // outputs: tv5 @@ -947,13 +947,20 @@ TEST_F(PointwiseTest, DomainMapFactory) { for (SegmentedGroup* group : segmented_fusion->groups()) { const std::vector& exprs = group->exprs(); - size_t num_full = std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) {return expr->isA();}); + size_t num_full = std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) { + return expr->isA(); + }); if (num_full != 0) { // this is the segment contains the factory op. EXPECT_EQ(exprs.size(), 2); EXPECT_EQ(num_full, 1); - auto binary_op_iter = std::find(exprs.begin(), exprs.end(), [](Expr* expr) {return expr->isA();}); - EXPECT_EQ(binary_op_iter->as()->getBinaryOpType(), BinaryOpType::Mul); + auto binary_op_iter = + std::find(exprs.begin(), exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + EXPECT_EQ( + (*binary_op_iter)->as()->getBinaryOpType(), + BinaryOpType::Mul); Fusion* group_fusion = group->getFusion(); // validate that we have a valid reference in the segmented fusion DomainMapUnitTest group_dm(group_fusion); @@ -962,16 +969,28 @@ TEST_F(PointwiseTest, DomainMapFactory) { } else { // validate segmentation has the correct ops EXPECT_EQ(exprs.size(), 3); - EXPECT_EQ(std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) {return expr->isA();}), 2); - EXPECT_EQ(std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) {return expr->isA();}), 1); + EXPECT_EQ( + std::count_if( + exprs.begin(), + exprs.end(), + [](Expr* expr) { return expr->isA(); }), + 2); + EXPECT_EQ( + std::count_if( + exprs.begin(), + exprs.end(), + [](Expr* expr) { return expr->isA(); }), + 1); Fusion* group_fusion = group->getFusion(); - auto output_add = std::find_if(group_fusion->outputs().begin(), group_fusion->outputs().end(), [](Val* val) { - return val->definition()->isA(); - }); + auto 
output_add = std::find_if( + group_fusion->outputs().begin(), + group_fusion->outputs().end(), + [](Val* val) { return val->definition()->isA(); }); EXPECT_TRUE(output_add != group_fusion->outputs().end()); DomainMapUnitTest group_dm(group_fusion); - // validate that the segmented fusion choose the add output as the reference - EXPECT_TRUE(group_dm.isValidReference(output_add->as())); + // validate that the segmented fusion choose the add output as the + // reference + EXPECT_TRUE(group_dm.isValidReference((*output_add)->as())); } } @@ -996,7 +1015,12 @@ TEST_F(PointwiseTest, DomainMapPad) { fusion->addOutput(tv2); // i3 = resize(b1 + 4 + 4) // tv3 {i3, i0} - auto tv3 = pad(tv0, {IrBuilder::create(0L), IrBuilder::create(0L), IrBuilder::create(4L), IrBuilder::create(4L)}); + auto tv3 = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(4L), + IrBuilder::create(4L)}); // tv4 {i3*i0} auto tv4 = reshape(tv3, {9, 5}, {45}); fusion->addOutput(tv4); @@ -1004,20 +1028,73 @@ TEST_F(PointwiseTest, DomainMapPad) { DomainMapUnitTest domain_map(fusion); // tv4 is covered by tv2, because i3 is produced by b1 + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); + // tv2 is not covered by tv4, it's missing i2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv4)); + EXPECT_TRUE(domain_map.isValidReference(tv2)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor t1 = at::empty_strided({7, 1, 5}, {5, 5, 1}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE: force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapSlice) { + preseg_passes::OptimizationPassGuard + 
optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv1 {i1, i0} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i1, i2} + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv1); + + // b3 = resize(i2 + 4 + 4) + // tv2 {i1, b3} + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(0L), + IrBuilder::create(1L), + IrBuilder::create(1L)}}); + fusion->addOutput(tv2); + // tv3 {i1, i0} + auto tv3 = add(tv0, tv2); + // tv4 {i1*i0} + auto tv4 = reshape(tv3, {2, 4}, {8}); + fusion->addOutput(tv4); + // TODO: add a slice that's not merged back into the consumer + + DomainMapUnitTest domain_map(fusion); + + // tv4 is not covered by tv2, because i0 is missing EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); - // tv1 is not covered by tv4, since it's missing i0 - EXPECT_FALSE(domain_map.testTargetCoverage(tv1, tv4)); + // tv2 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); - EXPECT_FALSE(domain_map.isValidReference(tv3)); - // tv5 has the same IDs as tv4, and is not a valid reference. 
- EXPECT_FALSE(domain_map.isValidReference(tv5)); + EXPECT_FALSE(domain_map.isValidReference(tv2)); + EXPECT_TRUE(domain_map.isValidReference(tv4)); - FusionExecutorCache executor_cache(std::move(fusion_ptr)); + // validate generated kernel auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = at::empty_strided({1, 5}, {5, 1}, options); - at::Tensor input1 = at::empty_strided({7, 1, 5}, {5, 5, 1}, options); - auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); - testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); + at::Tensor t0 = at::randn({2, 4}, options); + at::Tensor t1 = at::randn({2, 8}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE: force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } } // namespace nvfuser From 6abaa1d941636f80243ca4748e7d81d2dfc1cab4 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 18:18:56 -0800 Subject: [PATCH 49/58] test fixing --- tests/cpp/test_pointwise.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index d3de289ef8d..2204c17427b 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -955,7 +955,7 @@ TEST_F(PointwiseTest, DomainMapFactory) { EXPECT_EQ(exprs.size(), 2); EXPECT_EQ(num_full, 1); auto binary_op_iter = - std::find(exprs.begin(), exprs.end(), [](Expr* expr) { + std::find_if(exprs.begin(), exprs.end(), [](Expr* expr) { return expr->isA(); }); EXPECT_EQ( @@ -965,7 +965,8 @@ TEST_F(PointwiseTest, DomainMapFactory) { // validate that we have a valid reference in the segmented fusion DomainMapUnitTest group_dm(group_fusion); EXPECT_EQ(group_fusion->outputs().size(), 1); - EXPECT_TRUE(group_dm.isValidReference(group_fusion->outputs()[0])); + 
EXPECT_TRUE(group_dm.isValidReference( + group_fusion->outputs()[0]->as())); } else { // validate segmentation has the correct ops EXPECT_EQ(exprs.size(), 3); @@ -1053,7 +1054,7 @@ TEST_F(PointwiseTest, DomainMapSlice) { auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); - // tv1 {i1, i0} + // tv0 {i1, i0} TensorView* tv0 = makeContigTensor(2); fusion->addInput(tv0); // tv1 {i1, i2} From e8a4dddbe034f08d01abd9572ff7b4972a0b3472 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 19:04:22 -0800 Subject: [PATCH 50/58] fixing tests --- tests/cpp/test_pointwise.cpp | 71 +++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 10 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 2204c17427b..24a1fad64f4 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -1047,7 +1047,7 @@ TEST_F(PointwiseTest, DomainMapPad) { testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(PointwiseTest, DomainMapSlice) { +TEST_F(PointwiseTest, DomainMapSlice0) { preseg_passes::OptimizationPassGuard optimization_guard(false); auto fusion_ptr = std::make_unique(); @@ -1057,12 +1057,13 @@ TEST_F(PointwiseTest, DomainMapSlice) { // tv0 {i1, i0} TensorView* tv0 = makeContigTensor(2); fusion->addInput(tv0); - // tv1 {i1, i2} - TensorView* tv1 = makeContigTensor(2); + // tv1 {i1, i0} + // use concrete tensor to avoid need of concretization + TensorView* tv1 = makeContigConcreteTensor({2, 4}); fusion->addInput(tv1); - // b3 = resize(i2 + 4 + 4) - // tv2 {i1, b3} + // b3 = resize(i0 + 0 - 3) + // tv2 {i1, b2} auto tv2 = slice( tv1, {Slice(), @@ -1071,7 +1072,7 @@ TEST_F(PointwiseTest, DomainMapSlice) { IrBuilder::create(1L)}}); fusion->addOutput(tv2); // tv3 {i1, i0} - auto tv3 = add(tv0, tv2); + auto tv3 = add(tv0, tv1); // tv4 {i1*i0} auto tv4 = reshape(tv3, {2, 4}, {8}); fusion->addOutput(tv4); @@ -1079,9 +1080,59 @@ TEST_F(PointwiseTest, DomainMapSlice) { 
DomainMapUnitTest domain_map(fusion); - // tv4 is not covered by tv2, because i0 is missing + // tv2 and tv4 has the same source IDs, since b3 = resize(i0 + 0 - 3) + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_TRUE(domain_map.isValidReference(tv2)); + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 4}, options); + at::Tensor t1 = at::randn({2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE: force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapSlice1) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i2, i1, i0} + TensorView* tv0 = makeContigTensor(3); + fusion->addInput(tv0); + // tv1 {i1, i0} + // use concrete tensor to avoid need of concretization + TensorView* tv1 = makeContigConcreteTensor({2, 4}); + fusion->addInput(tv1); + + // b3 = resize(i0 + 0 - 3) + // tv2 {i1, b3} + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(0L), + IrBuilder::create(1L), + IrBuilder::create(1L)}}); + fusion->addOutput(tv2); + // tv3 {i2, i1, i0} + auto tv3 = add(tv0, tv1); + // tv4 {i2, i1*i0} + auto tv4 = reshape(tv3, {2, 2, 4}, {2, 8}); + fusion->addOutput(tv4); + // TODO: add a slice that's not merged back into the consumer + + DomainMapUnitTest domain_map(fusion); + + // i2 is missing in tv2 EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); - // tv2 is covered by tv4 EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); EXPECT_FALSE(domain_map.isValidReference(tv2)); @@ -1089,8 +1140,8 @@ TEST_F(PointwiseTest, 
DomainMapSlice) { // validate generated kernel auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({2, 4}, options); - at::Tensor t1 = at::randn({2, 8}, options); + at::Tensor t0 = at::randn({2, 2, 4}, options); + at::Tensor t1 = at::randn({2, 4}, options); std::vector aten_inputs = {t0, t1}; // NOTE: force pointwise scheduler here for unit test auto cg_results = From 6ada6570e5dfb646eee38b8abbb2f0865296f1e8 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 19:13:39 -0800 Subject: [PATCH 51/58] CLANGFORMAT --- csrc/scheduler/pointwise_utils.cpp | 14 ++++++++------ csrc/scheduler/transpose.cpp | 5 ++++- tests/cpp/test_pointwise.cpp | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index aa3c42890dc..8f7efc3065c 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -219,7 +219,9 @@ bool DomainMap::areAllTargetIdsCoveredBy( // it's safe for target_tv to have them. std::unordered_set covered_source_ids; for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { - NVF_ERROR(source_id_ref->definition() == nullptr || source_id_ref->definition()->isA()); + NVF_ERROR( + source_id_ref->definition() == nullptr || + source_id_ref->definition()->isA()); covered_source_ids.insert(source_id_ref); } // It's safe to have unmapped broadcast IterDomain. There're quite a few tests @@ -230,10 +232,11 @@ bool DomainMap::areAllTargetIdsCoveredBy( // continue; // } - // Note that ideally we should also be able to handle merge/split on broadcast - // IDs, so we should really move this skip inside the loop below + // Note that ideally we should also be able to handle merge/split on + // broadcast IDs, so we should really move this skip inside the loop below // `get_source_iter_domains(target_tv)` and skip broadcast source IDs. 
- // currently we have the issue that split/merge does not preserve expanded broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126 + // currently we have the issue that split/merge does not preserve expanded + // broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126 covered_source_ids.insert(id_out); } @@ -357,8 +360,7 @@ IterDomain* DomainMap::anyMapped( // Determine if output TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input and output -bool DomainMap::isValidReference(TensorView* tv) - const { +bool DomainMap::isValidReference(TensorView* tv) const { for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { if (input_tv->uses().empty()) { continue; diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 9980692f421..4321afd9c7f 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -174,7 +174,10 @@ class DomainMap : public pointwise_utils::DomainMap { // coverage check of the reference on outputs of the fusion. Note that // this is not ideal, we would want to instead have reference tensor // checked against all its target IO tensors. - // TODO: open an issue for this one. transpose scheduler is not supposed to reuse pointwise_utils::DomainMap::isValidRefrence. This function is too restrictive and doesn't align well with the scheme of transpose scheduler + // TODO: open an issue for this one. transpose scheduler is not supposed + // to reuse pointwise_utils::DomainMap::isValidRefrence. 
This function is + // too restrictive and doesn't align well with the scheme of transpose + // scheduler if (isValidReference(tv)) { int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv); if (dims > max_dims) { diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 24a1fad64f4..56f7f26cfcc 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -52,7 +52,7 @@ bool hasVectorizationCache(TensorView* tv) { class DomainMapUnitTest : public pointwise_utils::DomainMap { public: - DomainMapUnitTest(Fusion* fusion) : pointwise_utils::DomainMap(fusion){}; + DomainMapUnitTest(Fusion* fusion) : pointwise_utils::DomainMap(fusion) {}; bool testTargetCoverage(TensorView* target_tv, TensorView* reference_tv) const { return areAllTargetIdsCoveredBy(target_tv, reference_tv); From d797df8e12a846144cba365c161949ca97ce9700 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 19:15:55 -0800 Subject: [PATCH 52/58] removing python test since it's already covered in cpp test --- tests/python/test_pointwise.py | 37 ---------------------------------- 1 file changed, 37 deletions(-) diff --git a/tests/python/test_pointwise.py b/tests/python/test_pointwise.py index 025bfcf4f7b..e24a139555e 100644 --- a/tests/python/test_pointwise.py +++ b/tests/python/test_pointwise.py @@ -421,40 +421,3 @@ def fusion_func(fd: FusionDefinition): with pytest.raises(RuntimeError, match="No executor supports provided fusion."): _ = fd.execute(inputs) - - -def test_pointwise_issue_3512(): - inputs = [ - torch.testing.make_tensor( - (1, 2048, 512), dtype=torch.bfloat16, device="cuda:0" - ), - ] - - # T34 and T198 are both candidate for reference tv in pointwise scheduler. - # We can only pick T198 for scheduling though, because a expanded dimension - # is merged by the reshape that produces T198, which means transformation - # on T34 wouldn't be able to propagate from T192 to T198. 
- def fusion_func(fd: FusionDefinition): - T3 = fd.define_tensor( - shape=[1, 2048, 512], - contiguity=[None, True, True], - dtype=DataType.BFloat16, - is_cpu=False, - stride_order=[2, 1, 0], - ) - T33 = fd.ops.reshape(T3, new_shape=[1, 2048, 8, 64]) - T34 = fd.ops.permute(T33, dims=[0, 2, 1, 3]) - T185 = fd.ops.broadcast_in_dim( - T34, shape=[1, 8, 1, 2048, 64], broadcast_dims=[0, 1, 3, 4] - ) - T192 = fd.ops.broadcast_in_dim( - T185, shape=[1, 8, 4, 2048, 64], broadcast_dims=[0, 1, 2, 3, 4] - ) - T198 = fd.ops.reshape(T192, new_shape=[1, 32, 2048, 64]) - fd.add_output(T34) - fd.add_output(T198) - - with FusionDefinition() as fd: - fusion_func(fd) - - _ = fd.execute(inputs) From 25362cdf4d8d84204212a6c32802be1e0c172b26 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 19:22:42 -0800 Subject: [PATCH 53/58] oops, assert was placed in the wrong spot --- csrc/scheduler/pointwise_utils.cpp | 10 +++------- tests/cpp/test_pointwise.cpp | 4 ---- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 8f7efc3065c..e372343af5d 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -219,25 +219,21 @@ bool DomainMap::areAllTargetIdsCoveredBy( // it's safe for target_tv to have them. std::unordered_set covered_source_ids; for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { - NVF_ERROR( - source_id_ref->definition() == nullptr || - source_id_ref->definition()->isA()); covered_source_ids.insert(source_id_ref); } // It's safe to have unmapped broadcast IterDomain. 
There're quite a few tests // expecting pointwise scheduler to handle this pattern for (IterDomain* id_out : target_tv->getLogicalDomain()) { if (id_out->isBroadcast()) { - // if(!ca_map_.uniqueExactDefinitions(id_out).empty()) { - // continue; - // } + NVF_ERROR( + id_out->definition() == nullptr || + id_out->definition()->isA()); // Note that ideally we should also be able to handle merge/split on // broadcast IDs, so we should really move this skip inside the loop below // `get_source_iter_domains(target_tv)` and skip broadcast source IDs. // currently we have the issue that split/merge does not preserve expanded // broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126 - covered_source_ids.insert(id_out); } } diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 56f7f26cfcc..9a978916510 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -1076,10 +1076,8 @@ TEST_F(PointwiseTest, DomainMapSlice0) { // tv4 {i1*i0} auto tv4 = reshape(tv3, {2, 4}, {8}); fusion->addOutput(tv4); - // TODO: add a slice that's not merged back into the consumer DomainMapUnitTest domain_map(fusion); - // tv2 and tv4 has the same source IDs, since b3 = resize(i0 + 0 - 3) EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); @@ -1127,10 +1125,8 @@ TEST_F(PointwiseTest, DomainMapSlice1) { // tv4 {i2, i1*i0} auto tv4 = reshape(tv3, {2, 2, 4}, {2, 8}); fusion->addOutput(tv4); - // TODO: add a slice that's not merged back into the consumer DomainMapUnitTest domain_map(fusion); - // i2 is missing in tv2 EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); From a129e720200eb7d450db395257c41465440f93a7 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 19:25:09 -0800 Subject: [PATCH 54/58] CLANGFORMAT --- csrc/scheduler/pointwise_utils.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index ad4d6337ee5..8daaccc616a 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -30,8 +30,7 @@ class DomainMap { // Determine if a TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input and // output. - bool isValidReference(TensorView* tv) - const; + bool isValidReference(TensorView* tv) const; protected: // Determine if all IterDomains are mapped between input and the given tvs From b7f2efbfc86cd2a503cd03de30f9fa7a5c67994e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 22:27:53 -0800 Subject: [PATCH 55/58] adding naoya's example --- tests/cpp/test_pointwise.cpp | 54 ++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 9a978916510..55e6adc564c 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -998,14 +998,14 @@ TEST_F(PointwiseTest, DomainMapFactory) { testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); } -TEST_F(PointwiseTest, DomainMapPad) { +TEST_F(PointwiseTest, DomainMapPad0) { preseg_passes::OptimizationPassGuard optimization_guard(false); auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); - // tv1 {b1, i0} + // tv0 {b1, i0} TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); fusion->addInput(tv0); // tv1 {i2, b1, i0} @@ -1047,6 +1047,56 @@ TEST_F(PointwiseTest, DomainMapPad) { testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } +TEST_F(PointwiseTest, DomainMapPad1) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {b1, i0} + TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); + 
fusion->addInput(tv0); + // tv1 {i2, i3, i4, b5} + TensorView* tv1 = TensorViewBuilder().shape({-1, -1, -1, 1}).build(); + fusion->addInput(tv1); + + // tv2 {b6, b7, b1, i0} + auto tv2 = broadcast(tv0, {true, true, false, false}); + // tv3 {i2, i3, i4, i0} + auto tv3 = add(tv1, tv2); + fusion->addOutput(tv3); + // i8 = resize(b1 + 4 + 4) + // tv4 {i8, i0} + auto tv4 = + pad(tv0, + {IrBuilder::create(4L), + IrBuilder::create(4L), + IrBuilder::create(0L), + IrBuilder::create(0L)}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is covered by tv3, because i8 is produced by b1 + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv3)); + // tv3 is not covered by tv4, it's missing i2 and i3 + EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv4)); + EXPECT_TRUE(domain_map.isValidReference(tv3)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor t1 = at::empty_strided({2, 3, 4, 1}, {12, 4, 1, 1}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE: force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + TEST_F(PointwiseTest, DomainMapSlice0) { preseg_passes::OptimizationPassGuard optimization_guard(false); From 307569f553209a6e74b7adeed1946d914e255a35 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 11 Dec 2024 22:57:41 -0800 Subject: [PATCH 56/58] I was padding the wrong dimension here --- tests/cpp/test_pointwise.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 55e6adc564c..29e6ebd7bff 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -1070,10 +1070,10 @@ 
TEST_F(PointwiseTest, DomainMapPad1) { // tv4 {i8, i0} auto tv4 = pad(tv0, - {IrBuilder::create(4L), - IrBuilder::create(4L), + {IrBuilder::create(0L), IrBuilder::create(0L), - IrBuilder::create(0L)}); + IrBuilder::create(4L), + IrBuilder::create(4L)}); fusion->addOutput(tv4); DomainMapUnitTest domain_map(fusion); From af315c7dc43f2137ac50bfa7693091c5f7e8ffdc Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 12 Dec 2024 00:16:39 -0800 Subject: [PATCH 57/58] made a small refactor to avoid regression --- csrc/scheduler/pointwise_utils.cpp | 11 ++++++- tests/cpp/test_pointwise.cpp | 53 ++++++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index e372343af5d..1ac92dd714a 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -254,10 +254,19 @@ bool DomainMap::areAllTargetIdsCoveredBy( // Check all source iter domain involved in producing target_tv for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) { + // NOTE: we use concrete id instead. This allows us to link indirect + // broadcast. So in the example below: T2[i0, i1] = T0[i0, b0] + T1[i0, i1] + // T3[i0, i9] = pad(T0[i0, b0]) + // We have i9 in T3 + // -> source ID b0 + // -> concrete map to i1 + // So T3 is contained by T2. See test `PointwiseTest.DomainMapPad1` + auto concrete_source_id_out = + ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE); // if we find any source_id_out that's not contained, it's possible our // propagation would fail since transformation involving this iter domain // can't be resolved. 
- if (!getMappedInputConcreteID(covered_source_ids, source_id_out)) { + if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) { return false; } } diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 29e6ebd7bff..74c5e714ff1 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -814,8 +814,9 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { // transformation EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv1)); - // tv3 is covered by tv1, because the missing ID b3{1 ex 4} is broadcast - EXPECT_TRUE(domain_map.testTargetCoverage(tv3, tv1)); + // tv3 is not covered by tv1, because the missing ID b3{1 ex 4} is concretized + // as i4, which is not mapped on tv1 + EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv1)); // tv1 is covered by tv4 EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv4)); @@ -881,6 +882,51 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } +TEST_F(PointwiseTest, DomainMapTestEg2) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i0, i1} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i0, i1} + auto tv1 = relu(tv0); + fusion->addOutput(tv1); + // tv2 {i0, b2, i1} + auto tv2 = broadcast(tv1, {false, true, false}); + // tv3 {i0, b3{1 ex 4}, i1} + auto tv3 = expand( + tv2, + {tv2->axis(0)->extent(), + IrBuilder::create(4), + tv2->axis(2)->extent()}); + fusion->addOutput(tv3); + + DomainMapUnitTest domain_map(fusion); + // tv3 is covered by tv1, because the missing ID b3{1 ex 4} is broadcast and + // doesn't get resolved to a concrete broadcast ID. 
+ EXPECT_TRUE(domain_map.testTargetCoverage(tv3, tv1)); + + // tv1 is covered by tv3 + EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv3)); + + // tv1 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv1)); + + // tv3 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv3)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 7}, options); + std::vector aten_inputs = {t0}; + // NOTE: force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + TEST_F(PointwiseTest, DomainMapFactory) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); @@ -1078,7 +1124,8 @@ TEST_F(PointwiseTest, DomainMapPad1) { DomainMapUnitTest domain_map(fusion); - // tv4 is covered by tv3, because i8 is produced by b1 + // tv4 is covered by tv3, because i8 is produced by b1, a broadcast dimension + // concretized as i4 EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv3)); // tv3 is not covered by tv4, it's missing i2 and i3 EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv4)); From d46323cd299443823124d5b71be733c8540abe93 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 12 Dec 2024 02:58:25 -0800 Subject: [PATCH 58/58] committing something so I can trigger CI again --- tests/cpp/test_pointwise.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 74c5e714ff1..a6889a13562 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -316,7 +316,7 @@ TEST_F(PointwiseTest, Issue1567VectorizeAllocationDomain) { at::Tensor input1 = at::empty_strided({1, 128, 1}, {128, 1, 128}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for
testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -350,7 +350,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase0) { at::Tensor input1 = at::randn({1024, 2, 512}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs, false); auto pparams = cg_results.heuristic_params->as(); @@ -384,7 +384,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase1) { at::Tensor input1 = at::randn({1024, 512, 2}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -424,7 +424,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase2) { at::Tensor input1 = at::empty_strided({1024, 512, 2}, {2, 2048, 1}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -461,7 +461,7 @@ TEST_F(PointwiseTest, VIssue1567ectorizationFactorAnalysisCase3) { at::Tensor input1 = at::randn({512, 1024, 2}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); 
auto pparams = cg_results.heuristic_params->as(); @@ -803,7 +803,7 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { {tv2->axis(0)->extent(), IrBuilder::create(4), tv2->axis(2)->extent()}); - // Note that currently expand doesn't introduce an iter domain operation, so + // NOTE that currently expand doesn't introduce an iter domain operation, so // we don't see that i4 is produced by realizing the expanded extent of b3{1 // ex 4} tv4 {i0, i4*i1} auto tv4 = reshape(tv3, {2, 4, 3}, {2, 12}); @@ -831,7 +831,7 @@ TEST_F(PointwiseTest, DomainMapTestEg0) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({4, 7}, options); std::vector aten_inputs = {t0}; - // NOTE: force pointwise scheduler here for unit test + // NOTE force pointwise scheduler here for unit test auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); @@ -876,7 +876,7 @@ TEST_F(PointwiseTest, DomainMapTestEg1) { at::Tensor t0 = at::randn({2, 4}, options); at::Tensor t1 = at::randn({3, 2, 4}, options); std::vector aten_inputs = {t0, t1}; - // NOTE: force pointwise scheduler here for unit test + // NOTE force pointwise scheduler here for unit test auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); @@ -921,7 +921,7 @@ TEST_F(PointwiseTest, DomainMapTestEg2) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({4, 7}, options); std::vector aten_inputs = {t0}; - // NOTE: force pointwise scheduler here for unit test + // NOTE force pointwise scheduler here for unit test auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); @@ -941,7 +941,7 @@ TEST_F(PointwiseTest, DomainMapFactory) { // tv2 {b2, b3, i1} auto tv2 =
broadcast(tv0, {true, true, false}); - // Note: tv1 will be broadcasted to {b2, i0, i1} before the add. + // NOTE tv1 will be broadcasted to {b2, i0, i1} before the add. // tv3 {b2, i0, i1} auto tv3 = add(tv2, tv1); fusion->addOutput(tv3); @@ -1087,7 +1087,7 @@ TEST_F(PointwiseTest, DomainMapPad0) { at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); at::Tensor t1 = at::empty_strided({7, 1, 5}, {5, 5, 1}, options); std::vector aten_inputs = {t0, t1}; - // NOTE: force pointwise scheduler here for unit test + // NOTE force pointwise scheduler here for unit test auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); @@ -1138,7 +1138,7 @@ TEST_F(PointwiseTest, DomainMapPad1) { at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); at::Tensor t1 = at::empty_strided({2, 3, 4, 1}, {12, 4, 1, 1}, options); std::vector aten_inputs = {t0, t1}; - // NOTE: force pointwise scheduler here for unit test + // NOTE force pointwise scheduler here for unit test auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); @@ -1187,7 +1187,7 @@ TEST_F(PointwiseTest, DomainMapSlice0) { at::Tensor t0 = at::randn({2, 4}, options); at::Tensor t1 = at::randn({2, 4}, options); std::vector aten_inputs = {t0, t1}; - // NOTE: force pointwise scheduler here for unit test + // NOTE force pointwise scheduler here for unit test auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); @@ -1236,7 +1236,7 @@ TEST_F(PointwiseTest, DomainMapSlice1) { at::Tensor t0 = at::randn({2, 2, 4}, options); at::Tensor t1 = at::randn({2, 4}, options); std::vector aten_inputs = {t0, t1}; - // NOTE: force pointwise scheduler here for unit test + // NOTE force pointwise scheduler here for unit test auto cg_results = 
scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__);