Skip to content

Commit

Permalink
Added graph-split presolve to speedup computation
Browse files Browse the repository at this point in the history
  • Loading branch information
spupyrev committed May 29, 2024
1 parent d83ef61 commit 1b435b0
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 63 deletions.
9 changes: 8 additions & 1 deletion examples/huggingface/pippy_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run(args):
config.n_embd = args.n_embd or config.n_embd
config.n_layer = args.n_layer or config.n_layer
config.n_head = args.n_head or config.n_head
print("Using device:", args.device)
print("[Rank {}] Using device: {}".format(args.rank, args.device))

# Create model
model_class = GPT2ForSequenceClassification
Expand All @@ -38,13 +38,19 @@ def run(args):
example_inputs = generate_inputs_for_model(
model_class, gpt2, model_name, args.batch_size, args.device)

assert not args.autosplit or not args.graphsplit

split_policy = None
split_spec = None

if args.autosplit:
# Automatic split
from pippy import split_into_equal_size
split_policy = split_into_equal_size(args.world_size)
elif args.graphsplit:
# Graph-based split
from pippy import split_by_graph
split_policy = split_by_graph(args.world_size)
else:
# Use manual split spec
decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
Expand Down Expand Up @@ -106,6 +112,7 @@ def run(args):
parser.add_argument('--n_layer', type=int, default=None)
parser.add_argument('--n_head', type=int, default=None)
parser.add_argument('--autosplit', action="store_true")
parser.add_argument('--graphsplit', action="store_true")

args = parser.parse_args()

Expand Down
2 changes: 2 additions & 0 deletions pippy/ModelSplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
import torch.fx as fx

from pippy.graphsplit import split_by_graph_with_num_stages

from ._IR import aten_pipe_split_alias


Expand Down
8 changes: 4 additions & 4 deletions pippy/_IR.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,10 +925,10 @@ def set_multi_use_param_spec(
if isinstance(multi_use_param_spec, MultiUseParameterConfig):
multi_use_params_qualnames[param] = multi_use_param_spec
elif isinstance(multi_use_param_spec, dict):
multi_use_params_qualnames[
param
] = multi_use_param_spec.get(
param, MultiUseParameterConfig.TRANSMIT
multi_use_params_qualnames[param] = (
multi_use_param_spec.get(
param, MultiUseParameterConfig.TRANSMIT
)
)
else:
raise ValueError(
Expand Down
Loading

0 comments on commit 1b435b0

Please sign in to comment.