diff --git a/examples/huggingface/pippy_bert.py b/examples/huggingface/pippy_bert.py
index 16ad57a19..88846cbc6 100644
--- a/examples/huggingface/pippy_bert.py
+++ b/examples/huggingface/pippy_bert.py
@@ -8,7 +8,7 @@
 import torch
 import torch.distributed as dist
-from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint
+from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint
 
 from transformers import BertModel, BertConfig
@@ -31,9 +31,9 @@ def run(args):
     print(f"Total number of params = {get_number_of_params(bert) // 10 ** 6}M")
     print(bert)
 
-    # Input configs
-    example_inputs = generate_inputs_for_model(
-        model_class, bert, model_name, args.batch_size, args.device)
+    # Example microbatch inputs
+    example_mb = generate_inputs_for_model(
+        model_class, bert, model_name, args.batch_size // args.chunks, args.device)
 
     # Split points
     layers_per_rank = bert.config.num_hidden_layers // args.world_size
@@ -45,9 +45,8 @@
     # Create pipeline
     pipe = pipeline(
         bert,
-        num_chunks=args.chunks,
-        example_args=(),
-        example_kwargs=example_inputs,
+        mb_args=(),
+        mb_kwargs=example_mb,
         split_spec=split_spec,
     )
@@ -56,8 +55,7 @@
     print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params")
 
     # Create schedule runtime
-    stage = PipelineStage(
-        pipe,
+    stage = pipe.build_stage(
         args.rank,
         device=args.device,
     )
@@ -65,9 +63,13 @@
     # Attach to a schedule
     schedule = ScheduleGPipe(stage, args.chunks)
 
+    # Full batch inputs as in single-worker case
+    inputs = generate_inputs_for_model(
+        model_class, bert, model_name, args.batch_size, args.device)
+
     # Run
     if args.rank == 0:
-        schedule.step(**example_inputs)
+        schedule.step(**inputs)
     else:
         out = schedule.step()
diff --git a/examples/huggingface/pippy_deberta.py b/examples/huggingface/pippy_deberta.py
index 60a49fde2..76380926f 100644
--- a/examples/huggingface/pippy_deberta.py
+++ b/examples/huggingface/pippy_deberta.py
@@ -8,7 +8,7 @@
 import torch
 import torch.distributed as dist
-from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint
+from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint
 
 from transformers import DebertaModel, DebertaConfig
@@ -31,9 +31,9 @@ def run(args):
     print(f"Total number of params = {get_number_of_params(deberta) // 10 ** 6}M")
     print(deberta)
 
-    # Input configs
-    example_inputs = generate_inputs_for_model(
-        model_class, deberta, model_name, args.batch_size, args.device)
+    # Example microbatch inputs
+    mb_inputs = generate_inputs_for_model(
+        model_class, deberta, model_name, args.batch_size // args.chunks, args.device)
 
     # Split points
     layers_per_rank = deberta.config.num_hidden_layers // args.world_size
@@ -45,9 +45,8 @@
     # Create pipeline
     pipe = pipeline(
         deberta,
-        num_chunks=args.chunks,
-        example_args=(),
-        example_kwargs=example_inputs,
+        mb_args=(),
+        mb_kwargs=mb_inputs,
         split_spec=split_spec,
     )
@@ -56,8 +55,7 @@
     print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params")
 
     # Create schedule runtime
-    stage = PipelineStage(
-        pipe,
+    stage = pipe.build_stage(
         args.rank,
         device=args.device,
     )
@@ -65,9 +63,13 @@
     # Attach to a schedule
     schedule = ScheduleGPipe(stage, args.chunks)
 
+    # Full batch inputs as in single-worker case
+    inputs = generate_inputs_for_model(
+        model_class, deberta, model_name, args.batch_size, args.device)
+
     # Run
     if args.rank == 0:
-        schedule.step(**example_inputs)
+        schedule.step(**inputs)
     else:
         out = schedule.step()
diff --git a/examples/huggingface/pippy_gpt2.py b/examples/huggingface/pippy_gpt2.py
index 2a600cc4e..b11c7c3ea 100644
--- a/examples/huggingface/pippy_gpt2.py
+++ b/examples/huggingface/pippy_gpt2.py
@@ -34,9 +34,9 @@ def run(args):
     print(f"GPT-2 total number of params = {get_number_of_params(gpt2) // 10 ** 6}M")
     print(gpt2)
 
-    # Input configs
-    example_inputs = generate_inputs_for_model(
-        model_class, gpt2, model_name, args.batch_size, args.device)
+    # Example microbatch inputs
+    mb_inputs = generate_inputs_for_model(
+        model_class, gpt2, model_name, args.batch_size // args.chunks, args.device)
 
     assert not args.autosplit or not args.graphsplit
@@ -45,6 +45,7 @@
     if args.autosplit:
         # Automatic split
+        # TODO: Migrate to new auto split algorithms
         from pippy import split_into_equal_size
         split_policy = split_into_equal_size(args.world_size)
     elif args.graphsplit:
@@ -63,9 +64,8 @@
     # Only one of `split_spec` and `split_policy` is used
     pipe = pipeline(
         gpt2,
-        num_chunks=args.chunks,
-        example_args=(),
-        example_kwargs=example_inputs,
+        mb_args=(),
+        mb_kwargs=mb_inputs,
         split_spec=split_spec,
         split_policy=split_policy,
     )
@@ -75,8 +75,7 @@
     print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params")
 
     # Create schedule runtime
-    stage = PipelineStage(
-        pipe,
+    stage = pipe.build_stage(
         args.rank,
         device=args.device,
     )
@@ -84,9 +83,13 @@
     # Attach to a schedule
     schedule = ScheduleGPipe(stage, args.chunks)
 
+    # Full batch inputs as in single-worker case
+    inputs = generate_inputs_for_model(
+        model_class, gpt2, model_name, args.batch_size, args.device)
+
     # Run
     if args.rank == 0:
-        schedule.step(**example_inputs)
+        schedule.step(**inputs)
     else:
         out = schedule.step()
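For reference, a minimal self-contained sketch of the migrated pattern shared by the three examples above, restricted to the pipelining calls that appear in this diff (pipeline with mb_args/mb_kwargs, pipe.build_stage, ScheduleGPipe, SplitPoint). The BertModel setup, environment-variable handling, and the batch/chunk/sequence sizes are illustrative assumptions for a torchrun-style launch, not part of the patch:

# Minimal sketch of the new pipelining usage (assumptions: torchrun launch,
# BertModel, batch_size 32, chunks 4, seq_len 128; not part of the patch).
import os
import torch
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint
from transformers import BertConfig, BertModel

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
dist.init_process_group(rank=rank, world_size=world_size)

batch_size, chunks, seq_len = 32, 4, 128
config = BertConfig()
model = BertModel(config).to(device).eval()

# pipeline() now traces with microbatch-sized example inputs (mb_kwargs).
example_mb = {
    "input_ids": torch.randint(
        0, config.vocab_size, (batch_size // chunks, seq_len), device=device)
}

# One split point per rank boundary, as in the example scripts.
layers_per_rank = config.num_hidden_layers // world_size
split_spec = {
    f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, world_size)
}

pipe = pipeline(model, mb_args=(), mb_kwargs=example_mb, split_spec=split_spec)
stage = pipe.build_stage(rank, device=device)  # replaces PipelineStage(pipe, ...)
schedule = ScheduleGPipe(stage, chunks)

# Full-batch inputs go to rank 0 only; the schedule splits them into microbatches.
inputs = {
    "input_ids": torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device)
}
if rank == 0:
    schedule.step(**inputs)
else:
    out = schedule.step()

dist.destroy_process_group()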