Migrate some of the HF examples to use 2.4 PP APIs
ghstack-source-id: 267357bc4a61719407aa0bcc3dfa62bf40b7643e
Pull Request resolved: #1124
kwen2501 committed Jun 10, 2024
1 parent 17dae2c commit 3cde2b1
Showing 3 changed files with 36 additions and 29 deletions.
22 changes: 12 additions & 10 deletions examples/huggingface/pippy_bert.py
@@ -8,7 +8,7 @@

import torch
import torch.distributed as dist
-from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint
+from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint

from transformers import BertModel, BertConfig

@@ -31,9 +31,9 @@ def run(args):
print(f"Total number of params = {get_number_of_params(bert) // 10 ** 6}M")
print(bert)

-# Input configs
-example_inputs = generate_inputs_for_model(
-model_class, bert, model_name, args.batch_size, args.device)
+# Example microbatch inputs
+example_mb = generate_inputs_for_model(
+model_class, bert, model_name, args.batch_size // args.chunks, args.device)

# Split points
layers_per_rank = bert.config.num_hidden_layers // args.world_size
@@ -45,9 +45,8 @@ def run(args):
# Create pipeline
pipe = pipeline(
bert,
-num_chunks=args.chunks,
-example_args=(),
-example_kwargs=example_inputs,
+mb_args=(),
+mb_kwargs=example_mb,
split_spec=split_spec,
)

@@ -56,18 +55,21 @@ def run(args):
print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params")

# Create schedule runtime
-stage = PipelineStage(
-pipe,
+stage = pipe.build_stage(
args.rank,
device=args.device,
)

# Attach to a schedule
schedule = ScheduleGPipe(stage, args.chunks)

+# Full batch inputs as in single-worker case
+inputs = generate_inputs_for_model(
+model_class, bert, model_name, args.batch_size, args.device)

# Run
if args.rank == 0:
-schedule.step(**example_inputs)
+schedule.step(**inputs)
else:
out = schedule.step()

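Note: for reference, below is a minimal, self-contained sketch of the 2.4-style pattern pippy_bert.py lands on after this change. Only the pipeline() / build_stage() / ScheduleGPipe.step() usage mirrors the diff; the process-group setup, input shapes, and split_spec construction are illustrative assumptions (those parts sit outside the visible hunks).

# Hedged sketch of the migrated BERT flow. Assumes a torchrun-style launcher
# that sets RANK and WORLD_SIZE; split_spec and input shapes are illustrative.
import os
import torch
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint
from transformers import BertConfig, BertModel

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device("cpu")  # keep the sketch device-agnostic
dist.init_process_group(rank=rank, world_size=world_size)

chunks, batch_size, seq_len = 4, 32, 128
bert = BertModel(BertConfig()).to(device).eval()

# New API: pipeline() traces with ONE microbatch of example inputs (mb_kwargs),
# where the old API took full-batch example_kwargs plus num_chunks.
mb_kwargs = {"input_ids": torch.randint(
    0, bert.config.vocab_size, (batch_size // chunks, seq_len), device=device)}

# Assumed split spec: one stage boundary per group of encoder layers.
layers_per_rank = bert.config.num_hidden_layers // world_size
split_spec = {
    f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, world_size)
}

pipe = pipeline(bert, mb_args=(), mb_kwargs=mb_kwargs, split_spec=split_spec)

# New API: build the runtime stage from the pipe, replacing PipelineStage(pipe, ...).
stage = pipe.build_stage(rank, device=device)
schedule = ScheduleGPipe(stage, chunks)

# step() takes the FULL batch; the schedule splits it into `chunks` microbatches.
full_batch = {"input_ids": torch.randint(
    0, bert.config.vocab_size, (batch_size, seq_len), device=device)}
with torch.no_grad():
    if rank == 0:
        schedule.step(**full_batch)
    else:
        out = schedule.step()

dist.destroy_process_group()

With a torchrun-style launcher each rank runs this same script, and only the stage that rank owns is materialized from the traced pipe.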
22 changes: 12 additions & 10 deletions examples/huggingface/pippy_deberta.py
@@ -8,7 +8,7 @@

import torch
import torch.distributed as dist
-from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint
+from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint

from transformers import DebertaModel, DebertaConfig

@@ -31,9 +31,9 @@ def run(args):
print(f"Total number of params = {get_number_of_params(deberta) // 10 ** 6}M")
print(deberta)

-# Input configs
-example_inputs = generate_inputs_for_model(
-model_class, deberta, model_name, args.batch_size, args.device)
+# Example microbatch inputs
+mb_inputs = generate_inputs_for_model(
+model_class, deberta, model_name, args.batch_size // args.chunks, args.device)

# Split points
layers_per_rank = deberta.config.num_hidden_layers // args.world_size
@@ -45,9 +45,8 @@ def run(args):
# Create pipeline
pipe = pipeline(
deberta,
-num_chunks=args.chunks,
-example_args=(),
-example_kwargs=example_inputs,
+mb_args=(),
+mb_kwargs=mb_inputs,
split_spec=split_spec,
)

@@ -56,18 +55,21 @@ def run(args):
print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params")

# Create schedule runtime
-stage = PipelineStage(
-pipe,
+stage = pipe.build_stage(
args.rank,
device=args.device,
)

# Attach to a schedule
schedule = ScheduleGPipe(stage, args.chunks)

+# Full batch inputs as in single-worker case
+inputs = generate_inputs_for_model(
+model_class, deberta, model_name, args.batch_size, args.device)

# Run
if args.rank == 0:
-schedule.step(**example_inputs)
+schedule.step(**inputs)
else:
out = schedule.step()

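Note: the sizing change is the same in both files above. The old API traced with full-batch example inputs and was told num_chunks; the new API is handed a single microbatch for tracing (mb_kwargs) and only sees the full batch at schedule.step() time. A tiny illustration with assumed numbers:

# Illustrative numbers only; the relationship is what matters.
batch_size, chunks = 32, 4
mb_size = batch_size // chunks           # 8: shape used when tracing via mb_kwargs
assert mb_size * chunks == batch_size    # the full batch must split evenly

# pipeline(...)      sees tensors of shape (mb_size, seq_len)
# schedule.step(...) sees tensors of shape (batch_size, seq_len);
# ScheduleGPipe slices them back into `chunks` microbatches of mb_size each.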
21 changes: 12 additions & 9 deletions examples/huggingface/pippy_gpt2.py
@@ -34,9 +34,9 @@ def run(args):
print(f"GPT-2 total number of params = {get_number_of_params(gpt2) // 10 ** 6}M")
print(gpt2)

-# Input configs
-example_inputs = generate_inputs_for_model(
-model_class, gpt2, model_name, args.batch_size, args.device)
+# Example microbatch inputs
+mb_inputs = generate_inputs_for_model(
+model_class, gpt2, model_name, args.batch_size // args.chunks, args.device)

assert not args.autosplit or not args.graphsplit

@@ -45,6 +45,7 @@ def run(args):

if args.autosplit:
# Automatic split
+# TODO: Migrate to new auto split algorithms
from pippy import split_into_equal_size
split_policy = split_into_equal_size(args.world_size)
elif args.graphsplit:
@@ -63,9 +64,8 @@ def run(args):
# Only one of `split_spec` and `split_policy` is used
pipe = pipeline(
gpt2,
-num_chunks=args.chunks,
-example_args=(),
-example_kwargs=example_inputs,
+mb_args=(),
+mb_kwargs=mb_inputs,
split_spec=split_spec,
split_policy=split_policy,
)
@@ -75,18 +75,21 @@ def run(args):
print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params")

# Create schedule runtime
-stage = PipelineStage(
-pipe,
+stage = pipe.build_stage(
args.rank,
device=args.device,
)

# Attach to a schedule
schedule = ScheduleGPipe(stage, args.chunks)

+# Full batch inputs as in single-worker case
+inputs = generate_inputs_for_model(
+model_class, gpt2, model_name, args.batch_size, args.device)

# Run
if args.rank == 0:
-schedule.step(**example_inputs)
+schedule.step(**inputs)
else:
out = schedule.step()

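Note: pippy_gpt2.py also accepts a split_policy (args.autosplit via pippy.split_into_equal_size, or args.graphsplit) as an alternative to a manual split_spec; only one of the two is passed to pipeline(). The manual construction is outside the visible hunks, so the helper below is only a plausible sketch of how per-decoder split points could be declared, assuming GPT-2's transformer.h.<i> submodule names.

# Hypothetical helper (not from the diff): mark the beginning of every k-th
# GPT-2 decoder block as a pipeline-stage boundary.
from torch.distributed.pipelining import SplitPoint

def even_decoder_split_spec(n_layer: int, world_size: int) -> dict:
    """Return {'transformer.h.<i>': SplitPoint.BEGINNING} at even intervals."""
    decoders_per_rank = (n_layer + world_size - 1) // world_size
    return {
        f"transformer.h.{i * decoders_per_rank}": SplitPoint.BEGINNING
        for i in range(1, world_size)
    }

# Example: 12 decoder blocks over 4 ranks -> boundaries at h.3, h.6, h.9
print(even_decoder_split_spec(12, 4))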
