inner cache implemented #858
# Find my submodule
self.split_gm = self.pipe.split_gm
named_children = list(self.split_gm.named_children())
is this still needed?
@fduwjj The comment is attached to a blank line. If it refers to L91-L92 or L95, we will leave that code in place for safety until we modify all functions to use a list of submods instead of a single submod.
I am indeed asking about L95.
…last inner node (self.nodes[-1]) instead of self.node
code rebased
pippy/PipelineStage.py
Outdated
if self.inner_depth > 1:
    # Forward pass of all chunks
    for chunk in range(self.chunks):
        s = self.streams[chunk % self.nstreams]
        with torch.cuda.stream(s):
            output, send_reqs = self.forward_one_chunk_ipipe(
                chunk, args_split, kwargs_split, fwd_cache
            )
        all_send_reqs += send_reqs
        # Prepare for final output merge or reduction
        output_chunks[chunk] = output
else:
    # Forward pass of all chunks
    for chunk in range(self.chunks):
        s = self.streams[chunk % self.nstreams]
        with torch.cuda.stream(s):
            output, send_reqs = self.forward_one_chunk(
                chunk, args_split, kwargs_split, fwd_cache
            )
        all_send_reqs += send_reqs
        # Prepare for final output merge or reduction
        output_chunks[chunk] = output
The code has a lot of duplication here. Could you consolidate it a bit?
Cleaned up in the subsequent commit.
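For illustration only, one way the two branches could be consolidated is to pick the per-chunk function once and keep a single loop. This is a minimal sketch, not the actual follow-up commit: the `StageSketch` class and its stub methods are hypothetical stand-ins, and the CUDA stream context from the original loop is left out so the sketch runs without a GPU.

```python
from typing import Any, Dict, List, Tuple


class StageSketch:
    """Hypothetical stand-in for the pipeline stage class (not PiPPy's real code)."""

    def __init__(self, chunks: int, inner_depth: int = 1):
        self.chunks = chunks
        self.inner_depth = inner_depth

    # Stub: stands in for the existing per-chunk forward.
    def forward_one_chunk(self, chunk, args_split, kwargs_split, fwd_cache):
        return ("plain", chunk), []

    # Stub: stands in for the inner-pipelining per-chunk forward.
    def forward_one_chunk_ipipe(self, chunk, args_split, kwargs_split, fwd_cache):
        return ("ipipe", chunk), []

    def forward_all_chunks(
        self, args_split, kwargs_split, fwd_cache: Dict[int, Any]
    ) -> Tuple[List[Any], List[Any]]:
        # Choose the per-chunk function once instead of duplicating the loop.
        forward_fn = (
            self.forward_one_chunk_ipipe
            if self.inner_depth > 1
            else self.forward_one_chunk
        )
        all_send_reqs: List[Any] = []
        output_chunks: List[Any] = [None] * self.chunks
        # Forward pass of all chunks (stream selection omitted in this sketch)
        for chunk in range(self.chunks):
            output, send_reqs = forward_fn(chunk, args_split, kwargs_split, fwd_cache)
            all_send_reqs += send_reqs
            # Prepare for final output merge or reduction
            output_chunks[chunk] = output
        return output_chunks, all_send_reqs
```

In the real code, the `with torch.cuda.stream(s):` block would wrap the `forward_fn(...)` call exactly as in both original branches.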
    kwargs_split,
    fwd_cache: Dict[int, Any],
):
    if self.rank == self.nstages - 1:
If I understand correctly, self.nstages is the number of global stages? If so, does this line of code ever get hit?
self.nstages was already used in the original codebase, so I added a new variable, self.global_depth, to store the number of global stages. self.nstages is therefore the same as pp_group_size (the number of ranks).
try:
    if self.rank == self.nstages - 1:  # last stage
        output = self.forward_maybe_with_nosync(
Do you need an inner_depth == 0 check here?
No, we don't. inner_depth == 1 is the default (no inner pipelining), and when inner_depth == 1 this function is not called; forward_one_chunk is called instead.
LGTM. Maybe we can consolidate and extract the common logic to make the code cleaner?
Description
Use an internal cache instead of direct passing.