From f8add6ff944d0735f74ae7cadab03a7225624abc Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Wed, 6 Nov 2024 00:15:31 +0800
Subject: [PATCH] fix phi-3 tp

---
 vllm/model_executor/layers/linear.py | 34 ++++++++++++----------------
 1 file changed, 15 insertions(+), 19 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 1808dfc53072e..0492a7951a990 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -444,7 +444,10 @@ def weight_loader(self,
                 param.data[loaded_shard_id].copy_(loaded_weight)
                 param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             else:
-                param.weight_type = loaded_weight.item()
+                param.shard_weight_type = {
+                    i: loaded_weight.item()
+                    for i, _ in enumerate(self.output_sizes)
+                }
             return
 
         if is_gguf_weight:
@@ -455,20 +458,15 @@ def weight_loader(self,
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
             if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
                 param.shard_id.append(loaded_shard_id)
                 param.shard_id_map[loaded_shard_id] = len(param.data_container)
                 param.data_container.append(loaded_weight)
                 if len(param.data_container) == 2:
                     self.qweight = param.materialize_nested()
-            else:
-                param.materialize(loaded_weight.shape,
-                                  dtype=loaded_weight.dtype)
-                param.data.copy_(loaded_weight)
-            return
+                return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -784,12 +782,15 @@ def weight_loader(self,
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
         if is_gguf_weight_type:
+            idx_map = {"q": 0, "k": 1, "v": 2}
             if loaded_shard_id is not None:
-                idx_map = {"q": 0, "k": 1, "v": 2}
                 param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                 param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             else:
-                param.weight_type = loaded_weight.item()
+                param.shard_weight_type = {
+                    k: loaded_weight.item()
+                    for k in idx_map
+                }
             return
 
         if is_gguf_weight:
@@ -800,20 +801,15 @@ def weight_loader(self,
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
             if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
                 param.shard_id.append(loaded_shard_id)
                 param.shard_id_map[loaded_shard_id] = len(param.data_container)
                 param.data_container.append(loaded_weight)
                 if len(param.data_container) == 3:
                     self.qweight = param.materialize_nested()
-            else:
-                param.materialize(loaded_weight.shape,
-                                  dtype=loaded_weight.dtype)
-                param.data.copy_(loaded_weight)
-            return
+                return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)