Skip to content

Commit

Permalink
code cleaned
Browse files Browse the repository at this point in the history
  • Loading branch information
moonbucks committed Aug 9, 2023
1 parent e0b1357 commit 139442e
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions pippy/PipelineStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,28 +814,21 @@ def forward(self, *args, **kwargs):

output_chunks = [None] * self.chunks

if self.inner_depth > 1:
# Forward pass of all chunks
for chunk in range(self.chunks):
s = self.streams[chunk % self.nstreams]
with torch.cuda.stream(s):
# Forward pass of all chunks
for chunk in range(self.chunks):
s = self.streams[chunk % self.nstreams]
with torch.cuda.stream(s):
if self.inner_depth > 1:
output, send_reqs = self.forward_one_chunk_ipipe(
chunk, args_split, kwargs_split, fwd_cache
)
all_send_reqs += send_reqs
# Prepare for final output merge or reduction
output_chunks[chunk] = output
else:
# Forward pass of all chunks
for chunk in range(self.chunks):
s = self.streams[chunk % self.nstreams]
with torch.cuda.stream(s):
else:
output, send_reqs = self.forward_one_chunk(
chunk, args_split, kwargs_split, fwd_cache
)
all_send_reqs += send_reqs
# Prepare for final output merge or reduction
output_chunks[chunk] = output
all_send_reqs += send_reqs
# Prepare for final output merge or reduction
output_chunks[chunk] = output

# Wait for all sends to finish
# TODO: okay to delay the sync till completion of all chunks?
Expand Down

0 comments on commit 139442e

Please sign in to comment.