From 21313e09e3f9448817016290da20d0db1adf3664 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 15 Aug 2024 16:10:22 -0400 Subject: [PATCH] [Bugfix] Fix default weight loading for scalars (#7534) --- vllm/model_executor/model_loader/weight_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index bd8ff8ba8d8c0..ae5a4a4d78c69 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -516,11 +516,17 @@ def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" try: - assert param.size() == loaded_weight.size(), ( - f"Attempted to load weight ({loaded_weight.size()}) " - f"into parameter ({param.size()})") - - param.data.copy_(loaded_weight) + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})") + + param.data.copy_(loaded_weight) except Exception: # NOTE: This exception is added for the purpose of setting breakpoint to # debug weight loading issues.