Skip to content

Commit

Permalink
Add information for coordinating segments in python frontend. (#3289)
Browse files Browse the repository at this point in the history
# Overview
This PR adds information necessary for coordinating segments in the
python frontend. Changes pulled from
#3025.

## PR Details
* Track the fusion state ids for the inputs, outputs, and extents of a
Fusion. Inputs and extents are used to gather tensor arguments and
scalars to run a fusion segment, while the outputs are employed to store
results between segments.
* Add a map from each CPP value to its corresponding fusion state id, which is
needed to map values from the original fusion to its segmented fusions.

## Implementation Details

- `FusionState` is a lightweight representation of a CPP `Fusion`. 
- When calling `buildFusionIr`, a CPP `Fusion` is created from the
Python `FusionDefinition`. At this point, the `FusionState` creates a
mapping from CPP `Fusion` to its `State` objects.
- However, the `FusionState` is temporary and the CPP `Fusion` is cached
in `FusionCache`. The information linking the CPP `Fusion` and Python
`FusionDefinition` is stored in `FusionCache`.
- When we create a new `FusionState`, we look for a cached CPP `Fusion`.
If it exists, we restore the mapping from the data stored in
`FusionSchedules`.
  • Loading branch information
rdspring1 authored Oct 31, 2024
1 parent 3d9677d commit 621e146
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 12 deletions.
12 changes: 10 additions & 2 deletions csrc/python_frontend/fusion_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,15 +781,23 @@ void FusionCache::deserialize(std::string filename) {
NVF_CHECK(
trie_ptr->fusion_id == fb_trie_node->fusion_id(),
"The fusion id for this TrieNode should already be set.")
Fusion* fusion =
queryFusionSchedules(fb_trie_node->fusion_id())->preschedFusion();
FusionSchedules* fs = queryFusionSchedules(fb_trie_node->fusion_id());
Fusion* fusion = fs->preschedFusion();
try {
// There could be bad fusion in the serialization.
state->buildFusionIr(fusion);
} catch (const std::exception& e) {
// catch exception and setException for the terminal node
trie_ptr->setException(e.what());
}
// The FusionState creates a mapping from CPP Fusion to its State objects.
// Since the CPP Fusion is cached in FusionCache and the FusionState is
// temporary, the information linking CPP Fusion and Python
// FusionDefinition is stored in FusionCache.
fs->inputs_fid_ = state->inputs();
fs->outputs_fid_ = state->outputs();
fs->extents_fid_ = state->extents();
fs->map_value_to_fid_ = state->getValueMap();
}

// Table TrieNode => Field: children: [ulong]
Expand Down
8 changes: 8 additions & 0 deletions csrc/python_frontend/fusion_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ struct FusionSchedules {
std::mutex scheds_lock;
//! ID of fusion in python frontend fusion cache
int64_t fusion_id_ = -1;
//! Fusion IDs of input arguments for FusionState
std::vector<int64_t> inputs_fid_;
//! IDs for Extents for TensorView input arguments for FusionState
std::vector<int64_t> extents_fid_;
//! Fusion IDs of output arguments for FusionState
std::vector<int64_t> outputs_fid_;
//! Map Fusion Val to its corresponding FusionDefinition index
std::unordered_map<const Val*, int64_t> map_value_to_fid_;
};

//! \struct TrieNode
Expand Down
22 changes: 22 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ void FusionDefinition::finalizeDefinition() {
throw;
}

// The FusionState creates a mapping from CPP Fusion to its State objects.
// Since the CPP Fusion is cached in FusionCache and the FusionState is
// temporary, the information linking CPP Fusion and Python
// FusionDefinition is stored in FusionCache.
FusionSchedules* fs =
fusionCache()->queryFusionSchedules(fusion_id_.value());
fs->inputs_fid_ = inputs();
fs->outputs_fid_ = outputs();
fs->extents_fid_ = extents();
fs->map_value_to_fid_ = getValueMap();

if (isDebugDumpEnabled(DebugDumpOption::FusionIrOriginal)) {
printIr();
}
Expand All @@ -121,6 +132,17 @@ void FusionDefinition::finalizeDefinition() {
// build a proper fusion earlier.
NVF_CHECK(!opt_e.has_value(), opt_e.value());
fusion_id_ = std::optional<size_t>(trie_node_->fusion_id);

// A CPP fusion already exists in the FusionCache for this FusionDefinition.
// In this case, a new CPP Fusion is not created, so the mapping from CPP
// fusion to Python FusionDefinition is not initialized. This state is
// stored within FusionSchedules and is retrieved for this FusionDefinition.
FusionSchedules* fs =
fusionCache()->queryFusionSchedules(fusion_id_.value());
inputs_fid_ = fs->inputs_fid_;
outputs_fid_ = fs->outputs_fid_;
extents_fid_ = fs->extents_fid_;
map_value_to_fid_ = fs->map_value_to_fid_;
}

NVF_ERROR(
Expand Down
8 changes: 4 additions & 4 deletions csrc/python_frontend/fusion_record.h
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ struct TensorRecord : RecordFunctor {
}

fd.setFusionState(outputs_.at(0).index, tv);
fd.addInput(tv);
fd.addInput(tv, outputs_.at(0).index);
}

void print(std::ostream& os, bool close_function = true) const final {
Expand Down Expand Up @@ -1545,12 +1545,12 @@ struct OutputRecord : RecordFunctor {
}
tv_output->setAllocationDomain(allocation_domain, true);
}
fd.addOutput(tv_output);
fd.addOutput(tv_output, args_.at(0).index);
} else {
NVF_CHECK(
stride_order_.empty(),
"stride_order can't be dictated for scalar outputs.");
fd.addOutput(output);
fd.addOutput(output, args_.at(0).index);
}
}
}
Expand Down Expand Up @@ -2015,7 +2015,7 @@ struct ScalarRecord : RecordFunctor {
void operator()(FusionState& fd) final {
Val* output = IrBuilder::create<nvfuser::Val>(value_, dtype_);
if (!value_.hasValue()) {
fd.addInput(output);
fd.addInput(output, outputs_.at(0).index);
}
fd.setFusionState(outputs_.at(0).index, output);
}
Expand Down
89 changes: 87 additions & 2 deletions csrc/python_frontend/fusion_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,22 @@ std::unique_ptr<FusionState> FusionState::clone() {
state->fusion_state_.insert(
state->fusion_state_.end(), fusion_state_.begin(), fusion_state_.end());
state->num_recording_states_ = num_recording_states_;
std::copy(
inputs_fid_.begin(),
inputs_fid_.end(),
std::back_inserter(state->inputs_fid_));
std::copy(
outputs_fid_.begin(),
outputs_fid_.end(),
std::back_inserter(state->outputs_fid_));
std::copy(
extents_fid_.begin(),
extents_fid_.end(),
std::back_inserter(state->extents_fid_));
std::copy(
map_value_to_fid_.begin(),
map_value_to_fid_.end(),
std::inserter(state->map_value_to_fid_, state->map_value_to_fid_.end()));
return state;
}

Expand All @@ -108,6 +124,7 @@ void FusionState::buildFusionIr(Fusion* fusion) {
e.what());
}
}
addExtents();
}

