From d3722307ebd210b0f5347295635dc4588019d8a7 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 30 Aug 2024 12:24:45 -0400 Subject: [PATCH 1/8] Cache `ir_utils::allTvs` as part of Fusion This caches `ir_utils::allTvs(fusion)` as `fusion->allTvs()`. The cache is automatically invalidated whenever the TV graph topology changes; this mechanism is the same one used to recompute `Expr` uses automatically. --- csrc/fusion.cpp | 16 ++++++++++++++++ csrc/fusion.h | 7 +++++++ csrc/kernel_cache.cpp | 6 ++++-- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index b2fe3bab274..3b9eff347fd 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -116,6 +116,14 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_; + if (from->all_tvs_ptr_ != nullptr) { + to->all_tvs_ptr_ = std::make_unique>(); + to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size()); + for (TensorView* from_tv : *from->all_tvs_ptr_) { + to->all_tvs_ptr_->push_back(ir_cloner.clone(from_tv)->as()); + } + } + return ir_cloner; } @@ -854,4 +862,12 @@ bool isExpressionEvaluated(Fusion* fusion) { }); } +const std::vector& Fusion::allTvs() { + if (all_tvs_ptr_ == nullptr) { + all_tvs_ptr_ = + std::make_unique>(ir_utils::allTvs(this)); + } + return *all_tvs_ptr_; +} + } // namespace nvfuser diff --git a/csrc/fusion.h b/csrc/fusion.h index 72280c384c7..a7228776b9c 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -429,6 +429,10 @@ class NVF_API Fusion : public IrContainer { expected_dynamic_smem_bytes_ = bytes; } + //! This is a cached version of ir_utils::allTvs that is invalidated + //! whenever we invalidate TV uses + const std::vector& allTvs(); + protected: friend SegmentCandidateFinder; friend SegmentedFusion; @@ -458,6 +462,7 @@ class NVF_API Fusion : public IrContainer { //! the update). 
void invalidateTvUses() { all_tv_uses_valid_ = false; + all_tvs_ptr_ = nullptr; } private: @@ -485,6 +490,8 @@ class NVF_API Fusion : public IrContainer { // If set to a non-negative value during scheduling, this will be checked by // the executor. int64_t expected_dynamic_smem_bytes_ = -1LL; + + std::unique_ptr> all_tvs_ptr_ = nullptr; }; // Returns true if all fusion outputs are expression evaluated. diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index af2ecdc1d44..8e8c6be4f60 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -988,7 +988,8 @@ FusionKernelRuntime::FusionKernelRuntime( // SchedulerRuntimeInfo modifies the fusion, so it is required for both // compile paths. - std::vector all_tvs = ir_utils::allTvs(fusion.get()); + std::vector all_tvs = + fusion->allTvs(); // ir_utils::allTvs(fusion.get()); SchedulerRuntimeInfo runtime_info( fusion.get(), args, nullptr, all_tvs, forced_index_type); @@ -1453,7 +1454,8 @@ std::optional FusionKernelRuntime:: // Get all tensorviews for segmented fusion std::vector all_tvs_for_fusion_to_run = - ir_utils::allTvs(fusion_to_run); + fusion_to_run->allTvs(); + // ir_utils::allTvs(fusion_to_run); SchedulerRuntimeInfo fusion_to_run_info( fusion_to_run, From f89eeee82c318954562cb3a2029f3e07c5ab5424 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 1 Sep 2024 10:27:49 -0700 Subject: [PATCH 2/8] Cache allTvs in Fusion like allTvUses. 
--- csrc/fusion.cpp | 29 ++++++++++++++++------------- csrc/fusion.h | 14 +++++++------- csrc/ir/base_nodes.cpp | 5 +++-- csrc/ir/base_nodes.h | 4 ++-- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 3b9eff347fd..fac9b45b83a 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -94,7 +94,7 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { .hide_output = alias_info.hide_output}; } - to->all_tv_uses_valid_ = from->all_tv_uses_valid_; + to->all_tvs_and_uses_valid_ = from->all_tvs_and_uses_valid_; // This should never be true on copy, but copying for completeness. to->is_during_update_uses_ = from->is_during_update_uses_; @@ -176,7 +176,7 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); - all_tv_uses_valid_ = false; + all_tvs_and_uses_valid_ = false; is_during_update_uses_ = false; } @@ -258,7 +258,7 @@ void Fusion::addInput(Val* input) { inputs_.push_back(input); input->setIsFusionInput(true); - all_tv_uses_valid_ = false; + all_tvs_and_uses_valid_ = false; } void Fusion::addOutputInternal(Val* output) { @@ -272,7 +272,7 @@ void Fusion::addOutputInternal(Val* output) { outputs_.push_back(output); output->setIsFusionOutput(true); - all_tv_uses_valid_ = false; + all_tvs_and_uses_valid_ = false; } void Fusion::addOutput(Val* output) { @@ -298,7 +298,7 @@ void Fusion::removeInput(Val* input) { inputs_.erase(find_input); } input->setIsFusionInput(false); - all_tv_uses_valid_ = false; + all_tvs_and_uses_valid_ = false; } void Fusion::removeOutput(Val* output) { @@ -307,7 +307,7 @@ void Fusion::removeOutput(Val* output) { outputs_.erase(find_output); } output->setIsFusionOutput(false); - all_tv_uses_valid_ = false; + all_tvs_and_uses_valid_ = false; } void Fusion::replaceOutput(Val* output, Val* replacement) { @@ -619,12 +619,12 @@ void Fusion::registerExpr(Expr* expr) { } } -void Fusion::resetTvUses() { - FUSER_PERF_SCOPE("Fusion::resetTvUses"); +void 
Fusion::resetAllTvsAndUses() { + FUSER_PERF_SCOPE("Fusion::resetAllTvsAndUses"); is_during_update_uses_ = true; // getExprs only uses definition, so even if we've modified uses already to - // remove dead exprs, this could reinsert them. getExprs is also boundeds by + // remove dead exprs, this could reinsert them. getExprs is also bounded by // inputs as registered inputs will return nullptr as their definition. const auto all_tvs = ir_utils::filterByType(vals_); const auto used_exprs = StmtSort::getExprs(this); @@ -640,7 +640,11 @@ void Fusion::resetTvUses() { } } - all_tv_uses_valid_ = true; + // Update all_tvs entry as well + all_tvs_ptr_ = + std::make_unique>(ir_utils::allTvs(this)); + + all_tvs_and_uses_valid_ = true; is_during_update_uses_ = false; } @@ -863,9 +867,8 @@ bool isExpressionEvaluated(Fusion* fusion) { } const std::vector& Fusion::allTvs() { - if (all_tvs_ptr_ == nullptr) { - all_tvs_ptr_ = - std::make_unique>(ir_utils::allTvs(this)); + if (!all_tvs_and_uses_valid_) { + resetAllTvsAndUses(); } return *all_tvs_ptr_; } diff --git a/csrc/fusion.h b/csrc/fusion.h index a7228776b9c..bd19fb035f7 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -263,11 +263,11 @@ class NVF_API Fusion : public IrContainer { //! aliased. const AliasInfo& getOutputAlias(const Val* output) const; - bool isTVUseInfoValid() { - return all_tv_uses_valid_; + bool isAllTvsAndUsesValid() { + return all_tvs_and_uses_valid_; } - bool isUpdatingTVUseInfo() { + bool isUpdatingAllTvsAndUses() { return is_during_update_uses_; } @@ -455,13 +455,13 @@ class NVF_API Fusion : public IrContainer { //! Clear Expr's from TV uses that are not required to produce outputs from //! inputs. Only other place this is used (other than Fusion) is in - //! Val::uses() - void resetTvUses(); + //! Val::uses(). Also populate the all_tvs_ entry with all the used TVs. + void resetAllTvsAndUses(); //! Declare that TensorView uses need to be updated (but don't actually do //! the update). 
void invalidateTvUses() { - all_tv_uses_valid_ = false; + all_tvs_and_uses_valid_ = false; all_tvs_ptr_ = nullptr; } @@ -480,7 +480,7 @@ class NVF_API Fusion : public IrContainer { // Records if the current use data in the IR nodes are valid // the states are either all valid or all invalid - bool all_tv_uses_valid_ = false; + bool all_tvs_and_uses_valid_ = false; bool is_during_update_uses_ = false; std::vector> managed_data_; diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 16c2b61b397..c2a70a74579 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -89,8 +89,9 @@ NVFUSER_DEFINE_CLONE(Val) const std::vector& Val::uses() const { if (vtype_ == ValType::TensorView) { - if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) { - fusion()->resetTvUses(); + if (!fusion()->isAllTvsAndUsesValid() && + !fusion()->isUpdatingAllTvsAndUses()) { + fusion()->resetAllTvsAndUses(); } } return uses_; diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 024fe3cbf4b..27f59b1548b 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -361,8 +361,8 @@ class NVF_API Val : public Statement { //! Returns the Exprs for which this is an input. //! Note that uses() will occasionally trigger a deferred call to - //! resetTvUses() which can be expensive as it requires traversing the graph - //! using Val definitions. + //! resetAllTvsAndUses() which can be expensive as it requires traversing the + //! graph using Val definitions. const std::vector& uses() const; bool isFusionInput() const { From dde80ae51c197bb0dcf8ea9f41c6c29d1667f15e Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 1 Sep 2024 11:05:18 -0700 Subject: [PATCH 3/8] Revert "Cache allTvs in Fusion like allTvUses." This reverts commit f89eeee82c318954562cb3a2029f3e07c5ab5424. 
--- csrc/fusion.cpp | 29 +++++++++++++---------------- csrc/fusion.h | 14 +++++++------- csrc/ir/base_nodes.cpp | 5 ++--- csrc/ir/base_nodes.h | 4 ++-- 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index fac9b45b83a..3b9eff347fd 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -94,7 +94,7 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { .hide_output = alias_info.hide_output}; } - to->all_tvs_and_uses_valid_ = from->all_tvs_and_uses_valid_; + to->all_tv_uses_valid_ = from->all_tv_uses_valid_; // This should never be true on copy, but copying for completeness. to->is_during_update_uses_ = from->is_during_update_uses_; @@ -176,7 +176,7 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); - all_tvs_and_uses_valid_ = false; + all_tv_uses_valid_ = false; is_during_update_uses_ = false; } @@ -258,7 +258,7 @@ void Fusion::addInput(Val* input) { inputs_.push_back(input); input->setIsFusionInput(true); - all_tvs_and_uses_valid_ = false; + all_tv_uses_valid_ = false; } void Fusion::addOutputInternal(Val* output) { @@ -272,7 +272,7 @@ void Fusion::addOutputInternal(Val* output) { outputs_.push_back(output); output->setIsFusionOutput(true); - all_tvs_and_uses_valid_ = false; + all_tv_uses_valid_ = false; } void Fusion::addOutput(Val* output) { @@ -298,7 +298,7 @@ void Fusion::removeInput(Val* input) { inputs_.erase(find_input); } input->setIsFusionInput(false); - all_tvs_and_uses_valid_ = false; + all_tv_uses_valid_ = false; } void Fusion::removeOutput(Val* output) { @@ -307,7 +307,7 @@ void Fusion::removeOutput(Val* output) { outputs_.erase(find_output); } output->setIsFusionOutput(false); - all_tvs_and_uses_valid_ = false; + all_tv_uses_valid_ = false; } void Fusion::replaceOutput(Val* output, Val* replacement) { @@ -619,12 +619,12 @@ void Fusion::registerExpr(Expr* expr) { } } -void Fusion::resetAllTvsAndUses() { - FUSER_PERF_SCOPE("Fusion::resetAllTvsAndUses"); +void 
Fusion::resetTvUses() { + FUSER_PERF_SCOPE("Fusion::resetTvUses"); is_during_update_uses_ = true; // getExprs only uses definition, so even if we've modified uses already to - // remove dead exprs, this could reinsert them. getExprs is also bounded by + // remove dead exprs, this could reinsert them. getExprs is also boundeds by // inputs as registered inputs will return nullptr as their definition. const auto all_tvs = ir_utils::filterByType(vals_); const auto used_exprs = StmtSort::getExprs(this); @@ -640,11 +640,7 @@ void Fusion::resetAllTvsAndUses() { } } - // Update all_tvs entry as well - all_tvs_ptr_ = - std::make_unique>(ir_utils::allTvs(this)); - - all_tvs_and_uses_valid_ = true; + all_tv_uses_valid_ = true; is_during_update_uses_ = false; } @@ -867,8 +863,9 @@ bool isExpressionEvaluated(Fusion* fusion) { } const std::vector& Fusion::allTvs() { - if (!all_tvs_and_uses_valid_) { - resetAllTvsAndUses(); + if (all_tvs_ptr_ == nullptr) { + all_tvs_ptr_ = + std::make_unique>(ir_utils::allTvs(this)); } return *all_tvs_ptr_; } diff --git a/csrc/fusion.h b/csrc/fusion.h index bd19fb035f7..a7228776b9c 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -263,11 +263,11 @@ class NVF_API Fusion : public IrContainer { //! aliased. const AliasInfo& getOutputAlias(const Val* output) const; - bool isAllTvsAndUsesValid() { - return all_tvs_and_uses_valid_; + bool isTVUseInfoValid() { + return all_tv_uses_valid_; } - bool isUpdatingAllTvsAndUses() { + bool isUpdatingTVUseInfo() { return is_during_update_uses_; } @@ -455,13 +455,13 @@ class NVF_API Fusion : public IrContainer { //! Clear Expr's from TV uses that are not required to produce outputs from //! inputs. Only other place this is used (other than Fusion) is in - //! Val::uses(). Also populate the all_tvs_ entry with all the used TVs. - void resetAllTvsAndUses(); + //! Val::uses() + void resetTvUses(); //! Declare that TensorView uses need to be updated (but don't actually do //! the update). 
void invalidateTvUses() { - all_tvs_and_uses_valid_ = false; + all_tv_uses_valid_ = false; all_tvs_ptr_ = nullptr; } @@ -480,7 +480,7 @@ class NVF_API Fusion : public IrContainer { // Records if the current use data in the IR nodes are valid // the states are either all valid or all invalid - bool all_tvs_and_uses_valid_ = false; + bool all_tv_uses_valid_ = false; bool is_during_update_uses_ = false; std::vector> managed_data_; diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index c2a70a74579..16c2b61b397 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -89,9 +89,8 @@ NVFUSER_DEFINE_CLONE(Val) const std::vector& Val::uses() const { if (vtype_ == ValType::TensorView) { - if (!fusion()->isAllTvsAndUsesValid() && - !fusion()->isUpdatingAllTvsAndUses()) { - fusion()->resetAllTvsAndUses(); + if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) { + fusion()->resetTvUses(); } } return uses_; diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 27f59b1548b..024fe3cbf4b 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -361,8 +361,8 @@ class NVF_API Val : public Statement { //! Returns the Exprs for which this is an input. //! Note that uses() will occasionally trigger a deferred call to - //! resetAllTvsAndUses() which can be expensive as it requires traversing the - //! graph using Val definitions. + //! resetTvUses() which can be expensive as it requires traversing the graph + //! using Val definitions. const std::vector& uses() const; bool isFusionInput() const { From 71ad77d03c0b9dcd0e43b391433ae573aa51e38c Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 1 Sep 2024 12:25:09 -0700 Subject: [PATCH 4/8] Cache allTvs in Fusion like allTvUses. 
--- csrc/fusion.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 3b9eff347fd..84476b2538a 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -175,6 +175,7 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); + all_tvs_ptr_.reset(); all_tv_uses_valid_ = false; is_during_update_uses_ = false; @@ -259,6 +260,7 @@ void Fusion::addInput(Val* input) { input->setIsFusionInput(true); all_tv_uses_valid_ = false; + all_tvs_ptr_.reset(); } void Fusion::addOutputInternal(Val* output) { @@ -273,6 +275,7 @@ void Fusion::addOutputInternal(Val* output) { output->setIsFusionOutput(true); all_tv_uses_valid_ = false; + all_tvs_ptr_.reset(); } void Fusion::addOutput(Val* output) { @@ -299,6 +302,7 @@ void Fusion::removeInput(Val* input) { } input->setIsFusionInput(false); all_tv_uses_valid_ = false; + all_tvs_ptr_.reset(); } void Fusion::removeOutput(Val* output) { @@ -308,6 +312,7 @@ void Fusion::removeOutput(Val* output) { } output->setIsFusionOutput(false); all_tv_uses_valid_ = false; + all_tvs_ptr_.reset(); } void Fusion::replaceOutput(Val* output, Val* replacement) { From 57ba06a8963a6f1f1647ca307faa6631a937457a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 1 Sep 2024 12:28:49 -0700 Subject: [PATCH 5/8] Minor cleanup. --- csrc/fusion.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fusion.h b/csrc/fusion.h index a7228776b9c..c7ded315cec 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -462,7 +462,7 @@ class NVF_API Fusion : public IrContainer { //! the update). void invalidateTvUses() { all_tv_uses_valid_ = false; - all_tvs_ptr_ = nullptr; + all_tvs_ptr_.reset(); } private: From 8e308b5ddb66f0f8e86d4478e045abbea07a674f Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 1 Sep 2024 12:30:31 -0700 Subject: [PATCH 6/8] Minor cleanup. 
--- csrc/fusion.cpp | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 84476b2538a..a85047db242 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -175,9 +175,9 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); - all_tvs_ptr_.reset(); - all_tv_uses_valid_ = false; + invalidateTvUses(); + is_during_update_uses_ = false; } @@ -259,8 +259,7 @@ void Fusion::addInput(Val* input) { inputs_.push_back(input); input->setIsFusionInput(true); - all_tv_uses_valid_ = false; - all_tvs_ptr_.reset(); + invalidateTvUses(); } void Fusion::addOutputInternal(Val* output) { @@ -274,8 +273,7 @@ void Fusion::addOutputInternal(Val* output) { outputs_.push_back(output); output->setIsFusionOutput(true); - all_tv_uses_valid_ = false; - all_tvs_ptr_.reset(); + invalidateTvUses(); } void Fusion::addOutput(Val* output) { @@ -301,8 +299,7 @@ void Fusion::removeInput(Val* input) { inputs_.erase(find_input); } input->setIsFusionInput(false); - all_tv_uses_valid_ = false; - all_tvs_ptr_.reset(); + invalidateTvUses(); } void Fusion::removeOutput(Val* output) { @@ -311,8 +308,7 @@ void Fusion::removeOutput(Val* output) { outputs_.erase(find_output); } output->setIsFusionOutput(false); - all_tv_uses_valid_ = false; - all_tvs_ptr_.reset(); + invalidateTvUses(); } void Fusion::replaceOutput(Val* output, Val* replacement) { From 878a225dca72915d1e8a925487c3a096c85852b3 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 1 Sep 2024 12:31:46 -0700 Subject: [PATCH 7/8] Minor cleanup. 
--- csrc/fusion.cpp | 18 +++++++++--------- csrc/fusion.h | 2 +- csrc/ir/base_nodes.cpp | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index a85047db242..5d67a1c59c4 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -176,7 +176,7 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); - invalidateTvUses(); + invalidateTvsAndUses(); is_during_update_uses_ = false; } @@ -193,7 +193,7 @@ void Fusion::removeExpr(Expr* expr) { // Remove uses in inputs for (auto inp : expr->inputs()) { - // Note that if inp is a TensorView, this may call invalidateTvUses + // Note that if inp is a TensorView, this may call invalidateTvsAndUses inp->removeUse(expr); } @@ -259,7 +259,7 @@ void Fusion::addInput(Val* input) { inputs_.push_back(input); input->setIsFusionInput(true); - invalidateTvUses(); + invalidateTvsAndUses(); } void Fusion::addOutputInternal(Val* output) { @@ -273,7 +273,7 @@ void Fusion::addOutputInternal(Val* output) { outputs_.push_back(output); output->setIsFusionOutput(true); - invalidateTvUses(); + invalidateTvsAndUses(); } void Fusion::addOutput(Val* output) { @@ -299,7 +299,7 @@ void Fusion::removeInput(Val* input) { inputs_.erase(find_input); } input->setIsFusionInput(false); - invalidateTvUses(); + invalidateTvsAndUses(); } void Fusion::removeOutput(Val* output) { @@ -308,7 +308,7 @@ void Fusion::removeOutput(Val* output) { outputs_.erase(find_output); } output->setIsFusionOutput(false); - invalidateTvUses(); + invalidateTvsAndUses(); } void Fusion::replaceOutput(Val* output, Val* replacement) { @@ -335,7 +335,7 @@ void Fusion::replaceOutput(Val* output, Val* replacement) { } } // Mark uses invalid so that they will be reset next time uses() is called - invalidateTvUses(); + invalidateTvsAndUses(); } // Temporary WAR for issue #1112 @@ -591,7 +591,7 @@ void Fusion::registerExpr(Expr* expr) { // Don't just add this expr as a use of the input if it's a tensor as the // 
whole fusion needs to be traversed to rebuild the usage lists if (input->isA()) { - invalidateTvUses(); + invalidateTvsAndUses(); } else { input->addUse(expr); } @@ -614,7 +614,7 @@ void Fusion::registerExpr(Expr* expr) { // If that happens, our definition-based traversal can change and // introduce whole new branches, so we need to recompute the uses_ // vector after setDefinition. - invalidateTvUses(); + invalidateTvsAndUses(); } } } diff --git a/csrc/fusion.h b/csrc/fusion.h index c7ded315cec..90a405c4371 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -460,7 +460,7 @@ class NVF_API Fusion : public IrContainer { //! Declare that TensorView uses need to be updated (but don't actually do //! the update). - void invalidateTvUses() { + void invalidateTvsAndUses() { all_tv_uses_valid_ = false; all_tvs_ptr_.reset(); } diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 16c2b61b397..477cf2af126 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -110,7 +110,7 @@ bool Val::removeUse(Expr* expr) { uses_.erase(it); if (this->isA()) { // Call for a rebuild of uses_ vector - fusion()->invalidateTvUses(); + fusion()->invalidateTvsAndUses(); } return true; } From 72615c805bd858a9b518b0a588d3537c51648045 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 1 Sep 2024 19:24:45 -0700 Subject: [PATCH 8/8] Return a copy of allTvs so it doesn't invalidate while traversing fusion->allTvs. --- csrc/fusion.cpp | 12 ++++++++++-- csrc/fusion.h | 8 +++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 5d67a1c59c4..bd864c9f881 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -188,6 +188,9 @@ void Fusion::removeExpr(Expr* expr) { // we're going with the strictest model which errors. 
for (auto out : expr->outputs()) { + if (out->isA()) { + invalidateTvsAndUses(); + } out->setDefinition(nullptr); } @@ -195,6 +198,9 @@ void Fusion::removeExpr(Expr* expr) { for (auto inp : expr->inputs()) { // Note that if inp is a TensorView, this may call invalidateTvsAndUses inp->removeUse(expr); + if (inp->isA()) { + invalidateTvsAndUses(); + } } IrContainer::removeExpr(expr); @@ -237,6 +243,8 @@ void Fusion::removeVal(Val* val) { removeExpr(e); } IrContainer::removeVal(val); + + invalidateTvsAndUses(); } void Fusion::addInput(Val* input) { @@ -863,12 +871,12 @@ bool isExpressionEvaluated(Fusion* fusion) { }); } -const std::vector& Fusion::allTvs() { +std::vector Fusion::allTvs() { if (all_tvs_ptr_ == nullptr) { all_tvs_ptr_ = std::make_unique>(ir_utils::allTvs(this)); } - return *all_tvs_ptr_; + return std::vector(*all_tvs_ptr_); } } // namespace nvfuser diff --git a/csrc/fusion.h b/csrc/fusion.h index 90a405c4371..b640fc02fb5 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -429,9 +429,11 @@ class NVF_API Fusion : public IrContainer { expected_dynamic_smem_bytes_ = bytes; } - //! This is a cached version of ir_utils::allTvs that is invalidated - //! whenever we invalidate TV uses - const std::vector& allTvs(); + //! This is a cached version of ir_utils::allTvs. It returns a copy of the + //! vector instead of a reference because many operations invalidate the + //! cache; iterating over a reference while modifying the fusion could + //! easily cause a segfault. + std::vector allTvs(); protected: friend SegmentCandidateFinder;