Support FP8 model fallback KVCache to bfloat16 #1505

Open · changwangss wants to merge 2 commits into base: main
Conversation

changwangss (Contributor) commented on Nov 20, 2024:

I plan to load an FP8 model with the following config: Linear modules are FP8, while the KV cache and all other ops stay in bf16.

FP8Config(allowlist={"types": ["Linear"], "names": []}, blocklist={"types": [], "names": []})
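
For context, a minimal sketch of building this recipe programmatically. The import path is an assumption (Intel Neural Compressor's PyTorch FP8 API, 3.x) and may differ by version; the allowlist/blocklist values mirror the recipe quoted above.

```python
# Sketch only: import path is an assumption, not confirmed by this PR.
from neural_compressor.torch.quantization import FP8Config

config = FP8Config(
    allowlist={"types": ["Linear"], "names": []},  # quantize only Linear modules to FP8
    blocklist={"types": [], "names": []},          # nothing explicitly blocked; KV cache stays bf16
)
```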

When running model.generate through run_generation.py, the following error is raised:

    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1556, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1606, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/changwang/workspace/vllm/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 1278, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1556, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1606, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/changwang/workspace/vllm/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 962, in forward
    hidden_states, self_attn_weights, present_key_value = self.pre_attn(
  File "/home/changwang/workspace/vllm/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 1019, in pre_attn
    hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
  File "/home/changwang/workspace/vllm/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 682, in pre_attn_forward
    key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
  File "/home/changwang/workspace/vllm/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 426, in update
    prev.index_copy_(dim, idx - 1, cur)
RuntimeError: index_copy_(): self and source expected to have the same dtype, but got (self) Float8_e4m3fn and (source) BFloat16
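
The failure is a plain dtype mismatch in the in-place cache update and can be reproduced outside optimum-habana. A minimal sketch (CPU tensors, arbitrary shapes, requires a PyTorch build with float8_e4m3fn support):

```python
import torch

# past_key is pre-allocated with k_proj's weight dtype (fp8 under this recipe),
# while key_states is produced by the bf16 attention path.
past_key = torch.zeros(4, 8, dtype=torch.float8_e4m3fn)
key_states = torch.randn(1, 8, dtype=torch.bfloat16)
token_idx = torch.tensor([3])

# Same in-place update pattern as KVCache.update -> prev.index_copy_(dim, idx - 1, cur)
past_key.index_copy_(0, token_idx - 1, key_states)
# RuntimeError: index_copy_(): self and source expected to have the same dtype,
# but got (self) Float8_e4m3fn and (source) BFloat16
```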

@@ -628,10 +628,14 @@ def pre_attn_forward(
             else:
                 if past_key_value is None:
                     past_key = torch.zeros(
-                        key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
+                        key_states.shape,
+                        dtype=torch.bfloat16 if isinstance(self.k_cache, KVCache) else self.get_k_proj_weight_dtype(),

Review comment (Contributor):

Why not use the function? The default value is self.k_proj.weight.dtype.

changwangss (Author) replied on Nov 21, 2024:

For the recipe FP8Config(allowlist={"types": ["Linear"], "names": []}, blocklist={"types": [], "names": []}), self.k_proj.weight.dtype is torch.float8_e4m3fn, but the past_key dtype should be torch.bfloat16.

                     )
                     past_value = torch.zeros(
-                        key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
+                        key_states.shape,
+                        dtype=torch.bfloat16 if isinstance(self.v_cache, KVCache) else self.get_k_proj_weight_dtype(),

Review comment (Contributor):

Same as above, for past_value.
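
With the fallback in place, the pre-allocated buffers match the bf16 key/value states produced by the attention path, so the same in-place update succeeds. A minimal sketch under the same assumptions as the reproduction above (CPU tensors, arbitrary shapes):

```python
import torch

# When self.k_cache is a plain (non-quantized) KVCache, the buffer now falls back to bf16
# instead of k_proj's fp8 weight dtype, matching the dtype of the incoming key_states.
past_key = torch.zeros(4, 8, dtype=torch.bfloat16)
key_states = torch.randn(1, 8, dtype=torch.bfloat16)
token_idx = torch.tensor([3])

past_key.index_copy_(0, token_idx - 1, key_states)  # no dtype error
```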
