PyTorch native 2D LLaMA inference #922

Open. Wants to merge 6 commits into base: main.

Conversation

@kwen2501 (Contributor) commented Dec 21, 2023:

Current status

Working

# PP = 2, TP = 4
$ torchrun --nproc-per-node 8 pippy_llama.py
['make', 'think', 'you', 'be', 'getting', 'great', 'favorite', 'right']
['make', 'think', 'you', 'be', 'getting', 'great', 'favorite', 'right']
['make', 'think', 'you', 'be', 'getting', 'great', 'favorite', 'right']
['make', 'think', 'you', 'be', 'getting', 'great', 'favorite', 'right']

Previous issues:

TP self-attention was hitting the following issue:

view = l__self___model_layers_0_self_attn_q_proj.view(4, 4, 32, 128);  l__self___model_layers_0_self_attn_q_proj = None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[4, 4, 32, 128]' is invalid for input of size 16384

4 * 4 * 32 * 128 = 65536
65536 / 4 = 16384 (4 is my TP size)
so that explains it.
User code:

xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)

Cc: @fduwjj @wanchaol @HamidShojanazeri @wconstab

Can you shed light here?
@fduwjj mentioned that we would need to modify self.n_local_heads to be 4 times smaller -- whether in the eager case or the traced case.
In the traced case, I can modify the view node to change its arg, for example, 32 -> 8. That's slightly better than asking the user to modify model code. But is there a better way?

f"{layer_name}_mlp_down_proj": RowwiseParallel(),
})
tp_mesh = mesh_2d["tp"]
parallelize_module(stage.submod, tp_mesh, plan)
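For context, here is a hedged reconstruction of the TP plan that the hunk above is taken from; the exact (flattened) module names and the set of layers per stage are assumptions based on the excerpt, not the PR's actual code.

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "tp"))  # PP = 2, TP = 4
tp_mesh = mesh_2d["tp"]

plan = {}
for layer_name in layer_names_in_this_stage:  # e.g. ["model_layers_0", ...] after tracing (assumed)
    plan.update({
        f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
        f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
        f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
        f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
        f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
        f"{layer_name}_mlp_up_proj": ColwiseParallel(),
        f"{layer_name}_mlp_down_proj": RowwiseParallel(),
    })
parallelize_module(stage.submod, tp_mesh, plan)  # `stage` is the pipeline stage built in the script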
Contributor:

What you need to do for a submod is:

self.n_local_heads = self.n_local_heads // tp_degree
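To spell out the arithmetic behind this suggestion with the numbers from the error above (bsz = 4, seqlen = 4, 32 heads, head_dim = 128, TP = 4), assuming xq is the already-sharded local output of the column-parallel q_proj:

# ColwiseParallel shards q_proj's output over the head dimension, so each rank
# holds 4 * 4 * (32 // 4) * 128 = 16384 elements instead of 65536.
n_local_heads = 32 // 4                   # 8 heads per TP rank
xq = xq.view(4, 4, n_local_heads, 128)    # [4, 4, 8, 128] now matches the local size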

@kwen2501 (Contributor, Author), Dec 28, 2023:

Tried, but doesn't seem to work in this case.

  • Trial 1: modify num_heads before tracing
    Tracer will error out:
shape '[4, 4, 8, 128]' is invalid for input of size 65536

i.e., the tracer seems to be smart enough to do a size check. In this case, the input still has the old (unsharded) size, so the shapes mismatch.

  • Trial 2: modify num_heads after tracing
    This won't work because the old value of num_heads has been burned into the generated program during tracing.

@kwen2501 (Contributor, Author):

What works:
Modify the view and reshape ops' arguments after tracing,
e.g. the "view" node's args were modified from:
(l__self___model_layers_0_self_attn_q_proj, 4, 4, 32, 128) to
(l__self___model_layers_0_self_attn_q_proj, 4, 4, 8, 128)
See the added util modify_view
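A rough sketch of what such a post-tracing rewrite could look like with torch.fx (illustrative only, not the PR's actual modify_view; the four-dimensional-view heuristic is an assumption, and whether the tracer records these as call_method nodes or aten ops depends on the tracer used):

import torch.fx as fx

def shrink_view_heads(gm: fx.GraphModule, tp_size: int) -> None:
    # Rewrite x.view(bsz, seqlen, n_heads, head_dim)-style calls so the head
    # dimension matches the per-rank (column-sharded) size.
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target in ("view", "reshape"):
            dims = node.args[1:]  # args[0] is the input tensor
            if len(dims) == 4 and isinstance(dims[2], int) and dims[2] % tp_size == 0:
                node.update_arg(3, dims[2] // tp_size)  # e.g. 32 -> 8 for TP = 4
    gm.graph.lint()
    gm.recompile()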

@kwen2501 kwen2501 changed the title [WIP] PyTorch native 2D LLaMA PyTorch native 2D LLaMA Dec 28, 2023
@kwen2501 kwen2501 changed the title PyTorch native 2D LLaMA PyTorch native 2D LLaMA inference Dec 28, 2023
@wanchaol (Contributor) left a comment:

Please see the inline comments. I don't think we should have this modify_view operation to compose TP and PP; it's very dangerous for users. I think we should probably either not trace into the TransformerBlock and leave it for TP, or make TP happen first.



# Utility
def modify_view(
Contributor:

Hmmm, this is a very dangerous operation; we shouldn't compose PP + TP like this, IMO. In particular, this would modify any view operation in the traced graph, thereby making the graph very unsafe for the user. It also doesn't scale to other models: a non-LLaMA model might have view operations in non-attention layers, and this would either trigger the assertion failure or wrongly modify those view ops.

@kwen2501 (Contributor, Author), Jan 2, 2024:

Thanks much @wanchaol for the comments. I agree with your thoughts; modify_view is more of a workaround than a general solution. For a better one, I wonder if the following direction is worth considering.
Today, for TP to work well with attention, the view ops in user code need to be modified. This is true for both eager mode and graphed mode. I wonder if it would be possible for the output of ColwiseParallel to play well with unchanged view ops. This may require the output of ColwiseParallel to stay in some kind of DTensor form (rather than a regular tensor), and for that DTensor form to have a special rule for view ops (i.e., recognizing that some dimension of the view op needs to be adjusted based on how the source operation is distributed).
Any thoughts here?
Cc @fduwjj

@kwen2501 (Contributor, Author):

Removed modify_view() now

mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
pp_group = mesh_2d["pp"].get_group()

llama.to(device).eval()
Contributor:

It would be great to start envisioning deferred init as the next step.

@kwen2501 (Contributor, Author):

Agreed. We are working in that direction; see for example this PR and the BERT example in it:
#923
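For reference, here is a minimal sketch of what deferred init could look like with the meta device (an illustration of the general direction only; LlamaForCausalLM, config, and stage_mod are assumed placeholders, not the API used in #923):

import torch

with torch.device("meta"):
    llama = LlamaForCausalLM(config)  # parameters exist only on the meta device, no real memory yet
# ... split the model into pipeline stages while still on meta ...
stage_mod.to_empty(device="cuda")     # materialize only this stage's parameters
# ... then load or shard the real weights into stage_mod ...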

@HamidShojanazeri (Contributor) left a comment:

Functionality-wise it is working for me, and LGTM.


# We set this flag to true to allow operations on a mix of tensor and dtensor
# arguments. The mix is a result of `use_local_output=False`
DTensor._op_dispatcher._allow_implicit_replication = True
Contributor:

I don't think we should set this flag; it's a hack and only supposed to be used by FSDP...

Contributor:

What does it do?

# HACK: we convert DTensor to regular tensor here for it to
# work with send ops. DTensor may show up in PP + TP cases.
out.to_local()
if isinstance(out, torch.distributed._tensor.DTensor)
Contributor:

I think we should understand why isend would get a DTensor; if the pipeline splits at each TransformerBlock, it should not get DTensors as inputs.

@wanchaol (Contributor) commented Jan 3, 2024:

I still think the cleanest fix here is to make PP tracing + the unflattener work; otherwise we should probably wait for DTensor to support the scaled dot product attention op instead. I'm surprised that the current use_local_output approach works; I think the only reason is that the LLaMA 7B model does not use scaled_dot_product_attention.

@fduwjj (Contributor) commented Jan 8, 2024:

If they are all tensors, scaled_dot_product_attention should work as long as we pass in the correct sizes?

@kwen2501 (Contributor, Author):

Documenting my discussion with @wanchaol regarding DTensor and scaled_dot_product_attention:

@kwen2501 :
Should we do to_local as soon as we do colwise, or should we do to_local when we hit an op like scaled dot product, or should we have scaled dot product support a local form of DTensor? Maybe 2 and 3 are the same thing, meaning the DTensor dispatcher performs a to_local before calling the actual scaled dot product.

@wanchaol :
The current way is that we do to_local as soon as we leave the linear layer computation; this is the easiest thing to do with module forward hooks. If we do to_local as soon as we hit an op like scaled dot product attention, I feel this is technically like implementing the scaled dot product attention op already, i.e., when implementing a DTensor op, we just figure out the sharding and then call the op on the local tensor.

@kwen2501 :
Now, in this case, the view ops are between colwise and scaled dot product, so it seems that the delayed route would work better. But I do agree that, without the view ops, the early route would be easier. This means the delayed route is a user choice (likely non-default), and we patch that route with DTensor support for scaled dot product.

@wanchaol :
Yeah, I think we should support both routes via use_local_output=False/True.
The delayed route requires us to implement scaled dot product attention for DTensor, I think, but it shouldn't be too hard to enable it.
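A condensed sketch of the two routes in terms of the TP API (the choice of module is illustrative; use_local_output is the actual ColwiseParallel parameter):

from torch.distributed.tensor.parallel import ColwiseParallel

# Early route (default): convert back to a plain tensor right after the linear;
# the attention view ops then need per-rank sizes (e.g. n_heads // tp_size).
early = ColwiseParallel(use_local_output=True)
# Delayed route: keep a DTensor flowing through the view ops and into
# scaled_dot_product_attention, which then needs DTensor support for that op.
delayed = ColwiseParallel(use_local_output=False)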
