Add input padding during prefill for qwen2-7b (#12033)
plusbang authored Sep 6, 2024
1 parent f61b178 commit d2e1b9a
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -934,8 +934,21 @@ def forward(
                 " to max_prompt_len {self.max_prompt_len}"
             ),
         )
-        self.prefill_input_queue.put((hidden_states, position_ids, attention_mask, past_key_value))
-        return self.prefill_result_queue.get()
+        pad_len = self.max_prompt_len - seq_len
+        hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
+        position_ids = F.pad(position_ids, (0, pad_len), value=0)
+        attention_mask = F.pad(
+            attention_mask.to(torch.float16),
+            (0, pad_len, 0, pad_len),
+            value=torch.finfo(torch.float16).min,
+        )
+
+        args = (hidden_states, position_ids, attention_mask, past_key_value)
+        self.prefill_input_queue.put(args)
+        hidden_states, past_key_value = self.prefill_result_queue.get()
+        past_key_value.shrink(seq_len, self.transpose_value_cache)
+        hidden_states = hidden_states[:, :seq_len, :]
+        return hidden_states, past_key_value
 
     def shutdown(self):
         self.prefill_input_queue.put("stop")
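To make the shapes concrete, below is a minimal, self-contained sketch of the padding round trip the diff performs: inputs are padded to a fixed max_prompt_len so the prefill graph always sees a static shape, padded positions are masked out with the smallest float16 value, and the output is sliced back to the real sequence length. This is an illustration, not the ipex-llm code itself; the tensor sizes, the additive-mask convention (0 = attend, float16 min = masked), and the stand-in for the prefill call are assumptions for demonstration. In the real code the result comes back through prefill_result_queue, and past_key_value.shrink(seq_len, ...) trims the KV cache the same way the slice trims hidden_states.

```python
# Sketch of the fixed-shape prefill padding scheme (not the ipex-llm API).
# Sizes below are assumptions chosen for illustration.
import torch
import torch.nn.functional as F

max_prompt_len = 512          # fixed length the prefill graph is compiled for
batch, seq_len, hidden = 1, 7, 64
pad_len = max_prompt_len - seq_len

hidden_states = torch.randn(batch, seq_len, hidden)
position_ids = torch.arange(seq_len).unsqueeze(0)
attention_mask = torch.zeros(batch, 1, seq_len, seq_len)  # additive mask, 0 = attend

# F.pad pads the *last* dims first: (0, 0, 0, pad_len) leaves the hidden
# dim alone and right-pads the sequence dim with zeros.
hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
position_ids = F.pad(position_ids, (0, pad_len), value=0)

# The mask is padded along both the query and key dims with float16 min,
# so attention scores involving padded positions vanish after softmax.
attention_mask = F.pad(
    attention_mask.to(torch.float16),
    (0, pad_len, 0, pad_len),
    value=torch.finfo(torch.float16).min,
)

assert hidden_states.shape == (batch, max_prompt_len, hidden)
assert attention_mask.shape[-2:] == (max_prompt_len, max_prompt_len)

# ... run the fixed-shape prefill here; identity used as a stand-in ...
output = hidden_states

# Discard the padded tail before returning, mirroring the slice in the diff.
output = output[:, :seq_len, :]
assert output.shape == (batch, seq_len, hidden)
```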
