Bug when visualizing T5 models with generate #133

Open
salokr opened this issue May 21, 2024 · 2 comments
salokr commented May 21, 2024

I tried to visualize the attention maps for a T5 model but ran into issues while producing the plots.

I would like to emphasize a few points:

  • I used model.generate because my real project has no labels, and generate can also output attentions.
  • I also tried passing the input directly to model, but that doesn't work either.
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
outputs = model.generate(input_ids, output_attentions=True, return_dict_in_generate=True)

encoder_text = tokenizer.convert_ids_to_tokens(input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(outputs.sequences[0])


from bertviz import model_view
model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text
)

The error I am getting is:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[56], line 6
      2 decoder_text = tokenizer.convert_ids_to_tokens(outputs.sequences[0])#tokenizer.convert_ids_to_tokens(decoder_input_ids[0])
      5 from bertviz import model_view
----> 6 model_view(
      7     encoder_attention=outputs.encoder_attentions,
      8     decoder_attention=outputs.decoder_attentions,
      9     cross_attention=outputs.cross_attentions,
     10     encoder_tokens= encoder_text,
     11     decoder_tokens = decoder_text
     12 )

File /Volumes/x/envs/my_env/lib/python3.8/site-packages/bertviz/model_view.py:147, in model_view(attention, tokens, sentence_b_start, prettify_tokens, display_mode, encoder_attention, decoder_attention, cross_attention, encoder_tokens, decoder_tokens, include_layers, include_heads, html_action)
    145     if include_heads is None:
    146         include_heads = list(range(n_heads))
--> 147     decoder_attention = format_attention(decoder_attention, include_layers, include_heads)
    148     attn_data.append(
    149         {
    150             'name': 'Decoder',
   (...)
    154         }
    155     )
    156 if cross_attention is not None:

File /Volumes/x/envs/my_env/lib/python3.8/site-packages/bertviz/util.py:10, in format_attention(attention, layers, heads)
      7 squeezed = []
      8 for layer_attention in attention:
      9     # 1 x num_heads x seq_len x seq_len
---> 10     if len(layer_attention.shape) != 4:
     11         raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
     12                          "output_attentions=True when initializing your model.")
     13     layer_attention = layer_attention.squeeze(0)

AttributeError: 'tuple' object has no attribute 'shape'
@nikhilajoshy

Looks like outputs.decoder_attentions is not a tuple of tensors but a tuple of tuples in your case (one inner tuple per generated token). Please check the data types of the attention variables.
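
For reference, generate returns decoder_attentions and cross_attentions as a tuple with one entry per generated token, where each entry is itself a tuple of per-layer tensors; format_attention iterates over what it expects to be layer tensors and calls .shape on each, which is why the inner tuples raise the AttributeError above. A minimal inspection sketch (shapes assume t5-small with greedy decoding: 6 layers, 8 heads, batch size 1):

print(type(outputs.decoder_attentions))        # tuple: one entry per generated token
print(type(outputs.decoder_attentions[0]))     # tuple: one tensor per decoder layer
print(outputs.decoder_attentions[0][0].shape)  # torch.Size([1, 8, 1, 1]) at the first step
print(outputs.encoder_attentions[0].shape)     # per-layer tensor: (1, 8, src_len, src_len)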


gpantaz commented Aug 20, 2024

Any updates on this? I'm facing a similar problem with T5 + generate :)
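
One workaround that sidesteps the per-step tuples (a sketch, not an official bertviz fix): after generating, run one ordinary forward pass with the generated sequence as decoder_input_ids. The forward pass returns each attention field as a tuple of per-layer tensors over the full sequence, which is the format model_view expects. This reuses tokenizer, model, input_ids, encoder_text, decoder_text, and outputs from the first snippet:

import torch

from bertviz import model_view

with torch.no_grad():
    forward_outputs = model(
        input_ids=input_ids,
        decoder_input_ids=outputs.sequences,  # generated ids, including the decoder start token
        output_attentions=True,
        return_dict=True,
    )

model_view(
    encoder_attention=forward_outputs.encoder_attentions,  # per layer: (1, heads, src_len, src_len)
    decoder_attention=forward_outputs.decoder_attentions,  # per layer: (1, heads, tgt_len, tgt_len)
    cross_attention=forward_outputs.cross_attentions,      # per layer: (1, heads, tgt_len, src_len)
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text,  # tokens for outputs.sequences[0]
)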
