diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index 75bdbe741..8b597a918 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -59,6 +59,7 @@ def __init__( kwargs_chunk_spec=None, output_chunk_spec=None, nstreams=2, + global_depth=None, ): super().__init__() self.pipe = pipe @@ -76,10 +77,26 @@ def __init__( for i in range(nstreams): self.streams.append(torch.cuda.Stream()) + # inner pipelining + if global_depth is not None: + self.global_depth = global_depth + self.inner_depth = global_depth // nstages + else: + self.global_depth = nstages + self.inner_depth = 1 + # Find my submodule self.split_gm = self.pipe.split_gm named_children = list(self.split_gm.named_children()) - self.name, self.submod = named_children[rank] + + # submod = first inner node of each rank + self.name, self.submod = named_children[rank * self.inner_depth] + self.names, self.submods = [], [] + for i in range(self.inner_depth): + name, submod = named_children[rank * self.inner_depth + i] + self.names.append(name) + self.submods.append(submod) + logging.info( f"[{self.rank}][{self.name}] " f"Creating PipelineStage:\n" @@ -96,6 +113,16 @@ def __init__( if not found_node: raise AssertionError(f"Cannot find {self.name} in graph") + if ( + self.inner_depth > 1 + ): # when inner pipelining is enabled, we have multiple nodes for this rank + self.nodes = [] + for node in self.split_gm.graph.nodes: + if node.name in self.names: + self.nodes.append(node) + if len(self.nodes) == 0: + raise AssertionError(f"Cannot find {self.names} in graph") + # Find my backward node in graph if self.pipe.has_loss_and_backwards: found_bwd = False @@ -112,10 +139,29 @@ def __init__( f"Cannot find backward for {self.name} in graph" ) + if ( + self.inner_depth > 1 + ): # when inner pipelining is enabled, we have multiple bwd nodes for this rank + self.bwd_nodes = [] + seen_bwd = -1 + added_bwd = 0 + for node in reversed(self.split_gm.graph.nodes): + if (node.op, node.target) == ("call_function", stage_backward): + seen_bwd += 1 + if seen_bwd // self.inner_depth == self.rank: + self.bwd_nodes.append(node) + added_bwd += 1 + if added_bwd == self.inner_depth: + break + if len(self.bwd_nodes) == 0: + raise AssertionError( + f"Cannot find backward for {self.names} in graph" + ) + # Create submod to rank mapping self.submod_to_rank: Dict[str, int] = {} for i, (name, _) in enumerate(self.split_gm.named_children()): - self.submod_to_rank.setdefault(name, i) + self.submod_to_rank.setdefault(name, i // self.inner_depth) # Prepare send/recv infrastructure self._prepare_send_recv_infra() diff --git a/pippy/compile.py b/pippy/compile.py index 54d2fa13b..1061ec9d6 100644 --- a/pippy/compile.py +++ b/pippy/compile.py @@ -233,6 +233,7 @@ def compile_stage( output_chunk_spec=None, schedule="FillDrain", nstreams=2, + num_stages=None, **kwargs, ) -> PipelineStage: # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across @@ -308,6 +309,20 @@ def compile_stage( kwargs_chunk_spec, output_chunk_spec, ) + elif schedule == "TwoLevel": + return PipelineStage( + pipe, + rank, + num_ranks, + num_chunks, + device, + group, + args_chunk_spec, + kwargs_chunk_spec, + output_chunk_spec, + nstreams=nstreams, + global_depth=num_stages, + ) else: return PipelineStage( pipe,