Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

using sdpa if available #2359

Merged
merged 2 commits into from
Sep 30, 2024
Merged

using sdpa if available #2359

merged 2 commits into from
Sep 30, 2024

Conversation

jongwook
Copy link
Collaborator

No description provided.

@formater
Copy link

formater commented Nov 22, 2024

I encountered an issue with this patch.

File "/usr/local/src/transcriber_rt/lib/python3.11/site-packages/whisper/model.py", line 111, in forward
    wv, qk = self.qkv_attention(q, k, v, mask)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/src/transcriber_rt/lib/python3.11/site-packages/whisper/model.py", line 124, in qkv_attention
    a = scaled_dot_product_attention(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In this part of the code of model.py:

def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
            a = scaled_dot_product_attention(
                q, k, v, is_causal=mask is not None and n_ctx > 1
            )

Somehow we got Tensor and not a boolean in is_causal. If mask is a tensor, the resulting is_causal can itself be a tensor, depending on how it is handled. So maybe a potential solution would be to cast it as bool like:
is_causal = bool(mask is not None and n_ctx > 1) ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants