
gpt_bigcode: added internal bucketing fix #1526

Merged
merged 1 commit into huggingface:main on Dec 9, 2024

Conversation

mgonchar
Contributor

  • update kv-cache state in place at the decode phase
  • slice tensors with cache_idx to reduce excessive compute

This PR fixes the lost-context issue for the gpt_bigcode class of models (starcoderbase/starcoder) when the bucket_internal feature is used.

It unblocks the generation quality tests in https://github.com/huggingface/optimum-habana/blob/main/tests/test_text_generation_example.py#L36
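As an illustration only (not the actual diff; tensor names, shapes, and the [batch, seq, head_dim] cache layout are assumptions), a minimal sketch of the two changes described above: writing the decode-step key/value into the pre-allocated cache in place instead of concatenating, and slicing the cache to cache_idx so attention only runs over the filled part of the bucket:

import torch

def decode_step(past_key, past_value, key, value, token_idx, cache_idx):
    # past_key/past_value: pre-allocated buckets, shape [batch, bucket_len, head_dim]
    # key/value: current decode-step tensors, shape [batch, 1, head_dim]
    # In-place update: write the new entries at position token_idx - 1 instead of
    # concatenating, so the cache keeps its pre-allocated size and storage.
    past_key.index_copy_(1, token_idx - 1, key)
    past_value.index_copy_(1, token_idx - 1, value)
    # Slice with cache_idx so attention only sees the filled part of the bucket.
    return past_key[:, :cache_idx, :], past_value[:, :cache_idx, :]

batch, bucket_len, head_dim = 2, 128, 64
past_k = torch.zeros(batch, bucket_len, head_dim)
past_v = torch.zeros(batch, bucket_len, head_dim)
k = torch.randn(batch, 1, head_dim)
v = torch.randn(batch, 1, head_dim)
token_idx = torch.tensor([5])  # current decode position (1-based)
key_slice, value_slice = decode_step(past_k, past_v, k, v, token_idx, cache_idx=5)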

For example, with the command line:

python3 examples/text-generation/run_generation.py --model_name_or_path bigcode/starcoder --batch_size 2 --use_kv_cache --max_new_tokens 100 --bucket_size 128 --bucket_internal --use_hpu_graphs --bf16  --prompt 'def print_hello_world():'

without this fix:

Input/outputs:
input 1: ('def print_hello_world():',)
output 1: ('def print_hello_world():\n    print("Hello World\n\t\n    print("Hello World!\n\n\n\n\n\n\n\n\n\n#\n\n\n\n\n\n\n\n\n#\n        """\n    print_name,\n#0\n\n#\n\n\n\n\n\n\n\n\n#0\n#00\n\n\n#\n\n\n\n\n\n\n\n\n0200000000..\n\n\n0ape_get)\n       0",0.',)

input 2: ('def print_hello_world():',)
output 2: ('def print_hello_world():\n    print("Hello World\n\t\n    print("Hello World!\n\n\n\n\n\n\n\n\n\n#\n\n\n\n\n\n\n\n\n#\n        """\n    print_name,\n#0\n\n#\n\n\n\n\n\n\n\n\n#0\n#00\n\n\n#\n\n\n\n\n\n\n\n\n0200000000..\n\n\n0ape_get)\n       0",0.',)

with this fix:

Input/outputs:
input 1: ('def print_hello_world():',)
output 1: ('def print_hello_world():\n    print("Hello World")\n\ndef print_hello_world_twice():\n    print_hello_world()\n    print_hello_world()\n\ndef print_hello_world_thrice():\n    print_hello_world()\n    print_hello_world()\n    print_hello_world()\n\ndef print_hello_world_four_times():\n    print_hello_world()\n    print_hello_world()\n    print_hello_world()\n   ',)

input 2: ('def print_hello_world():',)
output 2: ('def print_hello_world():\n    print("Hello World")\n\ndef print_hello_world_twice():\n    print_hello_world()\n    print_hello_world()\n\ndef print_hello_world_thrice():\n    print_hello_world()\n    print_hello_world()\n    print_hello_world()\n\ndef print_hello_world_four_times():\n    print_hello_world()\n    print_hello_world()\n    print_hello_world()\n   ',)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@mgonchar mgonchar requested a review from ZhaiFeiyue as a code owner November 26, 2024 15:47
@vidyasiv
Contributor

@mgonchar, thanks very much for fixing this. Please update the test to run with the True flag: https://github.com/huggingface/optimum-habana/blob/main/tests/test_text_generation_example.py#L30-L36 and update the output if necessary.

@mgonchar mgonchar force-pushed the main_fix_bucket_internal branch from 21130ea to d580212 Compare November 27, 2024 18:19
@mgonchar mgonchar requested a review from regisss as a code owner November 27, 2024 18:19
@mgonchar
Contributor Author

mgonchar commented Nov 27, 2024

@mgonchar, thanks very much for fixing this. Please update the test to run with the True flag: https://github.com/huggingface/optimum-habana/blob/main/tests/test_text_generation_example.py#L30-L36 and update the output if necessary.

Sure, I've changed the test and run it locally; it passed.

Here is the result (I commented out all models other than starcoder):

============================= test session starts =============================
platform linux -- Python 3.10.12, pytest-8.3.3, pluggy-1.5.0
rootdir: /var/work/optimum-habana
configfile: setup.cfg
collected 9 items

test_text_generation_example.py .sssssss.                               [100%]

=============================== warnings summary ===============================
tests/test_text_generation_example.py::test_text_generation_bf16_1x[token0-bigcode/starcoder-256-True-6846.575763562658-True]
  /usr/lib/python3.10/inspect.py:288: FutureWarning: `torch.distributed.reduce_op` is deprecated, please use `torch.distributed.ReduceOp` instead
    return isinstance(object, types.FunctionType)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================ 2 passed, 7 skipped, 1 warning in 77.88s (0:01:17) ================

@mgonchar
Contributor Author

Output is fine; in my understanding, the output of the bucket vs. no-bucket case should be the same if bucket_size is equal in both cases.
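A tiny, purely illustrative check of that intuition (standalone toy code, not the model implementation): as long as the pre-allocated bucket is sliced down to the number of valid positions before attention, the scores match the unbucketed cache exactly.

import torch

torch.manual_seed(0)
valid_len, bucket_len, dim = 5, 8, 4
q = torch.randn(1, 1, dim)
k_valid = torch.randn(1, valid_len, dim)
# Bucketed cache: valid entries followed by padding up to the bucket boundary.
k_bucket = torch.cat((k_valid, torch.zeros(1, bucket_len - valid_len, dim)), dim=1)
scores_no_bucket = q @ k_valid.transpose(-1, -2)
scores_bucket = q @ k_bucket[:, :valid_len, :].transpose(-1, -2)  # slice to the valid length
print(torch.allclose(scores_no_bucket, scores_bucket))  # True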

# out-of-place path: concatenate the new step onto the cached tensors
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = torch.cat((key, value), dim=-1) if use_cache else None
# in-place decode-phase update of the pre-allocated cache introduced by this PR
key = past_key.index_copy_(1, token_idx - 1, key)
Contributor

Can you verify this works with tgi-gaudi? The out-of-place op was used to fix a specific issue when the tensor cache is disabled; otherwise we saw an error.

Contributor

Sent you the ticket link for the empty-tensor optional error with tgi-gaudi.

Contributor Author

@mgonchar mgonchar Nov 27, 2024

@vidyasiv I tried rolling back the changes from your commit #1181 and it works for me on the latest 1.18 with the command line

PT_HPU_DISABLE_TENSOR_CACHE=1 python run_generation.py --model_name_or_path bigcode/starcoder --batch_size 2 --use_hpu_graphs --use_kv_cache --max_new_tokens 100 --bf16

and the output is the same as without the PT_HPU_DISABLE_TENSOR_CACHE variable. It seems that the original issue was fixed.

Contributor Author

For the bucket case it also works fine and gives the same output:

PT_HPU_DISABLE_TENSOR_CACHE=1 python run_generation.py --model_name_or_path bigcode/starcoder --batch_size 2 --use_hpu_graphs --use_kv_cache  --max_new_tokens 100 --bf16 --bucket_size=128 --bucket_internal

Collaborator

@vidyasiv What's the TGI config that led to the error?

Contributor

@regisss the issue in TGI from the original ticket:

# server:
text-generation-launcher --model-id bigcode/starcoderbase-3b --sharded false --hostname 127.0.0.1 --max-input-length 2048  --max-batch-size 8 --dtype bfloat16

# In container: 
docker run -it --runtime=habana --name gaudi-tgi-scb-3b -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN -e ENABLE_HPU_GRAPH=True -e BATCH_BUCKET_SIZE=8 -e PREFILL_BATCH_BUCKET_SIZE=4 -e PAD_SEQUENCE_TO_MULTIPLE_OF=128 --cap-add=sys_nice --net=host --entrypoint bash tgi_gaudi

The HF equivalent back then was to set PT_HPU_DISABLE_TENSOR_CACHE=1 and use --use_hpu_graphs.


The code quality check failed; please run make style.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

- update kv-cache state inplace at decode phase
- slice tensors with cache_idx to reduce excessive compute
@mgonchar mgonchar force-pushed the main_fix_bucket_internal branch from d580212 to f2e494f Compare December 1, 2024 19:51
@mgonchar
Contributor Author

mgonchar commented Dec 1, 2024

Rebased, style fixed.

Contributor

@vidyasiv vidyasiv left a comment

LGTM based on Miroslav's testing.

@regisss regisss merged commit 9a4c6de into huggingface:main Dec 9, 2024
4 checks passed
zzhang37 pushed a commit to zzhang37/optimum-habana that referenced this pull request Dec 9, 2024
@mgonchar mgonchar deleted the main_fix_bucket_internal branch December 9, 2024 23:39
imangohari1 pushed a commit to imangohari1/optimum-habana that referenced this pull request Dec 10, 2024