Skip to content

Commit

Permalink
Cache ir_utils::allTvs as part of Fusion (#2873)
Browse files Browse the repository at this point in the history
This caches `ir_utils::allTvs(fusion)` as `fusion->allTvs()`. The cache
is automatically invalidated whenever the TV graph topology changes;
this mechanism is the same one used to recompute `Expr` uses
automatically.

---------

Co-authored-by: Christian Sarofeen <[email protected]>
  • Loading branch information
jacobhinkle and csarofeen authored Sep 2, 2024
1 parent 02f6514 commit 744bf54
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 13 deletions.
43 changes: 34 additions & 9 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {

to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_;

if (from->all_tvs_ptr_ != nullptr) {
to->all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>();
to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size());
for (TensorView* from_tv : *from->all_tvs_ptr_) {
to->all_tvs_ptr_->push_back(ir_cloner.clone(from_tv)->as<TensorView>());
}
}

return ir_cloner;
}

Expand Down Expand Up @@ -168,7 +176,8 @@ void Fusion::clear() noexcept {
managed_data_.clear();
managed_named_data_.clear();

all_tv_uses_valid_ = false;
invalidateTvsAndUses();

is_during_update_uses_ = false;
}

Expand All @@ -179,13 +188,19 @@ void Fusion::removeExpr(Expr* expr) {
// we're going with the strictest model which errors.

for (auto out : expr->outputs()) {
if (out->isA<TensorView>()) {
invalidateTvsAndUses();
}
out->setDefinition(nullptr);
}

// Remove uses in inputs
for (auto inp : expr->inputs()) {
// Note that if inp is a TensorView, this may call invalidateTvUses
// Note that if inp is a TensorView, this may call invalidateTvsAndUses
inp->removeUse(expr);
if (inp->isA<TensorView>()) {
invalidateTvsAndUses();
}
}

IrContainer::removeExpr(expr);
Expand Down Expand Up @@ -228,6 +243,8 @@ void Fusion::removeVal(Val* val) {
removeExpr(e);
}
IrContainer::removeVal(val);

invalidateTvsAndUses();
}

void Fusion::addInput(Val* input) {
Expand All @@ -250,7 +267,7 @@ void Fusion::addInput(Val* input) {
inputs_.push_back(input);
input->setIsFusionInput(true);

all_tv_uses_valid_ = false;
invalidateTvsAndUses();
}

void Fusion::addOutputInternal(Val* output) {
Expand All @@ -264,7 +281,7 @@ void Fusion::addOutputInternal(Val* output) {
outputs_.push_back(output);
output->setIsFusionOutput(true);

all_tv_uses_valid_ = false;
invalidateTvsAndUses();
}

void Fusion::addOutput(Val* output) {
Expand All @@ -290,7 +307,7 @@ void Fusion::removeInput(Val* input) {
inputs_.erase(find_input);
}
input->setIsFusionInput(false);
all_tv_uses_valid_ = false;
invalidateTvsAndUses();
}

void Fusion::removeOutput(Val* output) {
Expand All @@ -299,7 +316,7 @@ void Fusion::removeOutput(Val* output) {
outputs_.erase(find_output);
}
output->setIsFusionOutput(false);
all_tv_uses_valid_ = false;
invalidateTvsAndUses();
}

void Fusion::replaceOutput(Val* output, Val* replacement) {
Expand All @@ -326,7 +343,7 @@ void Fusion::replaceOutput(Val* output, Val* replacement) {
}
}
// Mark uses invalid so that they will be reset next time uses() is called
invalidateTvUses();
invalidateTvsAndUses();
}

// Temporary WAR for issue #1112
Expand Down Expand Up @@ -582,7 +599,7 @@ void Fusion::registerExpr(Expr* expr) {
// Don't just add this expr as a use of the input if it's a tensor as the
// whole fusion needs to be traversed to rebuild the usage lists
if (input->isA<TensorView>()) {
invalidateTvUses();
invalidateTvsAndUses();
} else {
input->addUse(expr);
}
Expand All @@ -605,7 +622,7 @@ void Fusion::registerExpr(Expr* expr) {
// If that happens, our definition-based traversal can change and
// introduce whole new branches, so we need to recompute the uses_
// vector after setDefinition.
invalidateTvUses();
invalidateTvsAndUses();
}
}
}
Expand Down Expand Up @@ -854,4 +871,12 @@ bool isExpressionEvaluated(Fusion* fusion) {
});
}

