
How to run with BNB 4bit or 8bit quantization? #3

Open
fireicewolf opened this issue Oct 19, 2024 · 7 comments

@fireicewolf

I tried to modify your example code to run this model on a low-VRAM card using a BNB 4-bit or 8-bit quantization config.

When I use a BNB 4-bit config like the one below:

qnt_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.float16,
                                bnb_4bit_use_double_quant=True)

The first error occurred at pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0):

RuntimeError: Input type (CUDABFloat16Type) and weight type (torch.cuda.HalfTensor) should be the same

Then I changed it to pixel_values = pixel_values.to(llm_dtype).unsqueeze(0) (llm_dtype is the dtype the LLaVA model weights are loaded in), which gives:

RuntimeError: self and mat2 must have the same dtype, but got Half and Byte

These errors seem to be caused by the image input dtype.

Any idea how to make this work?
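
For context, this is roughly how I'm loading the model and casting the image input (the checkpoint path and prompt are placeholders, and the AutoProcessor / LlavaForConditionalGeneration usage is my assumption based on a standard LLaVA setup, not your exact example code):

import torch
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

MODEL_PATH = "path/to/the/llava/checkpoint"  # placeholder

qnt_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.float16,
                                bnb_4bit_use_double_quant=True)

processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH,
                                                      quantization_config=qnt_config,
                                                      device_map="auto")

inputs = processor(text="Describe this image.", images=Image.open("example.jpg"),
                   return_tensors="pt").to("cuda")

# This cast is where the first error shows up.
pixel_values = inputs["pixel_values"].to(torch.bfloat16)
model.generate(input_ids=inputs["input_ids"], pixel_values=pixel_values, max_new_tokens=300)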

@fpgaminer (Owner)

Keep pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0), but maybe try changing bnb_4bit_compute_dtype=torch.float16 to bnb_4bit_compute_dtype=torch.bfloat16?
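
Something like this (untested), so the compute dtype matches the bfloat16 cast on the image input:

qnt_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.bfloat16,
                                bnb_4bit_use_double_quant=True)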

@fireicewolf (Author)

Keep pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0), but maybe try changing bnb_4bit_compute_dtype=torch.float16 to bnb_4bit_compute_dtype=torch.bfloat16?

Keeping bnb_4bit_compute_dtype=torch.bfloat16 to match pixel_values.to(torch.bfloat16).unsqueeze(0) causes this error:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/gradio/queueing.py", line 622, in process_events
    response = await route_utils.call_process_api(
  File "/opt/conda/lib/python3.10/site-packages/gradio/route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
  File "/opt/conda/lib/python3.10/site-packages/gradio/blocks.py", line 2014, in process_api
    result = await self.call_function(
  File "/opt/conda/lib/python3.10/site-packages/gradio/blocks.py", line 1567, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/opt/conda/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
  File "/opt/conda/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
  File "/opt/conda/lib/python3.10/site-packages/gradio/utils.py", line 846, in wrapper
    response = f(*args, **kwargs)
  File "/root/private_data/wd-joy-caption-cli/wd_llm_caption/gui.py", line 713, in caption_single_inference
    caption_text = get_caption_fn.my_llm.get_caption(
  File "/root/private_data/wd-joy-caption-cli/wd_llm_caption/utils/inference.py", line 607, in get_caption
    self.llm.generate(input_ids=input_ids, pixel_values=pixel_values,
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2047, in generate
    result = self._sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3007, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llava/modeling_llava.py", line 453, in forward
    image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1189, in forward
    return self.vision_model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1100, in forward
    pooler_output = self.head(last_hidden_state) if self.use_head else None
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1127, in forward
    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 5430, in multi_head_attention_forward
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
RuntimeError: self and mat2 must have the same dtype, but got BFloat16 and Byte

If I use a BNB 8-bit config:

qnt_config = BitsAndBytesConfig(load_in_8bit=True,
                                llm_int8_enable_fp32_cpu_offload=True)

it causes this error:

/opt/conda/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:324: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/gradio/queueing.py", line 622, in process_events
    response = await route_utils.call_process_api(
  File "/opt/conda/lib/python3.10/site-packages/gradio/route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
  File "/opt/conda/lib/python3.10/site-packages/gradio/blocks.py", line 2014, in process_api
    result = await self.call_function(
  File "/opt/conda/lib/python3.10/site-packages/gradio/blocks.py", line 1567, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/opt/conda/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
  File "/opt/conda/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
  File "/opt/conda/lib/python3.10/site-packages/gradio/utils.py", line 846, in wrapper
    response = f(*args, **kwargs)
  File "/root/private_data/wd-joy-caption-cli/wd_llm_caption/gui.py", line 713, in caption_single_inference
    caption_text = get_caption_fn.my_llm.get_caption(
  File "/root/private_data/wd-joy-caption-cli/wd_llm_caption/utils/inference.py", line 607, in get_caption
    self.llm.generate(input_ids=input_ids, pixel_values=pixel_values,
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2047, in generate
    result = self._sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3007, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llava/modeling_llava.py", line 453, in forward
    image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1189, in forward
    return self.vision_model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1100, in forward
    pooler_output = self.head(last_hidden_state) if self.use_head else None
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1127, in forward
    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 5430, in multi_head_attention_forward
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
RuntimeError: self and mat2 must have the same dtype, but got BFloat16 and Char

@fpgaminer (Owner)

Looks like this is a bug in transformers. I've submitted a bug report and pull request to get it fixed: huggingface/transformers#34294

We'll have to wait for that to get fixed and a new version of transformers released to have a clean fix here.
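
As a possible interim workaround (untested on my end), you could try keeping the vision tower out of quantization entirely, assuming your transformers version applies llm_int8_skip_modules to 4-bit loading as well:

qnt_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.bfloat16,
                                bnb_4bit_use_double_quant=True,
                                # leave the SigLIP vision tower unquantized
                                llm_int8_skip_modules=["vision_tower"])

The vision tower is small compared to the language model, so most of the memory savings should remain.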

@fireicewolf (Author)

Thanks for your help. Let's wait for the HF response.

@Tablaski

Interested as well. I could never run NF4, and the original model is way too slow on my setup :-(

@Tablaski

@fpgaminer any news? I've read the reply on the transformers GitHub and tried their solution, but it didn't change anything.

@effusiveperiscope

bitsandbytes-foundation/bitsandbytes#963 seems to be related. I tried with the current transformers GitHub main as of today. It looks like in modeling_siglip.py, self.attention.out_proj.weight has a uint8 dtype after 4-bit quantization, and the multi-head attention calculation fails with the same error. I don't know enough about quantization to know whether this is correct behavior.
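
For reference, this is roughly how I checked (the checkpoint path is a placeholder, and the module path is inferred from the traceback above):

import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

qnt_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.bfloat16)

model = LlavaForConditionalGeneration.from_pretrained("path/to/the/llava/checkpoint",
                                                      quantization_config=qnt_config,
                                                      device_map="auto")

# SiglipMultiheadAttentionPoolingHead -> nn.MultiheadAttention -> out_proj
out_proj = model.vision_tower.vision_model.head.attention.out_proj
print(type(out_proj.weight), out_proj.weight.dtype)  # quantized bnb parameter, torch.uint8

F.multi_head_attention_forward passes out_proj.weight straight into a matmul instead of calling the module's forward, so the bfloat16 activations hit a uint8 weight and the call fails.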
