Skip to content

Commit

Permalink
add supports_segmentation flag
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Nov 2, 2024
1 parent 1fb73ef commit 8923dc7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
3 changes: 2 additions & 1 deletion csrc/python_frontend/fusion_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ class FusionDefinition;
class FusionInterface;
class FusionState;
struct RecordFunctor;
struct UserSchedule;
class SegmentationState;
struct TrieNode;
struct UserSchedule;

//! This is a helper function used to print a Python-formatted
//! Fusion IR DataType when printing a fusion definition.
Expand Down
10 changes: 5 additions & 5 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,7 +1950,7 @@ def fusion_func(fd: FusionDefinition):
# ...but skip outputting scalars, which we don't support
fd.add_output(t)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

ab = [inputs[0], inputs[1], 1.0, -1.0]
i = 0
Expand Down Expand Up @@ -3183,7 +3183,7 @@ def fusion_func(fd: FusionDefinition) -> None:
fd.add_output(T54)
fd.add_output(T30)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)
# self.assertEqual(nvf_out[0], t24)

# Test that symbolic IterDomains can be concatenated
Expand Down Expand Up @@ -3881,7 +3881,7 @@ def fusion_func(fd: FusionDefinition) -> None:

fd.add_output(T88)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# See https://github.com/NVIDIA/Fuser/issues/2275
@pytest.mark.skipif(
Expand Down Expand Up @@ -3927,7 +3927,7 @@ def fusion_func(fd: FusionDefinition) -> None:
T101 = fd.ops.cat([T7, T100], dim=-1)
fd.add_output(T101)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# See https://github.com/NVIDIA/Fuser/issues/2317
@pytest.mark.skipif(
Expand Down Expand Up @@ -4708,4 +4708,4 @@ def fusion_func(fd: FusionDefinition) -> None:
T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0)
fd.add_output(T223)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)
20 changes: 17 additions & 3 deletions tests/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,24 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None)

# Run original FusionDefinition
# Clone FusionDefinition
# Apply segmentation if it is supported for this FusionDefinition
# Run cloned python definition
# Check that the result of cloned python definition matches original results
def check_cpp_translation(reference_outputs, fd, inputs, device=None):
def check_cpp_translation(
reference_outputs, fd, inputs, supports_segmentation, device=None
):
try:
torch.manual_seed(0)

# Clone
cloned_fd = FusionDefinition()
clone(fd, cloned_fd)
cloned_fd.segment(inputs)

# Segment
if supports_segmentation:
cloned_fd.segment(inputs)

# Run
cloned_outputs = cloned_fd.execute(inputs, device=device)

# Make sure the results of original and cloned definitions match.
Expand All @@ -269,6 +279,7 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None):
print(
"(A failure here suggests a mismatch in functionality between the original and cloned definitions.)"
)
print("Does FusionDefinition supports segmentation?\t", supports_segmentation)
print(fd.getReproErrorString("executing", inputs))
raise err

Expand Down Expand Up @@ -420,6 +431,7 @@ def exec_nvfuser(
new_fusion_expected=True,
device=None,
is_clonable=True,
supports_segmentation=True,
):
fc = FusionCache.get()
before_fusions = fc.num_fusions()
Expand All @@ -442,5 +454,7 @@ def exec_nvfuser(
self.assertEqual(fc.num_fusions() - before_fusions, int(new_fusion_expected))

if is_clonable:
self.assertTrue(check_cpp_translation(out, fd, inputs_cloned))
self.assertTrue(
check_cpp_translation(out, fd, inputs_cloned, supports_segmentation)
)
return out, fd

0 comments on commit 8923dc7

Please sign in to comment.