diff --git a/csrc/python_frontend/fusion_cache.cpp b/csrc/python_frontend/fusion_cache.cpp index 83ce851dbab..e95ee6820da 100644 --- a/csrc/python_frontend/fusion_cache.cpp +++ b/csrc/python_frontend/fusion_cache.cpp @@ -781,8 +781,8 @@ void FusionCache::deserialize(std::string filename) { NVF_CHECK( trie_ptr->fusion_id == fb_trie_node->fusion_id(), "The fusion id for this TrieNode should already be set.") - Fusion* fusion = - queryFusionSchedules(fb_trie_node->fusion_id())->preschedFusion(); + FusionSchedules* fs = queryFusionSchedules(fb_trie_node->fusion_id()); + Fusion* fusion = fs->preschedFusion(); try { // There could be bad fusion in the serialization. state->buildFusionIr(fusion); @@ -790,6 +790,14 @@ void FusionCache::deserialize(std::string filename) { // catch exception and setException for the terminal node trie_ptr->setException(e.what()); } + // The FusionState creates a mapping from CPP Fusion to its State objects. + // Since the CPP Fusion is cached in FusionCache and the FusionState is + // temporary, the information linking CPP Fusion and Python + // FusionDefinition is stored in FusionCache. + fs->inputs_fid_ = state->inputs(); + fs->outputs_fid_ = state->outputs(); + fs->extents_fid_ = state->extents(); + fs->map_value_to_fid_ = state->getValueMap(); } // Table TrieNode => Field: children: [ulong] diff --git a/csrc/python_frontend/fusion_cache.h b/csrc/python_frontend/fusion_cache.h index 190671b2b82..2d4f2533ba5 100644 --- a/csrc/python_frontend/fusion_cache.h +++ b/csrc/python_frontend/fusion_cache.h @@ -107,6 +107,14 @@ struct FusionSchedules { std::mutex scheds_lock; //! ID of fusion in python frontend fusion cache int64_t fusion_id_ = -1; + //! Fusion IDs of input arguments for FusionState + std::vector inputs_fid_; + //! IDs for Extents for TensorView input arguments for FusionState + std::vector extents_fid_; + //! Fusion IDs of output arguments for FusionState + std::vector outputs_fid_; + //! Map Fusion Val to its corresponding FusionDefinition index + std::unordered_map map_value_to_fid_; }; //! \struct TrieNode diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 09648a0bf36..05f12a7c2af 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -108,6 +108,17 @@ void FusionDefinition::finalizeDefinition() { throw; } + // The FusionState creates a mapping from CPP Fusion to its State objects. + // Since the CPP Fusion is cached in FusionCache and the FusionState is + // temporary, the information linking CPP Fusion and Python + // FusionDefinition is stored in FusionCache. + FusionSchedules* fs = + fusionCache()->queryFusionSchedules(fusion_id_.value()); + fs->inputs_fid_ = inputs(); + fs->outputs_fid_ = outputs(); + fs->extents_fid_ = extents(); + fs->map_value_to_fid_ = getValueMap(); + if (isDebugDumpEnabled(DebugDumpOption::FusionIrOriginal)) { printIr(); } @@ -121,6 +132,17 @@ void FusionDefinition::finalizeDefinition() { // build a proper fusion earlier. NVF_CHECK(!opt_e.has_value(), opt_e.value()); fusion_id_ = std::optional(trie_node_->fusion_id); + + // A CPP fusion already exists in the FusionCache for this FusionDefinition. + // In this case, a new CPP Fusion is not created, so the mapping from CPP + // fusion to Python FusionDefinition is not initialized. This state is + // stored within FusionSchedules and is retrieved for this FusionDefinition. + FusionSchedules* fs = + fusionCache()->queryFusionSchedules(fusion_id_.value()); + inputs_fid_ = fs->inputs_fid_; + outputs_fid_ = fs->outputs_fid_; + extents_fid_ = fs->extents_fid_; + map_value_to_fid_ = fs->map_value_to_fid_; } NVF_ERROR( diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 82879912509..154f8d28805 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -1368,7 +1368,7 @@ struct TensorRecord : RecordFunctor { } fd.setFusionState(outputs_.at(0).index, tv); - fd.addInput(tv); + fd.addInput(tv, outputs_.at(0).index); } void print(std::ostream& os, bool close_function = true) const final { @@ -1545,12 +1545,12 @@ struct OutputRecord : RecordFunctor { } tv_output->setAllocationDomain(allocation_domain, true); } - fd.addOutput(tv_output); + fd.addOutput(tv_output, args_.at(0).index); } else { NVF_CHECK( stride_order_.empty(), "stride_order can't be dictated for scalar outputs."); - fd.addOutput(output); + fd.addOutput(output, args_.at(0).index); } } } @@ -2015,7 +2015,7 @@ struct ScalarRecord : RecordFunctor { void operator()(FusionState& fd) final { Val* output = IrBuilder::create(value_, dtype_); if (!value_.hasValue()) { - fd.addInput(output); + fd.addInput(output, outputs_.at(0).index); } fd.setFusionState(outputs_.at(0).index, output); } diff --git a/csrc/python_frontend/fusion_state.cpp b/csrc/python_frontend/fusion_state.cpp index 99868f14b21..be8d8d0c514 100644 --- a/csrc/python_frontend/fusion_state.cpp +++ b/csrc/python_frontend/fusion_state.cpp @@ -85,6 +85,22 @@ std::unique_ptr FusionState::clone() { state->fusion_state_.insert( state->fusion_state_.end(), fusion_state_.begin(), fusion_state_.end()); state->num_recording_states_ = num_recording_states_; + std::copy( + inputs_fid_.begin(), + inputs_fid_.end(), + std::back_inserter(state->inputs_fid_)); + std::copy( + outputs_fid_.begin(), + outputs_fid_.end(), + std::back_inserter(state->outputs_fid_)); + std::copy( + extents_fid_.begin(), + extents_fid_.end(), + std::back_inserter(state->extents_fid_)); + std::copy( + map_value_to_fid_.begin(), + map_value_to_fid_.end(), + std::inserter(state->map_value_to_fid_, state->map_value_to_fid_.end())); return state; } @@ -108,6 +124,7 @@ void FusionState::buildFusionIr(Fusion* fusion) { e.what()); } } + addExtents(); } void FusionState::addRecord(RecordFunctor* record) { @@ -147,6 +164,10 @@ void FusionState::resetFusionState(Fusion* fusion, size_t size) { fusion_ = fusion; fusion_state_.clear(); fusion_state_.resize(size, {}); + inputs_fid_.clear(); + outputs_fid_.clear(); + extents_fid_.clear(); + map_value_to_fid_.clear(); } void FusionState::addFusionState(Val* val) { @@ -178,6 +199,7 @@ size_t FusionState::numFusionStates() const { void FusionState::setFusionState(size_t index, Val* val) { fusion_state_.at(index) = {val}; + map_value_to_fid_.emplace(val, (int64_t)index); } void FusionState::setFusionStateVector(size_t index, std::vector val) { @@ -189,14 +211,18 @@ void FusionState::setFusionStateVector(size_t index, std::vector val) { fusion_state_.at(index) = {val}; } -void FusionState::addInput(Val* input) { +void FusionState::addInput(Val* input, size_t index) { NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); fusion_->addInput(input); + map_value_to_fid_.emplace(input, (int64_t)index); + inputs_fid_.push_back((int64_t)index); } -void FusionState::addOutput(Val* output) { +void FusionState::addOutput(Val* output, size_t index) { NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); fusion_->addOutput(output); + map_value_to_fid_.emplace(output, (int64_t)index); + outputs_fid_.push_back((int64_t)index); } void FusionState::aliasOutputToInput(Val* output, Val* input) { @@ -206,4 +232,63 @@ void FusionState::aliasOutputToInput(Val* output, Val* input) { fusion_->aliasOutputToInput(output, input, AllocationType::ReuseBuffer); } +const std::unordered_map& FusionState::getValueMap() + const { + return map_value_to_fid_; +} + +const std::vector& FusionState::inputs() const { + return inputs_fid_; +} + +const std::vector& FusionState::outputs() const { + return outputs_fid_; +} + +const std::vector& FusionState::extents() const { + return extents_fid_; +} + +std::vector FusionState::getExtents(Fusion* fusion) { + NVF_CHECK(fusion != nullptr, "Fusion is undefined."); + + std::vector extents; + for (Val* v : fusion->inputs()) { + // short-circuit: skip if not TensorView + if (!v->isA()) { + continue; + } + TensorView* tv = v->as(); + std::vector logical_dom = + TensorDomain::noReductions(tv->getLogicalDomain()); + std::transform( + logical_dom.begin(), + logical_dom.end(), + std::back_inserter(extents), + [](IterDomain* id) { return id->getMaybeExpandedExtent(); }); + } + return extents; +} + +void FusionState::addExtents() { + NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); + + // The size of the tensor dimensions can be used as an input of the + // segments. NvFuser does not support returning scalar values. Segmentation + // must pass those sizes as segment arguments manually. + std::vector extents = getExtents(fusion_); + for (Val* extent : extents) { + int64_t num_extents = (int64_t)extents_fid_.size(); + // Use negative numbers to represent extent of iterDomains to avoid conflict + // with non-negative numbers used for scalars, vectors, and tensors. + // The extents are ordered based on the order of the fusion's inputs. + int64_t extent_fid = -num_extents - 1; + extents_fid_.push_back(extent_fid); + // The extent can already exist in the fusion. However, since scalars cannot + // be passed between segments, always overwrited existing fids. The original + // fusion definition will provide scalar extents. + map_value_to_fid_[extent] = extent_fid; + } +} + } // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/fusion_state.h b/csrc/python_frontend/fusion_state.h index bd75f7af5d6..7a83886514a 100644 --- a/csrc/python_frontend/fusion_state.h +++ b/csrc/python_frontend/fusion_state.h @@ -79,12 +79,21 @@ class FusionState { NVF_API void setFusionStateVector(size_t index, std::vector val); //! Adds a Tensor/Scalar input to the Fusion object - NVF_API void addInput(Val* input); + NVF_API void addInput(Val* input, size_t index); //! Adds a Tensor/Scalar output to the Fusion object - NVF_API void addOutput(Val* output); + NVF_API void addOutput(Val* output, size_t index); //! Alias an Output to Input in the Fusion object NVF_API void aliasOutputToInput(Val* output, Val* input); + //! Get map between CPP Fusion and Python FusionDefinition + NVF_API const std::unordered_map& getValueMap() const; + //! Get indicies for the inputs of FusionState + NVF_API const std::vector& inputs() const; + //! Get indicies for the outputs of FusionState + NVF_API const std::vector& outputs() const; + //! Get indicies for the extents of TensorView inputs of FusionState + NVF_API const std::vector& extents() const; + //! Add a Record void addRecord(RecordFunctor* record); //! Builds an nvFuser Fusion IR object @@ -94,6 +103,10 @@ class FusionState { std::unique_ptr clone(); private: + //! Get extents for TensorView inputs in Fusion + std::vector getExtents(Fusion* fusion); + //! Add extents of TensorView inputs to FusionState + void addExtents(); //! Change the fusion ptr and reset its state void resetFusionState(Fusion* fusion, size_t size); @@ -104,10 +117,18 @@ class FusionState { std::vector> recording_; //! A vector of state that represents Tensors/Vectors/Scalars std::vector recording_state_; + //! Input arguments for FusionState + std::vector inputs_fid_; + //! Output arguments for FusionState + std::vector outputs_fid_; + //! Extents for TensorView input arguments for FusionState + std::vector extents_fid_; + //! Map Fusion Val to its corresponding FusionDefinition index + std::unordered_map map_value_to_fid_; private: //! A ptr to the container used when building the Fusion IR from a definition - Fusion* fusion_; + Fusion* fusion_ = nullptr; //! A vector of nvFuser Fusion IR TensorViews/Vectors/Scalars for building the //! Fusion IR graph. //! NOTE: Vectors are represented by a vector. This could diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index b229107c45b..79f460fa232 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -999,6 +999,9 @@ void initNvFuserPythonBindings(PyObject* module) { // Mark the end of a schedule inst::Trace::instance()->endEvent(nullptr); }) + .def("inputs", [](FusionDefinition& self) { return self.inputs(); }) + .def("outputs", [](FusionDefinition& self) { return self.outputs(); }) + .def("extents", [](FusionDefinition& self) { return self.extents(); }) .def( "__repr__", [](FusionDefinition& self) { diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 4b4f25b9d66..7d9048e7bf6 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -53,7 +53,6 @@ class FusionDefinition(_C._FusionDefinition): def __init__(self, id=None, max_length=1024): super(FusionDefinition, self).__init__(id, max_length) self.profiled = False - self.inputs = None def __enter__(self): return self._setup_definition() diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 874223471eb..e0597757a9c 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4601,6 +4601,42 @@ def fusion_func(fd: FusionDefinition) -> None: for out in nvf_out: self.assertTrue(out.allclose(x[:, 1:, 2:])) + def test_fusion_information(self): + inputs = [ + torch.ones(2, 4, 8, device="cuda"), + torch.ones(2, 4, 8, device="cuda"), + ] + + def fusion_func(fd: FusionDefinition) -> None: + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.from_pytorch(inputs[1]) + c2 = fd.define_scalar(3.0) + + t3 = fd.ops.add(t0, t1) + t4 = fd.ops.mul(t3, c2) + t5 = fd.ops.sum(t4, [-1], False, DataType.Float) + + fd.add_output(t5) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + eager_out = torch.sum((inputs[0] + inputs[1]) * 3.0, dim=-1) + self.assertEqual(eager_out, nvf_out[0]) + + with FusionDefinition() as fd: + fusion_func(fd) + + nvf_out1 = fd.execute(inputs) + self.assertEqual(eager_out, nvf_out1[0]) + + # The input tensors are t0 and t1. + self.assertEqual(fd.inputs(), [0, 1]) + # The output tensors is t5. + self.assertEqual(fd.outputs(), [5]) + # The extents correspond with the dimensions for each input tensor. + # There are two input tensors with three dimensions each, so the + # extents range from [-1, -6]. + self.assertEqual(fd.extents(), [idx for idx in range(-1, -7, -1)]) + def test_issue_3292(self): inputs = [ torch.testing.make_tensor(