In `model.py`, the class `CausalSelfAttention` contains the line `q = apply_rotary_emb_func(q, cos, sin, False, True)`. However, in `apply_rotary_emb_func.py`, in the `forward` method of the `ApplyRotaryEmb` class, when `inplace == True` the forward computation does not appear to modify `q`, and the returned value is still the original `q`. This seems like it would be a problem for applying rotary position embeddings. That is just my superficial understanding; I would be grateful if someone more knowledgeable could explain it.
```python
import torch
from einops import rearrange

import rotary_emb  # fused CUDA kernel shipped with flash-attn


class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            # The output views alias the input halves themselves.
            o1, o2 = x1, x2
        else:
            o1, o2 = (
                out_ro.chunk(2, dim=-1)
                if not interleaved
                else (out_ro[..., ::2], out_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            x1,
            x2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            o1,
            o2,
            False,
        )
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x
```
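If it helps, my mental model of `rotary_emb.apply_rotary` is the standard RoPE rotation, roughly like this pure-PyTorch sketch (illustrative only; the real extension is a fused CUDA kernel, and I am assuming the final `False` argument selects the forward rather than conjugate rotation):

```python
import torch

def apply_rotary_sketch(x1, x2, cos, sin, o1, o2):
    # Standard RoPE: treat (x1, x2) as the two halves of each rotary pair
    # and rotate by the per-position angle, i.e. multiply by (cos + i*sin).
    rot1 = x1 * cos - x2 * sin
    rot2 = x1 * sin + x2 * cos
    # The kernel writes its results into the provided output views o1/o2.
    o1.copy_(rot1)
    o2.copy_(rot2)
```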
The value of the tensor `x` is already modified in place by the `rotary_emb.apply_rotary` call. We just do not allocate separate memory for the output: the output values are written to the same location as the input tensor to save memory.
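In other words, the return value is the same tensor object, backed by the same storage, as the input; only its contents have been overwritten with the rotated values. A quick sanity check, as a hedged sketch (assumes a CUDA build of the kernel and the repo's wrapper `apply_rotary_emb_func = ApplyRotaryEmb.apply`):

```python
import torch

# Assumes the wrapper from the repo: apply_rotary_emb_func = ApplyRotaryEmb.apply
batch, seqlen, nheads, headdim = 1, 8, 4, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
q_before = q.clone()

# cos/sin have shape (seqlen, rotary_dim // 2); here rotary_dim == headdim.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, headdim, 2, device="cuda").float() / headdim))
freqs = torch.outer(torch.arange(seqlen, device="cuda", dtype=torch.float32), inv_freq)
cos, sin = freqs.cos().half(), freqs.sin().half()

out = apply_rotary_emb_func(q, cos, sin, False, True)  # interleaved=False, inplace=True

assert out.data_ptr() == q.data_ptr()   # same storage: `out` is `q`
assert not torch.equal(out, q_before)   # but q's contents were rotated in place
```

So `q = apply_rotary_emb_func(q, cos, sin, False, True)` in `CausalSelfAttention` does receive the rotated values; rebinding `q` to the return value is just cosmetic, since the tensor was already mutated in place.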