diff --git a/csrc/python_frontend/fusion_cache.h b/csrc/python_frontend/fusion_cache.h index 76c4f68af6c..ae64b588c89 100644 --- a/csrc/python_frontend/fusion_cache.h +++ b/csrc/python_frontend/fusion_cache.h @@ -103,11 +103,11 @@ struct FusionSchedules { std::mutex scheds_lock; //! ID of fusion in python frontend fusion cache int64_t fusion_id_ = -1; - //! Input arguments for FusionState + //! Fusion IDs of input arguments for FusionState std::vector inputs_fid_; - //! Extents for TensorView input arguments for FusionState + //! IDs for Extents for TensorView input arguments for FusionState std::vector extents_fid_; - //! Output arguments for FusionState + //! Fusion IDs of output arguments for FusionState std::vector outputs_fid_; //! Map Fusion Val to its corresponding FusionDefinition index std::unordered_map map_value_to_fid_; diff --git a/csrc/python_frontend/fusion_state.cpp b/csrc/python_frontend/fusion_state.cpp index c3970c6edf9..be8d8d0c514 100644 --- a/csrc/python_frontend/fusion_state.cpp +++ b/csrc/python_frontend/fusion_state.cpp @@ -279,6 +279,9 @@ void FusionState::addExtents() { std::vector extents = getExtents(fusion_); for (Val* extent : extents) { int64_t num_extents = (int64_t)extents_fid_.size(); + // Use negative numbers to represent extent of iterDomains to avoid conflict + // with non-negative numbers used for scalars, vectors, and tensors. + // The extents are ordered based on the order of the fusion's inputs. int64_t extent_fid = -num_extents - 1; extents_fid_.push_back(extent_fid); // The extent can already exist in the fusion. However, since scalars cannot diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 4b4f25b9d66..7d9048e7bf6 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -53,7 +53,6 @@ class FusionDefinition(_C._FusionDefinition): def __init__(self, id=None, max_length=1024): super(FusionDefinition, self).__init__(id, max_length) self.profiled = False - self.inputs = None def __enter__(self): return self._setup_definition() diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 8080c48278c..b0b4af0236d 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4600,3 +4600,34 @@ def fusion_func(fd: FusionDefinition) -> None: nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) for out in nvf_out: self.assertTrue(out.allclose(x[:, 1:, 2:])) + + def test_fusion_information(self): + inputs = [ + torch.ones(2, 4, 8, device="cuda"), + torch.ones(2, 4, 8, device="cuda"), + ] + + def fusion_func(fd: FusionDefinition) -> None: + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.from_pytorch(inputs[1]) + c0 = fd.define_scalar(3.0) + + t2 = fd.ops.add(t0, t1) + t3 = fd.ops.mul(t2, c0) + t4 = fd.ops.sum(t3, [-1], False, DataType.Float) + + fd.add_output(t4) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + eager_out = torch.sum((inputs[0] + inputs[1]) * 3.0, dim=-1) + self.assertEqual(eager_out, nvf_out[0]) + + with FusionDefinition() as fd: + fusion_func(fd) + + nvf_out1 = fd.execute(inputs) + self.assertEqual(eager_out, nvf_out1[0]) + + self.assertEqual(fd.inputs(), [0, 1]) + self.assertEqual(fd.outputs(), [5]) + self.assertEqual(fd.extents(), [idx for idx in range(-1, -7, -1)])