void FusionState::addRecord(RecordFunctor* record) {
Expand Down Expand Up @@ -147,6 +164,10 @@ void FusionState::resetFusionState(Fusion* fusion, size_t size) {
fusion_ = fusion;
fusion_state_.clear();
fusion_state_.resize(size, {});
inputs_fid_.clear();
outputs_fid_.clear();
extents_fid_.clear();
map_value_to_fid_.clear();
}

void FusionState::addFusionState(Val* val) {
Expand Down Expand Up @@ -178,6 +199,7 @@ size_t FusionState::numFusionStates() const {

// Store a single Val in slot `index` of the fusion state and remember which
// fusion state id that Val was defined with.
void FusionState::setFusionState(size_t index, Val* val) {
  // at() bounds-checks the slot before anything is recorded in the map.
  fusion_state_.at(index) = {val};
  // emplace leaves an already-recorded id for `val` untouched.
  const auto fid = static_cast<int64_t>(index);
  map_value_to_fid_.emplace(val, fid);
}

void FusionState::setFusionStateVector(size_t index, std::vector<Val*> val) {
Expand All @@ -189,14 +211,18 @@ void FusionState::setFusionStateVector(size_t index, std::vector<Val*> val) {
fusion_state_.at(index) = {val};
}

void FusionState::addInput(Val* input) {
// Register `input` as an input of the Fusion being built and record the
// fusion state id (`index`) it corresponds to.
void FusionState::addInput(Val* input, size_t index) {
  NVF_CHECK(fusion_ != nullptr, "Fusion is undefined.");
  const auto fid = static_cast<int64_t>(index);
  fusion_->addInput(input);
  // emplace does not replace an id that is already recorded for this Val.
  map_value_to_fid_.emplace(input, fid);
  // Ids are appended in the order the inputs are added.
  inputs_fid_.push_back(fid);
}

void FusionState::addOutput(Val* output) {
// Register `output` as an output of the Fusion being built and record the
// fusion state id (`index`) it corresponds to.
void FusionState::addOutput(Val* output, size_t index) {
  NVF_CHECK(fusion_ != nullptr, "Fusion is undefined.");
  const auto fid = static_cast<int64_t>(index);
  fusion_->addOutput(output);
  // emplace does not replace an id that is already recorded for this Val.
  map_value_to_fid_.emplace(output, fid);
  // Ids are appended in the order the outputs are added.
  outputs_fid_.push_back(fid);
}

void FusionState::aliasOutputToInput(Val* output, Val* input) {
Expand All @@ -206,4 +232,63 @@ void FusionState::aliasOutputToInput(Val* output, Val* input) {
fusion_->aliasOutputToInput(output, input, AllocationType::ReuseBuffer);
}

// Read-only access to the map from a Fusion Val to its fusion state id.
// Entries are recorded by setFusionState, addInput, addOutput and addExtents.
// Returns a reference to the member, so no copy is made here.
const std::unordered_map<const Val*, int64_t>& FusionState::getValueMap()
const {
return map_value_to_fid_;
}

// Read-only access to the fusion state ids recorded by addInput, in the order
// the inputs were added. Returns a reference to the member.
const std::vector<int64_t>& FusionState::inputs() const {
return inputs_fid_;
}

// Read-only access to the fusion state ids recorded by addOutput, in the order
// the outputs were added. Returns a reference to the member.
const std::vector<int64_t>& FusionState::outputs() const {
return outputs_fid_;
}

// Read-only access to the ids assigned by addExtents to the extents of the
// TensorView inputs. These ids are negative (-1, -2, ...). Returns a reference
// to the member.
const std::vector<int64_t>& FusionState::extents() const {
return extents_fid_;
}

// Gather the extents of the TensorView inputs of `fusion`.
//
// Inputs are visited in fusion-input order. For each TensorView input, the
// logical domain is filtered through TensorDomain::noReductions and the
// (maybe expanded) extent of every remaining IterDomain is appended.
// Inputs that are not TensorViews contribute nothing.
std::vector<Val*> FusionState::getExtents(Fusion* fusion) {
  NVF_CHECK(fusion != nullptr, "Fusion is undefined.");

  std::vector<Val*> result;
  for (Val* input : fusion->inputs()) {
    if (input->isA<TensorView>()) {
      TensorView* tensor = input->as<TensorView>();
      std::vector<IterDomain*> dims =
          TensorDomain::noReductions(tensor->getLogicalDomain());
      for (IterDomain* dim : dims) {
        result.push_back(dim->getMaybeExpandedExtent());
      }
    }
  }
  return result;
}

// Assign a fusion state id to every extent of the fusion's TensorView inputs.
//
// Tensor dimension sizes can be needed as inputs of a segment. NvFuser does
// not support returning scalar values, so segmentation has to pass those
// sizes as segment arguments manually.
void FusionState::addExtents() {
  NVF_CHECK(fusion_ != nullptr, "Fusion is undefined.");

  for (Val* extent : getExtents(fusion_)) {
    // Extent ids are negative (-1, -2, ...) so they never collide with the
    // non-negative ids used for scalars, vectors, and tensors. They follow
    // the order of the fusion's inputs.
    const auto extent_fid = -static_cast<int64_t>(extents_fid_.size()) - 1;
    extents_fid_.push_back(extent_fid);
    // The extent may already have an id in the map. Since scalars cannot be
    // passed between segments, always overwrite the existing id; the original
    // fusion definition will provide scalar extents.
    map_value_to_fid_[extent] = extent_fid;
  }
}

} // namespace nvfuser::python_frontend
27 changes: 24 additions & 3 deletions csrc/python_frontend/fusion_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,21 @@ class FusionState {
NVF_API void setFusionStateVector(size_t index, std::vector<Val*> val);

//! Adds a Tensor/Scalar input to the Fusion object
NVF_API void addInput(Val* input);
NVF_API void addInput(Val* input, size_t index);
//! Adds a Tensor/Scalar output to the Fusion object
NVF_API void addOutput(Val* output);
NVF_API void addOutput(Val* output, size_t index);
//! Alias an Output to Input in the Fusion object
NVF_API void aliasOutputToInput(Val* output, Val* input);

//! Get map between CPP Fusion and Python FusionDefinition
NVF_API const std::unordered_map<const Val*, int64_t>& getValueMap() const;
//! Get indices for the inputs of FusionState
NVF_API const std::vector<int64_t>& inputs() const;
//! Get indices for the outputs of FusionState
NVF_API const std::vector<int64_t>& outputs() const;
//! Get indices for the extents of TensorView inputs of FusionState
NVF_API const std::vector<int64_t>& extents() const;

//! Add a Record
void addRecord(RecordFunctor* record);
//! Builds an nvFuser Fusion IR object
Expand All @@ -94,6 +103,10 @@ class FusionState {
std::unique_ptr<FusionState> clone();

private:
//! Get extents for TensorView inputs in Fusion
std::vector<Val*> getExtents(Fusion* fusion);
//! Add extents of TensorView inputs to FusionState
void addExtents();
//! Change the fusion ptr and reset its state
void resetFusionState(Fusion* fusion, size_t size);

Expand All @@ -104,10 +117,18 @@ class FusionState {
std::vector<std::unique_ptr<RecordFunctor>> recording_;
//! A vector of state that represents Tensors/Vectors/Scalars
std::vector<State> recording_state_;
//! Input arguments for FusionState
std::vector<int64_t> inputs_fid_;
//! Output arguments for FusionState
std::vector<int64_t> outputs_fid_;
//! Extents for TensorView input arguments for FusionState
std::vector<int64_t> extents_fid_;
//! Map Fusion Val to its corresponding FusionDefinition index
std::unordered_map<const Val*, int64_t> map_value_to_fid_;

private:
//! A ptr to the container used when building the Fusion IR from a definition
Fusion* fusion_;
Fusion* fusion_ = nullptr;
//! A vector of nvFuser Fusion IR TensorViews/Vectors/Scalars for building the
//! Fusion IR graph.
//! NOTE: Vectors are represented by a vector<Val*>. This could
Expand Down
3 changes: 3 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,9 @@ void initNvFuserPythonBindings(PyObject* module) {
// Mark the end of a schedule
inst::Trace::instance()->endEvent(nullptr);
})
.def("inputs", [](FusionDefinition& self) { return self.inputs(); })
.def("outputs", [](FusionDefinition& self) { return self.outputs(); })
.def("extents", [](FusionDefinition& self) { return self.extents(); })
.def(
"__repr__",
[](FusionDefinition& self) {
Expand Down
1 change: 0 additions & 1 deletion nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ class FusionDefinition(_C._FusionDefinition):
def __init__(self, id=None, max_length=1024):
super(FusionDefinition, self).__init__(id, max_length)
self.profiled = False
self.inputs = None

def __enter__(self):
return self._setup_definition()
Expand Down
36 changes: 36 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4601,6 +4601,42 @@ def fusion_func(fd: FusionDefinition) -> None:
for out in nvf_out:
self.assertTrue(out.allclose(x[:, 1:, 2:]))

def test_fusion_information(self):
    # Checks that a FusionDefinition reports the fusion state ids of its
    # inputs, outputs, and input-tensor extents after it has been executed.
    inputs = [
        torch.ones(2, 4, 8, device="cuda"),
        torch.ones(2, 4, 8, device="cuda"),
    ]

    def fusion_func(fd: FusionDefinition) -> None:
        t0 = fd.from_pytorch(inputs[0])
        t1 = fd.from_pytorch(inputs[1])
        c2 = fd.define_scalar(3.0)

        t3 = fd.ops.add(t0, t1)
        t4 = fd.ops.mul(t3, c2)
        t5 = fd.ops.sum(t4, [-1], False, DataType.Float)

        fd.add_output(t5)

    # First run goes through the shared test helper.
    nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
    expected = torch.sum((inputs[0] + inputs[1]) * 3.0, dim=-1)
    self.assertEqual(expected, nvf_out[0])

    # Second run builds the same definition directly so `fd` can be queried.
    with FusionDefinition() as fd:
        fusion_func(fd)

    direct_out = fd.execute(inputs)
    self.assertEqual(expected, direct_out[0])

    # The input tensors are t0 and t1.
    self.assertEqual(fd.inputs(), [0, 1])
    # The output tensor is t5.
    self.assertEqual(fd.outputs(), [5])
    # The extents correspond to the dimensions of each input tensor.
    # There are two input tensors with three dimensions each, so the
    # extent ids run from -1 down to -6.
    self.assertEqual(fd.extents(), list(range(-1, -7, -1)))

def test_issue_3292(self):
inputs = [
torch.testing.make_tensor(
Expand Down

0 comments on commit 621e146

Please sign in to comment.