Skip to content

Commit

Permalink
fix phi-3 tp
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py committed Nov 5, 2024
1 parent f4c78cd commit f8add6f
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,10 @@ def weight_loader(self,
param.data[loaded_shard_id].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
else:
param.weight_type = loaded_weight.item()
param.shard_weight_type = {
i: loaded_weight.item()
for i, _ in enumerate(self.output_sizes)
}
return

if is_gguf_weight:
Expand All @@ -455,20 +458,15 @@ def weight_loader(self,
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size

loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)

if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
else:
param.materialize(loaded_weight.shape,
dtype=loaded_weight.dtype)
param.data.copy_(loaded_weight)
return
return

param_data = param.data
output_dim = getattr(param, "output_dim", None)
Expand Down Expand Up @@ -784,12 +782,15 @@ def weight_loader(self,
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
idx_map = {"q": 0, "k": 1, "v": 2}
if loaded_shard_id is not None:
idx_map = {"q": 0, "k": 1, "v": 2}
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
else:
param.weight_type = loaded_weight.item()
param.shard_weight_type = {
k: loaded_weight.item()
for k in idx_map
}
return

if is_gguf_weight:
Expand All @@ -800,20 +801,15 @@ def weight_loader(self,
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size

loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)

if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
else:
param.materialize(loaded_weight.shape,
dtype=loaded_weight.dtype)
param.data.copy_(loaded_weight)
return
return

param_data = param.data
output_dim = getattr(param, "output_dim", None)
Expand Down

0 comments on commit f8add6f

Please sign in to comment.