Refactor HuggingFace examples to use torch.distributed.pipelining (#1116)

now that the `pippy` library has been migrated to
`torch.distributed.pipelining`.

Next, we will use these examples in this repo as a nightly CI against torch.
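For readers tracking the migration, the import changes map nearly one-to-one. A minimal sketch of the mapping as these examples use it (module paths per this commit; the `torch.distributed.pipelining` API has continued to evolve in later torch releases):

```python
# Old pippy imports and their torch.distributed.pipelining replacements,
# as used throughout the refactored examples in this commit:
#
#   from pippy import pipeline                         -> pipeline
#   from pippy import SplitPoint, annotate_split_points -> SplitPoint
#       (annotate_split_points is replaced by a split_spec dict argument)
#   from pippy.PipelineSchedule import ScheduleGPipe    -> ScheduleGPipe
#   from pippy import PipelineStage                     -> PipelineStage
from torch.distributed.pipelining import (
    PipelineStage,
    ScheduleGPipe,
    SplitPoint,
    pipeline,
)
```

The one non-mechanical change: `annotate_split_points`, which mutated the model in place, is gone; the refactored examples instead pass a `split_spec` dict directly to `pipeline()`.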
kwen2501 authored May 23, 2024
1 parent 2aa360f commit 395801c
Showing 26 changed files with 290 additions and 1,237 deletions.
114 changes: 0 additions & 114 deletions examples/huggingface/pippy_albert.py

This file was deleted.

112 changes: 0 additions & 112 deletions examples/huggingface/pippy_bart.py

This file was deleted.

41 changes: 18 additions & 23 deletions examples/huggingface/pippy_bert.py
@@ -8,32 +8,21 @@

 import torch
 import torch.distributed as dist
+from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint

-from pippy import pipeline
-from pippy import SplitPoint, annotate_split_points
-from pippy.PipelineSchedule import ScheduleGPipe
-from pippy import PipelineStage

-from transformers import BertForMaskedLM, BertConfig
+from transformers import BertModel, BertConfig

 from hf_utils import generate_inputs_for_model, get_number_of_params


-def add_split_points(bert, nranks):
-    layers_per_rank = bert.config.num_hidden_layers // nranks
-    for i in range(1, nranks):
-        annotate_split_points(
-            bert, {f"bert.encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING})
-
-
 def run(args):
     # Model configs
     config = BertConfig()
     print("Using device:", args.device)

     # Create model
-    model_class = BertForMaskedLM
-    model_name = "BertForMaskedLM"
+    model_class = BertModel
+    model_name = "BertModel"
     bert = model_class(config)
     bert.to(args.device)
     bert.eval()
@@ -46,24 +35,29 @@ def run(args):
     example_inputs = generate_inputs_for_model(
         model_class, bert, model_name, args.batch_size, args.device)

-    # Annotate split points
-    add_split_points(bert, args.world_size)
+    # Split points
+    layers_per_rank = bert.config.num_hidden_layers // args.world_size
+    split_spec = {
+        f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, args.world_size)
+    }

     # Create pipeline
-    bert_pipe = pipeline(
+    pipe = pipeline(
         bert,
         num_chunks=args.chunks,
         example_args=(),
         example_kwargs=example_inputs,
+        split_spec=split_spec,
     )
-    assert bert_pipe.num_stages == args.world_size, f"nstages = {bert_pipe.num_stages} nranks = {args.world_size}"
-    if args.rank == 0:
-        for i, sm in enumerate(bert_pipe.split_gm.children()):
-            print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params")
+    assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}"
+    smod = pipe.get_stage_module(args.rank)
+    print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params")

     # Create schedule runtime
     stage = PipelineStage(
-        bert_pipe,
+        pipe,
         args.rank,
         device=args.device,
     )
@@ -77,6 +71,7 @@ def run(args):
     else:
         out = schedule.step()

+    dist.destroy_process_group()
     print(f"Rank {args.rank} completes")
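Pulling the refactored pieces together, here is a condensed, self-contained sketch of the new flow on CPU. The launch command, environment handling, and shrunken config are assumptions for the sketch (the real script drives everything through its own `args`), and the `pipeline()`/`PipelineStage` signatures follow this commit's usage:

```python
# Condensed sketch of the refactored BERT example, assuming a launch like:
#   torchrun --nproc-per-node 2 pippy_bert.py
# (torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT in the environment).
import os

import torch
import torch.distributed as dist
from torch.distributed.pipelining import (
    PipelineStage, ScheduleGPipe, SplitPoint, pipeline,
)
from transformers import BertConfig, BertModel

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

config = BertConfig(num_hidden_layers=4)  # shrunken config so the sketch runs on CPU
bert = BertModel(config).eval()

# One cut per extra rank, placed at encoder-layer boundaries.
layers_per_rank = config.num_hidden_layers // world_size
split_spec = {
    f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, world_size)
}

chunks = 2
input_ids = torch.randint(0, config.vocab_size, (chunks, 16))
pipe = pipeline(
    bert,
    num_chunks=chunks,
    example_args=(),
    example_kwargs={"input_ids": input_ids},
    split_spec=split_spec,
)
assert pipe.num_stages == world_size

# Each rank wraps only its own stage and drives the GPipe schedule.
stage = PipelineStage(pipe, rank, device=torch.device("cpu"))
schedule = ScheduleGPipe(stage, chunks)
if rank == 0:
    # Passing kwargs mirrors this example's example_kwargs pattern (an assumption).
    schedule.step(input_ids=input_ids)
else:
    out = schedule.step()  # last rank receives the model output

dist.destroy_process_group()
```

Note the driver asymmetry, visible in the diff above as well: rank 0 feeds the full batch into `schedule.step(...)`, while downstream ranks call `step()` with no arguments.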
36 changes: 15 additions & 21 deletions examples/huggingface/pippy_blenderbot.py
@@ -8,24 +8,13 @@

 import torch
 import torch.distributed as dist
-
-from pippy import pipeline
-from pippy import SplitPoint, annotate_split_points
-from pippy.PipelineSchedule import ScheduleGPipe
-from pippy import PipelineStage
+from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint

 from transformers import BlenderbotForCausalLM, BlenderbotConfig

 from hf_utils import generate_inputs_for_model, get_number_of_params


-def add_split_points(blenderbot, nranks):
-    layers_per_rank = blenderbot.config.decoder_layers // nranks
-    for i in range(1, nranks):
-        annotate_split_points(
-            blenderbot, {f"model.decoder.layers.{i * layers_per_rank}": SplitPoint.BEGINNING})
-
-
 def run(args):
     # Model configs
     config = BlenderbotConfig()
@@ -47,24 +36,28 @@ def run(args):
         model_class, blenderbot, model_name, args.batch_size, args.device)
     input_ids = example_inputs["input_ids"]

-    # Annotate split points
-    add_split_points(blenderbot, args.world_size)
+    # Split points
+    layers_per_rank = blenderbot.config.decoder_layers // args.world_size
+    split_spec = {
+        f"model.decoder.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, args.world_size)
+    }

     # Create pipeline
-    blenderbot_pipe = pipeline(
+    pipe = pipeline(
         blenderbot,
         num_chunks=args.chunks,
         example_args=(input_ids, ),
+        split_spec=split_spec,
     )
-    nstages = len(list(blenderbot_pipe.split_gm.children()))
-    assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
-    if args.rank == 0:
-        for i, sm in enumerate(blenderbot_pipe.split_gm.children()):
-            print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params")
+    assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}"
+    smod = pipe.get_stage_module(args.rank)
+    print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params")

     # Create schedule runtime
     stage = PipelineStage(
-        blenderbot_pipe,
+        pipe,
         args.rank,
         device=args.device,
     )
@@ -78,6 +71,7 @@ def run(args):
     else:
         out = schedule.step()

+    dist.destroy_process_group()
     print(f"Rank {args.rank} completes")
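The same pattern repeats across the remaining examples: compute `layers_per_rank`, build a `split_spec` keyed by submodule name, and hand it to `pipeline()`. A toy, self-contained illustration of the `split_spec` semantics (the `Toy` module is an assumption, not from this repo): `SplitPoint.BEGINNING` cuts just before the named submodule executes, so a single entry yields two stages.

```python
import torch
from torch import nn
from torch.distributed.pipelining import SplitPoint, pipeline


class Toy(nn.Module):
    """Tiny stand-in model: four identical linear layers."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(*[nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        return self.layers(x)


toy = Toy()
pipe = pipeline(
    toy,
    num_chunks=2,
    example_args=(torch.randn(4, 8),),
    # Cut just before layers.2 runs -> stage 0 holds layers 0-1, stage 1 holds layers 2-3.
    split_spec={"layers.2": SplitPoint.BEGINNING},
)
assert pipe.num_stages == 2
```

`SplitPoint.END` is the complementary option, placing the cut after the named submodule instead.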
