
inner cache implemented #858

Merged
merged 16 commits from ip3 into pytorch:pp_tp_optimization on Aug 11, 2023
Conversation

moonbucks
Contributor

Description

Instead of direct passing, use an internal cache.
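
A minimal sketch of the idea, using hypothetical names (the real cache lives in the pipeline stage code touched by this PR): the producer stores each chunk's output in an internal fwd_cache keyed by chunk id, and the consumer looks it up later instead of receiving it as a direct argument.

from typing import Any, Dict

class InnerCacheSketch:
    """Illustrative only: keep per-chunk outputs in an internal cache."""

    def __init__(self, submod):
        self.submod = submod
        self.fwd_cache: Dict[int, Any] = {}  # chunk id -> cached output

    def produce_chunk(self, chunk: int, inp: Any) -> None:
        # Store the result internally instead of passing it along directly.
        self.fwd_cache[chunk] = self.submod(inp)

    def consume_chunk(self, chunk: int) -> Any:
        # A later step retrieves (and clears) the cached output by chunk id.
        return self.fwd_cache.pop(chunk)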

Type of change

  • [v] New feature (non-breaking change which adds functionality)

Checklist:

  • [v] Have you added tests that prove your fix is effective or that this feature works?
  • [v] Has code been commented, particularly in hard-to-understand areas?
  • [v] Have you made corresponding changes to the documentation?

# Find my submodule
self.split_gm = self.pipe.split_gm
named_children = list(self.split_gm.named_children())

Contributor

is this still needed?

Contributor Author

@fduwjj the comment is left on a blank line. If it points to L91-L92 or L95, we will leave it for safety reasons until we modify all functions to use a 'list of submods' instead of a single submod.
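
For context, a hedged sketch of what using a 'list of submods' could look like (hypothetical; split_gm comes from the snippet above, is_owned is an illustrative placeholder, not part of this PR):

# Collect every submodule this rank owns instead of keeping a single one.
my_submods = [
    submod
    for name, submod in self.split_gm.named_children()
    if self.is_owned(name)  # hypothetical ownership check
]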

Contributor

I am indeed asking about L95.

@moonbucks
Contributor Author

code rebased

Comment on lines 817 to 838
if self.inner_depth > 1:
    # Forward pass of all chunks
    for chunk in range(self.chunks):
        s = self.streams[chunk % self.nstreams]
        with torch.cuda.stream(s):
            output, send_reqs = self.forward_one_chunk_ipipe(
                chunk, args_split, kwargs_split, fwd_cache
            )
        all_send_reqs += send_reqs
        # Prepare for final output merge or reduction
        output_chunks[chunk] = output
else:
    # Forward pass of all chunks
    for chunk in range(self.chunks):
        s = self.streams[chunk % self.nstreams]
        with torch.cuda.stream(s):
            output, send_reqs = self.forward_one_chunk(
                chunk, args_split, kwargs_split, fwd_cache
            )
        all_send_reqs += send_reqs
        # Prepare for final output merge or reduction
        output_chunks[chunk] = output

Contributor

The code has a lot of duplication here; can you kindly consolidate it a bit?

Contributor Author

Cleaned up in the subsequent commit.
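
For reference, one way the duplicated loop could be collapsed, as a hedged sketch (the actual cleanup is in the subsequent commit; names follow the snippet above):

# Pick the per-chunk forward function once, then run a single loop.
forward_fn = (
    self.forward_one_chunk_ipipe if self.inner_depth > 1 else self.forward_one_chunk
)
# Forward pass of all chunks
for chunk in range(self.chunks):
    s = self.streams[chunk % self.nstreams]
    with torch.cuda.stream(s):
        output, send_reqs = forward_fn(chunk, args_split, kwargs_split, fwd_cache)
    all_send_reqs += send_reqs
    # Prepare for final output merge or reduction
    output_chunks[chunk] = output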

    kwargs_split,
    fwd_cache: Dict[int, Any],
):
    if self.rank == self.nstages - 1:

Contributor
@fduwjj Aug 9, 2023

If I understand correctly, self.nstages is the number of global stages? If so, does this line of code ever get hit?

Contributor Author

self.nstages was already used in the original codebase, so I added a new variable self.global_depth to store the global number of stages. self.nstages is the same as pp_group_size (the number of ranks).
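
To summarize my reading of the naming (hedged, not verified against the full diff):

# self.nstages      -> pp_group_size, i.e. the number of pipeline ranks (unchanged meaning)
# self.inner_depth  -> stages handled within one rank; 1 means no inner pipelining
# self.global_depth -> the global number of stages, stored separately from self.nstages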


try:
    if self.rank == self.nstages - 1:  # last stage
        output = self.forward_maybe_with_nosync(

Contributor

Do you need an inner_depth == 0 check here?

Contributor Author

No, we don't. inner_depth == 1 is the default (no inner pipelining), and when inner_depth == 1 this function is not called; forward_one_chunk is called instead.

Contributor
@fduwjj left a comment

LGTM. Maybe we can consolidate and extract the common logic to make the code look cleaner?

@moonbucks merged commit b740566 into pytorch:pp_tp_optimization Aug 11, 2023
21 of 25 checks passed
@moonbucks deleted the ip3 branch August 11, 2023 01:43