Move auto split out of GPT2 example into a separate file
ghstack-source-id: c7bad774204a595f634124c4018f77e64c12835f
Pull Request resolved: #1125
kwen2501 committed Jun 11, 2024
1 parent 3cde2b1 commit 098a2c7
Showing 1 changed file with 9 additions and 27 deletions.
examples/huggingface/pippy_gpt2.py (9 additions, 27 deletions)
@@ -8,7 +8,7 @@

 import torch
 import torch.distributed as dist
-from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint
+from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint

 from transformers import GPT2ForSequenceClassification, GPT2Config

@@ -38,36 +38,20 @@ def run(args):
     mb_inputs = generate_inputs_for_model(
         model_class, gpt2, model_name, args.batch_size // args.chunks, args.device)

-    assert not args.autosplit or not args.graphsplit
-
-    split_policy = None
-    split_spec = None
-
-    if args.autosplit:
-        # Automatic split
-        # TODO: Migrate to new auto split algorithms
-        from pippy import split_into_equal_size
-        split_policy = split_into_equal_size(args.world_size)
-    elif args.graphsplit:
-        # Graph-based split
-        from pippy import split_by_graph
-        split_policy = split_by_graph(args.world_size)
-    else:
-        # Use manual split spec
-        decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
-        print(f"decoders_per_rank = {decoders_per_rank}")
-        split_spec = {
-            f'transformer.h.{i * decoders_per_rank}': SplitPoint.BEGINNING
-            for i in range(1, args.world_size)
-        }
+    # Pipeline split spec
+    decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
+    print(f"decoders_per_rank = {decoders_per_rank}")
+    split_spec = {
+        f'transformer.h.{i * decoders_per_rank}': SplitPoint.BEGINNING
+        for i in range(1, args.world_size)
+    }

-    # Only one of `split_spec` and `split_policy` is used
+    # Create pipeline representation
     pipe = pipeline(
         gpt2,
         mb_args=(),
         mb_kwargs=mb_inputs,
         split_spec=split_spec,
-        split_policy=split_policy,
     )

     assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}"
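
Aside on the retained manual split (not part of the diff): the spec is plain ceiling division over the decoder stack. Below is a minimal standalone sketch of the same arithmetic, using GPT-2's default 12 decoder layers and a hypothetical 4-rank run; n_layer and world_size stand in for gpt2.config.n_layer and args.world_size.

    # Standalone sketch of the split arithmetic used in the example above.
    # 12 layers is GPT-2's default config; the 4-rank world size is hypothetical.
    from torch.distributed.pipelining import SplitPoint

    n_layer = 12      # stands in for gpt2.config.n_layer
    world_size = 4    # stands in for args.world_size

    # Ceiling division: each rank gets at most this many decoder blocks.
    decoders_per_rank = (n_layer + world_size - 1) // world_size  # == 3

    # SplitPoint.BEGINNING cuts before the named submodule, so the boundaries
    # land at transformer.h.3, transformer.h.6 and transformer.h.9,
    # yielding 4 stages of 3 decoder blocks each.
    split_spec = {
        f'transformer.h.{i * decoders_per_rank}': SplitPoint.BEGINNING
        for i in range(1, world_size)
    }
    print(split_spec)

Because the division rounds up, the final rank may receive fewer decoders when n_layer is not divisible by world_size; the assert on pipe.num_stages above catches a spec that produces the wrong number of stages.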
@@ -114,8 +98,6 @@ def run(args):
     parser.add_argument('--n_embd', type=int, default=None)
     parser.add_argument('--n_layer', type=int, default=None)
     parser.add_argument('--n_head', type=int, default=None)
-    parser.add_argument('--autosplit', action="store_true")
-    parser.add_argument('--graphsplit', action="store_true")

     args = parser.parse_args()

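For context, not shown in this diff: after pipeline() traces and splits the model, the example drives its local stage with the ScheduleGPipe imported above. The sketch below is an assumption about that surrounding code (args.rank, args.device, args.chunks, mb_inputs, and pipe are names used earlier in the file), not a quote from it.

    # Hedged sketch: typical consumption of a traced-and-split `pipe`.
    stage = pipe.build_stage(args.rank, device=args.device)      # this rank's stage module
    schedule = ScheduleGPipe(stage, n_microbatches=args.chunks)  # GPipe-style schedule

    if args.rank == 0:
        schedule.step(**mb_inputs)   # first stage feeds the micro-batched inputs
    else:
        schedule.step()              # later stages receive activations from peers

    dist.barrier()
    dist.destroy_process_group()

Launched under torchrun (for example, torchrun --nproc-per-node 4 pippy_gpt2.py), each rank then executes only its own slice of the decoder stack.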