Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
moonbucks committed Aug 8, 2023
1 parent a090a83 commit e0b1357
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions pippy/PipelineStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def __init__(
self.global_depth = nstages
self.inner_depth = 1

self.pipe_cache = [{} for i in range(chunks)] # [chunk][inner_rank] = value
self.pipe_cache: List[Dict[int, tuple]] = [
{} for i in range(chunks)
] # [chunk][inner_rank] = value

# Find my submodule
self.split_gm = self.pipe.split_gm
Expand Down Expand Up @@ -565,7 +567,6 @@ def forward_maybe_with_nosync(self, inner_rank, targets, *args, **kwargs):
out_val = self.submods[inner_rank](*args, **kwargs)
return out_val


def forward_maybe_with_nosync_save(self, targets, *args, **kwargs):
# If submod is wrapped with DDP, we use the `no_sync` context manager to
# avoid gradient all-reduce per microbatch
Expand Down Expand Up @@ -691,22 +692,29 @@ def forward_one_chunk_ipipe(
)

try:
if self.rank == self.nstages - 1: # last stage
if self.rank == self.nstages - 1: # last stage
output = self.forward_maybe_with_nosync(
0, targets, *composite_args, **composite_kwargs
)
self.pipe_cache[chunk][0] = output

for i in range(1, self.inner_depth-1):
for i in range(1, self.inner_depth - 1):
output = self.forward_maybe_with_nosync(
i, targets, self.pipe_cache[chunk][i-1], **composite_kwargs
i,
targets,
self.pipe_cache[chunk][i - 1],
**composite_kwargs,
)
self.pipe_cache[chunk][i] = output

output = self.forward_maybe_with_nosync(
self.inner_depth-1, targets, self.pipe_cache[chunk][self.inner_depth-2], targets, **composite_kwargs
) # self.inner_pipe >= 2 is asserted
self.pipe_cache[chunk][self.inner_depth-1] = output
self.inner_depth - 1,
targets,
self.pipe_cache[chunk][self.inner_depth - 2],
targets,
**composite_kwargs,
) # self.inner_pipe >= 2 is asserted
self.pipe_cache[chunk][self.inner_depth - 1] = output
else:
# 0th (first) inner node
output = self.forward_maybe_with_nosync(
Expand All @@ -717,7 +725,10 @@ def forward_one_chunk_ipipe(
# other inner nodes uses 'output' of previous inner node
for i in range(1, self.inner_depth):
output = self.forward_maybe_with_nosync(
i, targets, self.pipe_cache[chunk][i-1], **composite_kwargs
i,
targets,
self.pipe_cache[chunk][i - 1],
**composite_kwargs,
)
self.pipe_cache[chunk][i] = output

Expand Down Expand Up @@ -747,7 +758,6 @@ def forward_one_chunk_ipipe(

return output, send_reqs


def backward_one_chunk(
self,
bwd_chunk: int,
Expand Down Expand Up @@ -799,7 +809,6 @@ def forward(self, *args, **kwargs):
f"[Rank{self.rank}] ArgsSplit: {args_split}, KwargsSplit: {kwargs_split}"
)


# Activation send requests of all chunk
all_send_reqs: List[dist.Work] = []

Expand Down

0 comments on commit e0b1357

Please sign in to comment.