diff --git a/examples/huggingface/pippy_gpt2.py b/examples/huggingface/pippy_gpt2.py index b11c7c3ea..3d1a3e1c8 100644 --- a/examples/huggingface/pippy_gpt2.py +++ b/examples/huggingface/pippy_gpt2.py @@ -8,7 +8,7 @@ import torch import torch.distributed as dist -from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint +from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint from transformers import GPT2ForSequenceClassification, GPT2Config @@ -38,36 +38,20 @@ def run(args): mb_inputs = generate_inputs_for_model( model_class, gpt2, model_name, args.batch_size // args.chunks, args.device) - assert not args.autosplit or not args.graphsplit + # Pipeline split spec + decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size + print(f"decoders_per_rank = {decoders_per_rank}") + split_spec = { + f'transformer.h.{i * decoders_per_rank}': SplitPoint.BEGINNING + for i in range(1, args.world_size) + } - split_policy = None - split_spec = None - - if args.autosplit: - # Automatic split - # TODO: Migrate to new auto split algorithms - from pippy import split_into_equal_size - split_policy = split_into_equal_size(args.world_size) - elif args.graphsplit: - # Graph-based split - from pippy import split_by_graph - split_policy = split_by_graph(args.world_size) - else: - # Use manual split spec - decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size - print(f"decoders_per_rank = {decoders_per_rank}") - split_spec = { - f'transformer.h.{i * decoders_per_rank}': SplitPoint.BEGINNING - for i in range(1, args.world_size) - } - - # Only one of `split_spec` and `split_policy` is used + # Create pipeline representation pipe = pipeline( gpt2, mb_args=(), mb_kwargs=mb_inputs, split_spec=split_spec, - split_policy=split_policy, ) assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" @@ -114,8 +98,6 @@ def run(args): parser.add_argument('--n_embd', type=int, default=None) parser.add_argument('--n_layer', type=int, default=None) parser.add_argument('--n_head', type=int, default=None) - parser.add_argument('--autosplit', action="store_true") - parser.add_argument('--graphsplit', action="store_true") args = parser.parse_args()