Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add information for coordinating segments in python frontend. #3289

Merged
merged 5 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions csrc/python_frontend/fusion_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,15 +780,23 @@ void FusionCache::deserialize(std::string filename) {
NVF_CHECK(
trie_ptr->fusion_id == fb_trie_node->fusion_id(),
"The fusion id for this TrieNode should already be set.")
Fusion* fusion =
queryFusionSchedules(fb_trie_node->fusion_id())->preschedFusion();
FusionSchedules* fs = queryFusionSchedules(fb_trie_node->fusion_id());
Fusion* fusion = fs->preschedFusion();
try {
// There could be bad fusion in the serialization.
state->buildFusionIr(fusion);
} catch (const std::exception& e) {
// catch exception and setException for the terminal node
trie_ptr->setException(e.what());
}
// The FusionState creates a mapping from CPP Fusion to its State objects.
// Since the CPP Fusion is cached in FusionCache and the FusionState is
// temporary, the information linking CPP Fusion and Python
// FusionDefinition is stored in FusionCache.
fs->inputs_fid_ = state->inputs();
fs->outputs_fid_ = state->outputs();
fs->extents_fid_ = state->extents();
fs->map_value_to_fid_ = state->getValueMap();
}

// Table TrieNode => Field: children: [ulong]
Expand Down
8 changes: 8 additions & 0 deletions csrc/python_frontend/fusion_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ struct FusionSchedules {
std::mutex scheds_lock;
//! ID of fusion in python frontend fusion cache
int64_t fusion_id_ = -1;
//! Input arguments for FusionState
rdspring1 marked this conversation as resolved.
Show resolved Hide resolved
std::vector<int64_t> inputs_fid_;
//! Extents for TensorView input arguments for FusionState
std::vector<int64_t> extents_fid_;
//! Output arguments for FusionState
std::vector<int64_t> outputs_fid_;
//! Map Fusion Val to its corresponding FusionDefinition index
std::unordered_map<const Val*, int64_t> map_value_to_fid_;
};

//! \struct TrieNode
Expand Down
22 changes: 22 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@ void FusionDefinition::finalizeDefinition() {
throw;
}

// The FusionState creates a mapping from CPP Fusion to its State objects.
// Since the CPP Fusion is cached in FusionCache and the FusionState is
// temporary, the information linking CPP Fusion and Python
// FusionDefinition is stored in FusionCache.
FusionSchedules* fs =
fusionCache()->queryFusionSchedules(fusion_id_.value());
fs->inputs_fid_ = inputs();
fs->outputs_fid_ = outputs();
fs->extents_fid_ = extents();
fs->map_value_to_fid_ = getValueMap();

if (isDebugDumpEnabled(DebugDumpOption::FusionIrOriginal)) {
printIr();
}
Expand All @@ -120,6 +131,17 @@ void FusionDefinition::finalizeDefinition() {
// build a proper fusion earlier.
NVF_CHECK(!opt_e.has_value(), opt_e.value());
fusion_id_ = std::optional<size_t>(trie_node_->fusion_id);

// A CPP fusion already exists in the FusionCache for this FusionDefinition.
// In this case, a new CPP Fusion is not created, so the mapping from CPP
// fusion to Python FusionDefinition is not initialized. This state is
// stored within FusionSchedules and is retrieved for this FusionDefinition.
FusionSchedules* fs =
fusionCache()->queryFusionSchedules(fusion_id_.value());
inputs_fid_ = fs->inputs_fid_;
outputs_fid_ = fs->outputs_fid_;
extents_fid_ = fs->extents_fid_;
map_value_to_fid_ = fs->map_value_to_fid_;
}

NVF_ERROR(
Expand Down
8 changes: 4 additions & 4 deletions csrc/python_frontend/fusion_record.h
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ struct TensorRecord : RecordFunctor {
}

fd.setFusionState(outputs_.at(0).index, tv);
fd.addInput(tv);
fd.addInput(tv, outputs_.at(0).index);
}

