
model.py #188

Open
daxian-lh opened this issue May 28, 2024 · 1 comment
daxian-lh commented May 28, 2024

In model.py, inside the class named CausalSelfAttention, there is a line that reads q = apply_rotary_emb_func(q, cos, sin, False, True). However, in the code behind apply_rotary_emb_func, in the forward method of the class named ApplyRotaryEmb, it looks as though when inplace == True the forward pass never modifies q, and the returned value is still the original q. That would seem to break the application of rotary position embeddings. This is only my superficial understanding; I would appreciate it if someone familiar with the code could explain.

# Imports added for context: rearrange comes from einops, and rotary_emb is
# the fused CUDA extension built by flash-attn (import names assumed from flash-attn 1.x).
import torch
from einops import rearrange
import rotary_emb


class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style)
            instead of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        # x1 and x2 are views into x, not copies.
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            # The output buffers alias the input: the kernel below writes
            # the rotated values straight into x's storage.
            o1, o2 = x1, x2
        else:
            o1, o2 = (
                out_ro.chunk(2, dim=-1)
                if not interleaved
                else (out_ro[..., ::2], out_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            x1,
            x2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            o1,
            o2,
            False,
        )
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        # With inplace=True, x already holds the rotated values at this point,
        # so returning x returns the result, not the pre-rotation input.
        return out if not inplace else x
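
For readers without the CUDA extension, here is a minimal pure-PyTorch sketch of the rotation that rotary_emb.apply_rotary performs on the two halves (GPT-NeoX style); the helper name apply_rotary_ref is hypothetical, and conj=False is assumed to match the call above:

    import torch

    def apply_rotary_ref(x1, x2, cos, sin, o1, o2):
        # Rotary embedding on a (first-half, second-half) pair:
        #   o1 = x1 * cos - x2 * sin
        #   o2 = x1 * sin + x2 * cos
        # Compute into temporaries first so the result stays correct even
        # when o1/o2 alias x1/x2, as they do on the inplace path above.
        rot1 = x1 * cos - x2 * sin
        rot2 = x1 * sin + x2 * cos
        o1.copy_(rot1)
        o2.copy_(rot2)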

jzhang38 (Owner) commented Jun 6, 2024

The value of the tensor x has already been modified in place by the time rotary_emb.apply_rotary returns. We simply did not allocate separate memory for the output: the output values are written into the same storage as the input tensor to save memory.
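
A minimal sketch (plain PyTorch, no CUDA kernel required) of why returning x is correct: chunk and slicing return views, so writing through o1/o2 on the inplace path mutates x's own storage:

    import torch

    x = torch.arange(8.0).reshape(1, 1, 1, 8)  # (batch, seqlen, nheads, headdim)
    x_ro = x[..., :8]                          # rotary slice: a view of x
    x1, x2 = x_ro.chunk(2, dim=-1)             # two more views, still no copies

    # On the inplace path, ApplyRotaryEmb sets o1, o2 = x1, x2, so the kernel's
    # writes land in x's storage. Any in-place op through a view shows the effect:
    x1.mul_(-1.0)

    print(x)  # first half is negated: x itself was modified in place

So by the time forward returns x, that tensor already contains the rotated values, and no reassignment of q is needed in model.py.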
