Skip to content

Commit

Permalink
each PipelineStage holds multiple submods (#855)
Browse files Browse the repository at this point in the history
## Description

If `inner_depth > 1`, each PipelineStage holds multiple submods.

Fixes #(issue)

- [x] New feature (non-breaking change which adds functionality)

## Feature/Issue validation/testing

No error occurs when `inner_depth == 1` (the default case).

## Checklist:

- [x] Has code been commented, particularly in hard-to-understand areas?
  • Loading branch information
moonbucks authored Aug 8, 2023
1 parent 21ce7a4 commit de6ba61
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
50 changes: 48 additions & 2 deletions pippy/PipelineStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
kwargs_chunk_spec=None,
output_chunk_spec=None,
nstreams=2,
global_depth=None,
):
super().__init__()
self.pipe = pipe
Expand All @@ -76,10 +77,26 @@ def __init__(
for i in range(nstreams):
self.streams.append(torch.cuda.Stream())

# inner pipelining
if global_depth is not None:
self.global_depth = global_depth
self.inner_depth = global_depth // nstages
else:
self.global_depth = nstages
self.inner_depth = 1

# Find my submodule
self.split_gm = self.pipe.split_gm
named_children = list(self.split_gm.named_children())
self.name, self.submod = named_children[rank]

# submod = first inner node of each rank
self.name, self.submod = named_children[rank * self.inner_depth]
self.names, self.submods = [], []
for i in range(self.inner_depth):
name, submod = named_children[rank * self.inner_depth + i]
self.names.append(name)
self.submods.append(submod)

logging.info(
f"[{self.rank}][{self.name}] "
f"Creating PipelineStage:\n"
Expand All @@ -96,6 +113,16 @@ def __init__(
if not found_node:
raise AssertionError(f"Cannot find {self.name} in graph")

if (
self.inner_depth > 1
): # when inner pipelining is enabled, we have multiple nodes for this rank
self.nodes = []
for node in self.split_gm.graph.nodes:
if node.name in self.names:
self.nodes.append(node)
if len(self.nodes) == 0:
raise AssertionError(f"Cannot find {self.names} in graph")

# Find my backward node in graph
if self.pipe.has_loss_and_backwards:
found_bwd = False
Expand All @@ -112,10 +139,29 @@ def __init__(
f"Cannot find backward for {self.name} in graph"
)

if (
self.inner_depth > 1
): # when inner pipelining is enabled, we have multiple bwd nodes for this rank
self.bwd_nodes = []
seen_bwd = -1
added_bwd = 0
for node in reversed(self.split_gm.graph.nodes):
if (node.op, node.target) == ("call_function", stage_backward):
seen_bwd += 1
if seen_bwd // self.inner_depth == self.rank:
self.bwd_nodes.append(node)
added_bwd += 1
if added_bwd == self.inner_depth:
break
if len(self.bwd_nodes) == 0:
raise AssertionError(
f"Cannot find backward for {self.names} in graph"
)

# Create submod to rank mapping
self.submod_to_rank: Dict[str, int] = {}
for i, (name, _) in enumerate(self.split_gm.named_children()):
self.submod_to_rank.setdefault(name, i)
self.submod_to_rank.setdefault(name, i // self.inner_depth)

# Prepare send/recv infrastructure
self._prepare_send_recv_infra()
Expand Down
15 changes: 15 additions & 0 deletions pippy/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def compile_stage(
output_chunk_spec=None,
schedule="FillDrain",
nstreams=2,
num_stages=None,
**kwargs,
) -> PipelineStage:
# If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
Expand Down Expand Up @@ -308,6 +309,20 @@ def compile_stage(
kwargs_chunk_spec,
output_chunk_spec,
)
elif schedule == "TwoLevel":
return PipelineStage(
pipe,
rank,
num_ranks,
num_chunks,
device,
group,
args_chunk_spec,
kwargs_chunk_spec,
output_chunk_spec,
nstreams=nstreams,
global_depth=num_stages,
)
else:
return PipelineStage(
pipe,
Expand Down

0 comments on commit de6ba61

Please sign in to comment.