Skip to content

Commit

Permalink
A graph-based pipeline splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
spupyrev committed May 24, 2024
1 parent 395801c commit d83ef61
Show file tree
Hide file tree
Showing 5 changed files with 646 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__pycache__
build
pippy.egg-info
torchpippy.egg-info
pippy/version.py
dist
.idea/
Expand Down
54 changes: 54 additions & 0 deletions pippy/ModelSplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,57 @@ def _split_into_nstages_equal_size(
return gm

return _split_into_nstages_equal_size


"""
Create a Callable that splits a model into a given number of stages, based on the computation graph, while
trying to minimize the communication between the stages and to balance the computation across them.
Input:
nstages: the number of stages to split the module into
Output:
a Callable that transforms an input `fx.GraphModule` into an output `fx.GraphModule` that has `pipe_split` inserted
between `nstages` stages
"""


def split_by_graph(nstages: int) -> Callable[[fx.GraphModule], fx.GraphModule]:
    """Create a graph-based splitting pass for a fixed number of stages.

    Args:
        nstages: the number of stages to split the module into.

    Returns:
        A Callable that takes an `fx.GraphModule`, removes any existing
        `pipe_split` nodes, regroups the nodes according to the node-to-stage
        mapping returned by `split_by_graph_with_num_stages`, inserts a
        `pipe_split` call between consecutive stages, and returns the same
        (mutated and recompiled) module.
    """

    def _split_by_graph(
        gm: fx.GraphModule,
    ) -> fx.GraphModule:
        node_param_sizes = _analyze_node_size(gm)
        node2stage = split_by_graph_with_num_stages(
            gm, nstages, node_param_sizes
        )

        # Remove existing split points. Iterate over a snapshot so that
        # erasing a node never interferes with the traversal itself.
        for node in list(gm.graph.nodes):
            if "pipe_split" in node.name:
                gm.graph.erase_node(node)

        # Bucket the remaining nodes by stage in a single pass; each bucket
        # keeps the nodes in their current graph order.
        nodes_by_stage = {}
        for node in gm.graph.nodes:
            if node in node2stage:
                nodes_by_stage.setdefault(node2stage[node], []).append(node)

        # Modify the graph by grouping nodes on the same stage and adding
        # pipe_splits between the stages
        last_node = None
        for stage_idx in range(nstages):
            for node in nodes_by_stage.get(stage_idx, ()):
                # Only move the node when it is not already directly after
                # the previously placed one.
                if last_node is not None and last_node.next != node:
                    last_node.append(node)
                last_node = node
            # Insert pipe_split nodes after each stage, except the last one
            if stage_idx + 1 != nstages and last_node is not None:
                with gm.graph.inserting_after(last_node):
                    last_node = gm.graph.call_function(
                        aten_pipe_split_alias, (), {}
                    )

        # Since we transformed the graph, recompile the module
        gm.recompile()
        gm.graph.lint()
        return gm

    return _split_by_graph
7 changes: 6 additions & 1 deletion pippy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
)
from ._PipelineStage import PipelineStage
from .ManualPipelineStage import ManualPipelineStage
from .ModelSplit import split_into_equal_size, split_on_size_threshold
from .ModelSplit import (
split_by_graph,
split_into_equal_size,
split_on_size_threshold,
)
from .PipelineSchedule import (
Schedule1F1B,
ScheduleGPipe,
Expand All @@ -27,6 +31,7 @@
"annotate_split_points",
"split_into_equal_size",
"split_on_size_threshold",
"split_by_graph",
"pipeline",
"Schedule1F1B",
"ScheduleGPipe",
Expand Down
Loading

0 comments on commit d83ef61

Please sign in to comment.