Skip to content

Commit

Permalink
A graph-based pipeline splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
spupyrev committed Apr 26, 2024
1 parent 73e349b commit d976072
Show file tree
Hide file tree
Showing 6 changed files with 686 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__pycache__
build
pippy.egg-info
torchpippy.egg-info
pippy/version.py
dist
.idea/
Expand Down
46 changes: 39 additions & 7 deletions examples/huggingface/pippy_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def run(args):
config.n_embd = args.n_embd or config.n_embd
config.n_layer = args.n_layer or config.n_layer
config.n_head = args.n_head or config.n_head
print("Using device:", args.device)
print("[Rank {}] Using device: {}".format(args.rank, args.device))

# Create model
model_class = GPT2ForSequenceClassification
Expand All @@ -55,6 +55,8 @@ def run(args):
example_inputs = generate_inputs_for_model(
model_class, gpt2, model_name, args.batch_size, args.device)

assert not args.autosplit or not args.graphsplit

if args.autosplit:
# Automatic split
from pippy import split_into_equal_size
Expand All @@ -65,6 +67,16 @@ def run(args):
example_kwargs=example_inputs,
split_policy=split_into_equal_size(args.world_size),
)
elif args.graphsplit:
# Graph-based split
from pippy import split_by_graph
gpt2_pipe = pipeline(
gpt2,
num_chunks=args.chunks,
example_args=(),
example_kwargs=example_inputs,
split_policy=split_by_graph(args.world_size),
)
else:
# Manually annotate split points
add_split_points(gpt2, args.world_size)
Expand All @@ -90,14 +102,33 @@ def run(args):
# Attach to a schedule
schedule = ScheduleGPipe(stage, args.chunks)

# Run
if args.rank == 0:
schedule.step(**example_inputs)
else:
out = schedule.step()
def iter():
if args.rank == 0:
schedule.step(**example_inputs)
elif args.rank == args.world_size - 1:
out = schedule.step()
else:
schedule.step()

import time
# Warm-up
for _ in range(3):
iter()

print(f"Rank {args.rank} completes")
# Add a barrier here to synchronize all ranks
dist.barrier()
start_time = time.time()

for i in range(args.batches):
iter()

torch.cuda.synchronize()
end_time = time.time()

print("[Rank {}]: Time per batch: {:.4f} sec".format(
args.rank,
(end_time - start_time) / args.batches,
))

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand All @@ -117,6 +148,7 @@ def run(args):
parser.add_argument('--n_layer', type=int, default=None)
parser.add_argument('--n_head', type=int, default=None)
parser.add_argument('--autosplit', action="store_true")
parser.add_argument('--graphsplit', action="store_true")

args = parser.parse_args()

Expand Down
55 changes: 55 additions & 0 deletions pippy/ModelSplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.fx as fx
from pippy.graphsplit import split_by_graph_with_num_stages

from .IR import aten_pipe_split_alias

Expand Down Expand Up @@ -200,3 +201,57 @@ def _split_into_nstages_equal_size(
return gm

return _split_into_nstages_equal_size


"""
Create a Callable that splits a model into a given number of stages, based on the computation graph, while
trying to minimize the communication between the stages and to balance the computation
Input:
nstages: the number of stages to split the module into
Output:
a Callable that transforms an input `fx.GraphModule` into an output `fx.GraphModule` that has `pipe_split` inserted
between `nstages` stages
"""


def split_by_graph(nstages: int) -> Callable[[fx.GraphModule], fx.GraphModule]:
    def _split_by_graph(
        gm: fx.GraphModule,
    ) -> fx.GraphModule:
        """Insert `pipe_split` markers between `nstages` graph-derived stages.

        Asks `split_by_graph_with_num_stages` for a node -> stage assignment,
        reorders the graph so each stage's nodes are contiguous, and inserts a
        `pipe_split` call between consecutive stages.  Returns the (mutated and
        recompiled) input module.
        """
        node_param_sizes = _analyze_node_size(gm)
        node2stage = split_by_graph_with_num_stages(
            gm, nstages, node_param_sizes
        )

        # Remove existing split points.  Materialize the list first: erasing
        # while iterating `gm.graph.nodes` mutates the sequence being walked.
        stale_splits = [
            node for node in gm.graph.nodes if "pipe_split" in node.name
        ]
        for node in stale_splits:
            gm.graph.erase_node(node)

        # Modify the graph by grouping nodes on the same stage and adding
        # pipe_splits between the stages.  `node_order` preserves the current
        # topological order restricted to nodes that received a stage.
        node_order = [node for node in gm.graph.nodes if node in node2stage]
        last_node = None
        for stage_idx in range(nstages):
            # `node_order` is already filtered to keys of node2stage.
            nodes_at_stage = [
                node for node in node_order if node2stage[node] == stage_idx
            ]
            for node in nodes_at_stage:
                # Move `node` directly after `last_node` unless it is
                # already in position.
                if last_node is not None and last_node.next != node:
                    last_node.append(node)
                last_node = node
            # Insert pipe_split nodes after each stage, except the last one
            if stage_idx + 1 != nstages and last_node is not None:
                with gm.graph.inserting_after(last_node):
                    last_node = gm.graph.call_function(
                        aten_pipe_split_alias, (), {}
                    )

        # Validate the mutated graph, then regenerate the module's code.
        gm.graph.lint()
        gm.recompile()
        return gm

    return _split_by_graph
7 changes: 6 additions & 1 deletion pippy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
SplitPoint,
)
from .ManualPipelineStage import ManualPipelineStage
from .ModelSplit import split_into_equal_size, split_on_size_threshold
from .ModelSplit import (
split_by_graph,
split_into_equal_size,
split_on_size_threshold,
)
from .PipelineSchedule import (
Schedule1F1B,
ScheduleGPipe,
Expand All @@ -27,6 +31,7 @@
"annotate_split_points",
"split_into_equal_size",
"split_on_size_threshold",
"split_by_graph",
"pipeline",
"Schedule1F1B",
"ScheduleGPipe",
Expand Down
Loading

0 comments on commit d976072

Please sign in to comment.