diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index bd8ff8ba8d8c0..ae5a4a4d78c69 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -516,11 +516,17 @@ def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" try: - assert param.size() == loaded_weight.size(), ( - f"Attempted to load weight ({loaded_weight.size()}) " - f"into parameter ({param.size()})") - - param.data.copy_(loaded_weight) + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})") + + param.data.copy_(loaded_weight) except Exception: # NOTE: This exception is added for the purpose of setting breakpoint to # debug weight loading issues.