[WIP] enable doraPP #1117

Draft · wants to merge 3 commits into base: main
111 changes: 111 additions & 0 deletions pippy/PipelineSchedule.py
@@ -14,6 +14,8 @@
from .microbatch import merge_chunks, split_args_kwargs_into_chunks

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)



class PipelineSchedule(ABC):
@@ -776,3 +778,112 @@ def backward_stage_local_index(step):

# Return losses if there is a container passed in
self._update_losses(self._stages, losses)


class ScheduleDoraPP(PipelineScheduleMulti):
"""
    This is an interleaved DFS+BFS zero-bubble schedule.
"""
def __init__(
self,
stages: List[PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
self.pp_group_size = stages[0].group_size
self.deallocate_pipeline_outputs = False
# TODO: is this limitation a must?
if n_microbatches % self.pp_group_size != 0:
raise ValueError(
"Interleaved 1F1B requires the number of microbatches to be a "
f"multiple of the number of pipeline ranks ({self.pp_group_size}), "
f"but got {n_microbatches}."
)

super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
output_merge_spec=output_merge_spec,
)

self.n_local_stages = len(stages)
self.rank = stages[0].group_rank

def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
microbatch_size: Optional[int] = None,
model_dim: Optional[int] = None,
):

"""
        Operate on the microbatches for the interleaved 1F1B schedule
        (https://arxiv.org/pdf/2104.04473.pdf).

        The highest rank has a warmup (forward-only) count of
        [len(stages) - 1] * number of PP ranks, and each rank further away
        from the highest rank adds 2 warmup steps:
        - one that ran before the highest rank's warmup started,
        - one spent waiting for the backward result to trickle down from the highest rank.

TODO: Interleaved 1F1B does not support using sorted_batch_isend_irecv()
because it requires recvs and sends from different peers
to execute in the same coalesced operation. As a result, this schedule does
not support models with skip connections.
"""
arg_mbs, kwarg_mbs = self._check_inputs(
arg_mbs, kwarg_mbs, target_mbs, losses
)

num_round = max(self._n_microbatches // self.pp_group_size, 1)
assert (
self._n_microbatches % num_round == 0
), "Number of microbatches should be divisible by number of pipeline rounds."
        # The number of microbatches run in each round; with DFS ordering this equals the pipeline-parallel size.
num_microbatch_per_round = self._n_microbatches // num_round
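        # Hypothetical worked example (illustrative values, not taken from this PR):
        # with self._n_microbatches = 8 and pp_group_size = 4, num_round = 8 // 4 = 2
        # and num_microbatch_per_round = 8 // 2 = 4, i.e. one microbatch per
        # pipeline rank in each round.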

total_num_microbatches = self._n_microbatches * self.n_local_stages

        # Warmup steps increase by 2 for each rank (hop) away from the last pipeline stage.
num_warmup_microbatches = 0
        # The number of microbatches that the last pipeline stage runs before 1F1B.
num_warmup_microbatches += (self.n_local_stages - 1) * num_microbatch_per_round
        # Moving up from the last PP stage, each rank has 2 more warmup steps than the previous one.
num_warmup_microbatches += (
self.pp_group_size - self.rank - 1
) * 2
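        # Continuing the hypothetical example above: with n_local_stages = 2,
        # num_microbatch_per_round = 4, pp_group_size = 4 and rank = 0, this is
        # (2 - 1) * 4 + (4 - 0 - 1) * 2 = 10 warmup microbatches, which is then
        # clamped to total_num_microbatches (8 * 2 = 16 here, so no clamping).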

num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
if self._n_microbatches == self.pp_group_size:
num_1f1b_microbatches = self.rank
else:
num_1f1b_microbatches = 2 * self.rank
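        # In the hypothetical example above, _n_microbatches (8) != pp_group_size (4),
        # so rank 0 gets num_1f1b_microbatches = 2 * 0 = 0 and rank 3 would get 2 * 3 = 6.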

print("------------------------")
print(f"{num_warmup_microbatches=}, {num_microbatches_remaining=}, {num_1f1b_microbatches=}, {total_num_microbatches=}")
print("------------------------")

# Run warmup steps.
with record_function("warmup forward passes"):
fwd_wait_handles = None
bwd_wait_handles = None
for k in range(num_warmup_microbatches):
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()

with record_function("1f"):
print("forward step")

# Run 1F1B in steady state.
with record_function("forward 1F1B steady"):
for k in range(num_microbatches_remaining):
print("fwd_bwd")

with torch.profiler.record_function("cooldown backward"):
for k in range(num_microbatches_remaining, total_num_microbatches):
print("cooldown")