# Source page: example.py (148 lines, 116 loc, 5.03 KB) from a fork of pytorch/PiPPy.
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from typing import Any
class MyNetworkBlock(torch.nn.Module):
    """One fully-connected layer followed by a ReLU activation.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Registered as `lin` so the parameter names are `lin.weight` and
        # `lin.bias`.
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Return ``relu(lin(x))``."""
        return torch.relu(self.lin(x))
class MyNetwork(torch.nn.Module):
    """A stack of ``MyNetworkBlock`` layers followed by a linear output projection.

    The hidden blocks are registered as ``layer0``, ``layer1``, ... so they can
    be referred to by name (for example as pipeline split points).

    Args:
        in_dim: Number of input features.
        layer_dims: Output width of each hidden block, in order. Any iterable
            of ints is accepted. If empty, the output projection is applied
            to the input directly.
        num_classes: Width of the output projection. Defaults to 10, the
            number of output classes in this example.
    """

    def __init__(self, in_dim, layer_dims, num_classes=10):
        super().__init__()
        # Materialize once so generators work and len() is available below.
        layer_dims = list(layer_dims)
        prev_dim = in_dim
        for i, dim in enumerate(layer_dims):
            setattr(self, f"layer{i}", MyNetworkBlock(prev_dim, dim))
            prev_dim = dim
        self.num_layers = len(layer_dims)
        # `prev_dim` is the last hidden width, or `in_dim` when there are no
        # hidden layers; indexing `layer_dims[-1]` would raise IndexError then.
        self.output_proj = torch.nn.Linear(prev_dim, num_classes)

    def forward(self, x):
        """Apply ``layer0`` .. ``layer{num_layers - 1}`` in order, then ``output_proj``."""
        for i in range(self.num_layers):
            x = getattr(self, f"layer{i}")(x)
        return self.output_proj(x)
# Example model: three hidden blocks of widths 512, 1024 and 256 on a
# 512-feature input.
mn = MyNetwork(512, [512, 1024, 256])

from pippy.IR import Pipe

# Trace the model before any split points are annotated and show the result
# together with its first stage.
pipe = Pipe.from_tracing(mn)
print(pipe)
print(pipe.split_gm.submod_0)

from pippy.IR import annotate_split_points, PipeSplitWrapper

# Mark a stage boundary at the end of `layer0` and at the end of `layer1`.
# The call's result is not used, so it acts on `mn` itself, which is then
# re-traced to obtain the split pipeline.
annotate_split_points(
    mn,
    {
        "layer0": PipeSplitWrapper.SplitPoint.END,
        "layer1": PipeSplitWrapper.SplitPoint.END,
    },
)
pipe = Pipe.from_tracing(mn)

print(" pipe ".center(80, "*"))
print(pipe)
# Two split points produce three stages: submod_0, submod_1 and submod_2.
for stage_idx in range(3):
    print(f" submod{stage_idx} ".center(80, "*"))
    print(getattr(pipe.split_gm, f"submod_{stage_idx}"))
# To run a distributed job, we must launch this script in multiple processes.
# We use `torchrun` to do so in this example, e.g.
#
#     torchrun --nproc_per_node=3 example.py
#
# `torchrun` sets several environment variables in every process, including:
#   - `RANK`:       index of this process among all processes (all nodes)
#   - `LOCAL_RANK`: index of this process among the processes on its node
#   - `WORLD_SIZE`: total number of processes (all nodes)
# On a single node `RANK` equals `LOCAL_RANK`. On several nodes `LOCAL_RANK`
# repeats, so the global `RANK` is what must identify a process in the RPC
# group.
#
# To learn more about `torchrun`, see
# https://pytorch.org/docs/stable/elastic/run.html
import os

try:
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
except KeyError as err:
    raise RuntimeError(
        f"Environment variable {err.args[0]} is not set; launch this script "
        "with `torchrun`, e.g. `torchrun --nproc_per_node=3 example.py`."
    ) from err
# Fall back to LOCAL_RANK for launchers that do not export RANK.
rank = int(os.environ.get("RANK", local_rank))

# PiPPy uses the PyTorch RPC interface. To use RPC, we must call `init_rpc`
# and inform the RPC framework of this process's (global) rank and the total
# world size. We can pass the values `torchrun` provided directly.
#
# To learn more about the PyTorch RPC framework, see
# https://pytorch.org/docs/stable/rpc.html
import torch.distributed.rpc as rpc

rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

try:
    # PiPPy relies on the concept of a "driver" process. The driver process
    # should be a single process within the RPC group that instantiates the
    # PipelineDriver and issues commands on that object. The other processes
    # in the RPC group receive commands from this process and execute the
    # pipeline stages. Global rank 0 is the driver; selecting on `local_rank`
    # instead would create one driver per node in a multi-node run.
    if rank == 0:
        # We are going to use the PipelineDriverFillDrain class. This class
        # provides an interface for executing the `Pipe` in a style similar
        # to the GPipe fill-drain schedule. To learn more about GPipe and
        # the fill-drain schedule, see https://arxiv.org/abs/1811.06965
        from pippy.PipelineDriver import PipelineDriverFillDrain
        from pippy.microbatch import TensorChunkSpec

        # Pipelining relies on _micro-batching_, that is, dividing the
        # program's input data into smaller chunks and feeding those chunks
        # through the pipeline sequentially. Doing this requires that the
        # data and operations be _separable_, i.e. there must be at least one
        # dimension along which the data can be split such that the program
        # has no interactions across that dimension. PiPPy provides
        # `chunk_spec` arguments for this purpose, to specify the batch
        # dimension for tensors in each of the args, kwargs, and outputs.
        # The structure of each `chunk_spec` should mirror that of the data.
        # Here the program has a single tensor input and a single tensor
        # output, so we use a single `TensorChunkSpec` indicating dimension 0
        # for args[0] and for the output value.
        args_chunk_spec: Any = (TensorChunkSpec(0),)
        kwargs_chunk_spec: Any = {}
        output_chunk_spec: Any = TensorChunkSpec(0)

        # Number of micro-batches each input batch is split into.
        num_chunks = 64
        # Rows in the example input batch (512 / 64 = 8 rows per micro-batch).
        batch_size = 512

        # Finally, we instantiate the PipelineDriver. We pass in the pipe,
        # the number of micro-batches, the chunk specs, and the world size,
        # and the constructor distributes our code to the processes in the
        # RPC group. `driver` is an object we can invoke to run the pipeline.
        driver = PipelineDriverFillDrain(
            pipe,
            num_chunks,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_chunk_spec=output_chunk_spec,
            world_size=world_size,
        )

        # The second dimension (512) is the input width `mn` was built with.
        x = torch.randn(batch_size, 512)
        # Run the pipeline on `x`. The batch is split into `num_chunks`
        # micro-batches that are fed through the pipeline stages in sequence,
        # so different stages can work on different micro-batches at once.
        output = driver(x)
        # Run the original model on the same input for comparison.
        reference_output = mn(x)
        # Compare numerics of the pipeline and the original model.
        torch.testing.assert_close(output, reference_output)
        print(" Pipeline parallel model ran successfully! ".center(80, "*"))
finally:
    # Every process (driver and workers) must reach `rpc.shutdown()`. With the
    # default `graceful=True` it waits for all processes in the group, so
    # calling it even when the driver raised keeps the workers from blocking
    # on a driver that never arrives. The driver's exception still propagates
    # afterwards.
    rpc.shutdown()