diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py
index 0f01f5f819ea4..569fc8dfb6a21 100644
--- a/tests/quantization/test_bitsandbytes.py
+++ b/tests/quantization/test_bitsandbytes.py
@@ -9,7 +9,7 @@
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from tests.utils import fork_new_process_for_each_test
+from tests.utils import compare_two_settings, fork_new_process_for_each_test
 
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
@@ -82,6 +82,34 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              vllm_tp_size=2)
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason='Test requires at least 2 GPUs.')
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@fork_new_process_for_each_test
+def test_load_pp_4bit_bnb_model(model_name, description) -> None:
+    common_args = [
+        "--disable-log-stats",
+        "--disable-log-requests",
+        "--dtype",
+        "bfloat16",
+        "--enable-prefix-caching",
+        "--quantization",
+        "bitsandbytes",
+        "--load-format",
+        "bitsandbytes",
+        "--gpu-memory-utilization",
+        "0.7",
+    ]
+    pp_args = [
+        *common_args,
+        "--pipeline-parallel-size",
+        "2",
+    ]
+    compare_two_settings(model_name, common_args, pp_args)
+
+
 def log_generated_texts(prompts, outputs, runner_name):
     logged_texts = []
     for i, (_, generated_text) in enumerate(outputs):
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 8d3024534734b..715e6c11f86ce 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -991,7 +991,13 @@ def _load_weights(self, model_config: ModelConfig,
 
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        # TODO: Change this lazy import to normal import
+        # after the checks are updated to run on a new version
+        from vllm.model_executor.models.utils import is_pp_missing_parameter
         for quant_param_name in quant_state_dict:
+            if is_pp_missing_parameter(quant_param_name, model):
+                continue
+
             non_stacked_param_name = quant_param_name
             shard_index = 0
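
Note: the loader.py change above skips quantized checkpoint weights whose parameters are not materialized on the current pipeline-parallel rank. As a rough illustration of the idea (a hypothetical sketch, not vLLM's actual is_pp_missing_parameter helper), a rank can simply ignore any weight name that has no matching parameter in the module built for its own stage:

    # Hypothetical sketch of the pipeline-parallel "missing parameter" check;
    # names below (pp_missing, stage) are illustrative, not vLLM APIs.
    from typing import Any, Dict

    import torch
    import torch.nn as nn


    def pp_missing(param_name: str, stage: nn.Module) -> bool:
        # A weight is "missing" on this rank if the module built for this
        # pipeline stage has no parameter with that name.
        return param_name not in dict(stage.named_parameters())


    # Toy single-layer stage standing in for one pipeline-parallel rank.
    stage = nn.Sequential(nn.Linear(4, 4))
    quant_state_dict: Dict[str, Any] = {
        "0.weight": torch.zeros(4, 4),  # owned by this stage
        "1.weight": torch.zeros(4, 4),  # lives on another stage; skip it
    }

    for name in quant_state_dict:
        if pp_missing(name, stage):
            continue
        print(f"loading {name} on this rank")  # only "0.weight" is loaded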