
How can we get attention weights from example sequence and structure? #119

Open
what-is-what opened this issue Oct 8, 2024 · 2 comments


How can we get attention weights from an example sequence and structure?
There are no arguments for returning attention weights from the transformer blocks, unlike in ESM2.


gelnesr commented Oct 20, 2024

also interested in this feature, if available!

ebetica (Contributor) commented Nov 11, 2024

Unfortunately, PyTorch's flash attention doesn't let you do this. You'll have to hack it in; we'll look into supporting it officially. Here's where the attention is computed. You'll just have to use a plain PyTorch implementation of attention to expose the attention matrix.

if seq_id is not None:
    # Where True, enable participation in attention.
    mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
    mask_BHLL = mask_BLL.unsqueeze(1)
    context_BHLD = F.scaled_dot_product_attention(
        query_BHLD, key_BHLD, value_BHLD, mask_BHLL
    )
else:
    # Shortcut: if we don't use attention biases, then torch
    # will autoselect flash attention as the implementation.
    context_BHLD = F.scaled_dot_product_attention(
        query_BHLD, key_BHLD, value_BHLD
    )
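
For reference, a minimal sketch of what such a replacement might look like. The tensor names follow the snippet above; the helper attention_with_weights is illustrative only and not part of the package:

import torch

def attention_with_weights(query_BHLD, key_BHLD, value_BHLD, mask_BHLL=None):
    """Unfused attention that mirrors F.scaled_dot_product_attention
    but also returns the softmax-ed attention weights."""
    d = query_BHLD.shape[-1]
    # Raw scores: (B, H, L, L)
    scores_BHLL = query_BHLD @ key_BHLD.transpose(-2, -1) / d**0.5
    if mask_BHLL is not None:
        # Boolean mask, True = allowed to attend (as in the snippet above).
        scores_BHLL = scores_BHLL.masked_fill(~mask_BHLL, float("-inf"))
    attn_BHLL = torch.softmax(scores_BHLL, dim=-1)
    context_BHLD = attn_BHLL @ value_BHLD
    return context_BHLD, attn_BHLL

Swapping this in for both F.scaled_dot_product_attention calls and stashing attn_BHLL somewhere reachable (for example as an attribute on the attention module) would let you read the weights out after a forward pass, at the cost of flash attention's speed and memory savings.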
