PyTorch native 2D LLaMA inference #922
base: main
Conversation
examples/llama/2d_llama.py (outdated)
f"{layer_name}_mlp_down_proj": RowwiseParallel(), | ||
}) | ||
tp_mesh = mesh_2d["tp"] | ||
parallelize_module(stage.submod, tp_mesh, plan) |
What you need to do for a submod is:
self.n_local_heads = self.n_local_heads // tp_degree
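For context, a rough sketch of how that suggestion could be applied to each stage before tracing (stage.submod, tp_mesh, and the n_local_heads attribute are assumptions carried over from the snippets above):

# Hypothetical sketch: shrink the per-rank head count on every attention
# module of this stage's submodule, so the view into
# [bsz, seq, heads, head_dim] matches the column-sharded projection output.
tp_degree = tp_mesh.size()
for mod in stage.submod.modules():
    if hasattr(mod, "n_local_heads"):
        mod.n_local_heads = mod.n_local_heads // tp_degree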
Tried, but it doesn't seem to work in this case.
- Trial 1: modify num_heads before tracing
The tracer errors out with:
shape '[4, 4, 8, 128]' is invalid for input of size 65536
i.e. the tracer is smart enough to do a size check; here the input size is still the old size, so they mismatch.
- Trial 2: modify num_heads after tracing
This won't work because the old value of num_heads has been burned into the generated program during tracing.
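To illustrate why Trial 2 fails, here is a minimal standalone sketch (using torch.fx.symbolic_trace as a stand-in for the actual tracer): attribute values read during tracing become constants in the generated program.

import torch
from torch import fx, nn

class ToyAttnView(nn.Module):
    # Toy module that only models the reshape into [bsz, seq, num_heads, head_dim].
    def __init__(self, num_heads=32, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(self, x):
        # self.num_heads is read as a plain Python int during tracing, so its
        # value is baked into the traced "view" node as a constant argument.
        return x.view(x.size(0), x.size(1), self.num_heads, self.head_dim)

m = ToyAttnView()
gm = fx.symbolic_trace(m)
print(gm.code)   # ...view(size, size_1, 32, 128) -- 32 is now a literal
m.num_heads = 8  # changing the attribute after tracing ...
print(gm.code)   # ... leaves the traced program unchanged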
What works:
Modify the view and reshape ops' arguments after tracing.
E.g. modified the "view" node's args from:
(l__self___model_layers_0_self_attn_q_proj, 4, 4, 32, 128) to
(l__self___model_layers_0_self_attn_q_proj, 4, 4, 8, 128)
See the added util modify_view.
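For reference, a rough sketch of what such a utility might look like as an FX graph pass (an illustration only, not the exact modify_view added in this PR):

import torch.fx as fx

def modify_view_sketch(gm: fx.GraphModule, old_heads: int, new_heads: int) -> fx.GraphModule:
    # Hypothetical pass: rewrite the head-count argument of every traced
    # view/reshape call so it matches the per-rank head count after TP.
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target in ("view", "reshape"):
            node.args = tuple(new_heads if a == old_heads else a for a in node.args)
    gm.graph.lint()
    gm.recompile()
    return gm

This also shows why the approach is fragile: any view that happens to use the same integer for an unrelated dimension would be rewritten as well.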
Please see the inlined comments. I don't think we should have this modify_view operation to compose TP and PP; it's very dangerous to users. I think we should probably not trace into the TransformerBlock and leave it for TP, or make TP happen first.
examples/llama/2d_llama.py (outdated)
# Utility
def modify_view(
Hmmmm, this is a very dangerous operation; we shouldn't compose PP + TP like this, IMO. In particular, this would modify any view operation in the traced graph, thereby making the graph very unsafe for the user. It is also not scalable to other models: a non-llama model might have view operations in non-attention layers, which would either trigger the assertion failure or be wrongly modified.
Thanks much @wanchaol for the comments. I agree with your thoughts. modify_view is more of a workaround than a general solution. For a better one, I wonder if the following direction is worth consideration.
Today, for TP to work well with attention, the view ops in user code need to be modified. This is true for both eager mode and graphed mode. I wonder if it would be possible for the output of ColwiseParallel to play well with unchanged view ops. This may require the output of ColwiseParallel to stay in some kind of DTensor form (rather than a regular tensor), and for this DTensor form to have a special rule for the view ops (i.e. recognizing that some dimension of the view op needs to be adjusted based on how the source operation is distributed).
Any thoughts here?
Cc @fduwjj
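For concreteness, a rough sketch of that direction using the existing use_local_output knob on ColwiseParallel (the module names and variables below are illustrative assumptions, not the PR's actual plan):

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Sketch: keep the q/k/v projection outputs as DTensors so that DTensor's
# sharding propagation can handle the subsequent view over the head
# dimension, instead of patching traced view nodes.
plan = {
    "self_attn.q_proj": ColwiseParallel(use_local_output=False),
    "self_attn.k_proj": ColwiseParallel(use_local_output=False),
    "self_attn.v_proj": ColwiseParallel(use_local_output=False),
    "self_attn.o_proj": RowwiseParallel(),
}
parallelize_module(transformer_block, tp_mesh, plan)  # transformer_block, tp_mesh assumed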
Removed modify_view() now
examples/llama/2d_llama.py (outdated)
mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
pp_group = mesh_2d["pp"].get_group()
llama.to(device).eval()
it would be great to start envisioning the deferred init for the next step.
Agree. We are working towards that direction. See for example PR #923 and the BERT example in it.
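As a rough illustration of the deferred-init idea (an assumed pattern based on the meta device, not necessarily what PR #923 does):

import torch
from torch import nn

# Build the model on the meta device so no real storage is allocated ...
with torch.device("meta"):
    llama = nn.Sequential(*[nn.Linear(4096, 4096, bias=False) for _ in range(8)])

# ... then materialize only this rank's pipeline stage (a toy split here)
# and load its weights from a checkpoint afterwards.
stage_mod = llama[0:2]
stage_mod = stage_mod.to_empty(device="cuda")  # allocates uninitialized storage on the GPU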
Functionality-wise it is working for me, and LGTM.
# We set this flag to true to allow operations on a mix of tensor and dtensor
# arguments. The mix is a result of `use_local_output=False`
DTensor._op_dispatcher._allow_implicit_replication = True
I don't think we should set this flag; it's a hack and is only supposed to be used by FSDP...
what does it do?
# HACK: we convert DTensor to regular tensor here for it to
# work with send ops. DTensor may show up in PP + TP cases.
out.to_local()
if isinstance(out, torch.distributed._tensor.DTensor)
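For context, a hypothetical helper showing what the complete expression presumably looks like (not the PR's actual code):

from torch.distributed._tensor import DTensor

def to_sendable(out):
    # Materialize the local shard so that point-to-point send ops receive a
    # plain torch.Tensor instead of a DTensor.
    return out.to_local() if isinstance(out, DTensor) else out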
I think we should understand why isend would get a DTensor; if the pipeline splits at each TransformerBlock, it should not get DTensors as inputs.
I still think the cleanest fix here is to make PP tracing +
If they are all tensors,
Documenting my discussion with @wanchaol wrt DTensor and @kwen2501:
Current status: Working
Previous issues:
TP self attention was hitting the following issue:
4 * 4 * 32 * 128 = 65536
65536 / 4 = 16384 (4 is my TP size)
so that explains it.
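A quick standalone repro of that arithmetic (a sketch using a stand-in tensor rather than the real q_proj output):

import torch

# Per-rank q_proj output after colwise sharding with TP degree 4:
# 4 * 4 * (4096 / 4) = 16384 elements.
local_q = torch.randn(4, 4, 4096 // 4)
try:
    local_q.view(4, 4, 32, 128)   # still asks for 4 * 4 * 32 * 128 = 65536 elements
except RuntimeError as e:
    print(e)                      # shape '[4, 4, 32, 128]' is invalid for input of size 16384
local_q.view(4, 4, 32 // 4, 128)  # works once the head count is divided by the TP degree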
User code:
Cc: @fduwjj @wanchaol @HamidShojanazeri @wconstab
Can you shed light here?
@fduwjj mentioned that we would need to modify self.n_local_heads to be 4 times smaller -- whether in the eager case or the traced case.
In the traced case, I can modify the view node to change its arg, for example, 32 -> 8. That's slightly better than asking the user to modify model code. But is there a better way?