void print(std::ostream& os, bool close_function = true) const final {
Expand Down Expand Up @@ -1545,12 +1545,12 @@ struct OutputRecord : RecordFunctor {
}
tv_output->setAllocationDomain(allocation_domain, true);
}
fd.addOutput(tv_output);
fd.addOutput(tv_output, args_.at(0).index);
} else {
NVF_CHECK(
stride_order_.empty(),
"stride_order can't be dictated for scalar outputs.");
fd.addOutput(output);
fd.addOutput(output, args_.at(0).index);
}
}
}
Expand Down Expand Up @@ -2015,7 +2015,7 @@ struct ScalarRecord : RecordFunctor {
void operator()(FusionState& fd) final {
Val* output = IrBuilder::create<nvfuser::Val>(value_, dtype_);
if (!value_.hasValue()) {
fd.addInput(output);
fd.addInput(output, outputs_.at(0).index);
}
fd.setFusionState(outputs_.at(0).index, output);
}
Expand Down
86 changes: 84 additions & 2 deletions csrc/python_frontend/fusion_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,22 @@ std::unique_ptr<FusionState> FusionState::clone() {
state->fusion_state_.insert(
state->fusion_state_.end(), fusion_state_.begin(), fusion_state_.end());
state->num_recording_states_ = num_recording_states_;
std::copy(
inputs_fid_.begin(),
inputs_fid_.end(),
std::back_inserter(state->inputs_fid_));
std::copy(
outputs_fid_.begin(),
outputs_fid_.end(),
std::back_inserter(state->outputs_fid_));
std::copy(
extents_fid_.begin(),
extents_fid_.end(),
std::back_inserter(state->extents_fid_));
std::copy(
map_value_to_fid_.begin(),
map_value_to_fid_.end(),
std::inserter(state->map_value_to_fid_, state->map_value_to_fid_.end()));
return state;
}

Expand All @@ -108,6 +124,7 @@ void FusionState::buildFusionIr(Fusion* fusion) {
e.what());
}
}
addExtents();
}