std::vector<TensorView*> Fusion::allTvs() {
  // Lazily (re)build the cached TensorView list. The cache pointer is
  // cleared by invalidateTvsAndUses() whenever the TV graph topology
  // changes, so a null pointer means the cache is stale or never built.
  if (!all_tvs_ptr_) {
    all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>(
        ir_utils::allTvs(this));
  }
  // Return by value (a copy of the cached vector): callers may mutate the
  // fusion while iterating, which would invalidate a returned reference.
  return *all_tvs_ptr_;
}

} // namespace nvfuser
11 changes: 10 additions & 1 deletion csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,12 @@ class NVF_API Fusion : public IrContainer {
expected_dynamic_smem_bytes_ = bytes;
}

//! This is a cached version of ir_utils::allTvs; the cache is invalidated
//! whenever the TV graph topology changes. It returns a copy of the vector
//! rather than a reference, because the cache can be invalidated by many
//! operations: if we returned a reference and a caller iterated over it
//! while modifying the fusion, it could easily cause a segfault.
std::vector<TensorView*> allTvs();

protected:
friend SegmentCandidateFinder;
friend SegmentedFusion;
Expand Down Expand Up @@ -456,8 +462,9 @@ class NVF_API Fusion : public IrContainer {

//! Declare that TensorView uses need to be updated (but don't actually do
//! the update).
void invalidateTvUses() {
//! Invalidate both cached views of the TV graph: the lazily rebuilt
//! uses lists and the cached allTvs() vector. Neither is recomputed
//! here; each is rebuilt on its next access.
void invalidateTvsAndUses() {
// Uses are recomputed on the next uses() query.
all_tv_uses_valid_ = false;
// Drop the cached TV list so allTvs() rebuilds it on next call.
all_tvs_ptr_.reset();
}

private:
Expand Down Expand Up @@ -485,6 +492,8 @@ class NVF_API Fusion : public IrContainer {
// If set to a non-negative value during scheduling, this will be checked by
// the executor.
int64_t expected_dynamic_smem_bytes_ = -1LL;

std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr;
};

// Returns true if all fusion outputs are expression evaluated.
Expand Down
2 changes: 1 addition & 1 deletion csrc/ir/base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ bool Val::removeUse(Expr* expr) {
uses_.erase(it);
if (this->isA<TensorView>()) {
// Call for a rebuild of uses_ vector
fusion()->invalidateTvUses();
fusion()->invalidateTvsAndUses();
}
return true;
}
Expand Down
6 changes: 4 additions & 2 deletions csrc/kernel_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,8 @@ FusionKernelRuntime::FusionKernelRuntime(

// SchedulerRuntimeInfo modifies the fusion, so it is required for both
// compile paths.
std::vector<TensorView*> all_tvs = ir_utils::allTvs(fusion.get());
std::vector<TensorView*> all_tvs =
fusion->allTvs(); // ir_utils::allTvs(fusion.get());
SchedulerRuntimeInfo runtime_info(
fusion.get(), args, nullptr, all_tvs, forced_index_type);

Expand Down Expand Up @@ -1453,7 +1454,8 @@ std::optional<FusionKernelRuntime::HeuristicsPtr> FusionKernelRuntime::

// Get all tensorviews for segmented fusion
std::vector<TensorView*> all_tvs_for_fusion_to_run =
ir_utils::allTvs(fusion_to_run);
fusion_to_run->allTvs();
// ir_utils::allTvs(fusion_to_run);

SchedulerRuntimeInfo fusion_to_run_info(
fusion_to_run,
Expand Down

0 comments on commit 744bf54

Please sign in to comment.