generation utils update (minor) #1468
base: main
Conversation
- Fix the type hint: dtype cannot be a str
- Fix the device hint
- Remove the pad token id argument; the decoder_attention_mask is a binary mask of 0s and 1s
- Add an early return
- Extract `is_mqa_model` and `lazy_mode` into locals to avoid repeated dictionary lookups (see the sketch after this list)
- Use more descriptive variable names and simplify the nested loops for better readability
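To illustrate the dictionary-lookup refactor, here is a minimal sketch; the function name, the `model_kwargs` dictionary, and the loop body are hypothetical stand-ins for illustration, not the PR's actual code:

```python
# Hypothetical sketch: hoist repeated dictionary lookups into locals.
# `process_layers` and `model_kwargs` are assumed names, not the PR's code.

# Before: the same keys were looked up on every loop iteration, e.g.
#   for layer in layers:
#       if model_kwargs["is_mqa_model"] and model_kwargs["lazy_mode"]:
#           ...

# After: look each value up once, then reuse the locals inside the loop.
def process_layers(layers, model_kwargs):
    is_mqa_model = model_kwargs["is_mqa_model"]
    lazy_mode = model_kwargs["lazy_mode"]
    for layer in layers:
        if is_mqa_model and lazy_mode:
            ...  # per-layer cache handling
```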
The text-generation CI has been executed and will be compared with the main branch once the run is complete.
@yafshar, just a couple of comments below.
Please post the results of the CI, before and after the change.
@yafshar, makes sense.
@yafshar, could you post the CI results? Thanks.
LGTM!
Aligned with @emascarenhas, let's make sure there is no regression in generation tests and then I'll merge it 🙂
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I am running the slow CI tests.
@yafshar, can you post the results from the CI tests here? Thanks.
I just finished the CI tests on both main and this PR on the same machine:

```
python -m pytest tests/test_text_generation_example.py tests/test_encoder_decoder.py -v -s
```

4 failed, 59 passed. I checked the failures:

-> test_text_generation_bf16_1x[token0-EleutherAI/gpt-j-6b-1-False-160.5823842101192-False]

```
pr:   FAILED tests/test_text_generation_example.py::test_text_generation_bf16_1x[token0-google/gemma-7b-1-False-109.70751574382221-True] - AssertionError: assert 'DeepSpeed is...be efficient,' == 'DeepSpeed is... PyTorch, and'
main: FAILED tests/test_text_generation_example.py::test_text_generation_bf16_1x[token0-google/gemma-7b-1-False-109.70751574382221-True] - AssertionError: assert 'DeepSpeed is...be efficient,' == 'DeepSpeed is... PyTorch, and'
```

-> test_text_generation_bf16_1x[token0-state-spaces/mamba-130m-hf-1536-False-5385.511100161605-False]

```
pr:   FAILED tests/test_text_generation_example.py::test_text_generation_bf16_1x[token0-state-spaces/mamba-130m-hf-1536-False-5385.511100161605-False] - assert 4895.173518373703 >= ((2 - 1.05) * 5385.511100161605)
main: FAILED tests/test_text_generation_example.py::test_text_generation_bf16_1x[token0-state-spaces/mamba-130m-hf-1536-False-5385.511100161605-False] - assert 4895.212904578489 >= ((2 - 1.05) * 5385.511100161605)
```

-> test_text_generation_bf16_1x[token0-Deci/DeciLM-7B-1-False-120-False]

```
pr:   FAILED tests/test_text_generation_example.py::test_text_generation_bf16_1x[token0-Deci/DeciLM-7B-1-False-120-False] - assert 107.58924903315328 >= ((2 - 1.05) * 120)
main: FAILED tests/test_text_generation_example.py::test_text_generation_bf16_1x[token0-Deci/DeciLM-7B-1-False-120-False] - assert 107.56332773820075 >= ((2 - 1.05) * 120)
```

-> test_text_generation_fp8[token0-tiiuae/falcon-180B-4-950-True-128-128-2506.68]

```
pr:   FAILED tests/test_text_generation_example.py::test_text_generation_fp8[token0-tiiuae/falcon-180B-4-950-True-128-128-2506.68] - AssertionError: The following command failed:
main: FAILED tests/test_text_generation_example.py::test_text_generation_fp8[token0-tiiuae/falcon-180B-4-950-True-128-128-2506.68] - AssertionError: The following command failed:
```

The failures are exactly the same.
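For context on the mamba and DeciLM assertions above: the tests compare measured throughput against a recorded baseline with a tolerance factor, and `(2 - 1.05)` reduces to `0.95`, i.e. a 5% allowance below baseline. A minimal sketch of that check, with hypothetical names; only the numbers and the assertion shape come from the logs above:

```python
def check_throughput(measured: float, baseline: float, tolerance: float = 1.05) -> None:
    # (2 - tolerance) == 0.95, so the run must reach >= 95% of baseline.
    assert measured >= (2 - tolerance) * baseline

# Reproduces the mamba-130m failure from the log above:
# 4895.17 < 0.95 * 5385.51 ≈ 5116.24, so this raises AssertionError.
check_throughput(4895.173518373703, 5385.511100161605)
```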
@regisss @emascarenhas I do not see any regression; the behavior is the same as far as I have tested.
What does this PR do?
- Fixes the import path: `transformers.streamers` -> `transformers.generation.streamers`
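A minimal sketch of the corrected import; `TextStreamer` is used here only as an example class from that module, the PR may import a different name:

```python
# The streamer classes live under transformers.generation, so the
# fully qualified module path is transformers.generation.streamers.
from transformers.generation.streamers import TextStreamer
```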
- `return x.index_fill(1, torch.tensor(0), 1)` creates the index tensor `torch.tensor(0)` on the wrong (default) device; it is fixed by creating the index on the correct device: `index = torch.tensor(0, device=device)`
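A minimal sketch of the device fix; the wrapper function name is hypothetical, but the buggy and fixed lines are the ones quoted above:

```python
import torch

def fill_first_column(x: torch.Tensor) -> torch.Tensor:
    # Buggy version: torch.tensor(0) is created on the default (CPU)
    # device, which breaks when x lives on an accelerator:
    #   return x.index_fill(1, torch.tensor(0), 1)

    # Fixed version: create the index on the same device as x.
    index = torch.tensor(0, device=x.device)
    return x.index_fill(1, index, 1)
```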
Before submitting