Skip to content

Commit

Permalink
code cleaned
Browse files Browse the repository at this point in the history
  • Loading branch information
moonbucks committed Aug 9, 2023
1 parent e0b1357 commit 3942e82
Showing 1 changed file with 18 additions and 28 deletions.
46 changes: 18 additions & 28 deletions pippy/PipelineStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,38 +692,28 @@ def forward_one_chunk_ipipe(
)

try:
if self.rank == self.nstages - 1: # last stage
output = self.forward_maybe_with_nosync(
0, targets, *composite_args, **composite_kwargs
)
self.pipe_cache[chunk][0] = output

for i in range(1, self.inner_depth - 1):
# 0th (first) inner node
output = self.forward_maybe_with_nosync(
0, targets, *composite_args, **composite_kwargs
)
self.pipe_cache[chunk][0] = output

# other inner nodes uses 'output' of previous inner node
for i in range(1, self.inner_depth):
if (
self.rank == self.nstages - 1 and i == self.inner_depth - 1
): # last stage last inner node
# last inner node takes targets to calculate loss
output = self.forward_maybe_with_nosync(
i,
self.inner_depth - 1,
targets,
self.pipe_cache[chunk][self.inner_depth - 2],
targets,
self.pipe_cache[chunk][i - 1],
**composite_kwargs,
)
self.pipe_cache[chunk][i] = output

output = self.forward_maybe_with_nosync(
self.inner_depth - 1,
targets,
self.pipe_cache[chunk][self.inner_depth - 2],
targets,
**composite_kwargs,
) # self.inner_pipe >= 2 is asserted
self.pipe_cache[chunk][self.inner_depth - 1] = output
else:
# 0th (first) inner node
output = self.forward_maybe_with_nosync(
0, targets, *composite_args, **composite_kwargs
)
self.pipe_cache[chunk][0] = output
) # self.inner_pipe >= 2 is asserted
self.pipe_cache[chunk][self.inner_depth - 1] = output

# other inner nodes uses 'output' of previous inner node
for i in range(1, self.inner_depth):
else:
output = self.forward_maybe_with_nosync(
i,
targets,
Expand Down

0 comments on commit 3942e82

Please sign in to comment.