Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache ir_utils::allTvs as part of Fusion #2873

Merged
merged 8 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {

to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_;

if (from->all_tvs_ptr_ != nullptr) {
csarofeen marked this conversation as resolved.
Show resolved Hide resolved
to->all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>();
to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size());
for (TensorView* from_tv : *from->all_tvs_ptr_) {
to->all_tvs_ptr_->push_back(ir_cloner.clone(from_tv)->as<TensorView>());
}
}

return ir_cloner;
}

Expand Down Expand Up @@ -168,7 +176,8 @@ void Fusion::clear() noexcept {
managed_data_.clear();
managed_named_data_.clear();

all_tv_uses_valid_ = false;
invalidateTvsAndUses();

is_during_update_uses_ = false;
}

Expand All @@ -184,7 +193,7 @@ void Fusion::removeExpr(Expr* expr) {

// Remove uses in inputs
for (auto inp : expr->inputs()) {
// Note that if inp is a TensorView, this may call invalidateTvUses
// Note that if inp is a TensorView, this may call invalidateTvsAndUses
inp->removeUse(expr);
}

Expand Down Expand Up @@ -250,7 +259,7 @@ void Fusion::addInput(Val* input) {
inputs_.push_back(input);
input->setIsFusionInput(true);

all_tv_uses_valid_ = false;
invalidateTvsAndUses();
}

void Fusion::addOutputInternal(Val* output) {
Expand All @@ -264,7 +273,7 @@ void Fusion::addOutputInternal(Val* output) {
outputs_.push_back(output);
output->setIsFusionOutput(true);

all_tv_uses_valid_ = false;
invalidateTvsAndUses();
}

void Fusion::addOutput(Val* output) {
Expand All @@ -290,7 +299,7 @@ void Fusion::removeInput(Val* input) {
inputs_.erase(find_input);
}
input->setIsFusionInput(false);
all_tv_uses_valid_ = false;
invalidateTvsAndUses();
}

void Fusion::removeOutput(Val* output) {
Expand All @@ -299,7 +308,7 @@ void Fusion::removeOutput(Val* output) {
outputs_.erase(find_output);
}
output->setIsFusionOutput(false);
all_tv_uses_valid_ = false;
invalidateTvsAndUses();
}

void Fusion::replaceOutput(Val* output, Val* replacement) {
Expand All @@ -326,7 +335,7 @@ void Fusion::replaceOutput(Val* output, Val* replacement) {
}
}
// Mark uses invalid so that they will be reset next time uses() is called
invalidateTvUses();
invalidateTvsAndUses();
}

// Temporary WAR for issue #1112
Expand Down Expand Up @@ -582,7 +591,7 @@ void Fusion::registerExpr(Expr* expr) {
// Don't just add this expr as a use of the input if it's a tensor as the
// whole fusion needs to be traversed to rebuild the usage lists
if (input->isA<TensorView>()) {
invalidateTvUses();
invalidateTvsAndUses();
} else {
input->addUse(expr);
}
Expand All @@ -605,7 +614,7 @@ void Fusion::registerExpr(Expr* expr) {
// If that happens, our definition-based traversal can change and
// introduce whole new branches, so we need to recompute the uses_
// vector after setDefinition.
invalidateTvUses();
invalidateTvsAndUses();
}
}
}
Expand Down Expand Up @@ -854,4 +863,12 @@ bool isExpressionEvaluated(Fusion* fusion) {
});
}

//! Return every TensorView in this Fusion, computing the list lazily.
//! The result is cached in all_tvs_ptr_; the cache is dropped whenever
//! invalidateTvsAndUses() is called, so a subsequent call recomputes it.
const std::vector<TensorView*>& Fusion::allTvs() {
  if (!all_tvs_ptr_) {
    // First request since construction or the last invalidation:
    // recompute via the uncached helper and store the result.
    all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>(
        ir_utils::allTvs(this));
  }
  return *all_tvs_ptr_;
}

} // namespace nvfuser
9 changes: 8 additions & 1 deletion csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,10 @@ class NVF_API Fusion : public IrContainer {
expected_dynamic_smem_bytes_ = bytes;
}

//! This is a cached version of ir_utils::allTvs that is invalidated
//! whenever we invalidate TV uses
const std::vector<TensorView*>& allTvs();

protected:
friend SegmentCandidateFinder;
friend SegmentedFusion;
Expand Down Expand Up @@ -456,8 +460,9 @@ class NVF_API Fusion : public IrContainer {

//! Declare that TensorView uses need to be updated (but don't actually do
//! the update).
void invalidateTvUses() {
void invalidateTvsAndUses() {
  // Mark the per-TensorView uses_ lists stale; they are rebuilt lazily
  // the next time uses() is queried.
  all_tv_uses_valid_ = false;
  // Drop the cached allTvs() vector so it is recomputed on next access.
  all_tvs_ptr_.reset();
}

private:
Expand Down Expand Up @@ -485,6 +490,8 @@ class NVF_API Fusion : public IrContainer {
// If set to a non-negative value during scheduling, this will be checked by
// the executor.
int64_t expected_dynamic_smem_bytes_ = -1LL;

std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr;
};

// Returns true if all fusion outputs are expression evaluated.
Expand Down
2 changes: 1 addition & 1 deletion csrc/ir/base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ bool Val::removeUse(Expr* expr) {
uses_.erase(it);
if (this->isA<TensorView>()) {
// Call for a rebuild of uses_ vector
fusion()->invalidateTvUses();
fusion()->invalidateTvsAndUses();
}
return true;
}
Expand Down
6 changes: 4 additions & 2 deletions csrc/kernel_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,8 @@ FusionKernelRuntime::FusionKernelRuntime(

// SchedulerRuntimeInfo modifies the fusion, so it is required for both
// compile paths.
std::vector<TensorView*> all_tvs = ir_utils::allTvs(fusion.get());
std::vector<TensorView*> all_tvs =
fusion->allTvs(); // ir_utils::allTvs(fusion.get());
SchedulerRuntimeInfo runtime_info(
fusion.get(), args, nullptr, all_tvs, forced_index_type);

Expand Down Expand Up @@ -1453,7 +1454,8 @@ std::optional<FusionKernelRuntime::HeuristicsPtr> FusionKernelRuntime::

// Get all tensorviews for segmented fusion
std::vector<TensorView*> all_tvs_for_fusion_to_run =
ir_utils::allTvs(fusion_to_run);
fusion_to_run->allTvs();
// ir_utils::allTvs(fusion_to_run);

SchedulerRuntimeInfo fusion_to_run_info(
fusion_to_run,
Expand Down
Loading