-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add information for coordinating segments in python frontend. #3289
Changes from all commits
0b7dcdc
8310ace
946403a
09af663
f6246fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,6 +85,22 @@ std::unique_ptr<FusionState> FusionState::clone() { | |
state->fusion_state_.insert( | ||
state->fusion_state_.end(), fusion_state_.begin(), fusion_state_.end()); | ||
state->num_recording_states_ = num_recording_states_; | ||
std::copy( | ||
inputs_fid_.begin(), | ||
inputs_fid_.end(), | ||
std::back_inserter(state->inputs_fid_)); | ||
std::copy( | ||
outputs_fid_.begin(), | ||
outputs_fid_.end(), | ||
std::back_inserter(state->outputs_fid_)); | ||
std::copy( | ||
extents_fid_.begin(), | ||
extents_fid_.end(), | ||
std::back_inserter(state->extents_fid_)); | ||
std::copy( | ||
map_value_to_fid_.begin(), | ||
map_value_to_fid_.end(), | ||
std::inserter(state->map_value_to_fid_, state->map_value_to_fid_.end())); | ||
return state; | ||
} | ||
|
||
|
@@ -108,6 +124,7 @@ void FusionState::buildFusionIr(Fusion* fusion) { | |
e.what()); | ||
} | ||
} | ||
addExtents(); | ||
} | ||
|
||
void FusionState::addRecord(RecordFunctor* record) { | ||
|
@@ -147,6 +164,10 @@ void FusionState::resetFusionState(Fusion* fusion, size_t size) { | |
fusion_ = fusion; | ||
fusion_state_.clear(); | ||
fusion_state_.resize(size, {}); | ||
inputs_fid_.clear(); | ||
outputs_fid_.clear(); | ||
extents_fid_.clear(); | ||
map_value_to_fid_.clear(); | ||
} | ||
|
||
void FusionState::addFusionState(Val* val) { | ||
|
@@ -178,6 +199,7 @@ size_t FusionState::numFusionStates() const { | |
|
||
// Stores a single IR value at the given frontend state index and records
// the reverse mapping (Val* -> frontend index) used for segment coordination.
void FusionState::setFusionState(size_t index, Val* val) {
  fusion_state_.at(index) = {val};
  map_value_to_fid_.emplace(val, static_cast<int64_t>(index));
}
|
||
void FusionState::setFusionStateVector(size_t index, std::vector<Val*> val) { | ||
|
@@ -189,14 +211,18 @@ void FusionState::setFusionStateVector(size_t index, std::vector<Val*> val) { | |
fusion_state_.at(index) = {val}; | ||
} | ||
|
||
void FusionState::addInput(Val* input) { | ||
void FusionState::addInput(Val* input, size_t index) { | ||
NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); | ||
fusion_->addInput(input); | ||
map_value_to_fid_.emplace(input, (int64_t)index); | ||
inputs_fid_.push_back((int64_t)index); | ||
} | ||
|
||
void FusionState::addOutput(Val* output) { | ||
void FusionState::addOutput(Val* output, size_t index) { | ||
NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); | ||
fusion_->addOutput(output); | ||
map_value_to_fid_.emplace(output, (int64_t)index); | ||
outputs_fid_.push_back((int64_t)index); | ||
} | ||
|
||
void FusionState::aliasOutputToInput(Val* output, Val* input) { | ||
|
@@ -206,4 +232,63 @@ void FusionState::aliasOutputToInput(Val* output, Val* input) { | |
fusion_->aliasOutputToInput(output, input, AllocationType::ReuseBuffer); | ||
} | ||
|
||
const std::unordered_map<const Val*, int64_t>& FusionState::getValueMap() | ||
const { | ||
return map_value_to_fid_; | ||
} | ||
|
||
const std::vector<int64_t>& FusionState::inputs() const { | ||
return inputs_fid_; | ||
} | ||
|
||
const std::vector<int64_t>& FusionState::outputs() const { | ||
return outputs_fid_; | ||
} | ||
|
||
const std::vector<int64_t>& FusionState::extents() const { | ||
return extents_fid_; | ||
} | ||
|
||
// Collects the (maybe expanded) extent of every non-reduction logical
// IterDomain across all TensorView inputs of `fusion`, ordered to match
// the order of the fusion's inputs.
std::vector<Val*> FusionState::getExtents(Fusion* fusion) {
  NVF_CHECK(fusion != nullptr, "Fusion is undefined.");

  std::vector<Val*> extents;
  for (Val* input : fusion->inputs()) {
    // Only tensor inputs contribute extents; skip scalars and other Vals.
    if (!input->isA<TensorView>()) {
      continue;
    }
    auto* tv = input->as<TensorView>();
    for (IterDomain* id : TensorDomain::noReductions(tv->getLogicalDomain())) {
      extents.push_back(id->getMaybeExpandedExtent());
    }
  }
  return extents;
}
|
||
void FusionState::addExtents() { | ||
NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); | ||
|
||
// The size of the tensor dimensions can be used as an input of the | ||
// segments. NvFuser does not support returning scalar values. Segmentation | ||
// must pass those sizes as segment arguments manually. | ||
std::vector<Val*> extents = getExtents(fusion_); | ||
for (Val* extent : extents) { | ||
int64_t num_extents = (int64_t)extents_fid_.size(); | ||
// Use negative numbers to represent extent of iterDomains to avoid conflict | ||
// with non-negative numbers used for scalars, vectors, and tensors. | ||
// The extents are ordered based on the order of the fusion's inputs. | ||
int64_t extent_fid = -num_extents - 1; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this a negative index? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All scalars, vectors, and tensors use positive indices. The extents do not exist in the The extents are the size of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so the negative number here is just an initialization? does the number carry any meaning or does a global There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is an ordering component to the extent index. It is used for the same purpose as collecting the extents in We're mapping the tensor sizes to the extents like so https://github.com/NVIDIA/Fuser/pull/3025/files#diff-e512bea3b02f75ab1e81b759562879c5867e6e863679d6e7696fa34087dc3dc9R98-R100. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add this in a comment listing the use of negative indices to avoid conflict with other indices. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got'ya. It's hard to figure out the necessity without looking at the actual use. We can keep it as-is and revisit in follow up PRs. |
||
extents_fid_.push_back(extent_fid); | ||
// The extent can already exist in the fusion. However, since scalars cannot | ||
// be passed between segments, always overwrited existing fids. The original | ||
// fusion definition will provide scalar extents. | ||
map_value_to_fid_[extent] = extent_fid; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit lost here. iiuc, the map_value_to_fid_ on other values are mapped from the Val* to their index field in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not exposing the TensorView's logical domain to the python frontend, but I am tracking it in the |
||
} | ||
} | ||
|
||
} // namespace nvfuser::python_frontend |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wujingyue is trying to change how we bind IO buffers to kernels, i.e., we might rethink which domain to use here and how we are going to use it.
Not proposing any change — just trying to raise awareness.