From a1ab1cd9717a496bec8bf993a9432993aa1bf956 Mon Sep 17 00:00:00 2001 From: cyita Date: Thu, 26 Sep 2024 20:04:24 +0800 Subject: [PATCH 1/5] add npu generate --- .../src/ipex_llm/transformers/npu_generate.py | 236 ++++++++++++++++++ .../src/ipex_llm/transformers/npu_model.py | 3 + .../src/ipex_llm/transformers/speculative.py | 28 ++- 3 files changed, 256 insertions(+), 11 deletions(-) create mode 100644 python/llm/src/ipex_llm/transformers/npu_generate.py diff --git a/python/llm/src/ipex_llm/transformers/npu_generate.py b/python/llm/src/ipex_llm/transformers/npu_generate.py new file mode 100644 index 00000000000..da15f060022 --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_generate.py @@ -0,0 +1,236 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/generation +# /candidate_generator.py and +# https://github.com/huggingface/transformers/blob/main/src/transformers/generation +# /utils.py +# + +from typing import Any, Callable, Dict, List, Optional, Tuple +import os +import torch +import time +import numpy as np +import random +import logging +from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList +from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\ + _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify +from ipex_llm.utils.common import invalidInputError +from ipex_llm.transformers.utils import get_xpu_device_type + +logger = logging.getLogger("ipex_llm.npu") + +# patch GenerationMixin.generate +from transformers import GenerationMixin +original_generate = GenerationMixin.generate +query_group_size = 16 + + +@torch.no_grad() +def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, +): + return self.npu_generate(inputs=inputs, + generation_config=generation_config, + streamer=streamer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + **kwargs) + # if True: + # return self.lookup_generate(inputs=inputs, + # num_output_tokens=lookahead, + # generation_config=generation_config, + # streamer=streamer, + # logits_processor=logits_processor, + # stopping_criteria=stopping_criteria, + # prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + # **kwargs) + + + + # return original_generate(self, + # inputs=inputs, + # generation_config=generation_config, + # logits_processor=logits_processor, + # stopping_criteria=stopping_criteria, + # 
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + # synced_gpus=synced_gpus, + # assistant_model=assistant_model, + # streamer=streamer, + # **kwargs) + +GenerationMixin.generate = generate + + +def clear_benchmarks(self): + self.first_token_time = None + self.last_token_time = [] + self.encoder_time = 0 + + +def _update_model_kwargs_for_generation(outputs, + model_kwargs: Dict[str, Any]): + model_kwargs["past_key_values"]= outputs["past_key_values"] + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + return model_kwargs + + +@torch.no_grad() +def npu_generate(self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + streamer: Optional["BaseStreamer"] = None, + **sampling_kwargs): + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!") + input_ids, generation_config, logits_processor, stopping_criteria, \ + model_kwargs = _prepare_generate_args(self, inputs, generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + streamer=streamer, + generate_attention_mask=True, + **sampling_kwargs) + + step = 0 + max_new_tokens = generation_config.max_new_tokens + + clear_benchmarks(self) + + input_len = input_ids.shape[1] + + eos_token_id_set = None + if generation_config.eos_token_id is not None: + if isinstance(generation_config.eos_token_id, list): + eos_token_id_set = set(generation_config.eos_token_id) + else: + eos_token_id_set = set([generation_config.eos_token_id]) + + while True: + if step >= max_new_tokens: + break + + tic = time.time() + + if step == 0: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + output = self(**model_inputs, + return_dict=True) + logits = output['logits'] + logits = logits[:, -1:] + logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :]) + if generation_config.do_sample: + output_ids, prob_list = deepmind_sample(logits, + top_k=generation_config.top_k, + top_p=generation_config.top_p, + temperature=generation_config.temperature) + else: + output_ids = torch.argmax(logits, dim=-1) + input_ids = torch.cat((input_ids, output_ids), dim=-1) + + else: + # model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + model_inputs = { + "input_ids": input_ids[:, -1:], + "past_key_values": model_kwargs["past_key_values"], + # "position_ids": model_kwargs["position_ids"], + "use_cache": True, + "attention_mask": model_kwargs["attention_mask"], + } + + output = output = self(**model_inputs, + return_dict=True) + logits = output['logits'] + + logits[:, -1, :] = logits_processor(input_ids, + logits[:, -1, :]) + + if generation_config.do_sample: + output_ids, prob_list = deepmind_sample(logits, + top_k=generation_config.top_k, + top_p=generation_config.top_p, + temperature=generation_config.temperature) + output_ids = output_ids.transpose(0, 1) + else: + output_ids = torch.argmax(logits, dim=-1) + + + + input_ids = torch.cat((input_ids, output_ids), dim=-1) + + step += 1 + + model_kwargs = _update_model_kwargs_for_generation( + output, model_kwargs + ) + + toc = time.time() + if self.first_token_time is None: + self.first_token_time = toc - tic + else: + self.last_token_time.append(toc - tic) + + # Stop on eos and remove content after eos + if eos_token_id_set is not 
None: + output_ids_list = output_ids[0].tolist() + first_eos_idx = -1 + for out_idx, out_id in enumerate(output_ids_list): + if out_id in eos_token_id_set: + first_eos_idx = out_idx + break + if first_eos_idx > -1: + if streamer is not None: + streamer.put(output_ids[:(first_eos_idx + 1)].cpu()) + step -= (len(output_ids_list) - first_eos_idx - 1) + break + if streamer is not None: + streamer.put(output_ids.cpu()) + + step = min(step, max_new_tokens) + # e2e_toc = time.time() + self.n_token_generated = step + # self.e2e_time_without_first = e2e_toc - e2e_tic + + + # if self.do_print: + print(f"=========First token cost {self.first_token_time:.4f} s=========") + if len(self.last_token_time) > 1: + self.first_cost = self.first_token_time + self.rest_cost_mean = np.mean(self.last_token_time) + # if self.do_print: + print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(self.last_token_time)}" + f" tokens in all)=========") + + if streamer is not None: + streamer.end() + + return input_ids[:, : input_len + step] diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 56ca664cac9..519fd71b6a3 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -179,6 +179,9 @@ def from_pretrained(cls, *args, **kwargs): transpose_value_cache=transpose_value_cache, ) model.save_low_bit = types.MethodType(save_low_bit, model) + if model.config.model_type in ["qwen2", "llama"]: + from ipex_llm.transformers.npu_generate import npu_generate + model.npu_generate = types.MethodType(npu_generate, model) else: from ipex_llm.transformers.npu_models.convert import optimize_llm optimize_llm(model) diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 4600e99fefc..2c7e7dcac6f 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -526,7 +526,9 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa return past_key_values -def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sampling_kwargs): +def _prepare_generate_args(self, inputs, generation_config, streamer=None, logits_processor=None, + stopping_criteria=None, generate_attention_mask=False, + **sampling_kwargs): if generation_config is None: generation_config = self.generation_config @@ -551,8 +553,8 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam generation_config.pad_token_id = eos_token_id # 2. Set generation parameters if not already defined - logits_processor = LogitsProcessorList() - stopping_criteria = StoppingCriteriaList() + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() # 3. 
Define model inputs # inputs_tensor has to be defined @@ -574,15 +576,17 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam # else: # model_kwargs["use_cache"] = generation_config.use_cache - # accepts_attention_mask = "attention_mask" in set( - # inspect.signature(self.forward).parameters.keys()) - # requires_attention_mask = "encoder_outputs" not in model_kwargs + if generate_attention_mask: + import inspect + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs - # if model_kwargs.get("attention_mask", None) is None and \ - # requires_attention_mask and accepts_attention_mask: - # model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - # inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id - # ) + if model_kwargs.get("attention_mask", None) is None and \ + requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + ) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: @@ -605,6 +609,8 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam # 5. Prepare `input_ids` which will be used for auto-regressive generation input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + # Skip max length + if streamer is not None: streamer.put(input_ids.cpu()) From d37a17131c35dc32b149734b7cfe6694f1020e61 Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 27 Sep 2024 09:39:39 +0800 Subject: [PATCH 2/5] num_beams fallback --- .../src/ipex_llm/transformers/npu_generate.py | 56 +++++++------------ 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_generate.py b/python/llm/src/ipex_llm/transformers/npu_generate.py index da15f060022..9d8b6be77cb 100644 --- a/python/llm/src/ipex_llm/transformers/npu_generate.py +++ b/python/llm/src/ipex_llm/transformers/npu_generate.py @@ -29,9 +29,8 @@ import logging from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\ - _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify + _crop_past_key_values, _prepare_generate_args from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.utils import get_xpu_device_type logger = logging.getLogger("ipex_llm.npu") @@ -54,35 +53,26 @@ def generate( streamer: Optional["BaseStreamer"] = None, **kwargs, ): - return self.npu_generate(inputs=inputs, - generation_config=generation_config, - streamer=streamer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - **kwargs) - # if True: - # return self.lookup_generate(inputs=inputs, - # num_output_tokens=lookahead, - # generation_config=generation_config, - # streamer=streamer, - # logits_processor=logits_processor, - # stopping_criteria=stopping_criteria, - # prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - # **kwargs) - - - - # return original_generate(self, - # inputs=inputs, - # generation_config=generation_config, - # logits_processor=logits_processor, - # stopping_criteria=stopping_criteria, - # prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - 
# synced_gpus=synced_gpus, - # assistant_model=assistant_model, - # streamer=streamer, - # **kwargs) + if kwargs.get("num_beams", None) not in [None, 1]: + return original_generate(self, + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + **kwargs) + else: + return self.npu_generate(inputs=inputs, + generation_config=generation_config, + streamer=streamer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + **kwargs) + GenerationMixin.generate = generate @@ -112,7 +102,6 @@ def npu_generate(self, stopping_criteria: Optional[StoppingCriteriaList] = None, streamer: Optional["BaseStreamer"] = None, **sampling_kwargs): - print("!!!!!!!!!!!!!!!!!!!!!!!!!!!") input_ids, generation_config, logits_processor, stopping_criteria, \ model_kwargs = _prepare_generate_args(self, inputs, generation_config, logits_processor=logits_processor, @@ -158,7 +147,6 @@ def npu_generate(self, input_ids = torch.cat((input_ids, output_ids), dim=-1) else: - # model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) model_inputs = { "input_ids": input_ids[:, -1:], "past_key_values": model_kwargs["past_key_values"], @@ -183,8 +171,6 @@ def npu_generate(self, else: output_ids = torch.argmax(logits, dim=-1) - - input_ids = torch.cat((input_ids, output_ids), dim=-1) step += 1 From f4f54096a00aed02a5c15089fd1d5ce05b379547 Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 27 Sep 2024 14:33:41 +0800 Subject: [PATCH 3/5] fix --- .../llm/src/ipex_llm/transformers/npu_generate.py | 13 +++---------- python/llm/src/ipex_llm/transformers/speculative.py | 3 ++- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_generate.py b/python/llm/src/ipex_llm/transformers/npu_generate.py index 9d8b6be77cb..c58df38c57f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_generate.py +++ b/python/llm/src/ipex_llm/transformers/npu_generate.py @@ -15,8 +15,6 @@ # # Some parts of this file is adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/generation -# /candidate_generator.py and -# https://github.com/huggingface/transformers/blob/main/src/transformers/generation # /utils.py # @@ -85,7 +83,7 @@ def clear_benchmarks(self): def _update_model_kwargs_for_generation(outputs, model_kwargs: Dict[str, Any]): - model_kwargs["past_key_values"]= outputs["past_key_values"] + model_kwargs["past_key_values"]= outputs.past_key_values if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( @@ -145,14 +143,13 @@ def npu_generate(self, else: output_ids = torch.argmax(logits, dim=-1) input_ids = torch.cat((input_ids, output_ids), dim=-1) - else: model_inputs = { "input_ids": input_ids[:, -1:], "past_key_values": model_kwargs["past_key_values"], # "position_ids": model_kwargs["position_ids"], "use_cache": True, - "attention_mask": model_kwargs["attention_mask"], + "attention_mask": model_kwargs.get("attention_mask", None), } output = output = self(**model_inputs, @@ -202,19 +199,15 @@ def npu_generate(self, streamer.put(output_ids.cpu()) step = min(step, max_new_tokens) - # e2e_toc = time.time() self.n_token_generated = step - # self.e2e_time_without_first = e2e_toc - e2e_tic - # if 
self.do_print: print(f"=========First token cost {self.first_token_time:.4f} s=========") if len(self.last_token_time) > 1: self.first_cost = self.first_token_time self.rest_cost_mean = np.mean(self.last_token_time) - # if self.do_print: print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(self.last_token_time)}" - f" tokens in all)=========") + f" tokens in all)=========") if streamer is not None: streamer.end() diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 2c7e7dcac6f..5e59b735ff2 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -554,7 +554,8 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, logit # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + stopping_criteria = stopping_criteria if stopping_criteria is not None \ + else StoppingCriteriaList() # 3. Define model inputs # inputs_tensor has to be defined From b0802cf04b437f7fc864b2cee179b3c9e283bb38 Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 27 Sep 2024 14:36:44 +0800 Subject: [PATCH 4/5] fix style --- .../src/ipex_llm/transformers/npu_generate.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_generate.py b/python/llm/src/ipex_llm/transformers/npu_generate.py index c58df38c57f..5ef3375de1f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_generate.py +++ b/python/llm/src/ipex_llm/transformers/npu_generate.py @@ -64,12 +64,12 @@ def generate( **kwargs) else: return self.npu_generate(inputs=inputs, - generation_config=generation_config, - streamer=streamer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - **kwargs) + generation_config=generation_config, + streamer=streamer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + **kwargs) GenerationMixin.generate = generate @@ -83,7 +83,7 @@ def clear_benchmarks(self): def _update_model_kwargs_for_generation(outputs, model_kwargs: Dict[str, Any]): - model_kwargs["past_key_values"]= outputs.past_key_values + model_kwargs["past_key_values"] = outputs.past_key_values if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( @@ -175,7 +175,7 @@ def npu_generate(self, model_kwargs = _update_model_kwargs_for_generation( output, model_kwargs ) - + toc = time.time() if self.first_token_time is None: self.first_token_time = toc - tic @@ -201,13 +201,12 @@ def npu_generate(self, step = min(step, max_new_tokens) self.n_token_generated = step - print(f"=========First token cost {self.first_token_time:.4f} s=========") if len(self.last_token_time) > 1: self.first_cost = self.first_token_time self.rest_cost_mean = np.mean(self.last_token_time) - print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(self.last_token_time)}" - f" tokens in all)=========") + print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s " + f"({len(self.last_token_time)} tokens in all)=========") if streamer is not None: streamer.end() From 60378ea9d5c24c718a87494254b8555c8dce1e5c 
Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 27 Sep 2024 18:05:49 +0800 Subject: [PATCH 5/5] update decoder logits processor --- .../src/ipex_llm/transformers/npu_generate.py | 65 +++++++++++++------ 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_generate.py b/python/llm/src/ipex_llm/transformers/npu_generate.py index 5ef3375de1f..480175f78fd 100644 --- a/python/llm/src/ipex_llm/transformers/npu_generate.py +++ b/python/llm/src/ipex_llm/transformers/npu_generate.py @@ -29,6 +29,7 @@ from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\ _crop_past_key_values, _prepare_generate_args from ipex_llm.utils.common import invalidInputError +from transformers.modeling_outputs import CausalLMOutputWithPast logger = logging.getLogger("ipex_llm.npu") @@ -81,17 +82,6 @@ def clear_benchmarks(self): self.encoder_time = 0 -def _update_model_kwargs_for_generation(outputs, - model_kwargs: Dict[str, Any]): - model_kwargs["past_key_values"] = outputs.past_key_values - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - return model_kwargs - - @torch.no_grad() def npu_generate(self, inputs: Optional[torch.Tensor] = None, @@ -110,6 +100,14 @@ def npu_generate(self, step = 0 max_new_tokens = generation_config.max_new_tokens + attn_mask = model_kwargs.get("attention_mask", None) + + if self.config.model_type == "qwen2": + from transformers.generation.logits_process import MinNewTokensLengthLogitsProcessor + docoder_logits_processor = LogitsProcessorList([item for item in logits_processor \ + if not isinstance(item, MinNewTokensLengthLogitsProcessor)]) + else: + docoder_logits_processor = LogitsProcessorList() clear_benchmarks(self) @@ -126,14 +124,20 @@ def npu_generate(self, if step >= max_new_tokens: break - tic = time.time() + tic = time.perf_counter() if step == 0: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) output = self(**model_inputs, return_dict=True) - logits = output['logits'] - logits = logits[:, -1:] + # output = CausalLMOutputWithPast( + # loss=None, + # logits=torch.empty([1, 1, self.config.vocab_size]), + # past_key_values=None, + # hidden_states=None, + # attentions=None, + # ) + logits = output['logits'][:, -1:] logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :]) if generation_config.do_sample: output_ids, prob_list = deepmind_sample(logits, @@ -146,18 +150,28 @@ def npu_generate(self, else: model_inputs = { "input_ids": input_ids[:, -1:], - "past_key_values": model_kwargs["past_key_values"], + "past_key_values": past_key_values, # "position_ids": model_kwargs["position_ids"], "use_cache": True, - "attention_mask": model_kwargs.get("attention_mask", None), + "attention_mask": attn_mask, } output = output = self(**model_inputs, return_dict=True) + # output = CausalLMOutputWithPast( + # loss=None, + # logits=torch.empty([1, 1, self.config.vocab_size]), + # past_key_values=None, + # hidden_states=None, + # attentions=None, + # ) + t1 = time.perf_counter() + logits = output['logits'] - logits[:, -1, :] = logits_processor(input_ids, - logits[:, -1, :]) + logits[:, -1, :] = docoder_logits_processor(input_ids, + logits[:, -1, :]) + t2 = time.perf_counter() if generation_config.do_sample: output_ids, prob_list = deepmind_sample(logits, @@ -167,20 +181,29 @@ def npu_generate(self, output_ids = 
output_ids.transpose(0, 1)
             else:
                 output_ids = torch.argmax(logits, dim=-1)
+
+        t3 = time.perf_counter()
 
         input_ids = torch.cat((input_ids, output_ids), dim=-1)
+        t4 = time.perf_counter()
 
         step += 1
 
-        model_kwargs = _update_model_kwargs_for_generation(
-            output, model_kwargs
+        past_key_values = output.past_key_values
+        if attn_mask is not None:
+            attn_mask = torch.cat(
+                [attn_mask, attn_mask.new_ones((attn_mask.shape[0], 1))], dim=-1
         )
 
-        toc = time.time()
+        toc = time.perf_counter()
+
         if self.first_token_time is None:
             self.first_token_time = toc - tic
         else:
             self.last_token_time.append(toc - tic)
+        print(f"Prepare input & dummy output: {(t1 - tic)*1000} ms, update attn mask: {(toc - t4)*1000} ms, "
+              f"argmax: {(t3 - t2)*1000} ms, cat input id: {(t4 - t3) * 1000} ms, logits processor: {(t2 - t1)*1000} ms"
+              f" total generate: {(toc - t1)*1000} ms")
 
         # Stop on eos and remove content after eos
         if eos_token_id_set is not None: