Model structure #187

Closed
daxian-lh opened this issue May 28, 2024 · 1 comment

Comments

@daxian-lh

In the model there is the call q = apply_rotary_emb_func(q, cos, sin, False, True), but in fused_rotary_embedding, ApplyRotaryEmb is defined as:

def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
        of 1st half and 2nd half (GPT-NeoX style).
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    batch, seqlen, nheads, headdim = x.shape
    rotary_seqlen, rotary_dim = cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim
    assert seqlen <= rotary_seqlen
    assert sin.shape == (rotary_seqlen, rotary_dim // 2)
    x_ro = x[..., :rotary_dim]
    x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
    out = torch.empty_like(x) if not inplace else x
    out_ro = out[..., :rotary_dim]
    if inplace:
        o1, o2 = x1, x2
    else:
        o1, o2 = (
            out_ro.chunk(2, dim=-1)
            if not interleaved
            else (out_ro[..., ::2], out_ro[..., 1::2])
        )
    rotary_emb.apply_rotary(
        x1,
        x2,
        rearrange(cos[:seqlen], "s d -> s 1 d"),
        rearrange(sin[:seqlen], "s d -> s 1 d"),
        o1,
        o2,
        False,
    )
    if not inplace and rotary_dim < headdim:
        out[..., rotary_dim:].copy_(x[..., rotary_dim:])
    ctx.save_for_backward(cos, sin)
    ctx.interleaved = interleaved
    ctx.inplace = inplace

    return out if not inplace else x

When inplace == True, the function returns x, yet forward does not appear to modify x itself. Is this a bug?
Could anyone explain this? Thanks a lot.
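
One way to check this is to reproduce the inplace=True path in plain PyTorch. The sketch below rests on two assumptions: apply_rotary_reference is a hypothetical CPU stand-in for the compiled rotary_emb.apply_rotary kernel, and it uses the usual rotation formula, which may differ in detail from the real kernel. The point it illustrates is that slicing and chunk return views of x, so with o1, o2 = x1, x2 the kernel's writes land in x's own memory.

```python
import torch


def apply_rotary_reference(x1, x2, cos, sin, out1, out2):
    """Hypothetical pure-PyTorch stand-in for rotary_emb.apply_rotary.

    Assumes the standard rotation; writes results into out1/out2,
    which may alias x1/x2.
    """
    # Compute both halves before writing, since out1/out2 may share memory with x1/x2.
    r1 = x1 * cos - x2 * sin
    r2 = x1 * sin + x2 * cos
    out1.copy_(r1)
    out2.copy_(r2)


def rotary_inplace(x, cos, sin):
    """Mirror the inplace=True, interleaved=False path of the forward above."""
    seqlen = x.shape[1]
    rotary_dim = cos.shape[1] * 2
    x_ro = x[..., :rotary_dim]       # basic slicing returns a view of x
    x1, x2 = x_ro.chunk(2, dim=-1)   # chunk returns views as well
    o1, o2 = x1, x2                  # outputs alias the inputs, hence alias x
    apply_rotary_reference(
        x1,
        x2,
        cos[:seqlen].unsqueeze(1),   # (seqlen, 1, rotary_dim / 2), same as "s d -> s 1 d"
        sin[:seqlen].unsqueeze(1),
        o1,
        o2,
    )
    return x


torch.manual_seed(0)
x = torch.randn(2, 4, 3, 8)          # (batch, seqlen, nheads, headdim)
angles = torch.randn(4, 4)           # (seqlen, rotary_dim / 2)
cos, sin = angles.cos(), angles.sin()

before = x.clone()
out = rotary_inplace(x, cos, sin)
changed = not torch.equal(x, before)  # whether x itself was rewritten
```

If the views behave as described, out is the same tensor object as x and changed is True, which would suggest that returning x in the inplace case is intentional rather than a bug.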

@jzhang38
Owner

jzhang38 commented Jun 6, 2024

duplicate of #188

@jzhang38 jzhang38 closed this as completed Jun 6, 2024