
model(X, labels=Y, return_dict=True).loss is wrong #133

Closed
jettjaniak opened this issue Apr 27, 2024 · 5 comments · Fixed by #137

Comments

@jettjaniak (Contributor)

It should be model(X, labels=X).
Ideally we would force it to do what it was supposed to instead of shifting tokens on its own; but if we can't, we need to adjust the design of the tokenization script to produce sequences of seq_len (512) instead of seq_len + 1 (513).
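
To make the intended call concrete, here is a minimal sketch (a tiny randomly-initialized Llama, not the project's actual model or config; the X/Y split from a 513-token chunk is my assumption about the current setup):

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny randomly-initialized model so the sketch is self-contained
# (placeholder config, not the project's).
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    max_position_embeddings=512,
)
model = LlamaForCausalLM(config)

seq_len = 512
tokens = torch.randint(0, config.vocab_size, (1, seq_len + 1))  # 513 tokens, as the tokenization script currently produces
X, Y = tokens[:, :-1], tokens[:, 1:]  # assumed pre-shifted split behind the buggy call

# Wrong: LlamaForCausalLM already shifts labels internally (position i is
# scored against labels[i + 1]), so passing the pre-shifted Y shifts twice.
wrong_loss = model(X, labels=Y, return_dict=True).loss

# Intended: pass the same ids as inputs and labels and let the model shift.
right_loss = model(X, labels=X, return_dict=True).loss
```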

@jettjaniak (Contributor, Author)

We need some kind of performance test to catch issues like this in the future.

@jaidhyani (Collaborator) commented Apr 27, 2024

Don't we still want to pass it seq_len + 1? If it's converting it internally, we still get 512 positions on inputs of length 513, right?

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1233

@SrGonao (Contributor) commented Apr 29, 2024

I agree with Jai; we don't need to change the tokenizer. @jaidhyani, could you make a PR to fix this loss?

@jaidhyani (Collaborator) commented Apr 29, 2024 via email

@jettjaniak (Contributor, Author)

> Don't we still want to pass it seq_len + 1? If it's converting it internally, we still get 512 positions on inputs of length 513, right?
>
> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1233

I'm not sure what you're linking to; the line number probably shifted with new commits on main. I believe it computes logits for all input ids, and then discards the logits for the last position when computing the loss.

https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/llama/modeling_llama.py#L1213-L1214

It's not a big deal, but considering this I think the inputs should be seq_len. Let me know if you have strong takes.
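
For reference, the loss computation being linked to boils down to the following (a runnable paraphrase with dummy tensors, not the project's code; the exact lines in modeling_llama.py move between commits):

```python
import torch

vocab_size = 1000
batch, seq = 1, 513
logits = torch.randn(batch, seq, vocab_size)          # one logit row per input position
labels = torch.randint(0, vocab_size, (batch, seq))   # labels == input ids in the intended call

# Paraphrase of the label handling in LlamaForCausalLM.forward: the logits of
# the last position and the first label are dropped, then everything is
# flattened for token-level cross-entropy.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = torch.nn.CrossEntropyLoss()(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1),
)
```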

@jettjaniak jettjaniak linked a pull request May 15, 2024 that will close this issue