Skip to content

Commit

Permalink
create test_fusion_information
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Oct 30, 2024
1 parent 0b7dcdc commit 8310ace
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
6 changes: 3 additions & 3 deletions csrc/python_frontend/fusion_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ struct FusionSchedules {
std::mutex scheds_lock;
//! ID of fusion in python frontend fusion cache
int64_t fusion_id_ = -1;
//! Input arguments for FusionState
//! Fusion IDs of input arguments for FusionState
std::vector<int64_t> inputs_fid_;
//! Extents for TensorView input arguments for FusionState
//! IDs for Extents for TensorView input arguments for FusionState
std::vector<int64_t> extents_fid_;
//! Output arguments for FusionState
//! Fusion IDs of output arguments for FusionState
std::vector<int64_t> outputs_fid_;
//! Map Fusion Val to its corresponding FusionDefinition index
std::unordered_map<const Val*, int64_t> map_value_to_fid_;
Expand Down
3 changes: 3 additions & 0 deletions csrc/python_frontend/fusion_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,9 @@ void FusionState::addExtents() {
std::vector<Val*> extents = getExtents(fusion_);
for (Val* extent : extents) {
int64_t num_extents = (int64_t)extents_fid_.size();
// Use negative numbers to represent extent of iterDomains to avoid conflict
// with non-negative numbers used for scalars, vectors, and tensors.
// The extents are ordered based on the order of the fusion's inputs.
int64_t extent_fid = -num_extents - 1;
extents_fid_.push_back(extent_fid);
// The extent can already exist in the fusion. However, since scalars cannot
Expand Down
1 change: 0 additions & 1 deletion nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ class FusionDefinition(_C._FusionDefinition):
def __init__(self, id=None, max_length=1024):
super(FusionDefinition, self).__init__(id, max_length)
self.profiled = False
self.inputs = None

def __enter__(self):
return self._setup_definition()
Expand Down
31 changes: 31 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4600,3 +4600,34 @@ def fusion_func(fd: FusionDefinition) -> None:
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
for out in nvf_out:
self.assertTrue(out.allclose(x[:, 1:, 2:]))

def test_fusion_information(self):
inputs = [
torch.ones(2, 4, 8, device="cuda"),
torch.ones(2, 4, 8, device="cuda"),
]

def fusion_func(fd: FusionDefinition) -> None:
t0 = fd.from_pytorch(inputs[0])
t1 = fd.from_pytorch(inputs[1])
c0 = fd.define_scalar(3.0)

t2 = fd.ops.add(t0, t1)
t3 = fd.ops.mul(t2, c0)
t4 = fd.ops.sum(t3, [-1], False, DataType.Float)

fd.add_output(t4)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
eager_out = torch.sum((inputs[0] + inputs[1]) * 3.0, dim=-1)
self.assertEqual(eager_out, nvf_out[0])

with FusionDefinition() as fd:
fusion_func(fd)

nvf_out1 = fd.execute(inputs)
self.assertEqual(eager_out, nvf_out1[0])

self.assertEqual(fd.inputs(), [0, 1])
self.assertEqual(fd.outputs(), [5])
self.assertEqual(fd.extents(), [idx for idx in range(-1, -7, -1)])

0 comments on commit 8310ace

Please sign in to comment.