diff --git a/CMakeLists.txt b/CMakeLists.txt index de3e52f5055..702b562871b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -270,6 +270,7 @@ if(BUILD_PYTHON) ${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp ${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp ${NVFUSER_SRCS_DIR}/python_frontend/fusion_state.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/segmentation.cpp ${NVFUSER_SRCS_DIR}/python_frontend/translation.cpp ${NVFUSER_SRCS_DIR}/python_frontend/translation_utils.cpp ${NVFUSER_SRCS_DIR}/serde/fusion_record.cpp diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index dc1af7a9f5c..9a58405a8b3 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -673,4 +674,19 @@ std::vector> FusionDefinition::getValTolerances( return get_val_constants(preschedFusion(), inputs); } +int64_t FusionDefinition::setupSegmentation( + const at::ArrayRef& inputs) { + NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!"); + NVF_ERROR( + segmentation_state_ == nullptr, "SegmentationState already exists!"); + segmentation_state_ = std::make_unique(); + return segmentation_state_->setupSegmentation( + preschedFusion(), map_value_to_fid_, inputs); +} + +void FusionDefinition::finalizeSegmentation() { + // Destroy SegmentedState + segmentation_state_.reset(); +} + } // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/fusion_definition.h b/csrc/python_frontend/fusion_definition.h index f3400d6e2d5..88ab14fae75 100644 --- a/csrc/python_frontend/fusion_definition.h +++ b/csrc/python_frontend/fusion_definition.h @@ -7,11 +7,13 @@ // clang-format on #pragma once #include +#include #include #include -#include +#include #include +#include namespace nvfuser::python_frontend { @@ -20,8 +22,9 @@ class FusionDefinition; class FusionInterface; class FusionState; struct 
RecordFunctor; -struct UserSchedule; +class SegmentationState; struct TrieNode; +struct UserSchedule; //! This is helper function used to print a python formated //! Fusion IR DataType when printing a fusion definition. @@ -254,6 +257,12 @@ class NVF_API FusionDefinition : public FusionState { //! Get all Tensors in FusionState. NVF_API std::vector tensors(); + //! Run segmentation algorithm on FusionDefinition. Returns the number of + //! segments. + NVF_API int64_t setupSegmentation(const at::ArrayRef& inputs); + //! After creating segments, destroy SegmentationState. + NVF_API void finalizeSegmentation(); + private: //! Returns the FusionCache Ptr that holds the cache of Fusions FusionCache* fusionCache() const; @@ -288,6 +297,9 @@ class NVF_API FusionDefinition : public FusionState { UserSchedule* user_sched_; //! Number of recording_states_ before applying user schedule int64_t num_recording_states_presched_ = 0; + //! Data member that creates SegmentedFusion from cloned, prescheduled Fusion + //! then translates the segments to python FusionDefinitions. + std::unique_ptr segmentation_state_; public: //! The Operators are not directly defined in this header. 
They are defined diff --git a/csrc/python_frontend/fusion_state.cpp b/csrc/python_frontend/fusion_state.cpp index be8d8d0c514..a3fdb85963e 100644 --- a/csrc/python_frontend/fusion_state.cpp +++ b/csrc/python_frontend/fusion_state.cpp @@ -68,6 +68,27 @@ std::ostream& operator<<(std::ostream& os, const State& state) { return os; } +std::vector getExtents(Fusion* fusion) { + NVF_CHECK(fusion != nullptr, "Fusion is undefined."); + + std::vector extents; + for (Val* v : fusion->inputs()) { + // short-circuit: skip if not TensorView + if (!v->isA()) { + continue; + } + TensorView* tv = v->as(); + std::vector logical_dom = + TensorDomain::noReductions(tv->getLogicalDomain()); + std::transform( + logical_dom.begin(), + logical_dom.end(), + std::back_inserter(extents), + [](IterDomain* id) { return id->getMaybeExpandedExtent(); }); + } + return extents; +} + FusionState::FusionState() : end_record_(new EndRecord()), recording_(), @@ -249,27 +270,6 @@ const std::vector& FusionState::extents() const { return extents_fid_; } -std::vector FusionState::getExtents(Fusion* fusion) { - NVF_CHECK(fusion != nullptr, "Fusion is undefined."); - - std::vector extents; - for (Val* v : fusion->inputs()) { - // short-circuit: skip if not TensorView - if (!v->isA()) { - continue; - } - TensorView* tv = v->as(); - std::vector logical_dom = - TensorDomain::noReductions(tv->getLogicalDomain()); - std::transform( - logical_dom.begin(), - logical_dom.end(), - std::back_inserter(extents), - [](IterDomain* id) { return id->getMaybeExpandedExtent(); }); - } - return extents; -} - void FusionState::addExtents() { NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); diff --git a/csrc/python_frontend/fusion_state.h b/csrc/python_frontend/fusion_state.h index 7a83886514a..20444fc1a65 100644 --- a/csrc/python_frontend/fusion_state.h +++ b/csrc/python_frontend/fusion_state.h @@ -45,6 +45,9 @@ struct State { NVF_API std::ostream& operator<<(std::ostream& os, const State& state); +//! 
Get extents for TensorView inputs in Fusion +std::vector getExtents(Fusion* fusion); + //! FusionState contains the information used to build a new cpp Fusion object. //! Unlike FusionDefinition, it does not modify the FusionCache Trie structure. class FusionState { @@ -103,8 +106,6 @@ class FusionState { std::unique_ptr clone(); private: - //! Get extents for TensorView inputs in Fusion - std::vector getExtents(Fusion* fusion); //! Add extents of TensorView inputs to FusionState void addExtents(); //! Change the fusion ptr and reset its state diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index f17ea228ad0..adb5bda2117 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1002,6 +1002,35 @@ void initNvFuserPythonBindings(PyObject* module) { .def("inputs", [](FusionDefinition& self) { return self.inputs(); }) .def("outputs", [](FusionDefinition& self) { return self.outputs(); }) .def("extents", [](FusionDefinition& self) { return self.extents(); }) + .def( + "_setup_segmentation", + [](FusionDefinition& self, const py::iterable& iter) { + // Instrumentation to mark the beginning of segmentation + inst::Trace::instance()->beginEvent( + "FusionDefinition Segmentation"); + std::vector inputs; + for (py::handle obj : iter) { + // Allows for a Vector of Sizes to be inputed as a list/tuple + if (py::isinstance(obj) || + py::isinstance(obj)) { + for (py::handle item : obj) { + inputs.push_back( + torch::jit::toIValue(item, c10::AnyType::get())); + } + } else { + inputs.push_back( + torch::jit::toIValue(obj, c10::AnyType::get())); + } + } + return self.setupSegmentation(inputs); + }) + .def( + "_finalize_segmentation", + [](FusionDefinition& self) { + self.finalizeSegmentation(); + // Mark the end of segmentation + inst::Trace::instance()->endEvent(nullptr); + }) .def( "__repr__", [](FusionDefinition& self) { diff --git a/csrc/python_frontend/segmentation.cpp 
b/csrc/python_frontend/segmentation.cpp new file mode 100644 index 00000000000..7bdcbab3c29 --- /dev/null +++ b/csrc/python_frontend/segmentation.cpp @@ -0,0 +1,143 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +namespace nvfuser::python_frontend { + +int64_t SegmentationState::setupSegmentation( + Fusion* fusion, + const std::unordered_map& map_value_to_original_fid, + const at::ArrayRef& inputs) { + // Check state + NVF_ERROR(fusion != nullptr); + NVF_ERROR(cloned_fusion_ == nullptr); + NVF_ERROR(segmented_fusion_ == nullptr); + NVF_ERROR(group_run_order_.empty()); + NVF_ERROR(map_cloned_value_to_fid_.empty()); + NVF_ERROR(cloned_extents_.empty()); + + int8_t device = getCommonDeviceCUDA(inputs); + NVF_CHECK( + inputs.empty() || device > -1, "Inputs are not all on the same device!"); + + // Step 1) Clone preschedFusion CPP Fusion. + cloned_fusion_ = std::make_unique(); + + // The IRCloner returned by Fusion::copy acts as map from the original fusion + // to the cloned fusion. + IrCloner original_to_cloned_map = Fusion::copy(fusion, cloned_fusion_.get()); + + KernelArgumentHolder args = + KernelArgumentHolder::createKernelArgumentHolder(inputs, device); + + // Step 2) Concretize fusion with input arguments. + std::unordered_map symbolic_to_concrete_map = + DynamicTransform::concretizeFusion(cloned_fusion_.get(), args); + + // Step 3) Given the map_value_to_original_fid, the IRCloner returned by + // Fusion::copy, AND the symbolic_to_concrete map returned by + // concretization pass, create a mapping from cloned Vals to original fusion + // state indices. 
+ std::transform( + map_value_to_original_fid.begin(), + map_value_to_original_fid.end(), + std::inserter(map_cloned_value_to_fid_, map_cloned_value_to_fid_.end()), + [&](const auto& item) { + const Val* original_value = item.first; + int64_t fid = item.second; + Val* cloned_val = original_to_cloned_map.clone(original_value); + if (symbolic_to_concrete_map.count(cloned_val)) { + cloned_val = symbolic_to_concrete_map.at(cloned_val); + } + return std::make_pair(cloned_val, fid); + }); + + // Track the extents for input TensorViews in cloned CPP Fusion. + cloned_extents_ = getExtents(cloned_fusion_.get()); + + // Create runtime infomation + SchedulerRuntimeInfo runtime_info( + cloned_fusion_.get(), + args, + /*precomputed_values=*/nullptr, + cloned_fusion_->allTvs()); + + // Run segmentation algorithm + segmented_fusion_ = SegmentCandidateFinder::segment( + std::move(cloned_fusion_), &args, runtime_info); + + // Get the order for fusion segments + prepareGroupOrder(); + + // Return the number of segments created by segmentation algorithm. + return (int64_t)segmented_fusion_->groups().size(); +} + +void SegmentationState::prepareGroupOrder() { + NVF_ERROR(segmented_fusion_ != nullptr); + + // Gather initial inputs for SegmentedFusion. + std::unordered_set available_input( + segmented_fusion_->inputs().begin(), segmented_fusion_->inputs().end()); + + // The size of the tensor dimensions can be used as an input of the segments. + // NvFuser does not support returning scalar values. Segmentation must pass + // those sizes as segment arguments manually. 
+ std::vector extents = getExtents(segmented_fusion_->completeFusion()); + std::copy( + extents.begin(), + extents.end(), + std::inserter(available_input, available_input.end())); + + // Track the run status of all SegmentedGroups in SegmentedFusion + std::vector group_ran(segmented_fusion_->groups().size(), false); + + // While not all the SegmentedGroups are run: + while (!std::all_of( + group_ran.begin(), group_ran.end(), [](bool b) { return b; })) { + bool ran_any_group = false; + + // Find the first segment with all inputs available to run + for (size_t group_i : c10::irange(segmented_fusion_->groups().size())) { + SegmentedGroup* group = segmented_fusion_->groups().at(group_i); + + // short-circuit: Already ran this segmented group. + if (group_ran.at(group_i)) { + continue; + } + + const std::vector& group_inputs = group->inputs(); + bool ready_to_run = std::all_of( + group_inputs.begin(), + group_inputs.end(), + [&available_input](Val* val) { return available_input.count(val); }); + + // short-circuit: This segmented group is not ready to run. + if (!ready_to_run) { + continue; + } + + // Add SegmentedGroup to group_run_order_. + group_run_order_.push_back(group); + + // Mark all outputs of SegmentedGroup as ready. + const std::vector& group_outputs = group->outputs(); + for (size_t group_out_i : c10::irange(group_outputs.size())) { + available_input.insert(group_outputs.at(group_out_i)); + } + group_ran[group_i] = true; + ran_any_group = true; + } + NVF_ERROR( + ran_any_group, + "Failed to run any group; An error must have occured in segmentation."); + } +} + +} // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/segmentation.h b/csrc/python_frontend/segmentation.h new file mode 100644 index 00000000000..b2cb6582fdf --- /dev/null +++ b/csrc/python_frontend/segmentation.h @@ -0,0 +1,227 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser::python_frontend { + +class FusionDefinition; + +//! Overview: +//! Segmentation decomposes a fusion into a directed acyclic graph (DAG) of +//! sub-fusions. After applying the segmentation algorithm, we can translate +//! the sub-fusions into their corresponding python definitions. Then, given the +//! fusion's input arguments, the segments are run in the correct order to +//! produce the output results. +//! +//! Each FusionDefinition contains a set of states representing tensors, vectors +//! and scalars. Every state has a unique index, which matches the insertion +//! order of the state in the FusionDefinition. These indices form a linear +//! index space for each FusionDefinition. +//! +//! The original FusionDefinition stores the sequence of sub-fusions and acts as +//! an argument manager. It gathers the input arguments before running the +//! sub-fusion and stores its results. To perform this function, it requires a +//! map from the segment index space to the original index space. This mapping +//! is generated while creating the python definition for each sub-fusion. +//! +//! Algorithm: +//! Step 1: setupSegmentation runs the segmentation algorithm on the CPP Fusion +//! to create the SegmentedFusion. Then, sub-fusions are ordered according to +//! their dependencies by the prepareGroupOrder function. It returns the number +//! of segments in SegmentedFusion. +//! +//! Step 2: buildSegment creates the CPP Fusion for a given segment id, +//! translates it to a python FusionDefinition, then returns a mapping from the +//! segment fusion state indices to the original fusion state indices. +//! +//! =========================================================================== +//! +//! Example 1: A simple fusion with two iota operations. +//! +//! Original Fusion: +//! def nvfuser_fusion_id1(fd : FusionDefinition) -> None : +//! 
S0 = fd.define_scalar(2, dtype=DataType.Int) +//! S1 = fd.define_scalar(0, dtype=DataType.Int) +//! S2 = fd.define_scalar(2, dtype=DataType.Int) +//! T3 = fd.ops.iota(S0, S1, S2, dtype=DataType.Int) +//! S4 = fd.define_scalar(3, dtype=DataType.Int) +//! S5 = fd.define_scalar(100, dtype=DataType.Int32) +//! S6 = fd.define_scalar(1, dtype=DataType.Int32) +//! T7 = fd.ops.iota(S4, S5, S6, dtype=DataType.Int32) +//! fd.add_output(T3) +//! fd.add_output(T7) +//! +//! After Segmentation: +//! The original fusion is divided into two segments. There are no dependencies +//! between the two segments, so they can run in any order. +//! +//! First Segment: +//! def nvfuser_fusion_id2(fd : FusionDefinition) -> None : +//! S0 = fd.define_scalar(2, dtype=DataType.Int) +//! S1 = fd.define_scalar(0, dtype=DataType.Int) +//! S2 = fd.define_scalar(2, dtype=DataType.Int) +//! T3 = fd.ops.iota(S0, S1, S2, dtype=DataType.Int) +//! fd.add_output(T3) +//! +//! Second Segment: +//! def nvfuser_fusion_id3(fd : FusionDefinition) -> None : +//! S0 = fd.define_scalar(3, dtype=DataType.Int) +//! S1 = fd.define_scalar(100, dtype=DataType.Int32) +//! S2 = fd.define_scalar(1, dtype=DataType.Int32) +//! T3 = fd.ops.iota(S0, S1, S2, dtype=DataType.Int32) +//! fd.add_output(T3) +//! +//! The first segment corresponds with [S0, S1, S2, T3] in the original fusion. +//! The second segment corresponds with [S4, S5, S6, T7] in the original fusion. +//! +//! Neither segment requires any input arguments from the original fusion. +//! +//! For the first segment, the segment's T3 is mapped to the original's T3. +//! Segment Index : Original Index Mapping +//! -------------------------------------- +//! T3 : T3 +//! +//! For the second segment the segment's T3 is mapped to the original's T7. +//! Segment Index : Original Index Mapping +//! -------------------------------------- +//! T3 : T7 +//! +//! =========================================================================== +//! 
Example 2: A reduction + broadcast + pointwise fusion. +//! +//! Original Fusion: +//! def nvfuser_fusion_id1(fd : FusionDefinition) -> None : +//! T0 = fd.define_tensor(shape=[-1, -1], +//! contiguity=[True, True], +//! dtype=DataType.Float, +//! is_cpu=False) +//! T1 = fd.define_tensor(shape=[-1, -1], +//! contiguity=[True, True], +//! dtype=DataType.Float, +//! is_cpu=False) +//! T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float) +//! T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True]) +//! T4 = fd.ops.add(T1, T3) +//! fd.add_output(T4) +//! +//! After Segmentation: +//! The reduction scheduler does not support fusing any operations with an +//! inner reduction, so the original fusion is divided into two segments. +//! Segment 2 depends on Segment 1, so there is a strict segment ordering. +//! +//! First Segment: +//! def nvfuser_fusion_id2(fd : FusionDefinition) -> None : +//! T0 = fd.define_tensor(shape=[-1, -1], +//! contiguity=[True, True], +//! dtype=DataType.Float, +//! is_cpu=False) +//! T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float) +//! T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True]) +//! fd.add_output(T2) +//! +//! Second Segment: +//! def nvfuser_fusion_id3(fd : FusionDefinition) -> None : +//! T0 = fd.define_tensor(shape=[-1, -1], +//! contiguity=[True, True], +//! dtype=DataType.Float, +//! is_cpu=False) +//! T1 = fd.define_tensor(shape=[-1, 1], +//! contiguity=[True, None], +//! dtype=DataType.Float, +//! is_cpu=False) +//! T2 = fd.ops.add(T0, T1) +//! fd.add_output(T2) +//! +//! The first segment contains the reduction and broadcast operations, which +//! corresponds with [T0, T2, T3] in the original fusion. Therefore, the segment +//! index to original index map has two entries. +//! +//! Segment Index : Original Index Mapping +//! -------------------------------------- +//! T0 : T0 --- The first tensor argument for the original fusion. +//! 
T2 : T3 --- The broadcasted, reduction tensor is this segment's output. +//! +//! The second segment is the pointwise addition with the broadcasted reduction. +//! It corresponds with [T1, T3, T4] in the original fusion. +//! +//! Segment Index : Original Index Mapping +//! -------------------------------------- +//! T0 : T1 --- The second tensor argument for the original fusion. +//! T1 : T3 --- The broadcasted, reduction tensor, which is the output from the +//! first segment. +//! T2 : T4 --- The pointwise addition, which is the output for the original +//! fusion. +//! =========================================================================== +class SegmentationState { + public: + // setupSegmentation runs the segmentation algorithm on CPP Fusion to create + // SegmentedFusion. It returns the number of segments in SegmentedFusion. + // + // Details: + // 1) Clone preschedFusion CPP Fusion. + // 2) Concretize fusion using input arguments. + // 3) Given the map_value_to_original_fid, the IRCloner returned by + // Fusion::copy, AND symbolic_to_concrete map returned by concretization + // pass, create a mapping from cloned Vals to original fusion state + // indices + // 4) Get extents for cloned fusion + // 5) Create SchedulerRuntimeInfo + // 6) Run segmentation algorithm using cloned fusion, input arguments, and + // scheduler runtime information. + // 7) Get sequential order of fusion segments using prepareGroupOrder. + // 8) Return the number of segments created by segmentation algorithm. + int64_t setupSegmentation( + Fusion* fusion, + const std::unordered_map& map_value_to_original_fid, + const at::ArrayRef& inputs); + + private: + // prepareGroupOrder is similar to prepareRuntimeOrder. It generates the + // topological order of SegmentedGroups in SegmentedFusion. + // + // Details: + // 1) Gather initial inputs for SegmentedFusion. + // 2) Gather IterDomain extents from the tensor input arguments. 
+ // 3) Track the run status of all SegmentedGroups in SegmentedFusion + // 4) While not all the SegmentedGroups are run: + // 5) For each SegmentedGroup: + // 6) Skip SegmentedGroup if it is already run + // 7) Skip SegmentedGroup if inputs are not ready + // 8) Add SegmentedGroup to group_run_order_. Mark all outputs of + // SegmentedGroup as ready. + // 9) End For + // 10) Fail if none of the SegmentedGroups are available to run. + // 11) End While + void prepareGroupOrder(); + + private: + // Clone of original fusion for segmentation + std::unique_ptr cloned_fusion_ = nullptr; + + // This FusionDefinition may require multiple kernels if it cannot be handled + // by a single heuristic scheduler. SegmentedFusion takes a fusion and runs + // the segmentation algorithm. + std::unique_ptr segmented_fusion_ = nullptr; + + // Pre-determined order to run the segmented groups + std::vector group_run_order_; + + // Create copy of fusion for segmentation algorithm. IrCloner is a map + // between values in original and cloned fusions. 
+ std::unordered_map map_cloned_value_to_fid_; + + // Extents for TensorView input arguments for cloned Fusion + std::vector cloned_extents_; +}; + +} // namespace nvfuser::python_frontend diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index f986ffa0640..743aed4a23c 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -54,6 +54,11 @@ def __init__(self, id=None, max_length=1024): super(FusionDefinition, self).__init__(id, max_length) self.profiled = False + def segment(self, inputs): + num_segments = self._setup_segmentation(inputs) + self._finalize_segmentation() + return num_segments + def __enter__(self): return self._setup_definition() diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 0f2a9f9314d..22810b4e08f 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -1244,6 +1244,27 @@ def fusion_func(fd: FusionDefinition): for torch_dtype in list_of_dtype: test_dtype(torch_dtype) + def test_segmentation_reduction_pointwise_epilogue(self): + inputs = [ + torch.randn(2, 32, device="cuda", dtype=torch.float32), + torch.randn(2, 128, device="cuda", dtype=torch.float32), + ] + + def fusion_func(fd: FusionDefinition): + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.from_pytorch(inputs[1]) + t2 = fd.ops.sum(t0, [-1], True, torch_dtype_to_nvfuser_dtype(torch.float32)) + t3 = fd.ops.add(t1, t2) + fd.add_output(t3) + + nvf_out1, _ = self.exec_nvfuser(fusion_func, inputs) + eager_out = torch.sum(inputs[0], dim=-1, keepdim=True) + inputs[1] + self.assertEqual(eager_out, nvf_out1[0]) + + with FusionDefinition() as fd: + fusion_func(fd) + assert fd.segment(inputs) == 2 + def test_arithmetic_ops(self): inputs = [ torch.randn(3, 4, 5, device="cuda", dtype=torch.float32), @@ -3881,7 +3902,7 @@ def fusion_func(fd: FusionDefinition) -> None: fd.add_output(T88) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) # 
See https://github.com/NVIDIA/Fuser/issues/2275 @pytest.mark.skipif( @@ -3927,7 +3948,7 @@ def fusion_func(fd: FusionDefinition) -> None: T101 = fd.ops.cat([T7, T100], dim=-1) fd.add_output(T101) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) # See https://github.com/NVIDIA/Fuser/issues/2317 @pytest.mark.skipif(