void FusionState::addRecord(RecordFunctor* record) {
Expand Down Expand Up @@ -147,6 +164,10 @@ void FusionState::resetFusionState(Fusion* fusion, size_t size) {
fusion_ = fusion;
fusion_state_.clear();
fusion_state_.resize(size, {});
inputs_fid_.clear();
outputs_fid_.clear();
extents_fid_.clear();
map_value_to_fid_.clear();
}

void FusionState::addFusionState(Val* val) {
Expand Down Expand Up @@ -178,6 +199,7 @@ size_t FusionState::numFusionStates() const {

void FusionState::setFusionState(size_t index, Val* val) {
fusion_state_.at(index) = {val};
map_value_to_fid_.emplace(val, (int64_t)index);
}

void FusionState::setFusionStateVector(size_t index, std::vector<Val*> val) {
Expand All @@ -189,14 +211,18 @@ void FusionState::setFusionStateVector(size_t index, std::vector<Val*> val) {
fusion_state_.at(index) = {val};
}

void FusionState::addInput(Val* input) {
void FusionState::addInput(Val* input, size_t index) {
NVF_CHECK(fusion_ != nullptr, "Fusion is undefined.");
fusion_->addInput(input);
map_value_to_fid_.emplace(input, (int64_t)index);
inputs_fid_.push_back((int64_t)index);
}

void FusionState::addOutput(Val* output) {
void FusionState::addOutput(Val* output, size_t index) {
NVF_CHECK(fusion_ != nullptr, "Fusion is undefined.");
fusion_->addOutput(output);
map_value_to_fid_.emplace(output, (int64_t)index);
outputs_fid_.push_back((int64_t)index);
}

void FusionState::aliasOutputToInput(Val* output, Val* input) {
Expand All @@ -206,4 +232,60 @@ void FusionState::aliasOutputToInput(Val* output, Val* input) {
fusion_->aliasOutputToInput(output, input, AllocationType::ReuseBuffer);
}

const std::unordered_map<const Val*, int64_t>& FusionState::getValueMap()
const {
return map_value_to_fid_;
}

const std::vector<int64_t>& FusionState::inputs() const {
return inputs_fid_;
}

const std::vector<int64_t>& FusionState::outputs() const {
return outputs_fid_;
}

const std::vector<int64_t>& FusionState::extents() const {
return extents_fid_;
}

std::vector<Val*> FusionState::getExtents(Fusion* fusion) {
NVF_CHECK(fusion != nullptr, "Fusion is undefined.");

std::vector<Val*> extents;
for (Val* v : fusion->inputs()) {
// short-circuit: skip if not TensorView
if (!v->isA<TensorView>()) {
continue;
}
TensorView* tv = v->as<TensorView>();
std::vector<IterDomain*> logical_dom =
TensorDomain::noReductions(tv->getLogicalDomain());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wujingyue is trying to change how we bind IO buffers to kernels. i.e. we might rethink which domain and how we are going to use here.

Not proposing any change, just trying to raise awareness.

std::transform(
logical_dom.begin(),
logical_dom.end(),
std::back_inserter(extents),
[](IterDomain* id) { return id->getMaybeExpandedExtent(); });
}
return extents;
}

void FusionState::addExtents() {
NVF_CHECK(fusion_ != nullptr, "Fusion is undefined.");

// The size of the tensor dimensions can be used as an input of the
// segments. NvFuser does not support returning scalar values. Segmentation
// must pass those sizes as segment arguments manually.
std::vector<Val*> extents = getExtents(fusion_);
for (Val* extent : extents) {
int64_t num_extents = (int64_t)extents_fid_.size();
int64_t extent_fid = -num_extents - 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a negative index?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All scalars, vectors, and tensors use positive indices. The extents do not exist in the FusionState, so I used the negative numbers exclusively for the extent scalars.

The extents are the size of iterDomain in CPP fusion. We don't track those in FusionDefinition but they can become input arguments to a FusionDefinition after segmentation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the negative number here is just an initialization? does the number carry any meaning or does a global -1 would do it just fine?
sorry I might miss the part where extents_fid_ is being used.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an ordering component to the extent index.

It is used for the same purpose as collecting the extents in prepareRuntimeOrder.
https://github.com/NVIDIA/Fuser/blob/main/csrc/runtime/fusion_cache_utils.cpp#L199-L208

We're mapping the tensor sizes to the extents like so https://github.com/NVIDIA/Fuser/pull/3025/files#diff-e512bea3b02f75ab1e81b759562879c5867e6e863679d6e7696fa34087dc3dc9R98-R100.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this in a comment listing the use of negative indices to avoid conflict with other indices.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got'ya. It's hard to figure out the necessity without looking at the actual use. We can keep it as-is and revisit in follow up PRs.

extents_fid_.push_back(extent_fid);
// The extent can already exist in the fusion. However, since scalars cannot
// be passed between segments, always overwrited existing fids. The original
// fusion definition will provide scalar extents.
map_value_to_fid_[extent] = extent_fid;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit lost here.

iiuc, the map_value_to_fid_ on other values are mapped from the Val* to their index field in FusionState. Here looks like we are trying to create a the same thing for each TensorView's logical domain. Where are we creating the python container for that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not exposing the TensorView's logical domain to the python frontend, but I am tracking it in the FusionState. We may have to pass the scalar extents of the TensorView's logical domain as an input argument to a fusion segment.

}
}

} // namespace nvfuser::python_frontend
28 changes: 25 additions & 3 deletions csrc/python_frontend/fusion_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,24 @@ class FusionState {
NVF_API void setFusionStateVector(size_t index, std::vector<Val*> val);

//! Adds a Tensor/Scalar input to the Fusion object
NVF_API void addInput(Val* input);
NVF_API void addInput(Val* input, size_t index);
//! Adds a Tensor/Scalar output to the Fusion object
NVF_API void addOutput(Val* output);
NVF_API void addOutput(Val* output, size_t index);
//! Alias an Output to Input in the Fusion object
NVF_API void aliasOutputToInput(Val* output, Val* input);

//! Get map between CPP Fusion and Python FusionDefinition
NVF_API const std::unordered_map<const Val*, int64_t>& getValueMap() const;
//! Get indicies for the inputs of FusionState
NVF_API const std::vector<int64_t>& inputs() const;
//! Get indicies for the outputs of FusionState
NVF_API const std::vector<int64_t>& outputs() const;
//! Get indicies for the extents of TensorView inputs of FusionState
NVF_API const std::vector<int64_t>& extents() const;

//! Add extents of TensorView inputs to FusionState
NVF_API void addExtents();
rdspring1 marked this conversation as resolved.
Show resolved Hide resolved

//! Add a Record
void addRecord(RecordFunctor* record);
//! Builds an nvFuser Fusion IR object
Expand All @@ -94,6 +106,8 @@ class FusionState {
std::unique_ptr<FusionState> clone();

private:
//! Get extents for TensorView inputs in Fusion
std::vector<Val*> getExtents(Fusion* fusion);
//! Change the fusion ptr and reset its state
void resetFusionState(Fusion* fusion, size_t size);

Expand All @@ -104,10 +118,18 @@ class FusionState {
std::vector<std::unique_ptr<RecordFunctor>> recording_;
//! A vector of state that represents Tensors/Vectors/Scalars
std::vector<State> recording_state_;
//! Input arguments for FusionState
std::vector<int64_t> inputs_fid_;
//! Output arguments for FusionState
std::vector<int64_t> outputs_fid_;
//! Extents for TensorView input arguments for FusionState
std::vector<int64_t> extents_fid_;
//! Map Fusion Val to its corresponding FusionDefinition index
std::unordered_map<const Val*, int64_t> map_value_to_fid_;

private:
//! A ptr to the container used when building the Fusion IR from a definition
Fusion* fusion_;
Fusion* fusion_ = nullptr;
//! A vector of nvFuser Fusion IR TensorViews/Vectors/Scalars for building the
//! Fusion IR graph.
//! NOTE: Vectors are represented by a vector<Val*>. This could
Expand Down
3 changes: 3 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,9 @@ void initNvFuserPythonBindings(PyObject* module) {
// Mark the end of a schedule
inst::Trace::instance()->endEvent(nullptr);
})
.def("inputs", [](FusionDefinition& self) { return self.inputs(); })
.def("outputs", [](FusionDefinition& self) { return self.outputs(); })
.def("extents", [](FusionDefinition& self) { return self.extents(); })
.def(
"__repr__",
[](FusionDefinition& self) {
Expand Down
Loading