From 4cc189a6e013531e4def9948364ef77bdcc45fbc Mon Sep 17 00:00:00 2001 From: large Date: Sun, 7 Jul 2024 15:23:25 +0000 Subject: [PATCH 01/61] init --- llama.py | 13 ++ tts.py | 50 +++++++ vllm/model_executor/models/__init__.py | 3 +- vllm/model_executor/models/ttslm.py | 198 +++++++++++++++++++++++++ 4 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 llama.py create mode 100644 tts.py create mode 100644 vllm/model_executor/models/ttslm.py diff --git a/llama.py b/llama.py new file mode 100644 index 0000000000000..44dd6b2930cf7 --- /dev/null +++ b/llama.py @@ -0,0 +1,13 @@ +from vllm import LLM, SamplingParams + +llm = LLM(model='/home/largeniu/triton/llama3/Meta-Llama-3-8B-Instruct') +prompts = [ + "Hello, my name is", + "The capital of France is", +] +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +outputs = llm.generate(prompts, sampling_params) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/tts.py b/tts.py new file mode 100644 index 0000000000000..713966ed4203b --- /dev/null +++ b/tts.py @@ -0,0 +1,50 @@ +from vllm import LLM, SamplingParams +import torch + +tts = torch.load('/home/largeniu/ttslm/GPT.pt') + +text_emb_count = tts['emb_text.weight'].shape[0] +audio_emb_count = tts['emb_code.0.weight'].shape[0] +model_dim = tts['emb_text.weight'].shape[1] + +# append audio embeddings to text embeddings +# all_0 = text_emb + audio_emb_0 +all_0 = torch.cat([tts['emb_text.weight'], tts['emb_code.0.weight']], dim=0) + +# all_1 = zero + audio_emb_1 +all_1 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.1.weight']], dim=0) + +# all_2 = zero + audio_emb_2 +all_2 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.2.weight']], dim=0) + +# all_3 = zero + audio_emb_3 +all_3 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.3.weight']], dim=0) + +# remove text emb and audio emb in the model +tts.pop('emb_text.weight') +tts.pop('emb_code.0.weight') +tts.pop('emb_code.1.weight') +tts.pop('emb_code.2.weight') +tts.pop('emb_code.3.weight') + +# add new embeddings to the model +tts['emb_all.0.weight'] = all_0 +tts['emb_all.1.weight'] = all_1 +tts['emb_all.2.weight'] = all_2 +tts['emb_all.3.weight'] = all_3 + +# save the model +torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') + +tokenizer = torch.load('/home/largeniu/g/ChatTTS/asset/tokenizer.pt') +llm = LLM(model='/home/largeniu/ttslm', skip_tokenizer_init=True) +llm.set_tokenizer(tokenizer) +prompts = [ + "Hello, my name is", +] +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +outputs = llm.generate(prompts, sampling_params) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a4fe18d52d608..ef1d90dd41854 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -63,7 +63,8 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM") + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "ChatTtsLlm": ("ttslm", "ChatTtsLlm"), } _EMBEDDING_MODELS = { 
diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py new file mode 100644 index 0000000000000..85d009286314a --- /dev/null +++ b/vllm/model_executor/models/ttslm.py @@ -0,0 +1,198 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.utils.parametrizations import weight_norm +from transformers import LlamaConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm.sequence import IntermediateTensors, SamplerOutput + + +import lzma +import numpy as np +import pybase16384 as b14 + +class ChatTtsLlm(nn.Module): + def __init__(self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + # static parameters, put them in config later + self.spk_emb_dim = 192 + self.spk_KL = 8 + self.num_audio_tokens = 626 + self.num_text_tokens = 21178 + self.num_vq = 4 + + self.gpt = LlamaModel(config) + self.model_dim = self.gpt.config.hidden_size + self.emb_all = nn.ModuleList([ + nn.Embedding(self.num_audio_tokens + self.num_text_tokens, self.model_dim) for _ in range(self.num_vq) + ]) + + self.head_text = weight_norm(nn.Linear(self.model_dim, self.num_text_tokens, bias=False), name='weight') + self.head_code = nn.ModuleList([ + weight_norm(nn.Linear(self.model_dim, self.num_audio_tokens, bias=False), name='weight') for _ in range(self.num_vq) + ]) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + try: + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + except KeyError: + pass + break + else: + try: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + except KeyError: + pass + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.gpt( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return model_output + + def generate( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor, + attention_mask: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + + 
temperature: torch.Tensor, + max_new_token=2048, + min_new_token=0, + ) -> torch.Tensor: + attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = [] + hiddens = [] + start_idx = input_ids.shape[1] + end_idx = torch.zeros(input_ids.shape[0], dtype=torch.long) + finish = torch.zeros(input_ids.shape[0], dtype=torch.bool) + + old_temperature = temperature + + temperature = ( + temperature.unsqueeze(0) + .expand(input_ids.shape[0], -1) + .contiguous() + .view(-1, 1) + ) + + attention_mask_cache = torch.ones( + ( + input_ids.shape[0], + input_ids.shape[1] + max_new_token, + ), + dtype=torch.bool, + ) + + if attention_mask is not None: + attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_( + attention_mask + ) + + def prefill(self, + input_ids: torch.Tensor, + text_mask: torch.Tensor, + spk_emb: str) -> torch.Tensor: + emb_text = self.emb_text(input_ids[text_mask].narrow(1, 0, 1).squeeze_(1)) + text_mask_inv = text_mask.logical_not() + masked_input_ids: torch.Tensor = input_ids[text_mask_inv] + emb_code = [ + self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq) + ] + emb_code = torch.stack(emb_code, 2).sum(2) + emb = torch.zeros( + (input_ids.shape[:-1]) + (emb_text.shape[-1],), + dtype=emb_text.dtype, + ) + emb[text_mask] = emb_text + emb[text_mask_inv] = emb_code.to(emb.dtype) + if spk_emb: + self._apply_spk_emb(emb, spk_emb, input_ids) + + return emb + + @staticmethod + def _decode_spk_emb(spk_emb: str) -> np.ndarray: + return np.frombuffer( + lzma.decompress( + b14.decode_from_string(spk_emb), + format=lzma.FORMAT_RAW, + filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], + ), + dtype=np.float16, + ).copy() + + def _apply_spk_emb( + self, + emb: torch.Tensor, + spk_emb: str, + input_ids: torch.Tensor, + ): + n = ( + F.normalize( + torch.from_numpy( + self._decode_spk_emb(spk_emb), + ), + p=2.0, + dim=0, + eps=1e-12, + ) + .unsqueeze_(0) + .expand(emb.size(0), -1) + .unsqueeze_(1) + .expand(emb.shape) + ) + cond = ( + input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape) + ) + torch.where(cond, n, emb, out=emb) + del cond, n \ No newline at end of file From 0ec7803fbd1ba964ec49faf7b4598098a4225e60 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 9 Jul 2024 00:10:55 +0800 Subject: [PATCH 02/61] foward logits --- llama.py | 4 +- tts.py | 52 +++---- .../layers/prarllel_logits_processor.py | 140 ++++++++++++++++++ vllm/model_executor/models/ttslm.py | 43 +++++- 4 files changed, 207 insertions(+), 32 deletions(-) create mode 100644 vllm/model_executor/layers/prarllel_logits_processor.py diff --git a/llama.py b/llama.py index 44dd6b2930cf7..01c5c0c9eff2e 100644 --- a/llama.py +++ b/llama.py @@ -2,8 +2,8 @@ llm = LLM(model='/home/largeniu/triton/llama3/Meta-Llama-3-8B-Instruct') prompts = [ - "Hello, my name is", - "The capital of France is", + "Hi my name is", + # "The capital of France is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) outputs = llm.generate(prompts, sampling_params) diff --git a/tts.py b/tts.py index 713966ed4203b..c8e2dc925eb90 100644 --- a/tts.py +++ b/tts.py @@ -1,40 +1,40 @@ from vllm import LLM, SamplingParams import torch -tts = torch.load('/home/largeniu/ttslm/GPT.pt') +# tts = torch.load('/home/largeniu/ttslm/GPT.pt') -text_emb_count = tts['emb_text.weight'].shape[0] -audio_emb_count = tts['emb_code.0.weight'].shape[0] -model_dim = tts['emb_text.weight'].shape[1] +# text_emb_count = tts['emb_text.weight'].shape[0] +# audio_emb_count = tts['emb_code.0.weight'].shape[0] +# 
model_dim = tts['emb_text.weight'].shape[1] -# append audio embeddings to text embeddings -# all_0 = text_emb + audio_emb_0 -all_0 = torch.cat([tts['emb_text.weight'], tts['emb_code.0.weight']], dim=0) +# # append audio embeddings to text embeddings +# # all_0 = text_emb + audio_emb_0 +# all_0 = torch.cat([tts['emb_text.weight'], tts['emb_code.0.weight']], dim=0) -# all_1 = zero + audio_emb_1 -all_1 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.1.weight']], dim=0) +# # all_1 = zero + audio_emb_1 +# all_1 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.1.weight']], dim=0) -# all_2 = zero + audio_emb_2 -all_2 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.2.weight']], dim=0) +# # all_2 = zero + audio_emb_2 +# all_2 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.2.weight']], dim=0) -# all_3 = zero + audio_emb_3 -all_3 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.3.weight']], dim=0) +# # all_3 = zero + audio_emb_3 +# all_3 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.3.weight']], dim=0) -# remove text emb and audio emb in the model -tts.pop('emb_text.weight') -tts.pop('emb_code.0.weight') -tts.pop('emb_code.1.weight') -tts.pop('emb_code.2.weight') -tts.pop('emb_code.3.weight') +# # remove text emb and audio emb in the model +# tts.pop('emb_text.weight') +# tts.pop('emb_code.0.weight') +# tts.pop('emb_code.1.weight') +# tts.pop('emb_code.2.weight') +# tts.pop('emb_code.3.weight') -# add new embeddings to the model -tts['emb_all.0.weight'] = all_0 -tts['emb_all.1.weight'] = all_1 -tts['emb_all.2.weight'] = all_2 -tts['emb_all.3.weight'] = all_3 +# # add new embeddings to the model +# tts['emb_all.0.weight'] = all_0 +# tts['emb_all.1.weight'] = all_1 +# tts['emb_all.2.weight'] = all_2 +# tts['emb_all.3.weight'] = all_3 -# save the model -torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') +# # save the model +# torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') tokenizer = torch.load('/home/largeniu/g/ChatTTS/asset/tokenizer.pt') llm = LLM(model='/home/largeniu/ttslm', skip_tokenizer_init=True) diff --git a/vllm/model_executor/layers/prarllel_logits_processor.py b/vllm/model_executor/layers/prarllel_logits_processor.py new file mode 100644 index 0000000000000..f56f74eb3895f --- /dev/null +++ b/vllm/model_executor/layers/prarllel_logits_processor.py @@ -0,0 +1,140 @@ +"""A layer that compute logits from hidden_stats.""" +import inspect +from typing import Optional, List + +import torch +import torch.nn as nn + +from vllm.distributed import tensor_model_parallel_gather +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata + + +class ParallelLogitsProcessor(nn.Module): + """Process logits and apply logits processors from sampling metadata. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. Apply logits processors (if any). + """ + + def __init__(self, + vocab_size: int, + num_logits_processors: int = 0, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super().__init__() + self.scale = scale + self.vocab_size = vocab_size + self.num_logits_processors = num_logits_processors + # Whether the input is logits (default is hidden states). 
+ self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + + def forward( + self, + lm_heads: List[nn.Linear], + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.logits_as_input: + logits = hidden_states + else: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, lm_heads) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = torch.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + # Apply logits processors (if any). + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def _get_logits(self, hidden_states: torch.Tensor, + lm_heads: List[nn.Linear]) -> torch.Tensor: + # Get the logits for the next tokens. + logits_all = torch.zeros(self.num_logits_processors, hidden_states.size(0), self.vocab_size, device=hidden_states.device, dtype=hidden_states.dtype) + for i, lm_head in enumerate(lm_heads): + logits = lm_head(hidden_states) + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + logits_all[i] = logits + logits_all = logits_all.permute(1, 0, 2) + return logits_all + + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", forg_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + return hidden_states.index_select(0, + sampling_metadata.selected_token_indices) + + +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + found_logits_processors = False + logits_processed = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + + for seq_id, logits_row_idx in zip(seq_ids, + seq_group.sample_indices): + logits_row = logits[logits_row_idx] + past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids + prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids + + for logits_processor in logits_processors: + parameters = inspect.signature(logits_processor).parameters + if len(parameters) == 3: + logits_row = logits_processor(prompt_tokens_ids, + past_tokens_ids, + logits_row) + else: + logits_row = logits_processor(past_tokens_ids, + logits_row) + + logits[logits_row_idx] = logits_row + + logits_processed += len(seq_group.sample_indices) + len( + seq_group.prompt_logprob_indices) + + if found_logits_processors: + # verifies that no rows in logits were missed unexpectedly + assert logits_processed == logits.shape[0] + return logits diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 85d009286314a..90af817efdfd4 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -8,10 +8,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig +from 
vllm.model_executor.layers.prarllel_logits_processor import ParallelLogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.model_loader.weight_utils import default_weight_loader - +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -43,6 +45,8 @@ def __init__(self, self.head_code = nn.ModuleList([ weight_norm(nn.Linear(self.model_dim, self.num_audio_tokens, bias=False), name='weight') for _ in range(self.num_vq) ]) + self.logits_processor = ParallelLogitsProcessor(self.num_audio_tokens, self.num_vq) + self.sampler = Sampler() def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -76,7 +80,37 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): pass def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + code_emb = [ + self.emb_all[i](input_ids) + for i in range(self.num_vq) + ] + emb = torch.stack(code_emb, 1).sum(1) + return emb + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + # # compute logits for each vq + # # hidden: [token_count, model_dim] + # # logits: [token_count, num_audio_tokens, num_vq] + # logits = torch.zeros(hidden_states.size(0), self.num_audio_tokens, self.num_vq, dtype=hidden_states.dtype) + # for num_vq_iter in range(self.num_vq): + # x = self.head_code[num_vq_iter](hidden_states) + # logits[:, :, num_vq_iter] = x + + # # logits: [num_audio_tokens, num_vq] + # logits = logits.narrow(0, -1, 1).squeeze_(0) + # logits = logits.permute(1, 0) + # return logits + logits = self.logits_processor(self.head_code, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens def forward( self, @@ -87,13 +121,14 @@ def forward( intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.get_input_embeddings(input_ids) model_output = self.gpt( input_ids=input_ids, + inputs_embeds=hidden_states, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors ) return model_output From 43e921468bcb4021956bde80e5524cc9dc84b98c Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Wed, 10 Jul 2024 00:38:50 +0800 Subject: [PATCH 03/61] foward logits --- tts.py | 4 +- ...sor.py => multi_heads_logits_processor.py} | 2 +- .../layers/multi_heads_sampler.py | 54 +++++++++++++++++++ vllm/model_executor/models/ttslm.py | 12 ++--- vllm/sequence.py | 1 + vllm/worker/model_runner.py | 6 ++- 6 files changed, 69 insertions(+), 10 deletions(-) rename vllm/model_executor/layers/{prarllel_logits_processor.py => multi_heads_logits_processor.py} (99%) create mode 100644 vllm/model_executor/layers/multi_heads_sampler.py diff --git a/tts.py b/tts.py index c8e2dc925eb90..e565b4a15c109 100644 --- a/tts.py +++ b/tts.py @@ -37,12 +37,12 @@ # torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') tokenizer = 
torch.load('/home/largeniu/g/ChatTTS/asset/tokenizer.pt') -llm = LLM(model='/home/largeniu/ttslm', skip_tokenizer_init=True) +llm = LLM(model='/home/largeniu/ttslm', skip_tokenizer_init=True, dtype=torch.float32) llm.set_tokenizer(tokenizer) prompts = [ "Hello, my name is", ] -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams(temperature=1) outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt diff --git a/vllm/model_executor/layers/prarllel_logits_processor.py b/vllm/model_executor/layers/multi_heads_logits_processor.py similarity index 99% rename from vllm/model_executor/layers/prarllel_logits_processor.py rename to vllm/model_executor/layers/multi_heads_logits_processor.py index f56f74eb3895f..5b26929d393d7 100644 --- a/vllm/model_executor/layers/prarllel_logits_processor.py +++ b/vllm/model_executor/layers/multi_heads_logits_processor.py @@ -11,7 +11,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata -class ParallelLogitsProcessor(nn.Module): +class MultiHeadLogitsProcessor(nn.Module): """Process logits and apply logits processors from sampling metadata. This layer does the following: diff --git a/vllm/model_executor/layers/multi_heads_sampler.py b/vllm/model_executor/layers/multi_heads_sampler.py new file mode 100644 index 0000000000000..c697729334c91 --- /dev/null +++ b/vllm/model_executor/layers/multi_heads_sampler.py @@ -0,0 +1,54 @@ +"""A layer that samples the next tokens from the model's outputs.""" +import itertools +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor.layers.ops.sample import sample as sample_triton +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingTensors, + SequenceGroupToSample) +from vllm.sampling_params import SamplingType +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SamplerOutput, + SequenceOutput) +from vllm.model_executor.layers.sampler import Sampler, _apply_top_k_top_p, _sample, _get_logprobs, _build_sampler_output + +# (num_token_ids, num_parent_ids) per sequence group. +SampleResultType = List[Tuple[List[int], List[int]]] + + +class MultiheadsSampler(nn.Module): + def __init__(self, num_heads: int): + super().__init__() + + # Whether or not the SamplerOutput should have on-device tensors + # containing the sampled token ids and probabilities. This is used by + # speculative decoding. 
+ self.num_heads = num_heads + self.include_gpu_probs_tensor = False + self.heads = nn.ModuleList([Sampler() for _ in range(num_heads)]) + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + # Sample from each head + head_logits = logits.permute(1, 0, 2) + output0 = self.heads[0](head_logits[0], sampling_metadata) + for i in range(self.num_heads - 1): + output = self.heads[i + 1](head_logits[i], sampling_metadata) + self.merge_sample_results(output0, output) + + return output0 + + def merge_sample_results( + self, + source: SamplerOutput, + target: SamplerOutput, + ): + for o_a, o_b in zip(source.outputs, target.outputs): + for s_a, s_b in zip(o_a.samples, o_b.samples): + s_a.output_tokens.append(s_b.output_token) \ No newline at end of file diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 90af817efdfd4..6cb382ec8ba0c 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -8,9 +8,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.model_executor.layers.prarllel_logits_processor import ParallelLogitsProcessor +from vllm.model_executor.layers.multi_heads_logits_processor import MultiHeadLogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.multi_heads_sampler import MultiheadsSampler from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -45,8 +45,8 @@ def __init__(self, self.head_code = nn.ModuleList([ weight_norm(nn.Linear(self.model_dim, self.num_audio_tokens, bias=False), name='weight') for _ in range(self.num_vq) ]) - self.logits_processor = ParallelLogitsProcessor(self.num_audio_tokens, self.num_vq) - self.sampler = Sampler() + self.logits_processor = MultiHeadLogitsProcessor(self.num_audio_tokens, self.num_vq) + self.sampler = MultiheadsSampler(self.num_vq) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -81,10 +81,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: code_emb = [ - self.emb_all[i](input_ids) + self.emb_all[i](input_ids[:,i]) for i in range(self.num_vq) ] - emb = torch.stack(code_emb, 1).sum(1) + emb = torch.stack(code_emb, 2).sum(2) return emb def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/sequence.py b/vllm/sequence.py index d200115aa0921..818d5c3e74076 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -701,6 +701,7 @@ def __init__( self.parent_seq_id = parent_seq_id self.output_token = output_token self.logprobs = logprobs + self.output_tokens = [output_token] def __repr__(self) -> str: return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d0c82d6bbedf3..e8ceeae36f07f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -173,7 +173,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, multimodal_config: Optional[MultiModalConfig] = None, - return_hidden_states: bool = False, + return_hidden_states: bool = True, ): self.model_config = model_config 
self.parallel_config = parallel_config @@ -676,6 +676,10 @@ def _prepare_model_input_tensors( input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.device) + # if self.model_config.hf_config have num_token_lens attribute + if hasattr(self.model_config.hf_config, "num_token_len"): + if self.model_config.hf_config.num_token_len > 1: + input_tokens_tensor = input_tokens_tensor.unsqueeze(dim=1).expand(len(input_tokens), self.model_config.hf_config.num_token_len) input_positions_tensor = torch.tensor(input_positions, dtype=torch.long, device=self.device) From 41e16557d354833ed368158daa51d1c7c8ea8393 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 10 Jul 2024 11:12:53 +0000 Subject: [PATCH 04/61] update --- llama.py => testllama.py | 0 tts.py | 6 ++--- vllm/engine/output_processor/single_step.py | 4 ++-- vllm/sequence.py | 15 ++++++++++-- .../tokenizer_group/tokenizer_group.py | 2 +- vllm/worker/model_runner.py | 23 ++++++++++++++----- 6 files changed, 35 insertions(+), 15 deletions(-) rename llama.py => testllama.py (100%) diff --git a/llama.py b/testllama.py similarity index 100% rename from llama.py rename to testllama.py diff --git a/tts.py b/tts.py index e565b4a15c109..e5b5f1454fb66 100644 --- a/tts.py +++ b/tts.py @@ -36,11 +36,9 @@ # # save the model # torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') -tokenizer = torch.load('/home/largeniu/g/ChatTTS/asset/tokenizer.pt') -llm = LLM(model='/home/largeniu/ttslm', skip_tokenizer_init=True, dtype=torch.float32) -llm.set_tokenizer(tokenizer) +llm = LLM(model='/home/zhn/ttslm') prompts = [ - "Hello, my name is", + "[Stts][empty_spk][speed_5]your text one[Ptts]", ] sampling_params = SamplingParams(temperature=1) outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index fa672e1feda92..28ad41f9a893d 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -102,14 +102,14 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for child_sample in child_samples[:-1]: new_child_seq_id: int = next(self.seq_counter) child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, + child.append_token_id(child_sample.output_tokens, child_sample.logprobs) child_seqs.append((child, parent)) # Continue the parent sequence for the last child sample. # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, + parent.append_token_id(last_child_sample.output_tokens, last_child_sample.logprobs) child_seqs.append((parent, parent)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 818d5c3e74076..62f181e63ad0e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -115,6 +115,7 @@ def __init__( self, prompt_token_ids: List[int], output_token_ids: Optional[List[int]] = None, + num_token_head: int = 1, ) -> None: self._prompt_token_ids: List[int] = list(prompt_token_ids) self._prompt_token_ids_tuple: Tuple[int, ...] 
= tuple(prompt_token_ids) @@ -315,9 +316,19 @@ def append_token_id( token_id: int, logprobs: Dict[int, Logprob], ) -> None: - assert token_id in logprobs + # assert token_id in logprobs self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) + if isinstance(token_id, List): + self.data.append_token_id(token_id, logprobs[token_id[0]].logprob) + else: + self.data.append_token_id(token_id, logprobs[token_id].logprob) + + def append_token_ids( + self, + token_ids: List[int], + logprobs: Dict[int, Logprob], + ) -> None: + pass def get_len(self) -> int: return self.data.get_len() diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 9614f01d2b955..555faa9798b41 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -51,7 +51,7 @@ def encode(self, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - ret = tokenizer.encode(prompt) + ret = tokenizer.encode(prompt, add_special_tokens=False) self._raise_if_input_too_long(ret, lora_request) return ret diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e8ceeae36f07f..f420af2055c9e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -673,13 +673,19 @@ def _prepare_model_input_tensors( dtype=query_start_loc.dtype, out=query_start_loc[1:]) - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) # if self.model_config.hf_config have num_token_lens attribute if hasattr(self.model_config.hf_config, "num_token_len"): - if self.model_config.hf_config.num_token_len > 1: - input_tokens_tensor = input_tokens_tensor.unsqueeze(dim=1).expand(len(input_tokens), self.model_config.hf_config.num_token_len) + multi_head_input_tokens = [] + # duplicate input_tokens n times + for i in range(self.model_config.hf_config.num_token_len): + multi_head_input_tokens.append(input_tokens) + input_tokens_tensor = torch.tensor(multi_head_input_tokens, + dtype=torch.long, + device=self.device).permute(1, 0) + else: + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) input_positions_tensor = torch.tensor(input_positions, dtype=torch.long, device=self.device) @@ -938,7 +944,12 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. 
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + + # if self.model_config.hf_config have num_token_lens attribute + if hasattr(self.model_config.hf_config, "num_token_len"): + input_tokens = torch.zeros(max_batch_size, self.model_config.hf_config.num_token_len, dtype=torch.long).cuda() + else: + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) From 9c4be7bf974a8a932453e069d53e790705db9ff6 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 11 Jul 2024 00:09:46 +0800 Subject: [PATCH 05/61] genrate multihead output token --- tts.py | 4 +- vllm/engine/llm_engine.py | 4 + vllm/engine/output_processor/single_step.py | 20 ++++- vllm/engine/output_processor/stop_checker.py | 16 +++- vllm/model_executor/models/ttslm.py | 90 +++----------------- vllm/worker/model_runner.py | 25 ++---- 6 files changed, 56 insertions(+), 103 deletions(-) diff --git a/tts.py b/tts.py index e5b5f1454fb66..bb712ad1cc902 100644 --- a/tts.py +++ b/tts.py @@ -36,11 +36,11 @@ # # save the model # torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') -llm = LLM(model='/home/zhn/ttslm') +llm = LLM(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5) prompts = [ "[Stts][empty_spk][speed_5]your text one[Ptts]", ] -sampling_params = SamplingParams(temperature=1) +sampling_params = SamplingParams(temperature=1, detokenize=False) outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index de7604ece7c31..b49054e459ed3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -548,6 +548,10 @@ def process_model_inputs( lora_request=lora_request) else: prompt_token_ids = inputs["prompt_token_ids"] + + if hasattr(self.model_config.hf_config, "num_output_head"): + # duplicate the prompt_token_ids for each head + prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 28ad41f9a893d..db7bfc2efbdc2 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -102,15 +102,27 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for child_sample in child_samples[:-1]: new_child_seq_id: int = next(self.seq_counter) child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_tokens, - child_sample.logprobs) + + # if output_tokens more than one, it's has multi-head output + if len(child_sample.output_tokens) > 1: + child.append_token_id(child_sample.output_tokens, + child_sample.logprobs) + else: + child.append_token_id(child_sample.output_token, + child_sample.logprobs) child_seqs.append((child, parent)) # Continue the parent sequence for the last child sample. # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. 
last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_tokens, - last_child_sample.logprobs) + + # if output_tokens more than one, it's has multi-head output + if len(child_sample.output_tokens) > 1: + parent.append_token_id(last_child_sample.output_tokens, + last_child_sample.logprobs) + else: + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) child_seqs.append((parent, parent)) for seq, _ in child_seqs: diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 96f0d1142611b..0dbf29659d5fe 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Callable, List, Optional, Union from transformers import PreTrainedTokenizer @@ -27,6 +27,14 @@ def _get_max_model_len(self, lora_req: Optional[LoRARequest]): else: return self._max_model_len + def token_equal_or_in(self, + eos_token_id: int, + last_token_id: Union[int, List[int]]) -> bool: + if isinstance(last_token_id, list): + return eos_token_id in last_token_id + else: + return eos_token_id == last_token_id + def maybe_stop_sequence( self, seq: Sequence, @@ -47,7 +55,7 @@ def maybe_stop_sequence( # Check if the sequence has generated the EOS token. if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): + and self.token_equal_or_in(seq.eos_token_id, seq.get_last_token_id())): # Remove the last EOS token unless explicitly specified # This prevents unintended exposure of the EOS token if new_char_count and ( @@ -59,7 +67,9 @@ def maybe_stop_sequence( # Check if a stop token was encountered. # This assumes a single token produced per step. 
last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: + has_stop_token = any(self.token_equal_or_in(x, last_token_id) + for x in sampling_params.stop_token_ids) + if has_stop_token: if new_char_count and ( not sampling_params.include_stop_str_in_output): # Remove last token diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 6cb382ec8ba0c..f9421c0081d97 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -8,6 +8,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig +from vllm.inputs import INPUT_REGISTRY +from vllm.inputs.registry import InputContext from vllm.model_executor.layers.multi_heads_logits_processor import MultiHeadLogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.multi_heads_sampler import MultiheadsSampler @@ -21,6 +23,17 @@ import numpy as np import pybase16384 as b14 +def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int): + + from vllm.sequence import SequenceData + + + dummy_seq_data = SequenceData([[0] * ctx.model_config.hf_config.num_output_head] * seq_len) + dummy_multi_modal_data = None + + return dummy_seq_data, dummy_multi_modal_data + +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ttsllm) class ChatTtsLlm(nn.Module): def __init__(self, config: LlamaConfig, @@ -89,18 +102,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - # # compute logits for each vq - # # hidden: [token_count, model_dim] - # # logits: [token_count, num_audio_tokens, num_vq] - # logits = torch.zeros(hidden_states.size(0), self.num_audio_tokens, self.num_vq, dtype=hidden_states.dtype) - # for num_vq_iter in range(self.num_vq): - # x = self.head_code[num_vq_iter](hidden_states) - # logits[:, :, num_vq_iter] = x - - # # logits: [num_audio_tokens, num_vq] - # logits = logits.narrow(0, -1, 1).squeeze_(0) - # logits = logits.permute(1, 0) - # return logits logits = self.logits_processor(self.head_code, hidden_states, sampling_metadata) return logits @@ -132,69 +133,6 @@ def forward( ) return model_output - def generate( - self, - input_ids: torch.Tensor, - input_embeds: torch.Tensor, - attention_mask: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - - temperature: torch.Tensor, - max_new_token=2048, - min_new_token=0, - ) -> torch.Tensor: - attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = [] - hiddens = [] - start_idx = input_ids.shape[1] - end_idx = torch.zeros(input_ids.shape[0], dtype=torch.long) - finish = torch.zeros(input_ids.shape[0], dtype=torch.bool) - - old_temperature = temperature - - temperature = ( - temperature.unsqueeze(0) - .expand(input_ids.shape[0], -1) - .contiguous() - .view(-1, 1) - ) - - attention_mask_cache = torch.ones( - ( - input_ids.shape[0], - input_ids.shape[1] + max_new_token, - ), - dtype=torch.bool, - ) - - if attention_mask is not None: - attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_( - attention_mask - ) - - def prefill(self, - input_ids: torch.Tensor, - text_mask: torch.Tensor, - spk_emb: str) -> torch.Tensor: - emb_text = self.emb_text(input_ids[text_mask].narrow(1, 0, 1).squeeze_(1)) - text_mask_inv = text_mask.logical_not() - masked_input_ids: torch.Tensor = 
input_ids[text_mask_inv] - emb_code = [ - self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq) - ] - emb_code = torch.stack(emb_code, 2).sum(2) - emb = torch.zeros( - (input_ids.shape[:-1]) + (emb_text.shape[-1],), - dtype=emb_text.dtype, - ) - emb[text_mask] = emb_text - emb[text_mask_inv] = emb_code.to(emb.dtype) - if spk_emb: - self._apply_spk_emb(emb, spk_emb, input_ids) - - return emb - @staticmethod def _decode_spk_emb(spk_emb: str) -> np.ndarray: return np.frombuffer( @@ -230,4 +168,4 @@ def _apply_spk_emb( input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape) ) torch.where(cond, n, emb, out=emb) - del cond, n \ No newline at end of file + del cond, n diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f420af2055c9e..1805824bd58e8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -673,19 +673,10 @@ def _prepare_model_input_tensors( dtype=query_start_loc.dtype, out=query_start_loc[1:]) - # if self.model_config.hf_config have num_token_lens attribute - if hasattr(self.model_config.hf_config, "num_token_len"): - multi_head_input_tokens = [] - # duplicate input_tokens n times - for i in range(self.model_config.hf_config.num_token_len): - multi_head_input_tokens.append(input_tokens) - input_tokens_tensor = torch.tensor(multi_head_input_tokens, - dtype=torch.long, - device=self.device).permute(1, 0) - else: - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, dtype=torch.long, device=self.device) @@ -945,11 +936,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. 
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - # if self.model_config.hf_config have num_token_lens attribute - if hasattr(self.model_config.hf_config, "num_token_len"): - input_tokens = torch.zeros(max_batch_size, self.model_config.hf_config.num_token_len, dtype=torch.long).cuda() - else: - input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ + .dummy_data_for_profiling(self.model_config, max_batch_size) + input_tokens = torch.tensor(seq_data.prompt_token_ids, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) From d4886346862374ff298693965b83829bda00dc51 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 11 Jul 2024 08:51:53 +0000 Subject: [PATCH 06/61] update spk_emb --- tts.py | 11 ++-- vllm/engine/output_processor/single_step.py | 2 +- vllm/model_executor/models/ttslm.py | 57 ++++++++----------- vllm/multimodal/registry.py | 14 ++++- vllm/multimodal/speech.py | 63 +++++++++++++++++++++ 5 files changed, 108 insertions(+), 39 deletions(-) create mode 100644 vllm/multimodal/speech.py diff --git a/tts.py b/tts.py index bb712ad1cc902..31f65bb46b6d6 100644 --- a/tts.py +++ b/tts.py @@ -1,6 +1,6 @@ from vllm import LLM, SamplingParams import torch - +torch.random.manual_seed(999) # tts = torch.load('/home/largeniu/ttslm/GPT.pt') # text_emb_count = tts['emb_text.weight'].shape[0] @@ -36,11 +36,14 @@ # # save the model # torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') -llm = LLM(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5) +llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5) prompts = [ - "[Stts][empty_spk][speed_5]your text one[Ptts]", + { + "prompt": "[Stts][empty_spk][speed_5]your text one[Ptts]", + "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, + } ] -sampling_params = SamplingParams(temperature=1, detokenize=False) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048) outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index db7bfc2efbdc2..aaf21c782b313 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -117,7 +117,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, last_child_sample = child_samples[-1] # if output_tokens more than one, it's has multi-head output - if len(child_sample.output_tokens) > 1: + if 
len(last_child_sample.output_tokens) > 1: parent.append_token_id(last_child_sample.output_tokens, last_child_sample.logprobs) else: diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index f9421c0081d97..7e9d0fdb5fe84 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -16,6 +16,8 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.speech import SpeechPlugin from vllm.sequence import IntermediateTensors, SamplerOutput @@ -29,10 +31,11 @@ def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int): dummy_seq_data = SequenceData([[0] * ctx.model_config.hf_config.num_output_head] * seq_len) - dummy_multi_modal_data = None + dummy_multi_modal_data = {"speech": SpeechPlugin.sample_random_speaker()} return dummy_seq_data, dummy_multi_modal_data +@MULTIMODAL_REGISTRY.register_speech_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ttsllm) class ChatTtsLlm(nn.Module): def __init__(self, @@ -47,6 +50,7 @@ def __init__(self, self.num_audio_tokens = 626 self.num_text_tokens = 21178 self.num_vq = 4 + self.spk_emb_token_id = 21143 self.gpt = LlamaModel(config) self.model_dim = self.gpt.config.hidden_size @@ -111,6 +115,15 @@ def sample( sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) + for output in next_tokens.outputs: + for sample in output.samples: + sample.output_token += self.num_text_tokens + for i in range(self.num_vq): + sample.output_tokens[i] += self.num_text_tokens + dic = {} + for k,v in sample.logprobs.items(): + dic[k + self.num_text_tokens] = v + sample.logprobs = dic return next_tokens def forward( @@ -121,8 +134,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.get_input_embeddings(input_ids) + spk_emb = kwargs.pop("speech", None) + if spk_emb is not None: + self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) model_output = self.gpt( input_ids=input_ids, inputs_embeds=hidden_states, @@ -133,39 +150,13 @@ def forward( ) return model_output - @staticmethod - def _decode_spk_emb(spk_emb: str) -> np.ndarray: - return np.frombuffer( - lzma.decompress( - b14.decode_from_string(spk_emb), - format=lzma.FORMAT_RAW, - filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], - ), - dtype=np.float16, - ).copy() - - def _apply_spk_emb( + def apply_spk_emb( self, emb: torch.Tensor, - spk_emb: str, + spk_emb: torch.Tensor, + attn_metadata: AttentionMetadata, input_ids: torch.Tensor, ): - n = ( - F.normalize( - torch.from_numpy( - self._decode_spk_emb(spk_emb), - ), - p=2.0, - dim=0, - eps=1e-12, - ) - .unsqueeze_(0) - .expand(emb.size(0), -1) - .unsqueeze_(1) - .expand(emb.shape) - ) - cond = ( - input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape) - ) - torch.where(cond, n, emb, out=emb) - del cond, n + assert emb.size(1) == spk_emb.size(1) + assert attn_metadata.seq_lens_tensor.size(0) == spk_emb.size(0) + pass diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index e0716bbf15715..4c29ca6433b77 100644 --- a/vllm/multimodal/registry.py +++ 
b/vllm/multimodal/registry.py @@ -5,6 +5,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger +from vllm.multimodal.speech import SpeechPlugin from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, MultiModalPlugin, MultiModalTokensCalc) @@ -21,7 +22,7 @@ class MultiModalRegistry: The registry handles both external and internal data input. """ - DEFAULT_PLUGINS = (ImagePlugin(), ) + DEFAULT_PLUGINS = (ImagePlugin(), SpeechPlugin()) def __init__( self, @@ -70,6 +71,17 @@ def register_image_input_mapper( See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ return self.register_input_mapper("image", mapper) + + def register_speech_input_mapper( + self, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper for image data to a model class. + + See :meth:`MultiModalPlugin.register_input_mapper` for more details. + """ + return self.register_input_mapper("speech", mapper) def map_input(self, model_config: ModelConfig, data: MultiModalDataDict) -> MultiModalInputs: diff --git a/vllm/multimodal/speech.py b/vllm/multimodal/speech.py new file mode 100644 index 0000000000000..8390e2783b36e --- /dev/null +++ b/vllm/multimodal/speech.py @@ -0,0 +1,63 @@ +from functools import lru_cache +import lzma +from typing import List, Optional, Tuple, TypeVar + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import PreTrainedTokenizerBase + +from vllm.config import ModelConfig +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger +from vllm.transformers_utils.image_processor import get_image_processor +from vllm.transformers_utils.tokenizer import get_tokenizer + +from .base import MultiModalInputs, MultiModalPlugin +import pybase16384 as b14 + +class SpeechPlugin(MultiModalPlugin): + + def get_data_key(self) -> str: + return "speech" + + def _decode_spk_emb(self, spk_emb: str) -> np.ndarray: + return np.frombuffer( + lzma.decompress( + b14.decode_from_string(spk_emb), + format=lzma.FORMAT_RAW, + filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], + ), + dtype=np.float16, + ).copy() + + def _default_input_mapper(self, ctx: InputContext, + data: object) -> MultiModalInputs: + model_config = ctx.model_config + if isinstance(data, str): + n =F.normalize( + torch.from_numpy(self._decode_spk_emb(data)), + p=2.0, + dim=0, + eps=1e-12, + ).unsqueeze_(0) + + return MultiModalInputs({"speech": n}) + elif isinstance(data, torch.Tensor): + raise NotImplementedError("Embeddings input is not supported yet") + + raise TypeError(f"Invalid image type: {type(data)}") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + return 3000 + + @staticmethod + def sample_random_speaker() -> str: + return b14.encode_to_string( + lzma.compress( + np.random.randn(768).astype(np.float16).tobytes(), + format=lzma.FORMAT_RAW, + filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}] + ) + ) From 3cad09b13211df2b176dabe76fecbc87ab88dab1 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 11 Jul 2024 21:57:07 +0800 Subject: [PATCH 07/61] fix bug --- tts.py | 6 +++--- vllm/model_executor/layers/multi_heads_sampler.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tts.py b/tts.py index 31f65bb46b6d6..31c474eb0835b 100644 --- a/tts.py +++ b/tts.py @@ -36,14 +36,14 @@ # # save the model # torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') -llm = LLM(model='/home/zhn/ttslm', 
gpu_memory_utilization=0.5) +llm = LLM(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5, enforce_eager=True) prompts = [ { - "prompt": "[Stts][empty_spk][speed_5]your text one[Ptts]", + "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, } ] -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt diff --git a/vllm/model_executor/layers/multi_heads_sampler.py b/vllm/model_executor/layers/multi_heads_sampler.py index c697729334c91..21b490f674bcb 100644 --- a/vllm/model_executor/layers/multi_heads_sampler.py +++ b/vllm/model_executor/layers/multi_heads_sampler.py @@ -39,7 +39,7 @@ def forward( head_logits = logits.permute(1, 0, 2) output0 = self.heads[0](head_logits[0], sampling_metadata) for i in range(self.num_heads - 1): - output = self.heads[i + 1](head_logits[i], sampling_metadata) + output = self.heads[i + 1](head_logits[i + 1], sampling_metadata) self.merge_sample_results(output0, output) return output0 From 1bf42a2b458776a09d1246b25c2e770af9baf820 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 12 Jul 2024 07:49:00 +0000 Subject: [PATCH 08/61] add output hiddens --- tts.py | 13 +++++++++---- vllm/engine/output_processor/single_step.py | 2 ++ vllm/model_executor/models/ttslm.py | 2 +- vllm/outputs.py | 4 ++++ vllm/sequence.py | 3 +++ vllm/worker/model_runner.py | 2 ++ 6 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tts.py b/tts.py index 31c474eb0835b..6de8abdc46c2f 100644 --- a/tts.py +++ b/tts.py @@ -36,16 +36,21 @@ # # save the model # torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') -llm = LLM(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5, enforce_eager=True) +llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) prompts = [ { "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", "multi_modal_data": {"speech": 
'蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, + }, + { + "prompt": "[Stts][empty_spk][speed_5]Your text two[Ptts]", + "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, } ] sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) outputs = llm.generate(prompts, sampling_params) for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(output.prompt) + token_ids = output.outputs[0].token_ids + for token_id in token_ids: + print([x - 21178 for x in token_id]) \ No newline at end of file diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index aaf21c782b313..5084549561f39 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -124,6 +124,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, parent.append_token_id(last_child_sample.output_token, last_child_sample.logprobs) child_seqs.append((parent, parent)) + + parent.output_hiddens.append(outputs.hidden_state.clone()) for seq, _ in child_seqs: if seq_group.sampling_params.detokenize and self.detokenizer: diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 7e9d0fdb5fe84..944c98e304524 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -137,7 +137,7 @@ def forward( **kwargs: object ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.get_input_embeddings(input_ids) - spk_emb = kwargs.pop("speech", None) + spk_emb = kwargs.get("speech", None) if spk_emb is not None: self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) model_output = self.gpt( diff --git a/vllm/outputs.py 
b/vllm/outputs.py index 4cb7f06bdb8c7..dda5c6934b919 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from typing import List, Optional, Tuple, Union +import torch + from vllm.lora.request import LoRARequest from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceStatus) @@ -31,6 +33,7 @@ class CompletionOutput: token_ids: Tuple[int, ...] cumulative_logprob: float logprobs: Optional[SampleLogprobs] + hiddens: Optional[torch.Tensor] = None finish_reason: Optional[str] = None stop_reason: Union[int, str, None] = None lora_request: Optional[LoRARequest] = None @@ -129,6 +132,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": seq.get_output_token_ids(), seq.get_cumulative_logprob(), seq.output_logprobs if include_logprobs else None, + seq.output_hiddens, SequenceStatus.get_finished_reason(seq.status), seq.stop_reason) for seq in top_n_seqs ] diff --git a/vllm/sequence.py b/vllm/sequence.py index 62f181e63ad0e..c1acef5644b3d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -257,6 +257,7 @@ def __init__( self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] + self.output_hiddens: List[torch.Tensor] = [] self.output_text = "" self.status = SequenceStatus.WAITING @@ -747,10 +748,12 @@ def __init__( self, samples: List[SequenceOutput], prompt_logprobs: Optional[PromptLogprobs], + hidden_state: Optional[torch.Tensor] = None, ) -> None: self.samples = samples # Prompt logprob for each prompt query token. self.prompt_logprobs = prompt_logprobs + self.hidden_state = hidden_state def __repr__(self) -> str: return (f"CompletionSequenceGroupOutput(samples={self.samples}, " diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1805824bd58e8..050a5cf0e6894 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1282,6 +1282,8 @@ def execute_model( hidden_states = hidden_or_intermediate_states output.hidden_states = hidden_states + for i, o in enumerate(output): + o.hidden_state = hidden_states[i] return [output] From 2918e8b375ea56f086fdb76fa6c9fd9400a3c257 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 12 Jul 2024 10:09:43 +0000 Subject: [PATCH 09/61] add output hiddens --- testllama.py | 2 +- .../model_executor/layers/logits_processor.py | 13 +- .../layers/multi_heads_logits_processor.py | 140 ------------------ .../layers/multi_heads_sampler.py | 54 ------- vllm/model_executor/models/ttslm.py | 31 +++- 5 files changed, 34 insertions(+), 206 deletions(-) delete mode 100644 vllm/model_executor/layers/multi_heads_logits_processor.py delete mode 100644 vllm/model_executor/layers/multi_heads_sampler.py diff --git a/testllama.py b/testllama.py index 01c5c0c9eff2e..a4b29088ddd68 100644 --- a/testllama.py +++ b/testllama.py @@ -1,6 +1,6 @@ from vllm import LLM, SamplingParams -llm = LLM(model='/home/largeniu/triton/llama3/Meta-Llama-3-8B-Instruct') +llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct') prompts = [ "Hi my name is", # "The capital of France is", diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index f6fcf49ef464b..8bb5c1cc2bb93 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,6 +1,6 @@ """A layer that compute logits from hidden_stats.""" import inspect -from typing import Optional +from typing import Optional, Union import torch import torch.nn as nn @@ -70,12 
+70,15 @@ def forward( return logits def _get_logits(self, hidden_states: torch.Tensor, - lm_head: VocabParallelEmbedding, + lm_head: Union[VocabParallelEmbedding, nn.Linear], embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: # Get the logits for the next tokens. - logits = lm_head.linear_method.apply(lm_head, - hidden_states, - bias=embedding_bias) + if isinstance(lm_head, nn.Linear): + logits = lm_head(hidden_states) + else: + logits = lm_head.linear_method.apply(lm_head, + hidden_states, + bias=embedding_bias) logits = tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: diff --git a/vllm/model_executor/layers/multi_heads_logits_processor.py b/vllm/model_executor/layers/multi_heads_logits_processor.py deleted file mode 100644 index 5b26929d393d7..0000000000000 --- a/vllm/model_executor/layers/multi_heads_logits_processor.py +++ /dev/null @@ -1,140 +0,0 @@ -"""A layer that compute logits from hidden_stats.""" -import inspect -from typing import Optional, List - -import torch -import torch.nn as nn - -from vllm.distributed import tensor_model_parallel_gather -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata - - -class MultiHeadLogitsProcessor(nn.Module): - """Process logits and apply logits processors from sampling metadata. - - This layer does the following: - 1. Gather logits from model hidden_states. - 2. Scale logits if needed. - 3. Apply logits processors (if any). - """ - - def __init__(self, - vocab_size: int, - num_logits_processors: int = 0, - org_vocab_size: Optional[int] = None, - scale: float = 1.0, - logits_as_input: bool = False, - soft_cap: Optional[float] = None) -> None: - """ - Args: - scale: A scaling factor to apply to the logits. - """ - super().__init__() - self.scale = scale - self.vocab_size = vocab_size - self.num_logits_processors = num_logits_processors - # Whether the input is logits (default is hidden states). - self.logits_as_input = logits_as_input - # original vocabulary size (without LoRA). - self.org_vocab_size = org_vocab_size or vocab_size - # Soft cap the logits. Used in Gemma 2. - self.soft_cap = soft_cap - - def forward( - self, - lm_heads: List[nn.Linear], - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - embedding_bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if self.logits_as_input: - logits = hidden_states - else: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - - # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, lm_heads) - if logits is not None: - if self.soft_cap is not None: - logits = logits / self.soft_cap - logits = torch.tanh(logits) - logits = logits * self.soft_cap - - if self.scale != 1.0: - logits *= self.scale - - # Apply logits processors (if any). - logits = _apply_logits_processors(logits, sampling_metadata) - - return logits - - def _get_logits(self, hidden_states: torch.Tensor, - lm_heads: List[nn.Linear]) -> torch.Tensor: - # Get the logits for the next tokens. - logits_all = torch.zeros(self.num_logits_processors, hidden_states.size(0), self.vocab_size, device=hidden_states.device, dtype=hidden_states.dtype) - for i, lm_head in enumerate(lm_heads): - logits = lm_head(hidden_states) - logits = tensor_model_parallel_gather(logits) - # Remove paddings in vocab (if any). 
- if logits is not None: - logits = logits[:, :self.org_vocab_size] - logits_all[i] = logits - logits_all = logits_all.permute(1, 0, 2) - return logits_all - - def extra_repr(self) -> str: - s = f"vocab_size={self.vocab_size}" - s += f", forg_vocab_size={self.org_vocab_size}" - s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" - return s - - -def _prune_hidden_states( - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - return hidden_states.index_select(0, - sampling_metadata.selected_token_indices) - - -def _apply_logits_processors( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - found_logits_processors = False - logits_processed = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - - for seq_id, logits_row_idx in zip(seq_ids, - seq_group.sample_indices): - logits_row = logits[logits_row_idx] - past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids - prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids - - for logits_processor in logits_processors: - parameters = inspect.signature(logits_processor).parameters - if len(parameters) == 3: - logits_row = logits_processor(prompt_tokens_ids, - past_tokens_ids, - logits_row) - else: - logits_row = logits_processor(past_tokens_ids, - logits_row) - - logits[logits_row_idx] = logits_row - - logits_processed += len(seq_group.sample_indices) + len( - seq_group.prompt_logprob_indices) - - if found_logits_processors: - # verifies that no rows in logits were missed unexpectedly - assert logits_processed == logits.shape[0] - return logits diff --git a/vllm/model_executor/layers/multi_heads_sampler.py b/vllm/model_executor/layers/multi_heads_sampler.py deleted file mode 100644 index 21b490f674bcb..0000000000000 --- a/vllm/model_executor/layers/multi_heads_sampler.py +++ /dev/null @@ -1,54 +0,0 @@ -"""A layer that samples the next tokens from the model's outputs.""" -import itertools -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn - -from vllm.model_executor.layers.ops.sample import sample as sample_triton -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingTensors, - SequenceGroupToSample) -from vllm.sampling_params import SamplingType -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceOutput) -from vllm.model_executor.layers.sampler import Sampler, _apply_top_k_top_p, _sample, _get_logprobs, _build_sampler_output - -# (num_token_ids, num_parent_ids) per sequence group. -SampleResultType = List[Tuple[List[int], List[int]]] - - -class MultiheadsSampler(nn.Module): - def __init__(self, num_heads: int): - super().__init__() - - # Whether or not the SamplerOutput should have on-device tensors - # containing the sampled token ids and probabilities. This is used by - # speculative decoding. 
- self.num_heads = num_heads - self.include_gpu_probs_tensor = False - self.heads = nn.ModuleList([Sampler() for _ in range(num_heads)]) - - def forward( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - # Sample from each head - head_logits = logits.permute(1, 0, 2) - output0 = self.heads[0](head_logits[0], sampling_metadata) - for i in range(self.num_heads - 1): - output = self.heads[i + 1](head_logits[i + 1], sampling_metadata) - self.merge_sample_results(output0, output) - - return output0 - - def merge_sample_results( - self, - source: SamplerOutput, - target: SamplerOutput, - ): - for o_a, o_b in zip(source.outputs, target.outputs): - for s_a, s_b in zip(o_a.samples, o_b.samples): - s_a.output_tokens.append(s_b.output_token) \ No newline at end of file diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 944c98e304524..42362ebf51a48 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -10,9 +10,10 @@ from vllm.config import CacheConfig, LoRAConfig from vllm.inputs import INPUT_REGISTRY from vllm.inputs.registry import InputContext -from vllm.model_executor.layers.multi_heads_logits_processor import MultiHeadLogitsProcessor +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.multi_heads_sampler import MultiheadsSampler +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -62,8 +63,8 @@ def __init__(self, self.head_code = nn.ModuleList([ weight_norm(nn.Linear(self.model_dim, self.num_audio_tokens, bias=False), name='weight') for _ in range(self.num_vq) ]) - self.logits_processor = MultiHeadLogitsProcessor(self.num_audio_tokens, self.num_vq) - self.sampler = MultiheadsSampler(self.num_vq) + self.logits_processor = LogitsProcessor(self.num_audio_tokens) + self.sampler = Sampler() def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -106,7 +107,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.head_code, hidden_states, sampling_metadata) + logits = [ + self.logits_processor(self.head_code[i], hidden_states, sampling_metadata) + for i in range(self.num_vq) + ] + logits = torch.stack(logits, 0).permute(1, 0, 2) return logits def sample( @@ -114,7 +119,12 @@ def sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) + head_logits = logits.permute(1, 0, 2) + next_tokens = self.sampler(head_logits[0], sampling_metadata) + for i in range(self.num_vq - 1): + output = self.sampler(head_logits[i + 1], sampling_metadata) + self.merge_sample_results(next_tokens, output) + for output in next_tokens.outputs: for sample in output.samples: sample.output_token += self.num_text_tokens @@ -160,3 +170,12 @@ def apply_spk_emb( assert emb.size(1) == spk_emb.size(1) assert attn_metadata.seq_lens_tensor.size(0) == 
spk_emb.size(0) pass + + def merge_sample_results( + self, + source: SamplerOutput, + target: SamplerOutput, + ): + for o_a, o_b in zip(source.outputs, target.outputs): + for s_a, s_b in zip(o_a.samples, o_b.samples): + s_a.output_tokens.append(s_b.output_token) \ No newline at end of file From d1f961ffb6088ce93c79aaf5de9b6ad10db72a53 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Sat, 13 Jul 2024 23:59:58 +0800 Subject: [PATCH 10/61] add async --- tts.py | 2 +- tts_async.py | 39 +++++++++++++++++++++++++++++++++ vllm/engine/async_llm_engine.py | 4 ++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 tts_async.py diff --git a/tts.py b/tts.py index 6de8abdc46c2f..741678f11c66c 100644 --- a/tts.py +++ b/tts.py @@ -36,7 +36,7 @@ # # save the model # torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') -llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) +llm = LLM(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) prompts = [ { "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", diff --git a/tts_async.py b/tts_async.py new file mode 100644 index 0000000000000..5708e63709827 --- /dev/null +++ b/tts_async.py @@ -0,0 +1,39 @@ +import asyncio +import time + +import torch +from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams + +prompts = [ + { + "prompt": "[Stts][empty_spk][speed_5]Your text one Your text one Your text one Your text one Your text one[Ptts]", + "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, + }, + { + "prompt": "[Stts][empty_spk][speed_5]Your text two[Ptts]", + "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, + } +] + +engine_args = AsyncEngineArgs(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) +model = 
AsyncLLMEngine.from_engine_args(engine_args) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) + +async def generate_streaming(prompt, id): + results_generator = model.generate(prompt, sampling_params, request_id=id) + count=0 + async for request_output in results_generator: + token_ids = request_output.outputs[0].token_ids + print(f'{id} {[x - 21178 for x in token_ids[-1]]}') + count+=1 + + print(count) + +async def generate(): + tasks = [] + for i in range(5): + t = generate_streaming(prompts[i%2], i) + tasks.append(t) + await asyncio.gather(*tasks) + +asyncio.run(generate()) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 33e40c7b3624a..dcc568c0d7905 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -278,6 +278,10 @@ async def process_model_inputs_async( lora_request=lora_request) else: prompt_token_ids = inputs["prompt_token_ids"] + + if hasattr(self.model_config.hf_config, "num_output_head"): + # duplicate the prompt_token_ids for each head + prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), From 9a86ee8c66d06390e908e238bc681abd0e4e048d Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Sun, 14 Jul 2024 19:32:49 +0800 Subject: [PATCH 11/61] add spk emb --- tts.py | 4 ++-- vllm/model_executor/models/ttslm.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tts.py b/tts.py index 741678f11c66c..e66b44d043457 100644 --- a/tts.py +++ b/tts.py @@ -39,11 +39,11 @@ llm = LLM(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) prompts = [ { - "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", + "prompt": "[Stts][spk_emb][speed_5]Your text one[Ptts]", "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, }, { - "prompt": "[Stts][empty_spk][speed_5]Your text two[Ptts]", + "prompt": "[Stts][spk_emb][speed_5]Your text two[Ptts]", "multi_modal_data": {"speech": 
'蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, } ] diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 42362ebf51a48..250c963bf4c0c 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -169,7 +169,13 @@ def apply_spk_emb( ): assert emb.size(1) == spk_emb.size(1) assert attn_metadata.seq_lens_tensor.size(0) == spk_emb.size(0) - pass + # convert spk_emb to the same dtype as emb + spk_emb = spk_emb.to(emb.dtype) + # find the index of the speaker token + indices = (input_ids[:,0] == self.spk_emb_token_id).nonzero(as_tuple=True) + if indices[0].size(0) == 0: + return + emb.index_put_(indices, spk_emb) def merge_sample_results( self, From 693c2d930d72c89503f167c9c97e6515507778d9 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 15 Jul 2024 22:17:31 +0800 Subject: [PATCH 12/61] add benchmark --- benchmarks/benchmark_tts.py | 372 ++++++++++++++++++++++++++++++++++++ vllm/worker/model_runner.py | 6 +- 2 files changed, 377 insertions(+), 1 deletion(-) create mode 100644 benchmarks/benchmark_tts.py diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py new file mode 100644 index 0000000000000..1956231610a1b --- /dev/null +++ b/benchmarks/benchmark_tts.py @@ -0,0 +1,372 @@ +"""Benchmark offline inference throughput.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +import torch +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.utils import FlexibleArgumentParser + +def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], + gpu_memory_utilization: float = 0.9, + download_dir: Optional[str] = None, + load_format: str = EngineArgs.load_format, +) -> Tuple[float, int]: + from vllm import LLM, SamplingParams + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + 
kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + load_format=load_format, + ) + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=1, + temperature=1, + detokenize=False, + stop_token_ids=[21803], + max_tokens=2048, + top_k=1 + )) + + start = time.perf_counter() + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + total_output_tokens = 0 + for output in outputs: + total_output_tokens += len(output.outputs[0].token_ids) + end = time.perf_counter() + return end - start, total_output_tokens + + +def run_hf( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + use_beam_search: bool, + max_batch_size: int, + trust_remote_code: bool, +) -> float: + assert not use_beam_search + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: List[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt, prompt_len, output_len = requests[i] + # Add the prompt to the batch. + batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + _, next_prompt_len, next_output_len = requests[i + 1] + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=not use_beam_search, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. + batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. 
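+    # Each dataset line becomes a (prompt, prompt_len, output_len) tuple: the raw
+    # text is wrapped in the ChatTTS control-token template
+    # [Stts][spk_emb][speed_5]...[Ptts], and the two length fields are left as 0
+    # placeholders because generation is bounded by max_tokens and the stop token
+    # (21803) rather than by a per-request output budget.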
+ tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + requests = open(args.dataset).read().splitlines() + requests = [(f'[Stts][spk_emb][speed_5]{request}[Ptts]', 0, 0) for request in requests] + requests = requests[:args.num_prompts] + + input_ids = tokenizer([x[0] for x in requests], return_tensors="pt", padding=True).input_ids + + if args.backend == "vllm": + elapsed_time, total_num_tokens = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.download_dir, args.load_format) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.use_beam_search, args.hf_max_batch_size, + args.trust_remote_code) + else: + raise ValueError(f"Unknown backend: {args.backend}") + + print(f"Total input {input_ids.numel()}, total output {total_num_tokens}") + print(f"Elapsed time: {elapsed_time:.2f}s") + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). 
' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="auto", + choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], + help='device type for vLLM execution, supporting CUDA, OpenVINO and ' + 'CPU.') + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=None, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', + 'bitsandbytes' + ], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. 
See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + if args.backend == "vllm": + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + elif args.backend == "hf": + if args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + elif args.backend == "mii": + if args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.use_beam_search: + raise ValueError("Beam search is not supported for MII backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + if args.tokenizer != args.model: + raise ValueError("Tokenizer must be the same as the model for MII " + "backend.") + main(args) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 050a5cf0e6894..b0e38562aca05 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -612,7 +612,11 @@ def _prepare_model_input_tensors( graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size for _ in range(graph_batch_size - batch_size): - input_tokens.append(0) + if hasattr(self.model_config.hf_config, "num_output_head"): + # duplicate the prompt_token_ids for each head + input_tokens.append([0] * self.model_config.hf_config.num_output_head) + else: + input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) seq_lens.append(1) From 90802c50bfe49289c77bb78ae2ee3c02672d7401 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 15 Jul 2024 23:51:26 +0800 Subject: [PATCH 13/61] add benchmark --- benchmarks/backend_request_func.py | 3 +- benchmarks/benchmark_tts.py | 246 ++++++++++++++++++++++++++++- 2 files changed, 243 insertions(+), 6 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index fbab547d094fe..a783a3613ac1e 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -28,7 +28,8 @@ class RequestFuncInput: @dataclass class RequestFuncOutput: - generated_text: str = "" + generated_text: str = "", + output_tokens: Union[List[int], List[List[int]]] = [], success: bool = False latency: float = 0.0 ttft: float = 0.0 # Time to first token diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py index 1956231610a1b..9a33126791835 100644 --- a/benchmarks/benchmark_tts.py +++ b/benchmarks/benchmark_tts.py @@ -1,19 +1,243 @@ """Benchmark offline inference throughput.""" import argparse +import asyncio +from asyncio import tasks import json import random import time -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, AsyncGenerator +import warnings +import numpy as np import torch from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) +from benchmarks.backend_request_func import RequestFuncInput, 
RequestFuncOutput +from benchmarks.benchmark_serving import BenchmarkMetrics from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams from vllm.utils import FlexibleArgumentParser +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, +) -> Tuple[BenchmarkMetrics, List[int]]: + actual_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + # We use the tokenizer to count the number of output tokens for all + # serving backends instead of looking at len(outputs[i].itl) since + # multiple output tokens may be bundled together + # Note : this may inflate the output token count slightly + output_len = len(outputs[i].output_tokens) + actual_output_lens.append(output_len) + total_input += input_requests[i][1] + if output_len > 1: + tpots.append( + (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + completed += 1 + else: + actual_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + ) + + return metrics, actual_output_lens + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
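+        # Drawing i.i.d. exponential gaps makes the arrivals a Poisson process
+        # with a mean rate of `request_rate` requests per second, i.e. an
+        # open-loop load generator rather than a fixed-concurrency client.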
+ await asyncio.sleep(interval) + +async def generate_streaming(llm: AsyncLLMEngine, request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None)-> RequestFuncOutput: + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + ttft = 0.0 + st = time.perf_counter() + sampling_params = SamplingParams(n=1, temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) + results_generator = llm.generate(request_func_input.prompt, sampling_params, request_id=id) + async for request_output in results_generator: + token_ids = request_output.outputs[0].token_ids + # print(f'{id} {[x - 21178 for x in token_ids[-1]]}') + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + output.output_tokens = token_ids + + if pbar: + pbar.update(1) + return output + +async def run_vllm_async( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], + request_rate=16, + gpu_memory_utilization: float = 0.9, + download_dir: Optional[str] = None, + load_format: str = EngineArgs.load_format, +) -> Tuple[float, int]: + + + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + load_format=load_format, + ) + llm = AsyncLLMEngine.from_engine_args(engine_args) + pbar = tqdm(total=len(requests)) + benchmark_start_time = time.perf_counter() + + async for request in get_request(requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model, + prompt=prompt, + prompt_len=prompt_len, + output_len=output_len, + use_beam_search=use_beam_search, + ) + tasks.append( + asyncio.create_task( + generate_streaming(llm, request_func_input=request_func_input, + pbar=pbar))) + + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, actual_output_lens = calculate_metrics( + input_requests=requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + ) + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} 
{:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", + metrics.request_throughput)) + print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", + metrics.input_throughput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-')) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", + metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)', + n=50, + c='-')) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", + metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-')) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + def run_vllm( requests: List[Tuple[str, int, int]], model: str, @@ -144,7 +368,6 @@ def run_hf( end = time.perf_counter() return end - start - def main(args: argparse.Namespace): print(args) random.seed(args.seed) @@ -153,10 +376,21 @@ def main(args: argparse.Namespace): tokenizer = AutoTokenizer.from_pretrained( args.tokenizer, trust_remote_code=args.trust_remote_code) requests = open(args.dataset).read().splitlines() - requests = [(f'[Stts][spk_emb][speed_5]{request}[Ptts]', 0, 0) for request in requests] + requests = [(f'[Stts][spk_emb][speed_5]{request}[Ptts]', len(tokenizer(request).input_ids), 2048) for request in requests] requests = requests[:args.num_prompts] - input_ids = tokenizer([x[0] for x in requests], return_tensors="pt", padding=True).input_ids + total_input_tokens = sum(count for _, count, _ in requests) + + if args.streaming: + asyncio.run(run_vllm_async(requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.download_dir, args.load_format)) + return if args.backend == "vllm": elapsed_time, total_num_tokens = run_vllm( @@ -176,7 +410,7 @@ def main(args: argparse.Namespace): else: raise ValueError(f"Unknown backend: {args.backend}") - print(f"Total input {input_ids.numel()}, total output {total_num_tokens}") + print(f"Total input {total_input_tokens}, total output {total_num_tokens}") print(f"Elapsed time: {elapsed_time:.2f}s") print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s") @@ -200,6 +434,7 @@ def main(args: argparse.Namespace): type=str, choices=["vllm", "hf", "mii"], default="vllm") + parser.add_argument("--streaming", action="store_true") parser.add_argument("--dataset", type=str, default=None, @@ -369,4 +604,5 @@ def main(args: argparse.Namespace): if args.tokenizer != 
args.model: raise ValueError("Tokenizer must be the same as the model for MII " "backend.") + main(args) From a63218592b54b1f8199aec1d745b6a593cd6f0d2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 16 Jul 2024 09:28:34 +0000 Subject: [PATCH 14/61] remove weight norm --- benchmarks/benchmark_tts.py | 37 +++++++++++++++++------------ tts.py | 19 +++++++++++---- vllm/model_executor/models/ttslm.py | 6 ++--- 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py index 9a33126791835..965ad174bb18e 100644 --- a/benchmarks/benchmark_tts.py +++ b/benchmarks/benchmark_tts.py @@ -14,8 +14,8 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from benchmarks.backend_request_func import RequestFuncInput, RequestFuncOutput -from benchmarks.benchmark_serving import BenchmarkMetrics +from backend_request_func import RequestFuncInput, RequestFuncOutput +from benchmark_serving import BenchmarkMetrics from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams @@ -84,8 +84,8 @@ async def get_request( input_requests: List[Tuple[str, int, int]], request_rate: float, ) -> AsyncGenerator[Tuple[str, int, int], None]: - input_requests = iter(input_requests) - for request in input_requests: + requests = iter(input_requests) + for request in requests: yield request if request_rate == float("inf"): @@ -97,13 +97,14 @@ async def get_request( # The next request will be sent after the interval. await asyncio.sleep(interval) -async def generate_streaming(llm: AsyncLLMEngine, request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None)-> RequestFuncOutput: +async def generate_streaming(llm: AsyncLLMEngine, request_func_input: RequestFuncInput, request_id:str, pbar: Optional[tqdm] = None)-> RequestFuncOutput: output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len ttft = 0.0 st = time.perf_counter() + most_recent_timestamp = st sampling_params = SamplingParams(n=1, temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) - results_generator = llm.generate(request_func_input.prompt, sampling_params, request_id=id) + results_generator = llm.generate(request_func_input.prompt, sampling_params, request_id=request_id) async for request_output in results_generator: token_ids = request_output.outputs[0].token_ids # print(f'{id} {[x - 21178 for x in token_ids[-1]]}') @@ -151,9 +152,9 @@ async def run_vllm_async( gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, -) -> Tuple[float, int]: +): - + engine_count = 1 engine_args = AsyncEngineArgs( model=model, tokenizer=tokenizer, @@ -163,7 +164,7 @@ async def run_vllm_async( trust_remote_code=trust_remote_code, dtype=dtype, max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, + gpu_memory_utilization=gpu_memory_utilization/engine_count, enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, quantization_param_path=quantization_param_path, @@ -178,20 +179,22 @@ async def run_vllm_async( llm = AsyncLLMEngine.from_engine_args(engine_args) pbar = tqdm(total=len(requests)) benchmark_start_time = time.perf_counter() - + tasks: List[asyncio.Task] = [] + request_id = 0 async for request in get_request(requests, request_rate): prompt, prompt_len, output_len = request request_func_input = RequestFuncInput( + api_url="", 
model=model, prompt=prompt, prompt_len=prompt_len, output_len=output_len, use_beam_search=use_beam_search, ) + request_id += 1 tasks.append( asyncio.create_task( - generate_streaming(llm, request_func_input=request_func_input, - pbar=pbar))) + generate_streaming(llm, request_func_input=request_func_input, request_id=str(request_id), pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) @@ -375,8 +378,8 @@ def main(args: argparse.Namespace): # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( args.tokenizer, trust_remote_code=args.trust_remote_code) - requests = open(args.dataset).read().splitlines() - requests = [(f'[Stts][spk_emb][speed_5]{request}[Ptts]', len(tokenizer(request).input_ids), 2048) for request in requests] + lines = open(args.dataset).read().splitlines() + requests = [(f'[Stts][spk_emb][speed_5]{line}[Ptts]', len(tokenizer(line).input_ids), 2048) for line in lines] requests = requests[:args.num_prompts] total_input_tokens = sum(count for _, count, _ in requests) @@ -388,7 +391,7 @@ def main(args: argparse.Namespace): args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.distributed_executor_backend, + args.max_num_batched_tokens, args.distributed_executor_backend, args.request_rate, args.gpu_memory_utilization, args.download_dir, args.load_format)) return @@ -435,6 +438,10 @@ def main(args: argparse.Namespace): choices=["vllm", "hf", "mii"], default="vllm") parser.add_argument("--streaming", action="store_true") + parser.add_argument("--request-rate", + type=int, + default=None, + help="request rate per second") parser.add_argument("--dataset", type=str, default=None, diff --git a/tts.py b/tts.py index e66b44d043457..00074faaa9630 100644 --- a/tts.py +++ b/tts.py @@ -1,7 +1,7 @@ from vllm import LLM, SamplingParams import torch torch.random.manual_seed(999) -# tts = torch.load('/home/largeniu/ttslm/GPT.pt') +# tts = torch.load('/home/zhn/g/ChatTTS/asset/GPT.pt') # text_emb_count = tts['emb_text.weight'].shape[0] # audio_emb_count = tts['emb_code.0.weight'].shape[0] @@ -33,17 +33,26 @@ # tts['emb_all.2.weight'] = all_2 # tts['emb_all.3.weight'] = all_3 +# for i in range(4): +# original0 = tts[f'head_code.{i}.parametrizations.weight.original0'] +# original1 = tts[f'head_code.{i}.parametrizations.weight.original1'] +# # get the normalized weights based on the original 0 and 1 +# weight_norm0 = torch._weight_norm(original1, original0, dim=0) +# tts.pop(f'head_code.{i}.parametrizations.weight.original0') +# tts.pop(f'head_code.{i}.parametrizations.weight.original1') +# tts[f'lm_head.{i}.weight'] = weight_norm0 + # # save the model -# torch.save(tts, '/home/largeniu/ttslm/GPT_merged_emb.pt') +# torch.save(tts, '/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') -llm = LLM(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) +llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) prompts = [ { - "prompt": "[Stts][spk_emb][speed_5]Your text one[Ptts]", + "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", "multi_modal_data": {"speech": 
'蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, }, { - "prompt": "[Stts][spk_emb][speed_5]Your text two[Ptts]", + "prompt": "[Stts][empty_spk][speed_5]Your text two[Ptts]", "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, } ] diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 250c963bf4c0c..3880e49f5e821 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -60,8 +60,8 @@ def __init__(self, ]) self.head_text = weight_norm(nn.Linear(self.model_dim, self.num_text_tokens, bias=False), name='weight') - self.head_code = nn.ModuleList([ - weight_norm(nn.Linear(self.model_dim, self.num_audio_tokens, bias=False), name='weight') for _ in range(self.num_vq) + self.lm_head = nn.ModuleList([ + nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_vq) ]) self.logits_processor = LogitsProcessor(self.num_audio_tokens) self.sampler = Sampler() @@ -108,7 +108,7 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = [ - self.logits_processor(self.head_code[i], hidden_states, sampling_metadata) + self.logits_processor(self.lm_head[i], hidden_states, sampling_metadata) for i in range(self.num_vq) ] logits = torch.stack(logits, 0).permute(1, 0, 2) From 3e127db2231f439efb265d1fbc5e0ab187964cd3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 16 Jul 2024 09:48:42 +0000 Subject: [PATCH 15/61] VocabParallelEmbedding --- vllm/model_executor/models/ttslm.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 3880e49f5e821..7b8308ffd0f8b 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -13,7 
+13,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead +from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -46,8 +46,6 @@ def __init__(self, super().__init__() # static parameters, put them in config later - self.spk_emb_dim = 192 - self.spk_KL = 8 self.num_audio_tokens = 626 self.num_text_tokens = 21178 self.num_vq = 4 @@ -56,10 +54,9 @@ def __init__(self, self.gpt = LlamaModel(config) self.model_dim = self.gpt.config.hidden_size self.emb_all = nn.ModuleList([ - nn.Embedding(self.num_audio_tokens + self.num_text_tokens, self.model_dim) for _ in range(self.num_vq) + VocabParallelEmbedding(self.num_audio_tokens + self.num_text_tokens, self.model_dim) for _ in range(self.num_vq) ]) - self.head_text = weight_norm(nn.Linear(self.model_dim, self.num_text_tokens, bias=False), name='weight') self.lm_head = nn.ModuleList([ nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_vq) ]) From b8c94af8ee8377784d0cd7dc0a730bdbc44f0d4c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 29 Jul 2024 13:56:29 +0000 Subject: [PATCH 16/61] use base64 --- tts.py | 8 ++--- vllm/model_executor/models/ttslm.py | 24 +++++++-------- vllm/multimodal/registry.py | 2 +- vllm/multimodal/speech.py | 48 +++++++++++++++++++---------- 4 files changed, 49 insertions(+), 33 deletions(-) diff --git a/tts.py b/tts.py index 00074faaa9630..3529fc5510034 100644 --- a/tts.py +++ b/tts.py @@ -1,7 +1,7 @@ from vllm import LLM, SamplingParams import torch torch.random.manual_seed(999) -# tts = torch.load('/home/zhn/g/ChatTTS/asset/GPT.pt') +# tts = torch.load('/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') # text_emb_count = tts['emb_text.weight'].shape[0] # audio_emb_count = tts['emb_code.0.weight'].shape[0] @@ -45,15 +45,15 @@ # # save the model # torch.save(tts, '/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') -llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) +llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, enforce_eager=True) prompts = [ { "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", - "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, + 
"multi_modal_data": {"audio": 'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, }, { "prompt": "[Stts][empty_spk][speed_5]Your text two[Ptts]", - "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, + "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, } ] sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 7b8308ffd0f8b..fe342ff9ddeb4 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -8,7 +8,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.inputs import INPUT_REGISTRY +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs.registry import InputContext from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -18,7 +18,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.speech import SpeechPlugin +from vllm.multimodal.speech import SpeechPlugin,FishSpeechPlugin from vllm.sequence import IntermediateTensors, SamplerOutput @@ -32,7 +32,7 @@ def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int): dummy_seq_data = SequenceData([[0] * ctx.model_config.hf_config.num_output_head] * seq_len) - dummy_multi_modal_data = {"speech": SpeechPlugin.sample_random_speaker()} + dummy_multi_modal_data = {"audio": SpeechPlugin.sample_random_speaker()} return dummy_seq_data, 
dummy_multi_modal_data @@ -46,19 +46,19 @@ def __init__(self, super().__init__() # static parameters, put them in config later - self.num_audio_tokens = 626 - self.num_text_tokens = 21178 - self.num_vq = 4 + self.num_audio_tokens = config.num_audio_tokens + self.num_text_tokens = config.num_text_tokens + self.num_output_head = config.num_output_head self.spk_emb_token_id = 21143 self.gpt = LlamaModel(config) self.model_dim = self.gpt.config.hidden_size self.emb_all = nn.ModuleList([ - VocabParallelEmbedding(self.num_audio_tokens + self.num_text_tokens, self.model_dim) for _ in range(self.num_vq) + VocabParallelEmbedding(self.num_audio_tokens + self.num_text_tokens, self.model_dim) for _ in range(self.num_output_head) ]) self.lm_head = nn.ModuleList([ - nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_vq) + nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_output_head) ]) self.logits_processor = LogitsProcessor(self.num_audio_tokens) self.sampler = Sampler() @@ -97,7 +97,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: code_emb = [ self.emb_all[i](input_ids[:,i]) - for i in range(self.num_vq) + for i in range(self.num_output_head) ] emb = torch.stack(code_emb, 2).sum(2) return emb @@ -106,7 +106,7 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = [ self.logits_processor(self.lm_head[i], hidden_states, sampling_metadata) - for i in range(self.num_vq) + for i in range(self.num_output_head) ] logits = torch.stack(logits, 0).permute(1, 0, 2) return logits @@ -118,14 +118,14 @@ def sample( ) -> Optional[SamplerOutput]: head_logits = logits.permute(1, 0, 2) next_tokens = self.sampler(head_logits[0], sampling_metadata) - for i in range(self.num_vq - 1): + for i in range(self.num_output_head - 1): output = self.sampler(head_logits[i + 1], sampling_metadata) self.merge_sample_results(next_tokens, output) for output in next_tokens.outputs: for sample in output.samples: sample.output_token += self.num_text_tokens - for i in range(self.num_vq): + for i in range(self.num_output_head): sample.output_tokens[i] += self.num_text_tokens dic = {} for k,v in sample.logprobs.items(): diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 63df0df551330..50badbf9e4d13 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -85,7 +85,7 @@ def register_speech_input_mapper( See :meth:`MultiModalPlugin.register_input_mapper` for more details. 
""" - return self.register_input_mapper("speech", mapper) + return self.register_input_mapper("audio", mapper) def map_input(self, model_config: ModelConfig, data: MultiModalDataDict) -> MultiModalInputs: diff --git a/vllm/multimodal/speech.py b/vllm/multimodal/speech.py index 8390e2783b36e..9c995e3015f44 100644 --- a/vllm/multimodal/speech.py +++ b/vllm/multimodal/speech.py @@ -16,21 +16,41 @@ from .base import MultiModalInputs, MultiModalPlugin import pybase16384 as b14 +import base64 +import pickle + +class FishSpeechPlugin(MultiModalPlugin): + + def get_data_key(self) -> str: + return "audio1" + + def _default_input_mapper(self, ctx: InputContext, + data: object) -> MultiModalInputs: + if isinstance(data, str): + base64_decoded = base64.b64decode(data) + deserialized_data = pickle.loads(base64_decoded) + tensor = torch.from_numpy(deserialized_data) + return MultiModalInputs({"audio": tensor}) + elif isinstance(data, torch.Tensor): + raise NotImplementedError("Embeddings input is not supported yet") + + raise TypeError(f"Invalid image type: {type(data)}") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + return 16 + + @staticmethod + def get_default_audio(): + return 'a' class SpeechPlugin(MultiModalPlugin): def get_data_key(self) -> str: - return "speech" + return "audio" def _decode_spk_emb(self, spk_emb: str) -> np.ndarray: - return np.frombuffer( - lzma.decompress( - b14.decode_from_string(spk_emb), - format=lzma.FORMAT_RAW, - filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], - ), - dtype=np.float16, - ).copy() + n = base64.b64decode(spk_emb) + return np.frombuffer(n, dtype=np.float16).copy() def _default_input_mapper(self, ctx: InputContext, data: object) -> MultiModalInputs: @@ -54,10 +74,6 @@ def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: @staticmethod def sample_random_speaker() -> str: - return b14.encode_to_string( - lzma.compress( - np.random.randn(768).astype(np.float16).tobytes(), - format=lzma.FORMAT_RAW, - filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}] - ) - ) + n = np.random.randn(768).astype(np.float16) + s = base64.b64encode(n).decode("utf-8") + return s From bd0fc4fb4728055f1743cbcbc4239a4c421b8477 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 Aug 2024 06:45:10 +0000 Subject: [PATCH 17/61] merge main --- vllm/model_executor/models/ttslm.py | 10 +++++++--- vllm/multimodal/registry.py | 12 +++++++++++- vllm/sequence.py | 22 +++++++--------------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index fe342ff9ddeb4..230e587cb67c4 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union import torch from torch import nn @@ -18,7 +18,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.speech import SpeechPlugin,FishSpeechPlugin +from vllm.multimodal.speech import SpeechPlugin from vllm.sequence import IntermediateTensors, SamplerOutput @@ -26,7 +26,7 @@ import numpy as np import pybase16384 as b14 -def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int): +def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int, mm_counts: 
Mapping[str, int]): from vllm.sequence import SequenceData @@ -36,8 +36,12 @@ def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int): return dummy_seq_data, dummy_multi_modal_data +def get_max_speech_tokens(ctx: InputContext): + return 16 + @MULTIMODAL_REGISTRY.register_speech_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ttsllm) +@MULTIMODAL_REGISTRY.register_max_speech_tokens(get_max_speech_tokens) class ChatTtsLlm(nn.Module): def __init__(self, config: LlamaConfig, diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index c6015e225695e..1c9f0fc63e290 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -101,7 +101,7 @@ def register_speech_input_mapper( mapper: Optional[MultiModalInputMapper] = None, ): """ - Register an input mapper for image data to a model class. + Register an input mapper for speech data to a model class. See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ @@ -173,6 +173,16 @@ def register_max_image_tokens( image, that are passed to the language model for a model class. """ return self.register_max_multimodal_tokens("image", max_mm_tokens) + + def register_max_speech_tokens( + self, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of speech tokens, corresponding to a single + speech, that are passed to the language model for a model class. + """ + return self.register_max_multimodal_tokens("audio", max_mm_tokens) def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: """ diff --git a/vllm/sequence.py b/vllm/sequence.py index efc6ec7ece591..945879a2f8ab7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -132,10 +132,10 @@ def __init__( output_token_ids: Optional[List[int]] = None, num_token_head: int = 1, ) -> None: - self._prompt_token_ids = array('l', prompt_token_ids) + self._prompt_token_ids: List[int] = list(prompt_token_ids) self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) - self._output_token_ids = array( - 'l', output_token_ids if output_token_ids is not None else []) + self._output_token_ids: List[int] = ( + list(output_token_ids) if output_token_ids is not None else []) self.cumulative_logprob = 0.0 # The number of tokens that are computed (that run against the model). 
@@ -145,8 +145,8 @@ def __init__( self._update_cached_all_tokens() def _update_cached_all_tokens(self): - self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + - self._output_token_ids) + self._cached_all_token_ids: List[int] = (self._prompt_token_ids + + self._output_token_ids) @property def prompt_token_ids(self) -> Tuple[int, ...]: @@ -154,27 +154,19 @@ def prompt_token_ids(self) -> Tuple[int, ...]: @prompt_token_ids.setter def prompt_token_ids(self, new_prompt_token_ids) -> None: - self._prompt_token_ids = array('l', new_prompt_token_ids) + self._prompt_token_ids = list(new_prompt_token_ids) self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) self._update_cached_all_tokens() - @property - def prompt_token_ids_array(self) -> array: - return self._prompt_token_ids - @property def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @output_token_ids.setter def output_token_ids(self, new_output_token_ids) -> None: - self._output_token_ids = array('l', new_output_token_ids) + self._output_token_ids = list(new_output_token_ids) self._update_cached_all_tokens() - @property - def output_token_ids_array(self) -> array: - return self._output_token_ids - def append_token_id(self, token_id: int, logprob: float) -> None: self._output_token_ids.append(token_id) self._cached_all_token_ids.append(token_id) From a0acc4433178f49f8207475bcd46165182257925 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 Aug 2024 07:26:18 +0000 Subject: [PATCH 18/61] fix merge --- vllm/engine/llm_engine.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cd0ebf9bdb95d..55eba899b1bf0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -742,6 +742,10 @@ def _extract_prompt_components( request_id=request_id, lora_request=lora_request, ) + + if hasattr(self.model_config.hf_config, "num_output_head"): + # duplicate the prompt_token_ids for each head + prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] multi_modal_data = inputs.get("multi_modal_data") else: @@ -947,11 +951,6 @@ def process_model_inputs( request_id=request_id, ) else: - prompt_token_ids = inputs["prompt_token_ids"] - - if hasattr(self.model_config.hf_config, "num_output_head"): - # duplicate the prompt_token_ids for each head - prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") From cea25d217b0db0009af08f3686de95776f9f0c18 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 21 Aug 2024 03:07:58 +0000 Subject: [PATCH 19/61] fix merge main --- tts.py | 2 +- tts_async.py | 8 ++++---- vllm/engine/async_llm_engine.py | 9 ++++----- vllm/engine/llm_engine.py | 7 +++---- vllm/engine/output_processor/single_step.py | 8 +++++++- vllm/model_executor/models/ttslm.py | 1 - vllm/multimodal/speech.py | 1 - vllm/worker/model_runner.py | 10 +++++++--- 8 files changed, 26 insertions(+), 20 deletions(-) diff --git a/tts.py b/tts.py index 3529fc5510034..291a3841a5bde 100644 --- a/tts.py +++ b/tts.py @@ -45,7 +45,7 @@ # # save the model # torch.save(tts, '/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') -llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, enforce_eager=True) +llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5) prompts = [ { "prompt": "[Stts][empty_spk][speed_5]Your text 
one[Ptts]", diff --git a/tts_async.py b/tts_async.py index 5708e63709827..f1698bc16fb37 100644 --- a/tts_async.py +++ b/tts_async.py @@ -6,16 +6,16 @@ prompts = [ { - "prompt": "[Stts][empty_spk][speed_5]Your text one Your text one Your text one Your text one Your text one[Ptts]", - "multi_modal_data": {"speech": '蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, + "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", + "multi_modal_data": {"audio": 'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, }, { "prompt": "[Stts][empty_spk][speed_5]Your text two[Ptts]", - "multi_modal_data": {"speech": 
'蘁淰敩欀夃罘嶕姱树犌砵譢殅祋荬囱肩氼囪婚埂杮慧荡螆納茆硧壂姒嚉発籜偔柏滙卼庇擰斴檋檚枉熥枮嶒伭崍圚紑蟽菸剤襨灸螥菨讟愖瑻埨笿琻捌嬐娣舝紟亓膈壷瞁烊侦謐縂磬皡氛蜠椚册房悞绱女簉撚膌炌俨果膫肈堈惀啥撴瑦塵抚螐呄熾滬艘櫵甃卥訷恲厜袎匊峆沇爈欏蝑妗看夦摳臭诬戃圭歙瓭趁覚爞庄曙眆喜殉崂箲譠磋谒綍曆褋呺二狝蠟蚗煗曱痚誏攍课恞巧貧膐仨奞癶蠲崨緔荇扩瑹戾蝎淾烁恹泣掐璷蠟橍珺痛杈啤熶腎撨袘獮焊噭矇禆綿蔈裥罌嶗吒墯楺繥惶豳徸縍娦琖梍兠瀙旋咒貂狭藧浖忮兘搎蔁經儏硨杰棂繴掅巖諹晋啿苕畗貰蝚褰肜澆烌谞椻咞噊唦脶厑腺療才跷屉作匽娆弼恪宱璕灼艦劈瀥奼帅员結留笽椶祘畴葙愗愷犘圏沢嚀祵诽槐亄廅淂苫俩寮寫榄妑滪榡佝联绽啁琗絰臒柏潃葧莢熡澗擵蚏疨耩椒嶼萊之瓌蜈桷胷纽蠺平贒厐厥誐森杊橥櫐氯昫囮睝燎廴朆苐瓅崓璾亥划癊螄忊奬趞堾獪尴旂挮蟂樲濔嚼歌柝嬗襾戊拼蘗朤弉穛碇橢翸懵珖芔惤瓈妫啇嘾咏墛儜紶晞綜薒罙膹竧疝汽揌旁胅簹媯秀獅實珉目棩枛羊嘗琌褠磓畠壞稃蔘壖蓌垓搲致恏禔偘厓耛寍喿啟暚皻義灞牁柏玁喤喞褙必暤熞渥繙弝尸乒母蛝癯筂毰菂耵萤帿赶穲唩讌囉吓盶揉碁莻埐諎禛藯捃窀独畃跪咄擈艜挰岷葿矸啄珙聢瀣怳賎礉炶埬枬曬热侹焘柳巫桏痭藍粑傃橛眣槣脊埙孠浏喲儀卂蠯磟竈腏赥倎淲殦叺峋笎臽緋窼淽叛剋柆勷賻淮憸廽秏歔簾荔嶅赗褨愳贽腈修櫗廬勞嶑僾謰帿螷恉晝揨攮訹剄攝倪纜捷浹廎囑僄荂瑏啝摴笡趕臮蕙趘梴崿盽嵶癀堣謇檝螆朧浽譿耣薥怿槕児歬椷嬌宄豐冊翇肴芹剁帶嗲姡孵炋杶垪丑槅寺澚矼祆拏矑賮诂朡毇夂穚婷簵烀箚呠玨唙奶苧蒎螹舜绚蚨箒盲覻祐枋崣萇裻刾堺氨儮箒蕮嬫嗧譋嗏奠豲案礠嫙倈檧噻豊洺敋砿刘怸媽圌覲緾晃伫藸氬觠晽帖吳樦廏娍书惩漊粲謎工縺豰呄澁囱猩臞秦啭且疅褃娷腉蟱忂死虑臝捝咁嬔斸睌嶮燽肂姵珢挔娐羞悸竩壯榊怤跚膁惟烁坶樨喴曰夷断蔹垬梛嘳苯灰痩簗薽帓聤漪罷刴纕琒綴叱劒絖壈恭跋渃析稐哄劫峑琭胨瀒訦媅許硶砯誏芙螓剂膕涣蔝瓠償芦絸破啗諨皆塥摉糷琍诂羊粑埾獎厺塞弧剙眸屢嶵薐伥疖裤筯憨掍伖袺圣僕蔁绹倱襜垘犽抖窐刊偠瓻珬杪劯溤疿莶洽荷杉簡怪巚舆蓞咙杬叉姵聉画离嬑聘誷瀝箠悒珺謌綛揬妿蓱僐嶢蜎甅一㴁'}, + "multi_modal_data": {"audio": 'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, } ] -engine_args = AsyncEngineArgs(model='/home/largeniu/ttslm', gpu_memory_utilization=0.5, enforce_eager=True, dtype=torch.float32) +engine_args = AsyncEngineArgs(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, dtype=torch.float32) model = AsyncLLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8538fdd746f4f..a96ca769babc5 100644 --- a/vllm/engine/async_llm_engine.py +++ 
b/vllm/engine/async_llm_engine.py @@ -347,6 +347,10 @@ async def _extract_prompt_components_async( else: assert_never(inputs) + if hasattr(self.model_config.hf_config, "num_output_head"): + # duplicate the prompt_token_ids for each head + prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] + return prompt, prompt_token_ids, multi_modal_data async def _process_encoder_decoder_prompt_async( @@ -420,11 +424,6 @@ async def process_model_inputs_async( request_id=request_id, ) else: - prompt_token_ids = inputs["prompt_token_ids"] - - if hasattr(self.model_config.hf_config, "num_output_head"): - # duplicate the prompt_token_ids for each head - prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 55eba899b1bf0..fe94aa41fda84 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -742,15 +742,14 @@ def _extract_prompt_components( request_id=request_id, lora_request=lora_request, ) - - if hasattr(self.model_config.hf_config, "num_output_head"): - # duplicate the prompt_token_ids for each head - prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] multi_modal_data = inputs.get("multi_modal_data") else: assert_never(inputs) + if hasattr(self.model_config.hf_config, "num_output_head"): + # duplicate the prompt_token_ids for each head + prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] return prompt, prompt_token_ids, multi_modal_data def _apply_prompt_adapter( diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4c392cca4db3e..7f13ad8ff8627 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -87,7 +87,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, sample = outputs.samples[0] # only have one sequence seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) + # if output_tokens has more than one entry, this is a multi-head output + if len(sample.output_tokens) > 1: + seq.append_token_id(sample.output_tokens, + sample.logprobs) + else: + seq.append_token_id(sample.output_token, + sample.logprobs) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 230e587cb67c4..b36cce21119b9 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -24,7 +24,6 @@ import lzma import numpy as np -import pybase16384 as b14 def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): diff --git a/vllm/multimodal/speech.py b/vllm/multimodal/speech.py index 9c995e3015f44..af4e04a507292 100644 --- a/vllm/multimodal/speech.py +++ b/vllm/multimodal/speech.py @@ -15,7 +15,6 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from .base import MultiModalInputs, MultiModalPlugin -import pybase16384 as b14 import base64 import pickle diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 625aad28ff944..15a8e4e75193f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -480,7 +480,7 @@ def
_compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len - if isinstance(tokens, list): + if isinstance(tokens, list) and isinstance(tokens[0], list) == True: inter_data.input_tokens[seq_idx].extend(tokens) else: inter_data.input_tokens[seq_idx].append(tokens) @@ -699,7 +699,11 @@ def build(self) -> ModelInputForGPU: # Tokens and positions. if cuda_graph_pad_size: - input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) + if isinstance(input_tokens[0], list): + num_head = len(input_tokens[0]) + input_tokens.extend(itertools.repeat([0] * num_head, cuda_graph_pad_size)) + else: + input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) assert self.runner.device is not None input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, @@ -1208,7 +1212,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ - .dummy_data_for_profiling(self.model_config, max_batch_size) + .dummy_data_for_profiling(self.model_config, max_batch_size, self.mm_registry) input_tokens = torch.tensor(seq_data.prompt_token_ids, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() From 43252d9b5c4514a97e88a4c5644be681b4356884 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 21 Aug 2024 07:40:33 +0000 Subject: [PATCH 20/61] fix merge --- tts.py | 9 ++- tts_async.py | 10 ++- tts_fish.py | 77 +++++++++++++++++++ .../tokenizer_group/tokenizer_group.py | 2 +- 4 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 tts_fish.py diff --git a/tts.py b/tts.py index 291a3841a5bde..39bc0a92ecc8d 100644 --- a/tts.py +++ b/tts.py @@ -45,17 +45,22 @@ # # save the model # torch.save(tts, '/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') -llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5) +llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, dtype=torch.float32, enforce_eager=True) prompts = [ { "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, }, { - "prompt": "[Stts][empty_spk][speed_5]Your text two[Ptts]", + "prompt": "[Stts][empty_spk][speed_5]Anther long string[Ptts]", "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, } ] + +for i in range(12): + prompts.append(prompts[0]) + prompts.append(prompts[1]) + sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) outputs = llm.generate(prompts, sampling_params) for output in outputs: diff --git a/tts_async.py b/tts_async.py index f1698bc16fb37..00abc5b44886e 100644 --- a/tts_async.py +++ b/tts_async.py @@ -15,23 +15,27 @@ } ] -engine_args = AsyncEngineArgs(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, dtype=torch.float32) +engine_args = AsyncEngineArgs(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, dtype=torch.float32, enforce_eager=True) model = AsyncLLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) async def generate_streaming(prompt, id): results_generator = model.generate(prompt, sampling_params, request_id=id) count=0 + tokens = [] async for request_output in results_generator: token_ids = request_output.outputs[0].token_ids print(f'{id} {[x - 21178 for x in token_ids[-1]]}') + tokens.append([x - 21178 for x in token_ids[-1]]) count+=1 - print(count) + print(prompt['prompt']) + for token in tokens: + print(token) async def generate(): tasks = [] - for i in range(5): + for i in range(10): t = generate_streaming(prompts[i%2], i) tasks.append(t) await asyncio.gather(*tasks) diff --git a/tts_fish.py b/tts_fish.py new file mode 100644 index 0000000000000..e2e482e517979 
--- /dev/null +++ b/tts_fish.py @@ -0,0 +1,77 @@ +import torch +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams +llm = LLM(model='/home/zhn/xtts', gpu_memory_utilization=0.5, enforce_eager=True, add_bos_token=False, dtype=torch.float32) +prompts = [ + { + "prompt": "[zh-cn]ni3hao3", + } +] +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) +outputs = llm.generate(prompts, sampling_params) +for output in outputs: + print(output.prompt) + token_ids = output.outputs[0].token_ids + for token_id in token_ids: + print([x - 21178 for x in token_id]) +# tokenizer = AutoTokenizer.from_pretrained('/home/zhn/xtts', add_bos_token=False) +# id = tokenizer.encode('[zh-cn]ni3hao3') + +# torch.random.manual_seed(999) +# gpt = torch.load('/home/zhn/xtts/llama.pt') +# tts = torch.load('/home/zhn/xtts/checkpoint-902000.pt') + +# llama = tts['model']['llama'] +# layer_count = 24 +# for i in range(layer_count): +# name_qkv_0 = f'layers.{i}.attention.wqkv.weight' +# name_qkv_1 = f'gpt.layers.{i}.self_attn.qkv_proj.weight' +# llama[name_qkv_1] = llama.pop(name_qkv_0) + +# name_o_0 = f'layers.{i}.attention.wo.weight' +# name_o_1 = f'gpt.layers.{i}.self_attn.o_proj.weight' +# llama[name_o_1] = llama.pop(name_o_0) + +# name_gate_0 = f'layers.{i}.feed_forward.w1.weight' +# name_gate_1 = f'gpt.layers.{i}.mlp.gate_proj.weight' +# llama[name_gate_1] = llama.pop(name_gate_0) + +# name_up_0 = f'layers.{i}.feed_forward.w3.weight' +# name_up_1 = f'gpt.layers.{i}.mlp.up_proj.weight' +# llama[name_up_1] = llama.pop(name_up_0) + +# name_down_0 = f'layers.{i}.feed_forward.w2.weight' +# name_down_1 = f'gpt.layers.{i}.mlp.down_proj.weight' +# llama[name_down_1] = llama.pop(name_down_0) + +# name_ffn_norm_0 = f'layers.{i}.ffn_norm.weight' +# name_ffn_norm_1 = f'gpt.layers.{i}.input_layernorm.weight' +# llama[name_ffn_norm_1] = llama.pop(name_ffn_norm_0) + +# name_attn_norm_0 = f'layers.{i}.attention_norm.weight' +# name_attn_norm_1 = f'gpt.layers.{i}.post_attention_layernorm.weight' +# llama[name_attn_norm_1] = llama.pop(name_attn_norm_0) + +# name_norm_0 = f'norm.weight' +# name_norm_1 = f'gpt.norm.weight' +# llama[name_norm_1] = llama.pop(name_norm_0) + +# text_emb = llama['text_embeddings.weight'] +# code_emb_0 = llama['code_embeddings.weight'][0:1026, :] +# code_emb_1 = llama['code_embeddings.weight'][1026:2052, :] +# all_0 = torch.cat([text_emb, code_emb_0], dim=0) +# all_1 = torch.cat([torch.zeros_like(text_emb), code_emb_1], dim=0) +# llama['emb_all.0.weight'] = all_0 +# llama['emb_all.1.weight'] = all_1 +# llama.pop('text_embeddings.weight') +# llama.pop('code_embeddings.weight') + +# output = llama['output.weight'] +# lm_head = output[7002:, :] +# lm_head_0 = lm_head[0:1026, :] +# lm_head_1 = lm_head[1026:2052, :] +# llama['lm_head.0.weight'] = lm_head_0 +# llama['lm_head.1.weight'] = lm_head_1 +# llama.pop('output.weight') + +# torch.save(llama, '/home/zhn/xtts/llama.pt') diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index ad86f33a00e05..3e3776fcbbd49 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -66,7 +66,7 @@ async def encode_async( request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - ret = tokenizer.encode(prompt) + ret = 
tokenizer.encode(prompt, add_special_tokens=False) self._raise_if_input_too_long(ret, lora_request) return ret From dbf5f67b948776bfa1ada4ec2ffb04ff02daebb0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 30 Aug 2024 12:38:13 +0000 Subject: [PATCH 21/61] fix broken llama --- tts_fish.py | 77 ------------------------------------- vllm/worker/model_runner.py | 4 +- 2 files changed, 2 insertions(+), 79 deletions(-) delete mode 100644 tts_fish.py diff --git a/tts_fish.py b/tts_fish.py deleted file mode 100644 index e2e482e517979..0000000000000 --- a/tts_fish.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -from transformers import AutoTokenizer -from vllm import LLM, SamplingParams -llm = LLM(model='/home/zhn/xtts', gpu_memory_utilization=0.5, enforce_eager=True, add_bos_token=False, dtype=torch.float32) -prompts = [ - { - "prompt": "[zh-cn]ni3hao3", - } -] -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) -outputs = llm.generate(prompts, sampling_params) -for output in outputs: - print(output.prompt) - token_ids = output.outputs[0].token_ids - for token_id in token_ids: - print([x - 21178 for x in token_id]) -# tokenizer = AutoTokenizer.from_pretrained('/home/zhn/xtts', add_bos_token=False) -# id = tokenizer.encode('[zh-cn]ni3hao3') - -# torch.random.manual_seed(999) -# gpt = torch.load('/home/zhn/xtts/llama.pt') -# tts = torch.load('/home/zhn/xtts/checkpoint-902000.pt') - -# llama = tts['model']['llama'] -# layer_count = 24 -# for i in range(layer_count): -# name_qkv_0 = f'layers.{i}.attention.wqkv.weight' -# name_qkv_1 = f'gpt.layers.{i}.self_attn.qkv_proj.weight' -# llama[name_qkv_1] = llama.pop(name_qkv_0) - -# name_o_0 = f'layers.{i}.attention.wo.weight' -# name_o_1 = f'gpt.layers.{i}.self_attn.o_proj.weight' -# llama[name_o_1] = llama.pop(name_o_0) - -# name_gate_0 = f'layers.{i}.feed_forward.w1.weight' -# name_gate_1 = f'gpt.layers.{i}.mlp.gate_proj.weight' -# llama[name_gate_1] = llama.pop(name_gate_0) - -# name_up_0 = f'layers.{i}.feed_forward.w3.weight' -# name_up_1 = f'gpt.layers.{i}.mlp.up_proj.weight' -# llama[name_up_1] = llama.pop(name_up_0) - -# name_down_0 = f'layers.{i}.feed_forward.w2.weight' -# name_down_1 = f'gpt.layers.{i}.mlp.down_proj.weight' -# llama[name_down_1] = llama.pop(name_down_0) - -# name_ffn_norm_0 = f'layers.{i}.ffn_norm.weight' -# name_ffn_norm_1 = f'gpt.layers.{i}.input_layernorm.weight' -# llama[name_ffn_norm_1] = llama.pop(name_ffn_norm_0) - -# name_attn_norm_0 = f'layers.{i}.attention_norm.weight' -# name_attn_norm_1 = f'gpt.layers.{i}.post_attention_layernorm.weight' -# llama[name_attn_norm_1] = llama.pop(name_attn_norm_0) - -# name_norm_0 = f'norm.weight' -# name_norm_1 = f'gpt.norm.weight' -# llama[name_norm_1] = llama.pop(name_norm_0) - -# text_emb = llama['text_embeddings.weight'] -# code_emb_0 = llama['code_embeddings.weight'][0:1026, :] -# code_emb_1 = llama['code_embeddings.weight'][1026:2052, :] -# all_0 = torch.cat([text_emb, code_emb_0], dim=0) -# all_1 = torch.cat([torch.zeros_like(text_emb), code_emb_1], dim=0) -# llama['emb_all.0.weight'] = all_0 -# llama['emb_all.1.weight'] = all_1 -# llama.pop('text_embeddings.weight') -# llama.pop('code_embeddings.weight') - -# output = llama['output.weight'] -# lm_head = output[7002:, :] -# lm_head_0 = lm_head[0:1026, :] -# lm_head_1 = lm_head[1026:2052, :] -# llama['lm_head.0.weight'] = lm_head_0 -# llama['lm_head.1.weight'] = lm_head_1 -# llama.pop('output.weight') - -# torch.save(llama, '/home/zhn/xtts/llama.pt') diff --git 
a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 15a8e4e75193f..a499975869df1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -480,7 +480,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len - if isinstance(tokens, list) and isinstance(tokens[0], list) == True: + if isinstance(tokens, list) or hasattr(self.runner.model_config.hf_config, "num_output_head"): inter_data.input_tokens[seq_idx].extend(tokens) else: inter_data.input_tokens[seq_idx].append(tokens) @@ -699,7 +699,7 @@ def build(self) -> ModelInputForGPU: # Tokens and positions. if cuda_graph_pad_size: - if isinstance(input_tokens[0], list): + if hasattr(self.runner.model_config.hf_config, "num_output_head"): num_head = len(input_tokens[0]) input_tokens.extend(itertools.repeat([0] * num_head, cuda_graph_pad_size)) else: From 7e98539bec08042a75d508c1ca70ab3a73904198 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 30 Aug 2024 13:41:37 +0000 Subject: [PATCH 22/61] fix broken llama --- vllm/worker/model_runner.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 15a8e4e75193f..96bfda0cbee5f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -480,10 +480,16 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len - if isinstance(tokens, list) and isinstance(tokens[0], list) == True: - inter_data.input_tokens[seq_idx].extend(tokens) + if hasattr(self.runner.model_config.hf_config, "num_output_head"): + if inter_data.is_prompt: + inter_data.input_tokens[seq_idx].extend(tokens) + else: + inter_data.input_tokens[seq_idx].append(tokens) else: - inter_data.input_tokens[seq_idx].append(tokens) + if isinstance(tokens, list): + inter_data.input_tokens[seq_idx].extend(tokens) + else: + inter_data.input_tokens[seq_idx].append(tokens) if (seq_len - context_len) == 1: inter_data.input_positions[seq_idx].append(seq_len - 1) @@ -699,8 +705,8 @@ def build(self) -> ModelInputForGPU: # Tokens and positions. 
if cuda_graph_pad_size: - if isinstance(input_tokens[0], list): - num_head = len(input_tokens[0]) + if hasattr(self.runner.model_config.hf_config, "num_output_head"): + num_head = self.runner.model_config.hf_config.num_output_head input_tokens.extend(itertools.repeat([0] * num_head, cuda_graph_pad_size)) else: input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) From 350a758339234332a1a628576923339b71b6b039 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Sat, 31 Aug 2024 12:18:04 +0000 Subject: [PATCH 23/61] separate emb --- vllm/engine/async_llm_engine.py | 4 --- vllm/engine/llm_engine.py | 3 -- vllm/model_executor/models/ttslm.py | 45 ++++++++++++++--------------- vllm/worker/model_runner.py | 1 + 4 files changed, 23 insertions(+), 30 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a96ca769babc5..a28b20fcbbcd8 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -347,10 +347,6 @@ async def _extract_prompt_components_async( else: assert_never(inputs) - if hasattr(self.model_config.hf_config, "num_output_head"): - # duplicate the prompt_token_ids for each head - prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] - return prompt, prompt_token_ids, multi_modal_data async def _process_encoder_decoder_prompt_async( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fe94aa41fda84..979555eb6a05d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -747,9 +747,6 @@ def _extract_prompt_components( else: assert_never(inputs) - if hasattr(self.model_config.hf_config, "num_output_head"): - # duplicate the prompt_token_ids for each head - prompt_token_ids = [[i] * self.model_config.hf_config.num_output_head for i in prompt_token_ids] return prompt, prompt_token_ids, multi_modal_data def _apply_prompt_adapter( diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index b36cce21119b9..f330e533a25f2 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -30,7 +30,7 @@ def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int, mm_counts: Mapping[st from vllm.sequence import SequenceData - dummy_seq_data = SequenceData([[0] * ctx.model_config.hf_config.num_output_head] * seq_len) + dummy_seq_data = SequenceData([0] * seq_len) dummy_multi_modal_data = {"audio": SpeechPlugin.sample_random_speaker()} return dummy_seq_data, dummy_multi_modal_data @@ -56,8 +56,9 @@ def __init__(self, self.gpt = LlamaModel(config) self.model_dim = self.gpt.config.hidden_size - self.emb_all = nn.ModuleList([ - VocabParallelEmbedding(self.num_audio_tokens + self.num_text_tokens, self.model_dim) for _ in range(self.num_output_head) + self.emb_text = VocabParallelEmbedding(self.num_text_tokens, self.model_dim) + self.emb_code = nn.ModuleList([ + VocabParallelEmbedding(self.num_audio_tokens, self.model_dim) for _ in range(self.num_output_head) ]) self.lm_head = nn.ModuleList([ @@ -97,12 +98,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): except KeyError: pass - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - code_emb = [ - self.emb_all[i](input_ids[:,i]) - for i in range(self.num_output_head) - ] - emb = torch.stack(code_emb, 2).sum(2) + def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torch.Tensor: + if is_prompt: + emb = self.emb_text(input_ids) + else: + code_emb = [ + self.emb_code[i](input_ids[:,i]) +
for i in range(self.num_output_head) + ] + emb = torch.stack(code_emb, 2).sum(2) return emb def compute_logits(self, hidden_states: torch.Tensor, @@ -125,15 +129,6 @@ def sample( output = self.sampler(head_logits[i + 1], sampling_metadata) self.merge_sample_results(next_tokens, output) - for output in next_tokens.outputs: - for sample in output.samples: - sample.output_token += self.num_text_tokens - for i in range(self.num_output_head): - sample.output_tokens[i] += self.num_text_tokens - dic = {} - for k,v in sample.logprobs.items(): - dic[k + self.num_text_tokens] = v - sample.logprobs = dic return next_tokens def forward( @@ -144,12 +139,16 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + is_prompt: bool = True, **kwargs: object ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.get_input_embeddings(input_ids) - spk_emb = kwargs.get("speech", None) - if spk_emb is not None: - self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids, is_prompt) + spk_emb = kwargs.get("speech", None) + if spk_emb is not None: + self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) model_output = self.gpt( input_ids=input_ids, inputs_embeds=hidden_states, @@ -172,7 +171,7 @@ def apply_spk_emb( # convert spk_emb to the same dtype as emb spk_emb = spk_emb.to(emb.dtype) # find the index of the speaker token - indices = (input_ids[:,0] == self.spk_emb_token_id).nonzero(as_tuple=True) + indices = (input_ids == self.spk_emb_token_id).nonzero(as_tuple=True) if indices[0].size(0) == 0: return emb.index_put_(indices, spk_emb) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 96bfda0cbee5f..601a711d1f55c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1557,6 +1557,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, + is_prompt=model_input.is_prompt, **MultiModalInputs.as_kwargs(multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) From 4b263924a2ba48fc4ee873c251f1506352a525c7 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Sat, 31 Aug 2024 14:39:40 +0000 Subject: [PATCH 24/61] fix catpure run --- vllm/model_executor/models/ttslm.py | 2 +- vllm/worker/model_runner.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index f330e533a25f2..6893ace0e18a2 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -139,7 +139,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - is_prompt: bool = True, + is_prompt: bool = False, **kwargs: object ) -> Union[torch.Tensor, IntermediateTensors]: if inputs_embeds is not None: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 601a711d1f55c..ee2434d179474 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -815,7 +815,7 @@ def __init__( is_driver_worker: bool = False, prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, - return_hidden_states: bool = True, + return_hidden_states: bool = False, observability_config: 
Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, @@ -1216,10 +1216,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - - seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ - .dummy_data_for_profiling(self.model_config, max_batch_size, self.mm_registry) - input_tokens = torch.tensor(seq_data.prompt_token_ids, dtype=torch.long).cuda() + + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + if hasattr(self.model_config.hf_config, "num_output_head"): + input_tokens = torch.zeros(max_batch_size, self.model_config.hf_config.num_output_head, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) @@ -1606,8 +1606,8 @@ def execute_model( hidden_states = hidden_or_intermediate_states output.hidden_states = hidden_states - for i, o in enumerate(output): - o.hidden_state = hidden_states[i] + # for i, o in enumerate(output): + # o.hidden_state = hidden_states[i] return [output] From df4ba99faf975b0a17f4f6b59166b1b51d9e0b07 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 3 Sep 2024 13:46:10 +0000 Subject: [PATCH 25/61] support output hidden --- benchmarks/benchmark_tts.py | 6 ++-- testllama.py | 39 ++++++++++++++++----- tts.py | 8 ++--- vllm/engine/output_processor/single_step.py | 3 ++ vllm/model_executor/models/llama.py | 1 + vllm/outputs.py | 4 ++- vllm/sequence.py | 18 ++++++++++ vllm/worker/model_runner.py | 6 ++-- 8 files changed, 67 insertions(+), 18 deletions(-) diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py index 965ad174bb18e..6d26d7647c039 100644 --- a/benchmarks/benchmark_tts.py +++ b/benchmarks/benchmark_tts.py @@ -302,13 +302,15 @@ def run_vllm( max_tokens=2048, top_k=1 )) - + print(prompts) + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + print("warmup done") start = time.perf_counter() outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() total_output_tokens = 0 for output in outputs: total_output_tokens += len(output.outputs[0].token_ids) - end = time.perf_counter() return end - start, total_output_tokens diff --git a/testllama.py b/testllama.py index a4b29088ddd68..200cf7779a2a9 100644 --- a/testllama.py +++ b/testllama.py @@ -1,13 +1,36 @@ +import torch +import time from vllm import LLM, SamplingParams -llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct') +torch.random.manual_seed(999) + +llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct', gpu_memory_utilization=0.5) prompts = [ "Hi my name is", - # "The capital of France is", + "Tell me a joke", ] -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -outputs = llm.generate(prompts, sampling_params) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + +texts = [] +start = time.time() +for i in range(10): + sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=200, top_p=1, repetition_penalty=0.9) + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + texts.append(generated_text) +end = time.time() +print(f"Time 
taken: {end - start:.2f}s") +# for text in texts: +# print(text) + +# for i in range(5): +# prompts.append(prompts[0]) +# prompts.append(prompts[1]) + +# sampling_params = SamplingParams(temperature=1, top_k=1, max_tokens=100) +# outputs = llm.generate(prompts, sampling_params) +# for output in outputs: +# prompt = output.prompt +# generated_text = output.outputs[0].text +# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/tts.py b/tts.py index 39bc0a92ecc8d..ff2a9a3f77636 100644 --- a/tts.py +++ b/tts.py @@ -45,7 +45,7 @@ # # save the model # torch.save(tts, '/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') -llm = LLM(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, dtype=torch.float32, enforce_eager=True) +llm = LLM(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float16, enforce_eager=True) prompts = [ { "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", @@ -57,14 +57,14 @@ } ] -for i in range(12): +for i in range(0): prompts.append(prompts[0]) prompts.append(prompts[1]) -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) outputs = llm.generate(prompts, sampling_params) for output in outputs: print(output.prompt) token_ids = output.outputs[0].token_ids for token_id in token_ids: - print([x - 21178 for x in token_id]) \ No newline at end of file + print([x - 0 for x in token_id]) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 7f13ad8ff8627..7956f76b5b25c 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -94,6 +94,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, else: seq.append_token_id(sample.output_token, sample.logprobs) + # store the hidden state if it have one + if outputs.hidden_state is not None: + seq.output_hiddens.append(outputs.hidden_state.clone()) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0c67a9b8e198b..91783164c3315 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -425,6 +425,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + is_prompt: bool = False, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) diff --git a/vllm/outputs.py b/vllm/outputs.py index db33276632630..6a01946a5de31 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -39,6 +39,7 @@ class CompletionOutput: finish_reason: Optional[str] = None stop_reason: Union[int, str, None] = None lora_request: Optional[LoRARequest] = None + hidden_states: Optional[List[torch.Tensor]] = None def finished(self) -> bool: return self.finish_reason is not None @@ -148,7 +149,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": seq.get_cumulative_logprob() if include_logprobs else None, seq.output_logprobs if include_logprobs else None, SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) for seq in top_n_seqs + seq.stop_reason, + hidden_states=seq.output_hiddens) for seq in 
top_n_seqs ] # Every sequence in the sequence group should have the same prompt. diff --git a/vllm/sequence.py b/vllm/sequence.py index 945879a2f8ab7..ededd8b5a8493 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -158,6 +158,10 @@ def prompt_token_ids(self, new_prompt_token_ids) -> None: self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) self._update_cached_all_tokens() + @property + def prompt_token_ids_array(self) -> array: + return self._prompt_token_ids + @property def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @@ -167,6 +171,10 @@ def output_token_ids(self, new_output_token_ids) -> None: self._output_token_ids = list(new_output_token_ids) self._update_cached_all_tokens() + @property + def output_token_ids_array(self) -> array: + return self._output_token_ids + def append_token_id(self, token_id: int, logprob: float) -> None: self._output_token_ids.append(token_id) self._cached_all_token_ids.append(token_id) @@ -900,6 +908,16 @@ def __eq__(self, other: object) -> bool: log_probs_equal = other.logprobs == self.logprobs return equal and log_probs_equal +class MultiHeadSequenceOutput(SequenceOutput): + def __init__( + self, + parent_seq_id: int, + output_tokens: List[int], + logprobs: Dict[int, Logprob], + ) -> None: + super().__init__(parent_seq_id, output_tokens[0], logprobs) + self.output_tokens = output_tokens + class SequenceGroupOutput(ABC): """The base class for model outputs associated with a sequence group.""" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ee2434d179474..4008167796cff 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -815,7 +815,7 @@ def __init__( is_driver_worker: bool = False, prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, - return_hidden_states: bool = False, + return_hidden_states: bool = True, observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, @@ -1606,8 +1606,8 @@ def execute_model( hidden_states = hidden_or_intermediate_states output.hidden_states = hidden_states - # for i, o in enumerate(output): - # o.hidden_state = hidden_states[i] + for i, o in enumerate(output): + o.hidden_state = hidden_states[i] return [output] From ceea94dc0c8a5c70b412efa2c93957b15b684da8 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Wed, 4 Sep 2024 06:21:48 +0000 Subject: [PATCH 26/61] Enable repetition_penalties maybe work?? 
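The patch below only penalizes tokens that a head itself produced inside a sliding repetition_window, and it skips the text prompt, which uses a different token space than the audio codebooks. A minimal standalone sketch of that rule (the helper name and the sample history are illustrative, not code from the patch):

# Standalone sketch (not vLLM code): windowed repetition penalty for one output head.
import torch

def windowed_repetition_penalty(logits: torch.Tensor,
                                head_history: list,
                                penalty: float = 1.5,
                                window: int = 16) -> torch.Tensor:
    """logits: [vocab]; head_history: tokens this head has emitted so far."""
    recent = head_history[-window:]
    if penalty == 1.0 or not recent:
        return logits
    out = logits.clone()
    idx = torch.tensor(sorted(set(recent)), dtype=torch.long)
    vals = out[idx]
    # Same rule as the HF/vLLM repetition penalty: shrink positive logits and
    # push negative logits further down, so repeated codes are always discouraged.
    out[idx] = torch.where(vals > 0, vals / penalty, vals * penalty)
    return out

head0_logits = torch.randn(1026)   # fishtts codebook size used later in this series
head0_logits = windowed_repetition_penalty(head0_logits, [836, 378, 144, 836])

The tts_fish.py script later in the series passes repetition_penalty=1.5 and repetition_window=16, which map directly onto these two knobs.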
--- tts_async.py | 10 +++--- vllm/model_executor/layers/sampler.py | 39 +++++++++++++++++------- vllm/model_executor/models/ttslm.py | 6 ++-- vllm/model_executor/sampling_metadata.py | 10 ++++-- vllm/sampling_params.py | 2 ++ 5 files changed, 46 insertions(+), 21 deletions(-) diff --git a/tts_async.py b/tts_async.py index 00abc5b44886e..7052b92cac8ea 100644 --- a/tts_async.py +++ b/tts_async.py @@ -15,9 +15,9 @@ } ] -engine_args = AsyncEngineArgs(model='/home/zhn/ttslm', gpu_memory_utilization=0.5, dtype=torch.float32, enforce_eager=True) +engine_args = AsyncEngineArgs(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float16) model = AsyncLLMEngine.from_engine_args(engine_args) -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) async def generate_streaming(prompt, id): results_generator = model.generate(prompt, sampling_params, request_id=id) @@ -25,8 +25,8 @@ async def generate_streaming(prompt, id): tokens = [] async for request_output in results_generator: token_ids = request_output.outputs[0].token_ids - print(f'{id} {[x - 21178 for x in token_ids[-1]]}') - tokens.append([x - 21178 for x in token_ids[-1]]) + print(f'{id} {[x - 0 for x in token_ids[-1]]}') + tokens.append([x - 0 for x in token_ids[-1]]) count+=1 print(prompt['prompt']) @@ -35,7 +35,7 @@ async def generate_streaming(prompt, id): async def generate(): tasks = [] - for i in range(10): + for i in range(1): t = generate_streaming(prompts[i%2], i) tasks.append(t) await asyncio.gather(*tasks) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 41abdf211e7e7..8d8408edea525 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -44,7 +44,7 @@ class Sampler(nn.Module): in logits for each token in the input prompt. """ - def __init__(self): + def __init__(self, idx: int = -1): super().__init__() # Whether or not the SamplerOutput should have on-device tensors @@ -52,6 +52,7 @@ def __init__(self): # speculative decoding. self.include_gpu_probs_tensor = False self.should_modify_greedy_probs_inplace = False + self.head_idx = idx def _init_sampling_tensors( self, @@ -71,13 +72,14 @@ def _init_sampling_tensors( # Initialize new sampling tensors (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) + do_min_p, is_prompt) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype, head_idx=self.head_idx) self._sampling_tensors = sampling_tensors self._do_penalties = do_penalties self._do_top_p_top_k = do_top_p_top_k self._do_min_p = do_min_p + self._is_prompt = is_prompt def forward( self, @@ -107,16 +109,24 @@ def forward( do_penalties = self._do_penalties do_top_p_top_k = self._do_top_p_top_k do_min_p = self._do_min_p + is_prompt = self._is_prompt logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Apply presence and frequency penalties. 
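In the multi-head path each decode step appends one token per codebook, so the history the penalty sees has to be sliced out per head before the existing _apply_penalties path can be reused in the hunk below. A small sketch of that slicing, assuming per-step outputs are stored as fixed-length lists the way seq_data.output_token_ids_array is indexed in this patch:

# Sketch: recover one head's recent tokens from multi-head step outputs.
from typing import List, Sequence

def head_history(step_outputs: Sequence[Sequence[int]],
                 head_idx: int,
                 window: int = 16) -> List[int]:
    """step_outputs[t] is the list of tokens all heads emitted at step t."""
    return [step[head_idx] for step in step_outputs[-window:]]

steps = [[836, 54], [378, 273], [144, 426]]      # 2 heads, 3 decode steps
assert head_history(steps, head_idx=0, window=2) == [378, 144]
assert head_history(steps, head_idx=1, window=2) == [273, 426]

During the prompt phase the only tokens available are text tokens from a different vocabulary, which is why the branch below skips the penalty entirely in that case.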
if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) + if self.head_idx >= 0 and is_prompt: + # when multihead output and prompt phase, we do not apply penalties + # because the prompt tokens are not same as the output tokens + pass + else: + skip_prompt_repetition = self.head_idx >= 0 and not is_prompt + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties, + skip_prompt_repetition=skip_prompt_repetition) # Use float32 to apply temperature scaling. # Use in-place division to avoid creating a new tensor. @@ -250,13 +260,20 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: + repetition_penalties: torch.Tensor, + skip_prompt_repetition: bool = False) -> torch.Tensor: num_seqs, vocab_size = logits.shape - _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, - num_seqs) output_bin_counts, output_mask = _get_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) + if skip_prompt_repetition: + # when multihead output, we do not apply penalties for prompt tokens + # because the prompt tokens are not same as the output tokens + prompt_mask = torch.zeros((num_seqs, vocab_size), dtype=torch.bool, device=prompt_tokens_tensor.device) + else: + _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, + num_seqs) + repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) repetition_penalties[~(prompt_mask | output_mask)] = 1.0 logits = torch.where(logits > 0, logits / repetition_penalties, diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 6893ace0e18a2..065d0eb708684 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -65,7 +65,7 @@ def __init__(self, nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_output_head) ]) self.logits_processor = LogitsProcessor(self.num_audio_tokens) - self.sampler = Sampler() + self.samplers = [Sampler(head_idx) for head_idx in range(self.num_output_head)] def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -124,9 +124,9 @@ def sample( sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: head_logits = logits.permute(1, 0, 2) - next_tokens = self.sampler(head_logits[0], sampling_metadata) + next_tokens = self.samplers[0](head_logits[0], sampling_metadata) for i in range(self.num_output_head - 1): - output = self.sampler(head_logits[i + 1], sampling_metadata) + output = self.samplers[i](head_logits[i + 1], sampling_metadata) self.merge_sample_results(next_tokens, output) return next_tokens diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 94b4b14416821..8acd6c6dadb68 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -391,6 +391,7 @@ def from_sampling_metadata( device: torch.device, dtype: torch.dtype, *, + head_idx: int = -1, extra_seeds_to_generate: int = 0, extra_entropy: 
Optional[Tuple[int, ...]] = None ) -> Tuple["SamplingTensors", bool, bool, bool]: @@ -413,6 +414,7 @@ def from_sampling_metadata( do_penalties = False do_top_p_top_k = False do_min_p = False + is_prompt = False if _USE_TRITON_SAMPLER: prompt_best_of: List[int] = [] @@ -501,6 +503,7 @@ def from_sampling_metadata( if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids + repetition_window = seq_group.sampling_params.repetition_window if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) @@ -512,14 +515,17 @@ def from_sampling_metadata( for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] prompt_tokens.append(seq_data.prompt_token_ids_array) - output_tokens.append(seq_data.output_token_ids_array) + if head_idx >= 0 and seq_group.is_prompt == False: + output_tokens.append([i[head_idx] for i in seq_data.output_token_ids_array[-repetition_window:]]) + else: + output_tokens.append(seq_data.output_token_ids_array) sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, frequency_penalties, repetition_penalties, sampling_seeds, sample_indices, prompt_tokens, output_tokens, vocab_size, extra_seeds_to_generate, device, dtype) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) + return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p, is_prompt) @classmethod def from_lists(cls, temperatures: List[float], top_ps: List[float], diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 04250c682cd23..5d4da2ccec323 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -119,6 +119,7 @@ def __init__( presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, + repetition_window: int = 16, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -146,6 +147,7 @@ def __init__( self.presence_penalty = presence_penalty self.frequency_penalty = frequency_penalty self.repetition_penalty = repetition_penalty + self.repetition_window = repetition_window if 0 < temperature < _MAX_TEMP: logger.warning( "temperature %s is less than %s, which may cause numerical " From 74a2edd2280dacf48dee403920b5482f71c46683 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 5 Sep 2024 03:08:42 +0000 Subject: [PATCH 27/61] implement multihead sampler --- benchmarks/benchmark_tts.py | 4 +- .../layers/multi_head_sampler.py | 234 ++++++++++++++++++ vllm/model_executor/models/ttslm.py | 20 +- 3 files changed, 249 insertions(+), 9 deletions(-) create mode 100644 vllm/model_executor/layers/multi_head_sampler.py diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py index 6d26d7647c039..c65711fb8a91b 100644 --- a/benchmarks/benchmark_tts.py +++ b/benchmarks/benchmark_tts.py @@ -103,7 +103,7 @@ async def generate_streaming(llm: AsyncLLMEngine, request_func_input: RequestFun ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st - sampling_params = SamplingParams(n=1, temperature=1, detokenize=False, stop_token_ids=[21803], max_tokens=2048, top_k=1) + sampling_params = SamplingParams(n=1, temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) results_generator = llm.generate(request_func_input.prompt, sampling_params, request_id=request_id) async for request_output in results_generator: token_ids = request_output.outputs[0].token_ids @@ -298,7 +298,7 @@ def run_vllm( n=1, temperature=1, detokenize=False, - stop_token_ids=[21803], + 
stop_token_ids=[625], max_tokens=2048, top_k=1 )) diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py new file mode 100644 index 0000000000000..0f896a897f3bd --- /dev/null +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -0,0 +1,234 @@ +"""A layer that samples the next tokens from the model's outputs.""" +from array import array +import itertools +from math import inf +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor.layers.sampler import _apply_top_k_top_p, _get_bin_counts_and_mask +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.ops.sample import sample as sample_triton + +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingTensors, + SequenceGroupToSample) +from vllm.sampling_params import SamplingType +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SamplerOutput, + SequenceOutput) +from vllm.utils import (PyObjectCache, async_tensor_h2d, + is_pin_memory_available, make_tensor_with_pad, + maybe_expand_dim) + +from einops import rearrange + +_SAMPLING_EPS = 1e-5 + +class MultiheadSampler(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + batch_size, num_heads, vocab_size = logits.size() + logits = logits.reshape(batch_size * num_heads, vocab_size) + + self._init_sampling_tensors(num_heads, vocab_size, logits, sampling_metadata) + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + is_prompt = self._is_prompt + + if not is_prompt and do_penalties: + logits = self._apply_penalties(logits, sampling_tensors.output_tokens, sampling_tensors.repetition_penalties) + + # Use float32 to apply temperature scaling. + # Use in-place division to avoid creating a new tensor. + logits = logits.to(torch.float) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) + + if do_top_p_top_k: + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. 
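MultiheadSampler folds the head dimension into the batch dimension, applies temperature, top-k and penalties once over [batch * num_heads, vocab], and then reshapes the sampled ids back to one token per head. A self-contained sketch of that flatten-then-sample pattern (simplified: no penalties, one shared top-k value, and the function name is mine):

# Sketch of per-head sampling by flattening heads into the batch dimension.
import torch

def sample_per_head(logits: torch.Tensor,
                    temperature: float = 1.0,
                    top_k: int = 1) -> torch.Tensor:
    """logits: [batch, num_heads, vocab] -> sampled ids of shape [batch, num_heads]."""
    batch, num_heads, vocab = logits.shape
    flat = logits.reshape(batch * num_heads, vocab).float() / max(temperature, 1e-5)
    if 0 < top_k < vocab:
        kth = torch.topk(flat, top_k, dim=-1).values[:, -1:]
        flat = flat.masked_fill(flat < kth, float("-inf"))
    probs = torch.softmax(flat, dim=-1)
    ids = torch.multinomial(probs, num_samples=1)    # one draw per (sequence, head) row
    return ids.view(batch, num_heads)

ids = sample_per_head(torch.randn(2, 2, 1026), top_k=1)  # 2 fishtts codebooks of 1026 codes
print(ids.shape)                                         # torch.Size([2, 2])

With top_k=1, as the test scripts in this series use, each row degenerates to an argmax, i.e. greedy decoding per head.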
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float) + + id_next = torch.multinomial(probs, 1).to(logits.device) + id_next = id_next.reshape(-1, num_heads).tolist() + return self.build_sampler_output(id_next, sampling_metadata) + + + def _init_sampling_tensors(self, + num_heads: int, + vocab_size: int, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata): + self._sampling_tensors = None + sampling_tensors, do_penalties, do_top_p_top_k, is_prompt = self.from_sampling_metadata( + num_heads, vocab_size, logits, sampling_metadata + ) + + self._sampling_tensors = sampling_tensors + self._do_penalties = do_penalties + self._do_top_p_top_k = do_top_p_top_k + self._is_prompt = is_prompt + + def from_sampling_metadata(self, + num_heads: int, + vocab_size: int, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Tuple["SamplingTensors", bool, bool, bool]: + dtype = logits.dtype + device = logits.device + + output_tokens: List[array] = [] + top_ks: List[int] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + repetition_penalties: List[float] = [] + do_penalties = False + do_top_p_top_k = False + is_prompt = False + + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + temperature = sampling_params.temperature + r = sampling_params.repetition_penalty + top_p = sampling_params.top_p + + # k should not be greater than the vocab size. + top_k = min(sampling_params.top_k, vocab_size) + top_k = vocab_size if top_k == -1 else top_k + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. + temperature = 1.0 + if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS + or top_k != vocab_size): + do_top_p_top_k = True + if not do_penalties and (abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True + + is_prompt = seq_group.is_prompt + if seq_group.do_sample: + sample_lens = len(seq_group.sample_indices) + assert sample_lens == len(seq_ids) + temperatures += [temperature] * len(seq_ids) * num_heads + top_ps += [top_p] * len(seq_ids) * num_heads + top_ks += [top_k] * len(seq_ids) * num_heads + repetition_penalties += [r] * len(seq_ids) * num_heads + + if do_penalties: + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + repetition_window = seq_group.sampling_params.repetition_window + if seq_group.do_sample: + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + token_ids_in_window = seq_data.output_token_ids_array[-repetition_window:] + if token_ids_in_window: + for head_id in range(num_heads): + output_tokens.append([row[head_id] for row in token_ids_in_window]) + + pin_memory = is_pin_memory_available() + if do_penalties: + output_t = make_tensor_with_pad( + output_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) + else: + empty_tensor = torch.empty(0, device=device, dtype=torch.long) + output_t = empty_tensor + + temperatures_t = torch.tensor( + temperatures, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ps_t = torch.tensor( + top_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + repetition_penalties_t = torch.tensor( + repetition_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ks_t = torch.tensor( + top_ks, + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + + sampling_tensors = SamplingTensors( + 
output_tokens=output_t.to(device=device, non_blocking=True), + temperatures=temperatures_t.to(device=device, non_blocking=True), + top_ps=top_ps_t.to(device=device, non_blocking=True), + repetition_penalties=repetition_penalties_t.to(device=device, non_blocking=True), + top_ks=top_ks_t.to(device=device, non_blocking=True), + + min_ps=None, + presence_penalties=None, + frequency_penalties=None, + sampling_seeds=None, + sample_indices=None, + extra_seeds=None, + prompt_tokens=None + ) + + return (sampling_tensors, do_penalties, do_top_p_top_k, is_prompt) + + def _apply_penalties(self, logits: torch.Tensor, + output_tokens_tensor: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + num_seqs, vocab_size = logits.shape + output_bin_counts, output_mask = _get_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + + repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) + repetition_penalties[~(output_mask)] = 1.0 + logits = torch.where(logits > 0, logits / repetition_penalties, + logits * repetition_penalties) + + return logits + + def build_sampler_output(self, + sample_results: List[List[int]], + sampling_metadata: SamplingMetadata) -> SamplerOutput: + sampler_output: List[CompletionSequenceGroupOutput] = [] + for seq_group, sample_result in zip(sampling_metadata.seq_groups, sample_results): + seq_ids = seq_group.seq_ids + parent_id = 0 # no beam search for now + seq_outputs: List[SequenceOutput] = [] + log_prob = { sample_result[0]: Logprob(logprob=inf, rank=None, decoded_token=None) } + seq_output = SequenceOutput(seq_ids[parent_id], sample_result[0], log_prob) + seq_output.output_tokens = sample_result + seq_outputs.append(seq_output) + sampler_output.append(CompletionSequenceGroupOutput(seq_outputs, prompt_logprobs=None)) + + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None) + return SamplerOutput( + outputs=sampler_output, + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, + ) \ No newline at end of file diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 065d0eb708684..086c00f1a8b85 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -11,6 +11,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs.registry import InputContext from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.multi_head_sampler import MultiheadSampler from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding @@ -21,6 +22,8 @@ from vllm.multimodal.speech import SpeechPlugin from vllm.sequence import IntermediateTensors, SamplerOutput +from einops import rearrange +from transformers.generation import TopKLogitsWarper, TopPLogitsWarper import lzma import numpy as np @@ -65,7 +68,8 @@ def __init__(self, nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_output_head) ]) self.logits_processor = LogitsProcessor(self.num_audio_tokens) - self.samplers = [Sampler(head_idx) for head_idx in range(self.num_output_head)] + self.sampler = MultiheadSampler() + # self.samplers = [Sampler(head_idx) for head_idx in range(self.num_output_head)] def load_weights(self, weights: Iterable[Tuple[str, 
torch.Tensor]]): stacked_params_mapping = [ @@ -123,12 +127,13 @@ def sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - head_logits = logits.permute(1, 0, 2) - next_tokens = self.samplers[0](head_logits[0], sampling_metadata) - for i in range(self.num_output_head - 1): - output = self.samplers[i](head_logits[i + 1], sampling_metadata) - self.merge_sample_results(next_tokens, output) + # head_logits = logits.permute(1, 0, 2) + # next_tokens = self.samplers[0](head_logits[0], sampling_metadata) + # for i in range(self.num_output_head - 1): + # output = self.samplers[i](head_logits[i + 1], sampling_metadata) + # self.merge_sample_results(next_tokens, output) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def forward( @@ -183,4 +188,5 @@ def merge_sample_results( ): for o_a, o_b in zip(source.outputs, target.outputs): for s_a, s_b in zip(o_a.samples, o_b.samples): - s_a.output_tokens.append(s_b.output_token) \ No newline at end of file + s_a.output_tokens.append(s_b.output_token) + \ No newline at end of file From 76195a56b9bd45b1d895f336a701ece9046ca06f Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 6 Sep 2024 13:06:29 +0000 Subject: [PATCH 28/61] for xptts --- tts_fish.py | 94 +++++++++++++++++++ .../model_executor/layers/rotary_embedding.py | 52 +++++++--- vllm/model_executor/models/llama.py | 38 ++++---- vllm/model_executor/models/ttslm.py | 21 +++-- 4 files changed, 162 insertions(+), 43 deletions(-) create mode 100644 tts_fish.py diff --git a/tts_fish.py b/tts_fish.py new file mode 100644 index 0000000000000..460aecf595dea --- /dev/null +++ b/tts_fish.py @@ -0,0 +1,94 @@ +from vllm import LLM, SamplingParams +import torch +torch.random.manual_seed(999) +# tts1 = torch.load('/home/zhn/ttslm_dev/GPT_merged_emb_nonorm.pt') +# tts2 = torch.load('/home/zhn/fishtts/checkpoint-1400000.bak') + +# llama = tts2['model']['llama'] + +# llama.pop('freqs_cis') +# llama.pop('causal_mask') + +# llama['emb_text.weight'] = llama['text_embeddings.weight'] +# llama.pop('text_embeddings.weight') + +# llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:1026] +# llama['emb_code.1.weight'] = llama['code_embeddings.weight'][1026:] +# llama.pop('code_embeddings.weight') + +# layer = 24 +# dim = 1536 +# for i in range(layer): +# qkv_name = f'layers.{i}.attention.wqkv.weight' +# q = llama[qkv_name][0:dim] +# k = llama[qkv_name][dim:2*dim] +# v = llama[qkv_name][2*dim:] +# llama[f'gpt.layers.{i}.self_attn.q_proj.weight'] = q +# llama[f'gpt.layers.{i}.self_attn.k_proj.weight'] = k +# llama[f'gpt.layers.{i}.self_attn.v_proj.weight'] = v +# llama.pop(qkv_name) + +# wo_name = f'layers.{i}.attention.wo.weight' +# wo = llama[wo_name] +# llama[f'gpt.layers.{i}.self_attn.o_proj.weight'] = wo +# llama.pop(wo_name) + +# gate_proj_name = f'layers.{i}.feed_forward.w1.weight' +# w_gate = llama[gate_proj_name] +# llama[f'gpt.layers.{i}.mlp.gate_proj.weight'] = w_gate +# llama.pop(gate_proj_name) + +# gate_up_proj_name = f'layers.{i}.feed_forward.w3.weight' +# w_gate_up = llama[gate_up_proj_name] +# llama[f'gpt.layers.{i}.mlp.up_proj.weight'] = w_gate_up +# llama.pop(gate_up_proj_name) + +# gate_down_proj_name = f'layers.{i}.feed_forward.w2.weight' +# w_gate_down = llama[gate_down_proj_name] +# llama[f'gpt.layers.{i}.mlp.down_proj.weight'] = w_gate_down +# llama.pop(gate_down_proj_name) + +# attn_norm_name = f'layers.{i}.attention_norm.weight' +# w_attn_norm = llama[attn_norm_name] +# llama[f'gpt.layers.{i}.input_layernorm.weight'] = 
w_attn_norm +# llama.pop(attn_norm_name) + +# ffn_norm_name = f'layers.{i}.ffn_norm.weight' +# w_ffn_norm = llama[ffn_norm_name] +# llama[f'gpt.layers.{i}.post_attention_layernorm.weight'] = w_ffn_norm +# llama.pop(ffn_norm_name) + + +# norm_name = 'norm.weight' +# w_norm = llama[norm_name] +# llama['gpt.norm.weight'] = w_norm +# llama.pop(norm_name) + +# output_name = 'output.weight' +# w_output = llama[output_name] +# llama['lm_head.0.weight'] = w_output[7002:7002+1026] +# llama['lm_head.1.weight'] = w_output[7002+1026:7002+1026*2] +# llama.pop(output_name) + +# torch.save(llama, '/home/zhn/fishtts/llama.pt') + +llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, enforce_eager=True) +prompts = [ + { + "prompt_token_ids": [7001, 5023, 16, 62, 4550, 4557, 4790, 4963, 7, 4676, 4697, 17, + 4549, 2719, 4546, 7, 435, 20, 4499, 37, 1164, 4561, 4637, 828, + 566, 4496, 7, 120, 14, 4695, 32, 4765, 4594, 4648, 4513, 4692, + 37, 1164, 4555, 100, 4544, 4680, 7, 38, 4706, 36, 566, 4498, + 4717, 30, 1164, 4596, 7, 4597, 4858, 475, 20, 4496, 37, 1164, + 4499, 7, 132, 4604, 17, 4610, 17, 4650, 4603, 14, 4596, 4938, + 4513, 0, 0] + } +] + +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1) +outputs = llm.generate(prompts, sampling_params) +for output in outputs: + print(output.prompt) + token_ids = output.outputs[0].token_ids + for token_id in token_ids: + print([x - 0 for x in token_id]) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 95888e7976ad3..8249216479cbb 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -56,6 +56,31 @@ def _apply_rotary_emb( -1).transpose(1, 2) return x_out +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> torch.Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + #xshaped = x.reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + #return x_out2 + return x_out2.type_as(x) class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" @@ -86,6 +111,15 @@ def __init__( cos, sin = cache.chunk(2, dim=-1) freqs_cis = cos + 1j * sin self.register_buffer("freqs_cis", freqs_cis, persistent=False) + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + self.max_position_embeddings, + self.rotary_dim, + self.base, + ), + ) def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" @@ -187,14 +221,14 @@ def forward_native2( query = query.view(batch_size, seq_len, -1, self.head_size) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb(query_rot, freqs_cis) + query_rot = apply_rotary_emb(query_rot, freqs_cis) query 
= torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(batch_size, seq_len, -1, self.head_size) key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb(key_rot, freqs_cis) + key_rot = apply_rotary_emb(key_rot, freqs_cis) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -207,19 +241,7 @@ def forward_cuda( ) -> Tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device, - dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, - self.is_neox_style, self.rotary_dim, - offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) - return query, key + return self.forward_native2(positions, query, key, offsets) def forward_xpu( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 91783164c3315..72728e3cd8f5f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -25,6 +25,7 @@ import torch from torch import nn +import torch.nn.functional as F from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata @@ -82,12 +83,16 @@ def __init__( raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") self.act_fn = SiluAndMul() + + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x + y1 = F.silu(self.gate_proj(x)) + y2 = self.up_proj(x) + y = y1 * y2 + return self.down_proj(y) class LlamaAttention(nn.Module): @@ -241,26 +246,17 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( + x = hidden_states + n = self.input_layernorm(hidden_states) + h = self.self_attn( positions=positions, - hidden_states=hidden_states, + hidden_states=n, kv_cache=kv_cache, attn_metadata=attn_metadata, ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - + h = x + h + out = h + self.mlp(self.post_attention_layernorm(h)) + return out, None class LlamaModel(nn.Module): @@ -340,7 +336,7 @@ def forward( "residual": residual }) - hidden_states, _ = self.norm(hidden_states, residual) + hidden_states = self.norm(hidden_states, residual) return hidden_states diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 086c00f1a8b85..ab73addb81f23 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -77,8 +77,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".qkv_proj", ".q_proj", "q"), (".qkv_proj", 
".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), + # (".gate_up_proj", ".gate_proj", 0), + # (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: @@ -105,11 +105,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torch.Tensor: if is_prompt: emb = self.emb_text(input_ids) - else: + audio_start = torch.tensor([1024, 1022], device=input_ids.device) code_emb = [ - self.emb_code[i](input_ids[:,i]) + self.emb_code[i](audio_start[i]) for i in range(self.num_output_head) ] + code_emb = torch.stack(code_emb, 1).sum(1) + emb[-1] = code_emb + else: + code_emb = [ + self.emb_code[0](input_ids[:,0]), + self.emb_code[1](input_ids[:,1] - 2) + ] emb = torch.stack(code_emb, 2).sum(2) return emb @@ -151,9 +158,9 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids, is_prompt) - spk_emb = kwargs.get("speech", None) - if spk_emb is not None: - self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) + # spk_emb = kwargs.get("speech", None) + # if spk_emb is not None: + # self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) model_output = self.gpt( input_ids=input_ids, inputs_embeds=hidden_states, From 91fea92576f170cb09587001bb8fb65d0c9820b2 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 6 Sep 2024 15:28:27 +0000 Subject: [PATCH 29/61] fix llama --- fish.out | 106 ++++++++++++++++++ tts.out | 96 ++++++++++++++++ tts_fish.py | 4 +- .../model_executor/layers/rotary_embedding.py | 87 ++++++++++++-- vllm/model_executor/models/llama.py | 61 ++++++---- vllm/model_executor/models/ttslm.py | 6 +- 6 files changed, 319 insertions(+), 41 deletions(-) create mode 100644 fish.out create mode 100644 tts.out diff --git a/fish.out b/fish.out new file mode 100644 index 0000000000000..114b670438fc8 --- /dev/null +++ b/fish.out @@ -0,0 +1,106 @@ +[836, 54] +[378, 273] +[144, 426] +[637, 405] +[265, 658] +[375, 1018] +[411, 947] +[1004, 1016] +[254, 983] +[611, 60] +[761, 640] +[67, 427] +[726, 244] +[947, 375] +[502, 916] +[881, 500] +[303, 829] +[190, 787] +[245, 737] +[572, 629] +[997, 1004] +[54, 820] +[195, 967] +[783, 744] +[530, 683] +[52, 473] +[798, 476] +[761, 719] +[128, 427] +[601, 429] +[211, 321] +[885, 1005] +[364, 426] +[910, 882] +[1011, 99] +[974, 408] +[990, 610] +[997, 229] +[620, 646] +[373, 648] +[625, 362] +[572, 322] +[959, 927] +[22, 40] +[83, 379] +[132, 381] +[461, 194] +[467, 743] +[150, 89] +[553, 425] +[885, 43] +[1011, 91] +[293, 598] +[437, 395] +[978, 334] +[318, 845] +[836, 22] +[373, 377] +[806, 648] +[833, 355] +[632, 389] +[22, 357] +[324, 195] +[83, 415] +[467, 524] +[105, 650] +[132, 528] +[982, 151] +[643, 1020] +[656, 436] +[293, 117] +[148, 1023] +[256, 728] +[836, 91] +[373, 218] +[437, 371] +[1011, 400] +[343, 183] +[978, 17] +[422, 6] +[885, 162] +[319, 845] +[83, 382] +[404, 311] +[443, 485] +[105, 471] +[806, 146] +[22, 568] +[678, 765] +[467, 519] +[833, 662] +[836, 968] +[293, 547] +[373, 223] +[656, 240] +[688, 520] +[529, 474] +[437, 369] +[150, 117] +[885, 742] +[620, 828] +[1011, 508] +[318, 731] +[256, 220] +[407, 989] +[1025, 1025] diff --git a/tts.out b/tts.out new file mode 100644 index 0000000000000..68af2fd1eb420 --- /dev/null +++ b/tts.out @@ -0,0 +1,96 @@ +[Stts][empty_spk][speed_5]Your text one[Ptts] +[416, 441, 166, 216] +[416, 422, 166, 191] +[441, 
546, 166, 156] +[416, 597, 165, 41] +[441, 297, 166, 331] +[441, 541, 166, 31] +[421, 472, 165, 192] +[446, 287, 167, 456] +[416, 596, 166, 163] +[421, 447, 166, 317] +[446, 287, 167, 456] +[416, 596, 166, 132] +[421, 462, 166, 185] +[416, 346, 166, 432] +[416, 546, 165, 166] +[446, 592, 341, 471] +[468, 597, 466, 467] +[466, 541, 466, 466] +[468, 467, 463, 466] +[461, 456, 407, 401] +[336, 207, 456, 531] +[343, 243, 427, 203] +[468, 473, 306, 456] +[571, 597, 341, 466] +[546, 543, 166, 165] +[411, 391, 166, 166] +[306, 207, 187, 168] +[281, 406, 312, 468] +[163, 467, 336, 216] +[313, 168, 433, 578] +[183, 31, 427, 581] +[180, 41, 408, 576] +[338, 218, 403, 531] +[316, 216, 336, 215] +[416, 416, 162, 166] +[406, 131, 156, 161] +[406, 391, 166, 31] +[406, 390, 162, 41] +[406, 380, 191, 216] +[318, 168, 341, 416] +[218, 222, 541, 591] +[215, 215, 436, 592] +[216, 81, 432, 418] +[366, 491, 456, 456] +[471, 465, 456, 456] +[343, 625, 432, 625] +[Stts][empty_spk][speed_5]Anther long string[Ptts] +[416, 291, 166, 216] +[416, 422, 166, 166] +[441, 546, 166, 131] +[421, 422, 165, 166] +[441, 422, 166, 281] +[416, 591, 166, 10] +[421, 347, 165, 411] +[441, 421, 166, 283] +[421, 561, 166, 136] +[421, 297, 166, 411] +[416, 596, 166, 258] +[421, 541, 166, 140] +[421, 591, 291, 168] +[421, 420, 291, 205] +[291, 221, 216, 156] +[338, 207, 408, 583] +[338, 467, 457, 453] +[341, 216, 460, 490] +[416, 467, 337, 208] +[281, 207, 337, 466] +[333, 207, 456, 456] +[343, 467, 407, 451] +[216, 241, 306, 156] +[216, 242, 331, 458] +[341, 242, 312, 218] +[341, 216, 341, 491] +[216, 367, 441, 223] +[216, 216, 342, 416] +[86, 330, 341, 416] +[216, 216, 338, 208] +[91, 340, 336, 456] +[216, 365, 312, 466] +[346, 472, 336, 481] +[321, 172, 336, 492] +[163, 158, 332, 206] +[406, 157, 186, 158] +[406, 416, 188, 218] +[406, 381, 167, 216] +[411, 516, 162, 43] +[406, 391, 308, 158] +[281, 167, 331, 465] +[343, 219, 431, 401] +[218, 223, 431, 151] +[218, 207, 331, 406] +[341, 235, 308, 408] +[216, 165, 207, 226] +[343, 418, 337, 416] +[625, 625, 625, 625] \ No newline at end of file diff --git a/tts_fish.py b/tts_fish.py index 460aecf595dea..8f37d71275359 100644 --- a/tts_fish.py +++ b/tts_fish.py @@ -72,7 +72,7 @@ # torch.save(llama, '/home/zhn/fishtts/llama.pt') -llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, enforce_eager=True) +llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32) prompts = [ { "prompt_token_ids": [7001, 5023, 16, 62, 4550, 4557, 4790, 4963, 7, 4676, 4697, 17, @@ -85,7 +85,7 @@ } ] -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) outputs = llm.generate(prompts, sampling_params) for output in outputs: print(output.prompt) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 8249216479cbb..45616af86b257 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -111,15 +111,6 @@ def __init__( cos, sin = cache.chunk(2, dim=-1) freqs_cis = cos + 1j * sin self.register_buffer("freqs_cis", freqs_cis, persistent=False) - - self.register_buffer( - "freqs_cis", - precompute_freqs_cis( - self.max_position_embeddings, - self.rotary_dim, - self.base, - ), - ) def _compute_inv_freq(self, base: Union[int, 
float]) -> torch.Tensor: """Compute the inverse frequency.""" @@ -221,14 +212,14 @@ def forward_native2( query = query.view(batch_size, seq_len, -1, self.head_size) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb(query_rot, freqs_cis) + query_rot = _apply_rotary_emb(query_rot, freqs_cis) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(batch_size, seq_len, -1, self.head_size) key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb(key_rot, freqs_cis) + key_rot = _apply_rotary_emb(key_rot, freqs_cis) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -241,7 +232,19 @@ def forward_cuda( ) -> Tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops - return self.forward_native2(positions, query, key, offsets) + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key def forward_xpu( self, @@ -283,6 +286,66 @@ def extra_repr(self) -> str: s += f", base={self.base}, is_neox_style={self.is_neox_style}" return s +class XPRotaryEmbedding(nn.Module): + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype = None, + ) -> None: + super().__init__() + if dtype is None: + dtype = torch.get_default_dtype() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + self.max_position_embeddings, + self.rotary_dim, + self.base, + ), + ) + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if positions.dim() == 1: + batch_size = 1 + seq_len = positions.shape[0] + else: + batch_size, seq_len = positions.shape + if offsets is not None: + positions = positions + offsets + freqs_cis = self.freqs_cis.index_select(0, positions.flatten()) + freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1) + + query_shape = query.shape + query = query.view(batch_size, seq_len, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_emb(query_rot, freqs_cis) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(batch_size, seq_len, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_emb(key_rot, freqs_cis) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key class LinearScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with linear scaling. 
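XPRotaryEmbedding above precomputes freqs_cis as stacked (real, imag) pairs and rotates adjacent channel pairs, i.e. the interleaved rotary layout. The same math written with complex tensors for readability (equivalent in exact arithmetic, but not the patch's actual code path; function names here are mine):

# Sketch: the rotation implemented by precompute_freqs_cis/apply_rotary_emb,
# expressed with complex tensors instead of the stacked (real, imag) cache.
import torch

def freqs_cis_complex(seq_len: int, n_elem: int, base: float = 10000.0) -> torch.Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)          # [seq_len, n_elem // 2]

def rotate_interleaved(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq, heads, head_dim]; each adjacent pair (x0, x1) is one complex number.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    out = x_c * freqs_cis[None, :, None, :]                      # broadcast over batch and heads
    return torch.view_as_real(out).flatten(-2).type_as(x)

q = torch.randn(1, 8, 2, 64)                  # batch=1, seq_len=8, 2 heads, head_dim=64
q_rot = rotate_interleaved(q, freqs_cis_complex(seq_len=8, n_elem=64))
print(q_rot.shape)                            # torch.Size([1, 8, 2, 64])

In the patch the cache is materialized once (as a bfloat16 real/imag stack), so decode-time application is a single index_select over the current positions followed by the pairwise rotation.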
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 72728e3cd8f5f..0a80d30ad6743 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -42,7 +42,7 @@ QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) -from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding import XPRotaryEmbedding, get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -84,15 +84,15 @@ def __init__( "Only silu is supported for now.") self.act_fn = SiluAndMul() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + # self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + # self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + # self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) def forward(self, x): - y1 = F.silu(self.gate_proj(x)) - y2 = self.up_proj(x) - y = y1 * y2 - return self.down_proj(y) + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x class LlamaAttention(nn.Module): @@ -158,14 +158,19 @@ def __init__( if quant_config is not None and quant_config.get_name() == "gguf": is_neox_style = False - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=is_neox_style, - ) + if config.use_xp_rope: + self.rotary_emb = XPRotaryEmbedding( + self.head_dim, self.head_dim, max_position_embeddings, rope_theta, + is_neox_style) + else: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) self.attn = Attention(self.num_heads, self.head_dim, self.scaling, @@ -246,17 +251,25 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: - x = hidden_states - n = self.input_layernorm(hidden_states) - h = self.self_attn( + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( positions=positions, - hidden_states=n, + hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) - h = x + h - out = h + self.mlp(self.post_attention_layernorm(h)) - return out, None + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual class LlamaModel(nn.Module): @@ -336,7 +349,7 @@ def forward( "residual": residual }) - hidden_states = self.norm(hidden_states, residual) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index ab73addb81f23..70f1ef05fb7f5 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -77,8 +77,8 @@ def 
load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), - # (".gate_up_proj", ".gate_proj", 0), - # (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: @@ -115,7 +115,7 @@ def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torc else: code_emb = [ self.emb_code[0](input_ids[:,0]), - self.emb_code[1](input_ids[:,1] - 2) + self.emb_code[1](input_ids[:,1]) ] emb = torch.stack(code_emb, 2).sum(2) return emb From 5c9d4a78d66c88a7d01d50111cfd10b8991e25f2 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 9 Sep 2024 03:39:28 +0000 Subject: [PATCH 30/61] fp32 done --- benchmarks/benchmark_tts.py | 29 ++- fish.out | 338 +++++++++++++++++++--------- tts.py | 6 +- tts_fish.py | 1 + vllm/model_executor/models/llama.py | 17 +- vllm/model_executor/models/ttslm.py | 6 +- vllm/worker/model_runner.py | 2 +- 7 files changed, 263 insertions(+), 136 deletions(-) diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py index c65711fb8a91b..55bad76d7e19e 100644 --- a/benchmarks/benchmark_tts.py +++ b/benchmarks/benchmark_tts.py @@ -273,7 +273,7 @@ def run_vllm( tensor_parallel_size=tensor_parallel_size, seed=seed, trust_remote_code=trust_remote_code, - dtype=dtype, + dtype=torch.float32, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, enforce_eager=enforce_eager, @@ -289,22 +289,19 @@ def run_vllm( ) # Add the requests to the engine. - prompts: List[str] = [] - sampling_params: List[SamplingParams] = [] - for prompt, _, output_len in requests: - prompts.append(prompt) - sampling_params.append( - SamplingParams( - n=1, - temperature=1, - detokenize=False, - stop_token_ids=[625], - max_tokens=2048, - top_k=1 - )) + prompts = [ + { + "prompt_token_ids": [7001, 5023, 16, 62, 4550, 4557, 4790, 4963, 7, 4676, 4697, 17, + 4549, 2719, 4546, 7, 435, 20, 4499, 37, 1164, 4561, 4637, 828, + 566, 4496, 7, 120, 14, 4695, 32, 4765, 4594, 4648, 4513, 4692, + 37, 1164, 4555, 100, 4544, 4680, 7, 38, 4706, 36, 566, 4498, + 4717, 30, 1164, 4596, 7, 4597, 4858, 475, 20, 4496, 37, 1164, + 4499, 7, 132, 4604, 17, 4610, 17, 4650, 4603, 14, 4596, 4938, + 4513, 0, 0] + } + ] + sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) print(prompts) - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - print("warmup done") start = time.perf_counter() outputs = llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() diff --git a/fish.out b/fish.out index 114b670438fc8..c3103d175a719 100644 --- a/fish.out +++ b/fish.out @@ -1,106 +1,236 @@ [836, 54] -[378, 273] -[144, 426] -[637, 405] -[265, 658] -[375, 1018] -[411, 947] -[1004, 1016] -[254, 983] -[611, 60] -[761, 640] -[67, 427] -[726, 244] -[947, 375] -[502, 916] -[881, 500] -[303, 829] -[190, 787] -[245, 737] -[572, 629] -[997, 1004] -[54, 820] -[195, 967] -[783, 744] -[530, 683] -[52, 473] -[798, 476] -[761, 719] -[128, 427] -[601, 429] -[211, 321] -[885, 1005] -[364, 426] -[910, 882] -[1011, 99] -[974, 408] -[990, 610] -[997, 229] -[620, 646] -[373, 648] -[625, 362] -[572, 322] -[959, 927] -[22, 40] -[83, 379] -[132, 381] -[461, 194] -[467, 743] -[150, 89] -[553, 425] -[885, 43] -[1011, 91] -[293, 598] -[437, 395] -[978, 334] -[318, 845] -[836, 
22] -[373, 377] -[806, 648] -[833, 355] -[632, 389] -[22, 357] -[324, 195] -[83, 415] -[467, 524] -[105, 650] -[132, 528] -[982, 151] -[643, 1020] -[656, 436] -[293, 117] -[148, 1023] -[256, 728] -[836, 91] -[373, 218] -[437, 371] -[1011, 400] -[343, 183] -[978, 17] -[422, 6] -[885, 162] -[319, 845] -[83, 382] -[404, 311] -[443, 485] -[105, 471] -[806, 146] -[22, 568] -[678, 765] -[467, 519] -[833, 662] -[836, 968] -[293, 547] -[373, 223] -[656, 240] -[688, 520] -[529, 474] -[437, 369] -[150, 117] -[885, 742] -[620, 828] -[1011, 508] -[318, 731] -[256, 220] -[407, 989] +[521, 273] +[633, 81] +[771, 88] +[431, 330] +[1016, 665] +[573, 848] +[903, 513] +[582, 10] +[575, 460] +[783, 157] +[756, 293] +[52, 213] +[179, 807] +[869, 249] +[760, 916] +[729, 934] +[936, 386] +[986, 883] +[62, 486] +[830, 252] +[1020, 771] +[650, 425] +[632, 374] +[802, 470] +[140, 229] +[150, 605] +[964, 648] +[643, 529] +[73, 633] +[788, 880] +[179, 491] +[269, 766] +[485, 836] +[586, 923] +[983, 593] +[726, 336] +[935, 639] +[1010, 378] +[913, 869] +[883, 810] +[563, 925] +[411, 288] +[598, 90] +[850, 694] +[220, 78] +[514, 1013] +[443, 229] +[632, 255] +[802, 470] +[150, 504] +[964, 374] +[828, 656] +[410, 963] +[650, 408] +[83, 500] +[185, 146] +[523, 968] +[73, 202] +[254, 326] +[193, 47] +[179, 665] +[788, 892] +[485, 960] +[375, 534] +[7, 84] +[579, 653] +[953, 540] +[220, 588] +[810, 117] +[708, 906] +[246, 545] +[534, 377] +[464, 593] +[434, 88] +[817, 595] +[919, 459] +[135, 765] +[632, 255] +[802, 470] +[150, 648] +[140, 605] +[964, 435] +[982, 793] +[454, 815] +[638, 371] +[698, 117] +[68, 1003] +[106, 867] +[991, 597] +[440, 588] +[766, 526] +[449, 586] +[254, 53] +[521, 490] +[82, 386] +[256, 507] +[611, 81] +[767, 934] +[377, 281] +[885, 780] +[620, 885] +[513, 932] +[343, 223] +[643, 435] +[632, 470] +[802, 504] +[140, 648] +[150, 229] +[964, 374] +[828, 255] +[790, 737] +[650, 963] +[443, 382] +[581, 989] +[410, 408] +[456, 507] +[763, 346] +[974, 617] +[102, 461] +[902, 292] +[290, 353] +[196, 117] +[575, 787] +[880, 317] +[415, 82] +[622, 230] +[607, 286] +[503, 664] +[106, 624] +[445, 31] +[478, 876] +[782, 430] +[268, 405] +[770, 571] +[67, 305] +[571, 507] +[481, 202] +[583, 791] +[325, 857] +[592, 373] +[65, 667] +[1020, 881] +[632, 229] +[802, 470] +[150, 296] +[443, 605] +[828, 737] +[836, 910] +[530, 869] +[482, 467] +[943, 190] +[126, 142] +[179, 665] +[103, 599] +[241, 856] +[621, 940] +[583, 494] +[515, 202] +[135, 157] +[806, 346] +[744, 863] +[746, 703] +[817, 281] +[114, 135] +[82, 426] +[337, 894] +[759, 84] +[228, 160] +[885, 966] +[802, 435] +[632, 296] +[150, 470] +[964, 605] +[140, 737] +[828, 229] +[443, 989] +[410, 648] +[650, 374] +[790, 382] +[581, 793] +[456, 389] +[629, 461] +[343, 302] +[292, 122] +[766, 815] +[182, 559] +[99, 644] +[236, 91] +[231, 957] +[771, 704] +[462, 152] +[737, 380] +[196, 387] +[254, 610] +[480, 684] +[511, 246] +[400, 20] +[1023, 856] +[662, 801] +[698, 665] +[241, 117] +[633, 41] +[106, 454] +[492, 251] +[967, 306] +[783, 864] +[541, 2] +[29, 924] +[449, 876] +[54, 588] +[195, 235] +[752, 275] +[920, 593] +[276, 505] +[89, 336] +[524, 794] +[744, 373] +[641, 779] +[440, 551] +[411, 392] +[215, 849] +[172, 202] +[967, 293] +[761, 343] +[67, 467] +[488, 195] +[620, 847] +[885, 648] +[150, 605] +[964, 435] +[443, 470] +[828, 989] +[650, 374] +[581, 382] [1025, 1025] diff --git a/tts.py b/tts.py index ff2a9a3f77636..95a9218d9cc9a 100644 --- a/tts.py +++ b/tts.py @@ -1,7 +1,7 @@ from vllm import LLM, SamplingParams import torch 
torch.random.manual_seed(999) -# tts = torch.load('/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') +# tts = torch.load('/home/zhn/ttslm_dev/GPT_merged_emb_nonorm.pt') # text_emb_count = tts['emb_text.weight'].shape[0] # audio_emb_count = tts['emb_code.0.weight'].shape[0] @@ -45,7 +45,7 @@ # # save the model # torch.save(tts, '/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') -llm = LLM(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float16, enforce_eager=True) +llm = LLM(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float16) prompts = [ { "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", @@ -61,7 +61,7 @@ prompts.append(prompts[0]) prompts.append(prompts[1]) -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) outputs = llm.generate(prompts, sampling_params) for output in outputs: print(output.prompt) diff --git a/tts_fish.py b/tts_fish.py index 8f37d71275359..7bb72140c4ec7 100644 --- a/tts_fish.py +++ b/tts_fish.py @@ -92,3 +92,4 @@ token_ids = output.outputs[0].token_ids for token_id in token_ids: print([x - 0 for x in token_id]) + print(len(token_ids)) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0a80d30ad6743..0adaf3f002d8e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -83,17 +83,16 @@ def __init__( raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") self.act_fn = SiluAndMul() - - # self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - # self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - # self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + def forward(self, x): + y1 = F.silu(self.gate_proj(x)) + y2 = self.up_proj(x) + y = y1 * y2 + return self.down_proj(y) class LlamaAttention(nn.Module): diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 70f1ef05fb7f5..ab73addb81f23 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -77,8 +77,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), + # (".gate_up_proj", ".gate_proj", 0), + # (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: @@ -115,7 +115,7 @@ def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torc else: code_emb = [ self.emb_code[0](input_ids[:,0]), - self.emb_code[1](input_ids[:,1]) + self.emb_code[1](input_ids[:,1] - 2) ] emb = torch.stack(code_emb, 2).sum(2) return emb diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4008167796cff..d2d281b721931 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1219,7 +1219,7 @@ def 
capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() if hasattr(self.model_config.hf_config, "num_output_head"): - input_tokens = torch.zeros(max_batch_size, self.model_config.hf_config.num_output_head, dtype=torch.long).cuda() + input_tokens = torch.zeros(max_batch_size, self.model_config.hf_config.num_output_head, dtype=torch.long).cuda() + 2 input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) From 1eee0895072704ad991cbc9642a81c628d6852d5 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 9 Sep 2024 09:55:20 +0000 Subject: [PATCH 31/61] update tokenizer --- benchmarks/benchmark_tts.py | 2 +- tts_fish.py | 33 +++++++++++++++++++++------------ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py index 55bad76d7e19e..97ce75478e148 100644 --- a/benchmarks/benchmark_tts.py +++ b/benchmarks/benchmark_tts.py @@ -273,7 +273,7 @@ def run_vllm( tensor_parallel_size=tensor_parallel_size, seed=seed, trust_remote_code=trust_remote_code, - dtype=torch.float32, + dtype=dtype, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, enforce_eager=enforce_eager, diff --git a/tts_fish.py b/tts_fish.py index 7bb72140c4ec7..963a1291459ab 100644 --- a/tts_fish.py +++ b/tts_fish.py @@ -1,8 +1,10 @@ from vllm import LLM, SamplingParams +from tokenizers import Tokenizer +import pypinyin import torch torch.random.manual_seed(999) # tts1 = torch.load('/home/zhn/ttslm_dev/GPT_merged_emb_nonorm.pt') -# tts2 = torch.load('/home/zhn/fishtts/checkpoint-1400000.bak') +# tts2 = torch.load('/data/fishtts/checkpoint-1400000.bak') # llama = tts2['model']['llama'] @@ -70,19 +72,26 @@ # llama['lm_head.1.weight'] = w_output[7002+1026:7002+1026*2] # llama.pop(output_name) -# torch.save(llama, '/home/zhn/fishtts/llama.pt') +# torch.save(llama, '/data/fishtts/llama.pt') -llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32) +texts = [ + '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。', + '在这个数字的世界里,你是我的唯一,爱情如同网络连接,无论距离多遥远。我们的心相互链接,在虚拟的空间中漫游,每条信息都是爱的表达,每个瞬间都是甜蜜的时刻。爱情不再是纸上文字,而是数码世界的交流,在屏幕上,我们相拥相视,你是我的电子爱情。'] +llm_inputs = [] +tokenizer = Tokenizer.from_file('/data/fishtts/vocab.json') +for text in texts: + pinyin = "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) + txt = f"[zh-cn]{pinyin}" + txt = txt.replace(" ", "[SPACE]") + token_ids = tokenizer.encode(txt).ids + token_ids.insert(0, 7001) + token_ids.append(0) + token_ids.append(0) + llm_inputs.append(token_ids) + +llm = LLM(model='/data/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) prompts = [ - { - "prompt_token_ids": [7001, 5023, 16, 62, 4550, 4557, 4790, 4963, 7, 4676, 4697, 17, - 4549, 2719, 4546, 7, 435, 20, 4499, 37, 1164, 4561, 4637, 828, - 566, 4496, 7, 120, 14, 4695, 32, 4765, 4594, 4648, 4513, 4692, - 37, 1164, 4555, 100, 4544, 4680, 7, 38, 4706, 36, 566, 4498, - 4717, 30, 1164, 4596, 7, 4597, 4858, 475, 20, 4496, 37, 1164, - 4499, 7, 132, 4604, 17, 4610, 17, 4650, 4603, 14, 4596, 4938, - 4513, 0, 0] - } + {"prompt_token_ids": llm_input} for llm_input in llm_inputs ] sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) From 
7636836d99f78bb904b1db8ca1ed1dada125c263 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 9 Sep 2024 14:22:35 +0000 Subject: [PATCH 32/61] fix batch greater than 1 --- tts_async.py | 38 +++++++++++-------- tts_fish.py | 29 ++++++++------ .../layers/multi_head_sampler.py | 1 + vllm/model_executor/models/ttslm.py | 28 ++++++-------- 4 files changed, 52 insertions(+), 44 deletions(-) diff --git a/tts_async.py b/tts_async.py index 7052b92cac8ea..66fdbfb90ed55 100644 --- a/tts_async.py +++ b/tts_async.py @@ -1,23 +1,31 @@ import asyncio import time - +from tokenizers import Tokenizer +import pypinyin import torch from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams -prompts = [ - { - "prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", - "multi_modal_data": {"audio": 'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, - }, - { - "prompt": "[Stts][empty_spk][speed_5]Your text two[Ptts]", - "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, - } -] +texts = [ + '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。', + '在这个数字的世界里,你是我的唯一,爱情如同网络连接,无论距离多遥远。我们的心相互链接,在虚拟的空间中漫游,每条信息都是爱的表达,每个瞬间都是甜蜜的时刻。爱情不再是纸上文字,而是数码世界的交流,在屏幕上,我们相拥相视,你是我的电子爱情。'] +llm_inputs = [] +tokenizer = Tokenizer.from_file('/home/zhn/fishtts/vocab.json') +for text in texts: + pinyin = "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) + txt = f"[zh-cn]{pinyin}" + txt = txt.replace(" ", "[SPACE]") + token_ids = tokenizer.encode(txt).ids + token_ids.insert(0, 7001) + token_ids.append(0) + token_ids.append(7003) + llm_inputs.append(token_ids) -engine_args = AsyncEngineArgs(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float16) +engine_args = AsyncEngineArgs(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True, enforce_eager=True) model = AsyncLLMEngine.from_engine_args(engine_args) -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) +prompts = [ + {"prompt_token_ids": llm_input} for llm_input in llm_inputs +] async def generate_streaming(prompt, id): results_generator = model.generate(prompt, sampling_params, request_id=id) @@ -29,13 +37,13 @@ async def generate_streaming(prompt, id): tokens.append([x - 0 
for x in token_ids[-1]]) count+=1 - print(prompt['prompt']) + print(id) for token in tokens: print(token) async def generate(): tasks = [] - for i in range(1): + for i in range(2): t = generate_streaming(prompts[i%2], i) tasks.append(t) await asyncio.gather(*tasks) diff --git a/tts_fish.py b/tts_fish.py index 963a1291459ab..2358328396d02 100644 --- a/tts_fish.py +++ b/tts_fish.py @@ -4,22 +4,27 @@ import torch torch.random.manual_seed(999) # tts1 = torch.load('/home/zhn/ttslm_dev/GPT_merged_emb_nonorm.pt') -# tts2 = torch.load('/data/fishtts/checkpoint-1400000.bak') +# tts2 = torch.load('/home/zhn/fishtts/checkpoint-1400000.bak') +# layer = 24 +# dim = 1536 +# num_audio_tokens = 1026 +# num_text_tokens = 7002 # llama = tts2['model']['llama'] # llama.pop('freqs_cis') # llama.pop('causal_mask') -# llama['emb_text.weight'] = llama['text_embeddings.weight'] +# text_emb = llama['text_embeddings.weight'] +# for i in range(100): +# text_emb = torch.cat([text_emb, torch.zeros((1,dim), device=text_emb.device)], 0) +# llama['emb_text.weight'] = text_emb # llama.pop('text_embeddings.weight') -# llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:1026] -# llama['emb_code.1.weight'] = llama['code_embeddings.weight'][1026:] +# llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:num_audio_tokens] +# llama['emb_code.1.weight'] = llama['code_embeddings.weight'][num_audio_tokens-2:num_audio_tokens - 2 + num_audio_tokens] # llama.pop('code_embeddings.weight') -# layer = 24 -# dim = 1536 # for i in range(layer): # qkv_name = f'layers.{i}.attention.wqkv.weight' # q = llama[qkv_name][0:dim] @@ -68,17 +73,17 @@ # output_name = 'output.weight' # w_output = llama[output_name] -# llama['lm_head.0.weight'] = w_output[7002:7002+1026] -# llama['lm_head.1.weight'] = w_output[7002+1026:7002+1026*2] +# llama['lm_head.0.weight'] = w_output[num_text_tokens:num_text_tokens+num_audio_tokens] +# llama['lm_head.1.weight'] = w_output[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2] # llama.pop(output_name) -# torch.save(llama, '/data/fishtts/llama.pt') +# torch.save(llama, '/home/zhn/fishtts/llama.pt') texts = [ '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。', '在这个数字的世界里,你是我的唯一,爱情如同网络连接,无论距离多遥远。我们的心相互链接,在虚拟的空间中漫游,每条信息都是爱的表达,每个瞬间都是甜蜜的时刻。爱情不再是纸上文字,而是数码世界的交流,在屏幕上,我们相拥相视,你是我的电子爱情。'] llm_inputs = [] -tokenizer = Tokenizer.from_file('/data/fishtts/vocab.json') +tokenizer = Tokenizer.from_file('/home/zhn/fishtts/vocab.json') for text in texts: pinyin = "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) txt = f"[zh-cn]{pinyin}" @@ -86,10 +91,10 @@ token_ids = tokenizer.encode(txt).ids token_ids.insert(0, 7001) token_ids.append(0) - token_ids.append(0) + token_ids.append(7003) llm_inputs.append(token_ids) -llm = LLM(model='/data/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) +llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True, enforce_eager=True) prompts = [ {"prompt_token_ids": llm_input} for llm_input in llm_inputs ] diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py index 0f896a897f3bd..7b32b8a7f46ee 100644 --- a/vllm/model_executor/layers/multi_head_sampler.py +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -39,6 +39,7 @@ def forward( ) -> Optional[SamplerOutput]: batch_size, num_heads, vocab_size = logits.size() logits = logits.reshape(batch_size * num_heads, 
vocab_size) + logits[:, 0] = logits[:, 1] = -inf # Mask out the padding token. self._init_sampling_tensors(num_heads, vocab_size, logits, sampling_metadata) sampling_tensors = self._sampling_tensors diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index ab73addb81f23..9e9c3b4f53dfb 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -55,7 +55,7 @@ def __init__(self, self.num_audio_tokens = config.num_audio_tokens self.num_text_tokens = config.num_text_tokens self.num_output_head = config.num_output_head - self.spk_emb_token_id = 21143 + self.spk_emb_token_id = 7003 self.gpt = LlamaModel(config) self.model_dim = self.gpt.config.hidden_size @@ -105,17 +105,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torch.Tensor: if is_prompt: emb = self.emb_text(input_ids) - audio_start = torch.tensor([1024, 1022], device=input_ids.device) - code_emb = [ - self.emb_code[i](audio_start[i]) - for i in range(self.num_output_head) - ] - code_emb = torch.stack(code_emb, 1).sum(1) - emb[-1] = code_emb else: code_emb = [ - self.emb_code[0](input_ids[:,0]), - self.emb_code[1](input_ids[:,1] - 2) + self.emb_code[i](input_ids[:,i]) for i in range(self.num_output_head) ] emb = torch.stack(code_emb, 2).sum(2) return emb @@ -158,9 +150,8 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids, is_prompt) - # spk_emb = kwargs.get("speech", None) - # if spk_emb is not None: - # self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) + spk_emb = kwargs.get("speech", None) + self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) model_output = self.gpt( input_ids=input_ids, inputs_embeds=hidden_states, @@ -178,10 +169,13 @@ def apply_spk_emb( attn_metadata: AttentionMetadata, input_ids: torch.Tensor, ): - assert emb.size(1) == spk_emb.size(1) - assert attn_metadata.seq_lens_tensor.size(0) == spk_emb.size(0) - # convert spk_emb to the same dtype as emb - spk_emb = spk_emb.to(emb.dtype) + audio_start = torch.tensor([1024, 1024], device=input_ids.device) + code_emb = [ + self.emb_code[i](audio_start[i]) + for i in range(self.num_output_head) + ] + spk_emb = torch.stack(code_emb, 1).sum(1).to(emb.dtype) + # find the index of the speaker token indices = (input_ids == self.spk_emb_token_id).nonzero(as_tuple=True) if indices[0].size(0) == 0: From ee811bb209f660f46464fb9d6d198da3fb73f503 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 10 Sep 2024 05:14:47 +0000 Subject: [PATCH 33/61] fix capture bug --- vllm/model_executor/models/ttslm.py | 23 +++++++++++++++++------ vllm/worker/model_runner.py | 2 +- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 9e9c3b4f53dfb..8e843c637c2cc 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -55,7 +55,7 @@ def __init__(self, self.num_audio_tokens = config.num_audio_tokens self.num_text_tokens = config.num_text_tokens self.num_output_head = config.num_output_head - self.spk_emb_token_id = 7003 + self.audio_start_token_id = config.audio_start_token_id self.gpt = LlamaModel(config) self.model_dim = self.gpt.config.hidden_size @@ -105,6 +105,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torch.Tensor: 
if is_prompt: emb = self.emb_text(input_ids) + audio_start = torch.tensor([1024, 1024], device=input_ids.device) + code_emb = [ + self.emb_code[i](audio_start[i]) + for i in range(self.num_output_head) + ] + start_token = torch.stack(code_emb, 1).sum(1).to(emb.dtype) + + # find the index of the speaker token + indices = (input_ids == self.audio_start_token_id).nonzero(as_tuple=True) + if indices[0].size(0) != 0: + emb.index_put_(indices, start_token) else: code_emb = [ self.emb_code[i](input_ids[:,i]) for i in range(self.num_output_head) @@ -150,8 +161,8 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids, is_prompt) - spk_emb = kwargs.get("speech", None) - self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) + # spk_emb = kwargs.get("speech", None) + # self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) model_output = self.gpt( input_ids=input_ids, inputs_embeds=hidden_states, @@ -174,13 +185,13 @@ def apply_spk_emb( self.emb_code[i](audio_start[i]) for i in range(self.num_output_head) ] - spk_emb = torch.stack(code_emb, 1).sum(1).to(emb.dtype) + start_token = torch.stack(code_emb, 1).sum(1).to(emb.dtype) # find the index of the speaker token - indices = (input_ids == self.spk_emb_token_id).nonzero(as_tuple=True) + indices = (input_ids == self.audio_start_token_id).nonzero(as_tuple=True) if indices[0].size(0) == 0: return - emb.index_put_(indices, spk_emb) + emb.index_put_(indices, start_token) def merge_sample_results( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d2d281b721931..4008167796cff 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1219,7 +1219,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() if hasattr(self.model_config.hf_config, "num_output_head"): - input_tokens = torch.zeros(max_batch_size, self.model_config.hf_config.num_output_head, dtype=torch.long).cuda() + 2 + input_tokens = torch.zeros(max_batch_size, self.model_config.hf_config.num_output_head, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) From a63f01206d211b25607edebc9a2cfbbd7530c405 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 10 Sep 2024 05:48:18 +0000 Subject: [PATCH 34/61] update benchmark code --- benchmarks/benchmark_tts.py | 50 +++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py index 97ce75478e148..e36d39c9519cf 100644 --- a/benchmarks/benchmark_tts.py +++ b/benchmarks/benchmark_tts.py @@ -3,12 +3,15 @@ import asyncio from asyncio import tasks import json +import os import random import time from typing import List, Optional, Tuple, AsyncGenerator import warnings import numpy as np +import pypinyin +from tokenizers import Tokenizer import torch from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, @@ -41,7 +44,7 @@ def calculate_metrics( # Note : this may inflate the output token count slightly output_len = len(outputs[i].output_tokens) actual_output_lens.append(output_len) - total_input += input_requests[i][1] + total_input += len(input_requests[i][1]) if output_len > 1: tpots.append( (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) @@ -103,7 +106,7 @@ async def 
generate_streaming(llm: AsyncLLMEngine, request_func_input: RequestFun ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st - sampling_params = SamplingParams(n=1, temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) + sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) results_generator = llm.generate(request_func_input.prompt, sampling_params, request_id=request_id) async for request_output in results_generator: token_ids = request_output.outputs[0].token_ids @@ -182,12 +185,12 @@ async def run_vllm_async( tasks: List[asyncio.Task] = [] request_id = 0 async for request in get_request(requests, request_rate): - prompt, prompt_len, output_len = request + prompt, token_ids, output_len = request request_func_input = RequestFuncInput( api_url="", model=model, - prompt=prompt, - prompt_len=prompt_len, + prompt={"prompt_token_ids": token_ids}, + prompt_len=len(token_ids), output_len=output_len, use_beam_search=use_beam_search, ) @@ -289,19 +292,13 @@ def run_vllm( ) # Add the requests to the engine. - prompts = [ - { - "prompt_token_ids": [7001, 5023, 16, 62, 4550, 4557, 4790, 4963, 7, 4676, 4697, 17, - 4549, 2719, 4546, 7, 435, 20, 4499, 37, 1164, 4561, 4637, 828, - 566, 4496, 7, 120, 14, 4695, 32, 4765, 4594, 4648, 4513, 4692, - 37, 1164, 4555, 100, 4544, 4680, 7, 38, 4706, 36, 566, 4498, - 4717, 30, 1164, 4596, 7, 4597, 4858, 475, 20, 4496, 37, 1164, - 4499, 7, 132, 4604, 17, 4610, 17, 4650, 4603, 14, 4596, 4938, - 4513, 0, 0] - } - ] + prompts = [] + sampling_params: List[SamplingParams] = [] + for prompt, ids, output_len in requests: + prompts.append({"prompt_token_ids": ids}) + sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) - print(prompts) + start = time.perf_counter() outputs = llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() @@ -375,13 +372,24 @@ def main(args: argparse.Namespace): random.seed(args.seed) # Sample the requests. 
- tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) lines = open(args.dataset).read().splitlines() - requests = [(f'[Stts][spk_emb][speed_5]{line}[Ptts]', len(tokenizer(line).input_ids), 2048) for line in lines] + requests = [] + # combin tokenizer path + + tokenizer = Tokenizer.from_file(os.path.join(args.tokenizer, 'vocab.json')) + for text in lines: + pinyin = "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) + txt = f"[zh-cn]{pinyin}" + txt = txt.replace(" ", "[SPACE]") + token_ids = tokenizer.encode(txt).ids + token_ids.insert(0, 7001) + token_ids.append(0) + token_ids.append(7003) + requests.append((text, token_ids, 2048)) + requests = requests[:args.num_prompts] - total_input_tokens = sum(count for _, count, _ in requests) + total_input_tokens = sum(len(ids) for _, ids, _ in requests) if args.streaming: asyncio.run(run_vllm_async(requests, args.model, args.tokenizer, args.quantization, From 120693ad09f001a07b5afaf639fdbc9c9cf0c7f7 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 10 Sep 2024 08:04:00 +0000 Subject: [PATCH 35/61] update logits for first 2 tokens --- vllm/model_executor/layers/multi_head_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py index 7b32b8a7f46ee..0f896a897f3bd 100644 --- a/vllm/model_executor/layers/multi_head_sampler.py +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -39,7 +39,6 @@ def forward( ) -> Optional[SamplerOutput]: batch_size, num_heads, vocab_size = logits.size() logits = logits.reshape(batch_size * num_heads, vocab_size) - logits[:, 0] = logits[:, 1] = -inf # Mask out the padding token. 
self._init_sampling_tensors(num_heads, vocab_size, logits, sampling_metadata) sampling_tensors = self._sampling_tensors From b37a43190ec66853212434aa87b01c24e16aff01 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 10 Sep 2024 08:31:47 +0000 Subject: [PATCH 36/61] undo usless changes --- fish.out | 236 ----------------------- tts.out | 96 --------- vllm/model_executor/layers/sampler.py | 38 ++-- vllm/model_executor/sampling_metadata.py | 10 +- 4 files changed, 13 insertions(+), 367 deletions(-) delete mode 100644 fish.out delete mode 100644 tts.out diff --git a/fish.out b/fish.out deleted file mode 100644 index c3103d175a719..0000000000000 --- a/fish.out +++ /dev/null @@ -1,236 +0,0 @@ -[836, 54] -[521, 273] -[633, 81] -[771, 88] -[431, 330] -[1016, 665] -[573, 848] -[903, 513] -[582, 10] -[575, 460] -[783, 157] -[756, 293] -[52, 213] -[179, 807] -[869, 249] -[760, 916] -[729, 934] -[936, 386] -[986, 883] -[62, 486] -[830, 252] -[1020, 771] -[650, 425] -[632, 374] -[802, 470] -[140, 229] -[150, 605] -[964, 648] -[643, 529] -[73, 633] -[788, 880] -[179, 491] -[269, 766] -[485, 836] -[586, 923] -[983, 593] -[726, 336] -[935, 639] -[1010, 378] -[913, 869] -[883, 810] -[563, 925] -[411, 288] -[598, 90] -[850, 694] -[220, 78] -[514, 1013] -[443, 229] -[632, 255] -[802, 470] -[150, 504] -[964, 374] -[828, 656] -[410, 963] -[650, 408] -[83, 500] -[185, 146] -[523, 968] -[73, 202] -[254, 326] -[193, 47] -[179, 665] -[788, 892] -[485, 960] -[375, 534] -[7, 84] -[579, 653] -[953, 540] -[220, 588] -[810, 117] -[708, 906] -[246, 545] -[534, 377] -[464, 593] -[434, 88] -[817, 595] -[919, 459] -[135, 765] -[632, 255] -[802, 470] -[150, 648] -[140, 605] -[964, 435] -[982, 793] -[454, 815] -[638, 371] -[698, 117] -[68, 1003] -[106, 867] -[991, 597] -[440, 588] -[766, 526] -[449, 586] -[254, 53] -[521, 490] -[82, 386] -[256, 507] -[611, 81] -[767, 934] -[377, 281] -[885, 780] -[620, 885] -[513, 932] -[343, 223] -[643, 435] -[632, 470] -[802, 504] -[140, 648] -[150, 229] -[964, 374] -[828, 255] -[790, 737] -[650, 963] -[443, 382] -[581, 989] -[410, 408] -[456, 507] -[763, 346] -[974, 617] -[102, 461] -[902, 292] -[290, 353] -[196, 117] -[575, 787] -[880, 317] -[415, 82] -[622, 230] -[607, 286] -[503, 664] -[106, 624] -[445, 31] -[478, 876] -[782, 430] -[268, 405] -[770, 571] -[67, 305] -[571, 507] -[481, 202] -[583, 791] -[325, 857] -[592, 373] -[65, 667] -[1020, 881] -[632, 229] -[802, 470] -[150, 296] -[443, 605] -[828, 737] -[836, 910] -[530, 869] -[482, 467] -[943, 190] -[126, 142] -[179, 665] -[103, 599] -[241, 856] -[621, 940] -[583, 494] -[515, 202] -[135, 157] -[806, 346] -[744, 863] -[746, 703] -[817, 281] -[114, 135] -[82, 426] -[337, 894] -[759, 84] -[228, 160] -[885, 966] -[802, 435] -[632, 296] -[150, 470] -[964, 605] -[140, 737] -[828, 229] -[443, 989] -[410, 648] -[650, 374] -[790, 382] -[581, 793] -[456, 389] -[629, 461] -[343, 302] -[292, 122] -[766, 815] -[182, 559] -[99, 644] -[236, 91] -[231, 957] -[771, 704] -[462, 152] -[737, 380] -[196, 387] -[254, 610] -[480, 684] -[511, 246] -[400, 20] -[1023, 856] -[662, 801] -[698, 665] -[241, 117] -[633, 41] -[106, 454] -[492, 251] -[967, 306] -[783, 864] -[541, 2] -[29, 924] -[449, 876] -[54, 588] -[195, 235] -[752, 275] -[920, 593] -[276, 505] -[89, 336] -[524, 794] -[744, 373] -[641, 779] -[440, 551] -[411, 392] -[215, 849] -[172, 202] -[967, 293] -[761, 343] -[67, 467] -[488, 195] -[620, 847] -[885, 648] -[150, 605] -[964, 435] -[443, 470] -[828, 989] -[650, 374] -[581, 382] -[1025, 1025] diff --git a/tts.out b/tts.out deleted file mode 
100644 index 68af2fd1eb420..0000000000000 --- a/tts.out +++ /dev/null @@ -1,96 +0,0 @@ -[Stts][empty_spk][speed_5]Your text one[Ptts] -[416, 441, 166, 216] -[416, 422, 166, 191] -[441, 546, 166, 156] -[416, 597, 165, 41] -[441, 297, 166, 331] -[441, 541, 166, 31] -[421, 472, 165, 192] -[446, 287, 167, 456] -[416, 596, 166, 163] -[421, 447, 166, 317] -[446, 287, 167, 456] -[416, 596, 166, 132] -[421, 462, 166, 185] -[416, 346, 166, 432] -[416, 546, 165, 166] -[446, 592, 341, 471] -[468, 597, 466, 467] -[466, 541, 466, 466] -[468, 467, 463, 466] -[461, 456, 407, 401] -[336, 207, 456, 531] -[343, 243, 427, 203] -[468, 473, 306, 456] -[571, 597, 341, 466] -[546, 543, 166, 165] -[411, 391, 166, 166] -[306, 207, 187, 168] -[281, 406, 312, 468] -[163, 467, 336, 216] -[313, 168, 433, 578] -[183, 31, 427, 581] -[180, 41, 408, 576] -[338, 218, 403, 531] -[316, 216, 336, 215] -[416, 416, 162, 166] -[406, 131, 156, 161] -[406, 391, 166, 31] -[406, 390, 162, 41] -[406, 380, 191, 216] -[318, 168, 341, 416] -[218, 222, 541, 591] -[215, 215, 436, 592] -[216, 81, 432, 418] -[366, 491, 456, 456] -[471, 465, 456, 456] -[343, 625, 432, 625] -[Stts][empty_spk][speed_5]Anther long string[Ptts] -[416, 291, 166, 216] -[416, 422, 166, 166] -[441, 546, 166, 131] -[421, 422, 165, 166] -[441, 422, 166, 281] -[416, 591, 166, 10] -[421, 347, 165, 411] -[441, 421, 166, 283] -[421, 561, 166, 136] -[421, 297, 166, 411] -[416, 596, 166, 258] -[421, 541, 166, 140] -[421, 591, 291, 168] -[421, 420, 291, 205] -[291, 221, 216, 156] -[338, 207, 408, 583] -[338, 467, 457, 453] -[341, 216, 460, 490] -[416, 467, 337, 208] -[281, 207, 337, 466] -[333, 207, 456, 456] -[343, 467, 407, 451] -[216, 241, 306, 156] -[216, 242, 331, 458] -[341, 242, 312, 218] -[341, 216, 341, 491] -[216, 367, 441, 223] -[216, 216, 342, 416] -[86, 330, 341, 416] -[216, 216, 338, 208] -[91, 340, 336, 456] -[216, 365, 312, 466] -[346, 472, 336, 481] -[321, 172, 336, 492] -[163, 158, 332, 206] -[406, 157, 186, 158] -[406, 416, 188, 218] -[406, 381, 167, 216] -[411, 516, 162, 43] -[406, 391, 308, 158] -[281, 167, 331, 465] -[343, 219, 431, 401] -[218, 223, 431, 151] -[218, 207, 331, 406] -[341, 235, 308, 408] -[216, 165, 207, 226] -[343, 418, 337, 416] -[625, 625, 625, 625] \ No newline at end of file diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 8d8408edea525..1d45b42e7608e 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -44,7 +44,7 @@ class Sampler(nn.Module): in logits for each token in the input prompt. """ - def __init__(self, idx: int = -1): + def __init__(self): super().__init__() # Whether or not the SamplerOutput should have on-device tensors @@ -52,7 +52,6 @@ def __init__(self, idx: int = -1): # speculative decoding. 
self.include_gpu_probs_tensor = False self.should_modify_greedy_probs_inplace = False - self.head_idx = idx def _init_sampling_tensors( self, @@ -72,14 +71,13 @@ def _init_sampling_tensors( # Initialize new sampling tensors (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p, is_prompt) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype, head_idx=self.head_idx) + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) self._sampling_tensors = sampling_tensors self._do_penalties = do_penalties self._do_top_p_top_k = do_top_p_top_k self._do_min_p = do_min_p - self._is_prompt = is_prompt def forward( self, @@ -115,19 +113,11 @@ def forward( # Apply presence and frequency penalties. if do_penalties: - if self.head_idx >= 0 and is_prompt: - # when multihead output and prompt phase, we do not apply penalties - # because the prompt tokens are not same as the output tokens - pass - else: - skip_prompt_repetition = self.head_idx >= 0 and not is_prompt - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties, - skip_prompt_repetition=skip_prompt_repetition) - + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) # Use float32 to apply temperature scaling. # Use in-place division to avoid creating a new tensor. logits = logits.to(torch.float) @@ -260,19 +250,13 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor, - skip_prompt_repetition: bool = False) -> torch.Tensor: + repetition_penalties: torch.Tensor) -> torch.Tensor: num_seqs, vocab_size = logits.shape output_bin_counts, output_mask = _get_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - if skip_prompt_repetition: - # when multihead output, we do not apply penalties for prompt tokens - # because the prompt tokens are not same as the output tokens - prompt_mask = torch.zeros((num_seqs, vocab_size), dtype=torch.bool, device=prompt_tokens_tensor.device) - else: - _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, - num_seqs) + _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, + num_seqs) repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) repetition_penalties[~(prompt_mask | output_mask)] = 1.0 diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 8acd6c6dadb68..94b4b14416821 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -391,7 +391,6 @@ def from_sampling_metadata( device: torch.device, dtype: torch.dtype, *, - head_idx: int = -1, extra_seeds_to_generate: int = 0, extra_entropy: Optional[Tuple[int, ...]] = None ) -> Tuple["SamplingTensors", bool, bool, bool]: @@ -414,7 +413,6 @@ def from_sampling_metadata( do_penalties = False do_top_p_top_k = False do_min_p = False - is_prompt = False if _USE_TRITON_SAMPLER: prompt_best_of: List[int] = [] @@ -503,7 +501,6 @@ def from_sampling_metadata( if do_penalties: for seq_group in 
sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids - repetition_window = seq_group.sampling_params.repetition_window if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) @@ -515,17 +512,14 @@ def from_sampling_metadata( for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] prompt_tokens.append(seq_data.prompt_token_ids_array) - if head_idx >= 0 and seq_group.is_prompt == False: - output_tokens.append([i[head_idx] for i in seq_data.output_token_ids_array[-repetition_window:]]) - else: - output_tokens.append(seq_data.output_token_ids_array) + output_tokens.append(seq_data.output_token_ids_array) sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, frequency_penalties, repetition_penalties, sampling_seeds, sample_indices, prompt_tokens, output_tokens, vocab_size, extra_seeds_to_generate, device, dtype) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p, is_prompt) + return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod def from_lists(cls, temperatures: List[float], top_ps: List[float], From 0e6701fff3b52b61110daaad0aea4124704143eb Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 10 Sep 2024 08:36:15 +0000 Subject: [PATCH 37/61] undo usless changes --- tts.py | 70 --------------------------- vllm/model_executor/layers/sampler.py | 7 ++- 2 files changed, 3 insertions(+), 74 deletions(-) delete mode 100644 tts.py diff --git a/tts.py b/tts.py deleted file mode 100644 index 95a9218d9cc9a..0000000000000 --- a/tts.py +++ /dev/null @@ -1,70 +0,0 @@ -from vllm import LLM, SamplingParams -import torch -torch.random.manual_seed(999) -# tts = torch.load('/home/zhn/ttslm_dev/GPT_merged_emb_nonorm.pt') - -# text_emb_count = tts['emb_text.weight'].shape[0] -# audio_emb_count = tts['emb_code.0.weight'].shape[0] -# model_dim = tts['emb_text.weight'].shape[1] - -# # append audio embeddings to text embeddings -# # all_0 = text_emb + audio_emb_0 -# all_0 = torch.cat([tts['emb_text.weight'], tts['emb_code.0.weight']], dim=0) - -# # all_1 = zero + audio_emb_1 -# all_1 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.1.weight']], dim=0) - -# # all_2 = zero + audio_emb_2 -# all_2 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.2.weight']], dim=0) - -# # all_3 = zero + audio_emb_3 -# all_3 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.3.weight']], dim=0) - -# # remove text emb and audio emb in the model -# tts.pop('emb_text.weight') -# tts.pop('emb_code.0.weight') -# tts.pop('emb_code.1.weight') -# tts.pop('emb_code.2.weight') -# tts.pop('emb_code.3.weight') - -# # add new embeddings to the model -# tts['emb_all.0.weight'] = all_0 -# tts['emb_all.1.weight'] = all_1 -# tts['emb_all.2.weight'] = all_2 -# tts['emb_all.3.weight'] = all_3 - -# for i in range(4): -# original0 = tts[f'head_code.{i}.parametrizations.weight.original0'] -# original1 = tts[f'head_code.{i}.parametrizations.weight.original1'] -# # get the normalized weights based on the original 0 and 1 -# weight_norm0 = torch._weight_norm(original1, original0, dim=0) -# tts.pop(f'head_code.{i}.parametrizations.weight.original0') -# tts.pop(f'head_code.{i}.parametrizations.weight.original1') -# tts[f'lm_head.{i}.weight'] = weight_norm0 - -# # save the model -# torch.save(tts, '/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') - -llm = LLM(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float16) -prompts = [ - { - 
"prompt": "[Stts][empty_spk][speed_5]Your text one[Ptts]", - "multi_modal_data": {"audio": 'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, - }, - { - "prompt": "[Stts][empty_spk][speed_5]Anther long string[Ptts]", - "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, - } -] - -for i in range(0): - prompts.append(prompts[0]) - prompts.append(prompts[1]) - -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) -outputs = llm.generate(prompts, sampling_params) -for output in outputs: - print(output.prompt) - token_ids = output.outputs[0].token_ids - for token_id in token_ids: - print([x - 0 for x in token_id]) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1d45b42e7608e..41abdf211e7e7 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -107,7 +107,6 @@ def forward( do_penalties = self._do_penalties do_top_p_top_k = self._do_top_p_top_k do_min_p = self._do_min_p - is_prompt = self._is_prompt logits = _apply_min_tokens_penalty(logits, sampling_metadata) @@ -118,6 +117,7 @@ def forward( sampling_tensors.presence_penalties, sampling_tensors.frequency_penalties, sampling_tensors.repetition_penalties) + # Use float32 to apply temperature scaling. # Use in-place division to avoid creating a new tensor. 
logits = logits.to(torch.float) @@ -252,11 +252,10 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, frequency_penalties: torch.Tensor, repetition_penalties: torch.Tensor) -> torch.Tensor: num_seqs, vocab_size = logits.shape - output_bin_counts, output_mask = _get_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) - _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, num_seqs) + output_bin_counts, output_mask = _get_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) repetition_penalties[~(prompt_mask | output_mask)] = 1.0 From a620e3322dd4265b058e6074b1fd769527c9f2df Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 10 Sep 2024 14:29:09 +0000 Subject: [PATCH 38/61] fix merge failures for fish --- benchmarks/benchmark_tts.py | 84 ++++++++++++++----- tts_async.py | 3 +- tts_fish.py | 4 +- vllm/engine/llm_engine.py | 7 +- .../layers/multi_head_sampler.py | 5 +- vllm/model_executor/models/ttslm.py | 5 +- vllm/multimodal/audio.py | 2 +- vllm/multimodal/registry.py | 2 +- vllm/multimodal/speech.py | 2 +- vllm/sampling_params.py | 1 + vllm/sequence.py | 18 ++-- 11 files changed, 89 insertions(+), 44 deletions(-) diff --git a/benchmarks/benchmark_tts.py b/benchmarks/benchmark_tts.py index e36d39c9519cf..d4be77e5ee623 100644 --- a/benchmarks/benchmark_tts.py +++ b/benchmarks/benchmark_tts.py @@ -2,11 +2,12 @@ import argparse import asyncio from asyncio import tasks +from datetime import datetime import json import os import random import time -from typing import List, Optional, Tuple, AsyncGenerator +from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator import warnings import numpy as np @@ -36,6 +37,7 @@ def calculate_metrics( itls: List[float] = [] tpots: List[float] = [] ttfts: List[float] = [] + e2els: List[float] = [] for i in range(len(outputs)): if outputs[i].success: # We use the tokenizer to count the number of output tokens for all @@ -50,6 +52,7 @@ def calculate_metrics( (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) itls += outputs[i].itl ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) completed += 1 else: actual_output_lens.append(0) @@ -59,26 +62,35 @@ def calculate_metrics( "All requests failed. 
This is likely due to a misconfiguration " "on the benchmark arguments.", stacklevel=2) + selected_percentiles = [50, 95, 99] metrics = BenchmarkMetrics( completed=completed, total_input=total_input, total_output=sum(actual_output_lens), request_throughput=completed / dur_s, - input_throughput=total_input / dur_s, output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend - median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, - p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], mean_tpot_ms=np.mean(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, - p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], mean_itl_ms=np.mean(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, - p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.median(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.mean(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], ) return metrics, actual_output_lens @@ -222,27 +234,53 @@ async def run_vllm_async( metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) - print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", - metrics.input_throughput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) - print("{:<40} {:<10.2f}".format("Median TTFT (ms):", - metrics.median_ttft_ms)) - print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + for p, v in metrics.percentiles_ttft_ms: + print("{:<40} {:<10.2f}".format(f"{p}th percentile (ms):", v)) print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 
1st token)', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) - print("{:<40} {:<10.2f}".format("Median TPOT (ms):", - metrics.median_tpot_ms)) - print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) - print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) - print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) - print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) - print("=" * 50) + for p, v in metrics.percentiles_tpot_ms: + print("{:<40} {:<10.2f}".format(f"{p}th percentile (ms):", v)) + + benchmark_result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + + result_json: Dict[str, Any] = {} + + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + + # Traffic + result_json["request_rate"] = ( + request_rate if request_rate < float("inf") else "inf") + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} + + # Save to file + base_model_id = 'xtts' + file_name = f"vllm-{request_rate}qps-{base_model_id}-{current_dt}.json" #noqa + # with open(file_name, "w") as outfile: + # json.dump(result_json, outfile) def run_vllm( requests: List[Tuple[str, int, int]], diff --git a/tts_async.py b/tts_async.py index 66fdbfb90ed55..36c0ba350fe64 100644 --- a/tts_async.py +++ b/tts_async.py @@ -20,7 +20,7 @@ token_ids.append(7003) llm_inputs.append(token_ids) -engine_args = AsyncEngineArgs(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True, enforce_eager=True) +engine_args = AsyncEngineArgs(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) model = AsyncLLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) prompts = [ @@ -38,6 +38,7 @@ async def generate_streaming(prompt, id): count+=1 print(id) + print(len(tokens)) for token in tokens: print(token) diff --git a/tts_fish.py b/tts_fish.py index 2358328396d02..7524714269270 100644 --- a/tts_fish.py +++ b/tts_fish.py @@ -81,7 +81,9 @@ texts = [ '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。', - '在这个数字的世界里,你是我的唯一,爱情如同网络连接,无论距离多遥远。我们的心相互链接,在虚拟的空间中漫游,每条信息都是爱的表达,每个瞬间都是甜蜜的时刻。爱情不再是纸上文字,而是数码世界的交流,在屏幕上,我们相拥相视,你是我的电子爱情。'] + '在这个数字的世界里,你是我的唯一,爱情如同网络连接,无论距离多遥远。我们的心相互链接,在虚拟的空间中漫游,每条信息都是爱的表达,每个瞬间都是甜蜜的时刻。爱情不再是纸上文字,而是数码世界的交流,在屏幕上,我们相拥相视,你是我的电子爱情。', + '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。' + ] llm_inputs = [] tokenizer = Tokenizer.from_file('/home/zhn/fishtts/vocab.json') for text in texts: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 94271c4a93151..889a83869c59a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1469,7 +1469,12 @@ def 
_advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) + if len(sample.output_tokens) > 1: + seq.append_token_id(sample.output_tokens, + sample.logprobs) + else: + seq.append_token_id(sample.output_token, + sample.logprobs) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py index 0f896a897f3bd..38d4114812b27 100644 --- a/vllm/model_executor/layers/multi_head_sampler.py +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn -from vllm.model_executor.layers.sampler import _apply_top_k_top_p, _get_bin_counts_and_mask +from vllm.model_executor.layers.sampler import SamplerOutput, _apply_top_k_top_p, _get_bin_counts_and_mask from vllm.triton_utils import HAS_TRITON if HAS_TRITON: @@ -18,8 +18,7 @@ SequenceGroupToSample) from vllm.sampling_params import SamplingType from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceOutput) + PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.utils import (PyObjectCache, async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 8e843c637c2cc..654eccd611c3d 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -1,3 +1,4 @@ +from array import array from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union import torch @@ -13,14 +14,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.multi_head_sampler import MultiheadSampler from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.speech import SpeechPlugin -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors from einops import rearrange from transformers.generation import TopKLogitsWarper, TopPLogitsWarper diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index b4bf4b4541db8..fe2406bd12f54 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -6,7 +6,7 @@ class AudioPlugin(MultiModalPlugin): """Plugin for audio data.""" def get_data_key(self) -> str: - return "audio" + return "audio1" def _default_input_mapper(self, ctx: InputContext, data: object) -> MultiModalInputs: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 09cc4fbdbfc10..4cacfc6da62c4 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -233,7 +233,7 @@ def init_mm_limits_per_prompt( key: config_limits_per_plugin.get(key, 1) for key in self._plugins } - + limits_per_plugin['audio'] = 1 
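# A standalone sketch (not from the patch) of what the registry hunk above is
# setting up: a per-plugin limit table, with the "audio" plugin pinned to one
# item per prompt. The helper and names below are illustrative only, not vLLM
# APIs.
from typing import Dict, Mapping

def check_mm_limits(mm_data: Mapping[str, list],
                    limits_per_plugin: Dict[str, int]) -> None:
    """Raise if any modality carries more items than its per-prompt limit."""
    for key, items in mm_data.items():
        limit = limits_per_plugin.get(key, 1)
        if len(items) > limit:
            raise ValueError(f"Got {len(items)} '{key}' items; at most "
                             f"{limit} allowed per prompt.")

check_mm_limits({"audio": ["clip-0"]}, {"audio": 1})  # one audio clip is allowed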
self._limits_by_model[model_config] = limits_per_plugin def get_mm_limits_per_prompt( diff --git a/vllm/multimodal/speech.py b/vllm/multimodal/speech.py index af4e04a507292..4b2dfbcf90592 100644 --- a/vllm/multimodal/speech.py +++ b/vllm/multimodal/speech.py @@ -21,7 +21,7 @@ class FishSpeechPlugin(MultiModalPlugin): def get_data_key(self) -> str: - return "audio1" + return "audio" def _default_input_mapper(self, ctx: InputContext, data: object) -> MultiModalInputs: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 647bbc27e1af7..29025a86ac771 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -121,6 +121,7 @@ class SamplingParams( presence_penalty: float = 0.0 frequency_penalty: float = 0.0 repetition_penalty: float = 1.0 + repetition_window: int = 16 temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 diff --git a/vllm/sequence.py b/vllm/sequence.py index 39f081a558f3a..72a35040f4edd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -148,9 +148,8 @@ class SequenceData(msgspec.Struct, """ # NOTE: we cannot use Union[List, array] because msgspec cannot support # union of 2 list types. - _prompt_token_ids: array - _output_token_ids: array = msgspec.field( - default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) + _prompt_token_ids: list + _output_token_ids: list = msgspec.field(default_factory=list) ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 @@ -166,8 +165,6 @@ class SequenceData(msgspec.Struct, _new_appended_tokens: List[int] = msgspec.field(default_factory=list) def __post_init__(self) -> None: - assert self._prompt_token_ids.typecode == "l" - assert self._output_token_ids.typecode == "l" self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( self._prompt_token_ids) self._update_cached_all_tokens() @@ -207,13 +204,13 @@ def output_token_ids(self, new_output_token_ids) -> None: self._update_cached_all_tokens() @property - def output_token_ids_array(self) -> array: + def output_token_ids_array(self) -> list: """Return the prompt token ids in array type. Note that the array is in "I" type, and it is not compatible with torch.long (2 bytes vs 4 bytes). So beware of the usage. 
""" - assert isinstance(self._output_token_ids, array) + assert isinstance(self._output_token_ids, list) return self._output_token_ids def append_token_id(self, token_id: int, logprob: float) -> None: @@ -385,8 +382,7 @@ def __init__( f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") - self.data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids)) + self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_hiddens: List[torch.Tensor] = [] self.output_text = "" @@ -474,11 +470,11 @@ def reset_state_for_recompute(self): def append_token_id(self, token_id: int, logprobs: Dict[int, Logprob]) -> None: - assert token_id in logprobs self.output_logprobs.append(logprobs) if isinstance(token_id, List): self.data.append_token_id(token_id, logprobs[token_id[0]].logprob) else: + assert token_id in logprobs self.data.append_token_id(token_id, logprobs[token_id].logprob) def append_token_ids( @@ -971,6 +967,7 @@ class SequenceOutput( parent_seq_id: int output_token: int logprobs: Dict[int, Logprob] + output_tokens: List[int] = None def __repr__(self) -> str: return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " @@ -1006,6 +1003,7 @@ class CompletionSequenceGroupOutput( samples: List[SequenceOutput] # Prompt logprob for each prompt query token. prompt_logprobs: Optional[PromptLogprobs] + hidden_state: Optional[torch.Tensor] = None def __repr__(self) -> str: return (f"CompletionSequenceGroupOutput(samples={self.samples}, " From 52d49d43d3d02edbc544e1b7ef534bdcb5c1cd96 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 10 Sep 2024 16:06:38 +0000 Subject: [PATCH 39/61] NOT FINISHED: num_scheduler_steps for multi head sampler --- testllama.py | 2 +- tts_fish.py | 2 +- vllm/inputs/registry.py | 3 +- .../layers/multi_head_sampler.py | 93 ++++++++++++++++--- vllm/model_executor/models/llama.py | 2 +- 5 files changed, 82 insertions(+), 20 deletions(-) diff --git a/testllama.py b/testllama.py index 200cf7779a2a9..3e985c80f4465 100644 --- a/testllama.py +++ b/testllama.py @@ -4,7 +4,7 @@ torch.random.manual_seed(999) -llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct', gpu_memory_utilization=0.5) +llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct', gpu_memory_utilization=0.5, enforce_eager=True, num_scheduler_steps=8) prompts = [ "Hi my name is", "Tell me a joke", diff --git a/tts_fish.py b/tts_fish.py index 7524714269270..d47b8eb6786dd 100644 --- a/tts_fish.py +++ b/tts_fish.py @@ -96,7 +96,7 @@ token_ids.append(7003) llm_inputs.append(token_ids) -llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True, enforce_eager=True) +llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True, enforce_eager=True, num_scheduler_steps=8) prompts = [ {"prompt_token_ids": llm_input} for llm_input in llm_inputs ] diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ae6c6c05d9f72..e60ec2026b99f 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -130,8 +130,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len) + dummy_seq_data = SequenceData([0] * seq_len) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py index 
38d4114812b27..e0bb30d37863c 100644 --- a/vllm/model_executor/layers/multi_head_sampler.py +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn -from vllm.model_executor.layers.sampler import SamplerOutput, _apply_top_k_top_p, _get_bin_counts_and_mask +from vllm.model_executor.layers.sampler import MaybeDeferredSampleResultType, SampleResultArgsType, SampleReturnType, SamplerOutput, _apply_top_k_top_p, _get_bin_counts_and_mask from vllm.triton_utils import HAS_TRITON if HAS_TRITON: @@ -30,6 +30,10 @@ class MultiheadSampler(nn.Module): def __init__(self): super().__init__() + # Whether or not the SamplerOutput should have on-device tensors + # containing the sampled token ids and probabilities. This is used by + # speculative decoding. + self.include_gpu_probs_tensor = False def forward( self, @@ -60,11 +64,57 @@ def forward( # We use float32 for probabilities and log probabilities. # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - id_next = torch.multinomial(probs, 1).to(logits.device) - id_next = id_next.reshape(-1, num_heads).tolist() - return self.build_sampler_output(id_next, sampling_metadata) + # Sample the next tokens. + maybe_deferred_sample_results, maybe_sampled_tokens_tensor = self._sample( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=False + ) + + id_next = maybe_sampled_tokens_tensor.reshape(-1, num_heads).tolist() + + if self.include_gpu_probs_tensor: + # Since we will defer sampler result Pythonization, + # preserve GPU-side tensors in support of later + # deferred pythonization of logprobs + sampled_token_ids_tensor = maybe_sampled_tokens_tensor.to(dtype=torch.long, device=probs.device) + on_device_tensors = (probs, logprobs, sampled_token_ids_tensor) + else: + # Since Pythonization has already happened, don't preserve + # GPU-side tensors. 
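# Rough, self-contained sketch of the multi-head sampling step above (an
# illustration, not the vLLM code): logits arrive with one row per
# (sequence, head) pair, one token is drawn per row with torch.multinomial,
# and the rows are regrouped so every sequence ends up with one token id per
# output head. Shapes below are example values.
import torch

num_seqs, num_heads, vocab_size = 2, 4, 16
logits = torch.randn(num_seqs * num_heads, vocab_size)

probs = torch.softmax(logits, dim=-1, dtype=torch.float)
sampled = torch.multinomial(probs, num_samples=1)   # (num_seqs * num_heads, 1)
per_seq = sampled.reshape(-1, num_heads)            # (num_seqs, num_heads)
print(per_seq.tolist())                             # e.g. [[3, 7, 1, 9], [0, 5, 2, 4]]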
+ on_device_tensors = None + + return self.build_sampler_output(id_next, sampling_metadata, + on_device_tensors=on_device_tensors, + maybe_deferred_sample_results=maybe_deferred_sample_results, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) + + def _sample( + self, + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, + ) -> SampleReturnType: + id_next_tensor = torch.multinomial(probs, 1).to(dtype=torch.long, device=probs.device) + + maybe_deferred_args = SampleResultArgsType( + sampling_metadata=sampling_metadata, + sample_metadata=None, + multinomial_samples=None, + greedy_samples=None, + beam_search_logprobs=None, + sample_results_dict={}) + return id_next_tensor, id_next_tensor def _init_sampling_tensors(self, num_heads: int, @@ -212,22 +262,35 @@ def _apply_penalties(self, logits: torch.Tensor, def build_sampler_output(self, sample_results: List[List[int]], - sampling_metadata: SamplingMetadata) -> SamplerOutput: + sampling_metadata: SamplingMetadata, + maybe_deferred_sample_results: MaybeDeferredSampleResultType = None, + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,torch.Tensor]] = None, + skip_sampler_cpu_output: bool = False) -> SamplerOutput: sampler_output: List[CompletionSequenceGroupOutput] = [] - for seq_group, sample_result in zip(sampling_metadata.seq_groups, sample_results): - seq_ids = seq_group.seq_ids - parent_id = 0 # no beam search for now - seq_outputs: List[SequenceOutput] = [] - log_prob = { sample_result[0]: Logprob(logprob=inf, rank=None, decoded_token=None) } - seq_output = SequenceOutput(seq_ids[parent_id], sample_result[0], log_prob) - seq_output.output_tokens = sample_result - seq_outputs.append(seq_output) - sampler_output.append(CompletionSequenceGroupOutput(seq_outputs, prompt_logprobs=None)) + if skip_sampler_cpu_output: + pass + else: + for seq_group, sample_result in zip(sampling_metadata.seq_groups, sample_results): + seq_ids = seq_group.seq_ids + parent_id = 0 # no beam search for now + seq_outputs: List[SequenceOutput] = [] + log_prob = { sample_result[0]: Logprob(logprob=inf, rank=None, decoded_token=None) } + seq_output = SequenceOutput(seq_ids[parent_id], sample_result[0], log_prob) + seq_output.output_tokens = sample_result + seq_outputs.append(seq_output) + sampler_output.append(CompletionSequenceGroupOutput(seq_outputs, prompt_logprobs=None)) - sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None) + # If not specified, store None values in SamplerOutput. 
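# Hedged sketch of the bookkeeping in build_sampler_output above, using a plain
# dataclass instead of vLLM's SequenceOutput: the first head's token serves as
# the primary output_token while the full per-head list is kept in
# output_tokens for downstream consumers.
from dataclasses import dataclass, field
from typing import List

@dataclass
class SeqOut:  # stand-in for SequenceOutput, for illustration only
    parent_seq_id: int
    output_token: int
    output_tokens: List[int] = field(default_factory=list)

def build_outputs(seq_ids: List[int],
                  sample_results: List[List[int]]) -> List[SeqOut]:
    return [SeqOut(parent_seq_id=sid, output_token=heads[0],
                   output_tokens=list(heads))
            for sid, heads in zip(seq_ids, sample_results)]

print(build_outputs([0, 1], [[11, 502], [98, 7]]))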
+ if on_device_tensors is not None: + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors + else: + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, + None) return SamplerOutput( outputs=sampler_output, sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, logprobs=logprobs_tensor, + deferred_sample_results=maybe_deferred_sample_results, ) \ No newline at end of file diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index abadddc024be7..acf4966a5fc6b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -157,7 +157,7 @@ def __init__( if quant_config is not None and quant_config.get_name() == "gguf": is_neox_style = False - if config.use_xp_rope: + if hasattr(config, "use_xp_rope") and config.use_xp_rope: self.rotary_emb = XPRotaryEmbedding( self.head_dim, self.head_dim, max_position_embeddings, rope_theta, is_neox_style) From 2b53f49a5262d3a496d20c354bbcc152a728503b Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 12 Sep 2024 05:39:53 +0000 Subject: [PATCH 40/61] another fix --- vllm/model_executor/layers/multi_head_sampler.py | 9 ++++++--- vllm/worker/multi_step_model_runner.py | 5 ++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py index e0bb30d37863c..3eeeb52bf1ff5 100644 --- a/vllm/model_executor/layers/multi_head_sampler.py +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -110,11 +110,14 @@ def _sample( sampling_metadata=sampling_metadata, sample_metadata=None, multinomial_samples=None, - greedy_samples=None, + greedy_samples=id_next_tensor, beam_search_logprobs=None, sample_results_dict={}) - return id_next_tensor, id_next_tensor + if not sampling_metadata.skip_sampler_cpu_output: + return id_next_tensor, id_next_tensor + else: + return maybe_deferred_args, id_next_tensor def _init_sampling_tensors(self, num_heads: int, @@ -292,5 +295,5 @@ def build_sampler_output(self, sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, logprobs=logprobs_tensor, - deferred_sample_results=maybe_deferred_sample_results, + deferred_sample_results_args=maybe_deferred_sample_results, ) \ No newline at end of file diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index b13cf39bd846e..9c85a45ff7c51 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -361,8 +361,11 @@ def execute_model( # if CPU is ahead. 
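# The llama.py hunk above guards an optional config attribute with hasattr; a
# tiny sketch of the same pattern using getattr with a default, assuming a
# config object that may or may not define the flag (class and names here are
# made up for the example).
class _DummyConfig:
    pass

def pick_rotary_impl(config) -> str:
    # Fall back to the default rotary embedding when the flag is absent.
    return "xp_rope" if getattr(config, "use_xp_rope", False) else "default_rope"

cfg = _DummyConfig()
print(pick_rotary_impl(cfg))   # default_rope
cfg.use_xp_rope = True
print(pick_rotary_impl(cfg))   # xp_rope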
if self.is_driver_worker and get_pp_group().is_last_rank: if self.pinned_sampled_token_ids is None: + num_output_heads = 1 + if hasattr(self.model_config.hf_config, "num_output_head"): + num_output_heads = self.model_config.hf_config.num_output_head self.pinned_sampled_token_ids = torch.zeros( - (self.scheduler_config.max_num_seqs, 1), + (self.scheduler_config.max_num_seqs, num_output_heads), dtype=torch.long, device="cpu", pin_memory=True) From e2dbf85eb8b4d0579f9b573923ec662cdb475453 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 13 Sep 2024 02:31:14 +0000 Subject: [PATCH 41/61] non fused mlp --- vllm/model_executor/models/llama.py | 60 +++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index acf4966a5fc6b..592fb786a76d5 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -56,6 +56,28 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +class LlamaMLPNonFused(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, x): + y1 = F.silu(self.gate_proj(x)) + y2 = self.up_proj(x) + y = y1 * y2 + return self.down_proj(y) + class LlamaMLP(nn.Module): def __init__( @@ -84,15 +106,11 @@ def __init__( "Only silu is supported for now.") self.act_fn = SiluAndMul() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - def forward(self, x): - y1 = F.silu(self.gate_proj(x)) - y2 = self.up_proj(x) - y = y1 * y2 - return self.down_proj(y) + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x class LlamaAttention(nn.Module): @@ -229,14 +247,24 @@ def __init__( cache_config=cache_config, prefix=f"{prefix}.self_attn", ) - self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - bias=getattr(config, "mlp_bias", False), - prefix=f"{prefix}.mlp", - ) + if getattr(config, "use_fused_mlp", True): + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = LlamaMLPNonFused( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, From 64d8f54f4f2d0984806e2b1f95d67e55c934be93 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 13 Sep 2024 07:10:27 +0000 Subject: [PATCH 42/61] overall good, only kernal issue now --- vllm/engine/output_processor/multi_step.py | 8 +++++++- vllm/model_executor/layers/multi_head_sampler.py | 5 +++-- 
vllm/worker/multi_step_model_runner.py | 15 ++++++++++++--- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b5..bf98d40967013 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -137,7 +137,13 @@ def _process_decode_and_stop(self, seq: Sequence, def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], sampling_params: SamplingParams) -> None: - output_token_ids = [sample.output_token for sample in valid_samples] + output_token_ids = [] + for sample in valid_samples: + if len(sample.output_tokens) > 1: + output_token_ids.append(sample.output_tokens) + else: + output_token_ids.append(sample.output_token) + output_logprobs = [sample.logprobs for sample in valid_samples] # Truncate to max_tokens if necessary. diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py index 3eeeb52bf1ff5..0f7ff7a46d719 100644 --- a/vllm/model_executor/layers/multi_head_sampler.py +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -77,13 +77,14 @@ def forward( modify_greedy_probs=False ) - id_next = maybe_sampled_tokens_tensor.reshape(-1, num_heads).tolist() + sampled_token_ids_tensor = maybe_sampled_tokens_tensor.reshape(-1, num_heads) + id_next = sampled_token_ids_tensor.tolist() if self.include_gpu_probs_tensor: # Since we will defer sampler result Pythonization, # preserve GPU-side tensors in support of later # deferred pythonization of logprobs - sampled_token_ids_tensor = maybe_sampled_tokens_tensor.to(dtype=torch.long, device=probs.device) + sampled_token_ids_tensor = sampled_token_ids_tensor.to(dtype=torch.long, device=probs.device) on_device_tensors = (probs, logprobs, sampled_token_ids_tensor) else: # Since Pythonization has already happened, don't preserve diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 9c85a45ff7c51..84aaed88f49f9 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -691,6 +691,9 @@ def _pythonize_sampler_output( seq_ids = seq_group.seq_ids next_token_ids = sample_result + num_output_heads = sampled_token_ids.shape[1] + if num_output_heads > 1: + next_token_ids = [next_token_ids] parent_ids = [0] if cache is not None: @@ -708,7 +711,11 @@ def _pythonize_sampler_output( seq_output: SequenceOutput = cache.cached_seq_output.get_object( ) seq_output.parent_seq_id = seq_ids[parent_id] - seq_output.output_token = next_token_id + if num_output_heads > 1: + seq_output.output_token = next_token_id[0] + seq_output.output_tokens = next_token_id + else: + seq_output.output_token = next_token_id if logprobs_are_requested: seq_output.logprobs = group_sample_logprobs[tdx] @@ -719,8 +726,10 @@ def _pythonize_sampler_output( logprobs.logprob = float('inf') logprobs.rank = None logprobs.decoded_token = None - - seq_output.logprobs[next_token_id] = logprobs + if num_output_heads > 1: + seq_output.logprobs[next_token_id[0]] = logprobs + else: + seq_output.logprobs[next_token_id] = logprobs seq_outputs.append(seq_output) From b79d49bbe9d41da38b844462037f18c61be73271 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 13 Sep 2024 08:40:57 +0000 Subject: [PATCH 43/61] fix bug --- vllm/engine/output_processor/multi_step.py | 2 +- vllm/engine/output_processor/single_step.py | 2 +- vllm/model_executor/models/ttslm.py | 11 +++++++++++ 3 files changed, 13 
insertions(+), 2 deletions(-) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index bf98d40967013..0245a4a8652c1 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -139,7 +139,7 @@ def _process_seq_outputs(self, seq: Sequence, sampling_params: SamplingParams) -> None: output_token_ids = [] for sample in valid_samples: - if len(sample.output_tokens) > 1: + if sample.output_tokens and len(sample.output_tokens) > 1: output_token_ids.append(sample.output_tokens) else: output_token_ids.append(sample.output_token) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 2a4c1096bb974..f9f1da293dc0f 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -119,7 +119,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # only have one sequence seq = seq_group.seqs[0] if not is_async: - if len(sample.output_tokens) > 1: + if sample.output_tokens and len(sample.output_tokens) > 1: seq.append_token_id(sample.output_tokens, sample.logprobs) else: diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 654eccd611c3d..2a2b1ad671658 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/ttslm.py @@ -52,6 +52,8 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() + self.config = config + # static parameters, put them in config later self.num_audio_tokens = config.num_audio_tokens self.num_text_tokens = config.num_text_tokens @@ -81,6 +83,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # (".gate_up_proj", ".gate_proj", 0), # (".gate_up_proj", ".up_proj", 1), ] + + if getattr(self.config, "use_fused_mlp", True): + stacked_params_mapping.extend( + [ + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1) + ] + ) + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: From 44fe04f80945d1899846dc2b5cc8dbbc12e61785 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 13 Sep 2024 08:41:42 +0000 Subject: [PATCH 44/61] udpate --- testllama.py | 7 ++++--- tts_fish.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/testllama.py b/testllama.py index 3e985c80f4465..41853eba76a01 100644 --- a/testllama.py +++ b/testllama.py @@ -4,10 +4,10 @@ torch.random.manual_seed(999) -llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct', gpu_memory_utilization=0.5, enforce_eager=True, num_scheduler_steps=8) +llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct', gpu_memory_utilization=0.5, enforce_eager=True) prompts = [ - "Hi my name is", - "Tell me a joke", + "Hi my name is ", + "Tell me a joke ", ] texts = [] @@ -18,6 +18,7 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") texts.append(generated_text) end = time.time() print(f"Time taken: {end - start:.2f}s") diff --git a/tts_fish.py b/tts_fish.py index d47b8eb6786dd..5ff27bbeedc4e 100644 --- a/tts_fish.py +++ b/tts_fish.py @@ -96,12 +96,12 @@ token_ids.append(7003) llm_inputs.append(token_ids) -llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True, enforce_eager=True, num_scheduler_steps=8) +llm = 
LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) prompts = [ {"prompt_token_ids": llm_input} for llm_input in llm_inputs ] -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) outputs = llm.generate(prompts, sampling_params) for output in outputs: print(output.prompt) From bac660b346ce27d3016ea1cc426012d03e8e61ce Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 13 Sep 2024 08:47:57 +0000 Subject: [PATCH 45/61] fix bug --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 889a83869c59a..ca2acf6b35026 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1469,7 +1469,7 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] - if len(sample.output_tokens) > 1: + if sample.output_tokens and len(sample.output_tokens) > 1: seq.append_token_id(sample.output_tokens, sample.logprobs) else: From d1210e4d64c5179020674c4e490f6e037dfe20b1 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 19 Sep 2024 13:52:47 +0000 Subject: [PATCH 46/61] clean up --- testllama.py | 13 -------- vllm/model_executor/layers/sampler.py | 39 +++++++----------------- vllm/model_executor/sampling_metadata.py | 10 ++---- 3 files changed, 13 insertions(+), 49 deletions(-) diff --git a/testllama.py b/testllama.py index 200cf7779a2a9..275ee81b530a4 100644 --- a/testllama.py +++ b/testllama.py @@ -21,16 +21,3 @@ texts.append(generated_text) end = time.time() print(f"Time taken: {end - start:.2f}s") -# for text in texts: -# print(text) - -# for i in range(5): -# prompts.append(prompts[0]) -# prompts.append(prompts[1]) - -# sampling_params = SamplingParams(temperature=1, top_k=1, max_tokens=100) -# outputs = llm.generate(prompts, sampling_params) -# for output in outputs: -# prompt = output.prompt -# generated_text = output.outputs[0].text -# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 8d8408edea525..41abdf211e7e7 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -44,7 +44,7 @@ class Sampler(nn.Module): in logits for each token in the input prompt. """ - def __init__(self, idx: int = -1): + def __init__(self): super().__init__() # Whether or not the SamplerOutput should have on-device tensors @@ -52,7 +52,6 @@ def __init__(self, idx: int = -1): # speculative decoding. 
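# Hedged sketch of what the repetition_window sampling parameter above is
# intended to do: apply the repetition penalty only over the most recent
# window of generated tokens rather than the full history. This is an
# illustration under that assumption, not the vLLM implementation.
import torch

def windowed_repetition_penalty(logits: torch.Tensor,
                                output_token_ids: list,
                                penalty: float,
                                window: int) -> torch.Tensor:
    """Penalize token ids seen in the last `window` generated tokens."""
    recent = torch.tensor(sorted(set(output_token_ids[-window:])),
                          dtype=torch.long)
    out = logits.clone()
    vals = out[recent]
    # Standard repetition penalty: shrink positive logits, grow negative ones.
    out[recent] = torch.where(vals > 0, vals / penalty, vals * penalty)
    return out

logits = torch.randn(1026)  # e.g. one audio-codebook head
penalized = windowed_repetition_penalty(logits, [5, 9, 5, 700],
                                        penalty=1.5, window=16)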
self.include_gpu_probs_tensor = False self.should_modify_greedy_probs_inplace = False - self.head_idx = idx def _init_sampling_tensors( self, @@ -72,14 +71,13 @@ def _init_sampling_tensors( # Initialize new sampling tensors (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p, is_prompt) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype, head_idx=self.head_idx) + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) self._sampling_tensors = sampling_tensors self._do_penalties = do_penalties self._do_top_p_top_k = do_top_p_top_k self._do_min_p = do_min_p - self._is_prompt = is_prompt def forward( self, @@ -109,24 +107,16 @@ def forward( do_penalties = self._do_penalties do_top_p_top_k = self._do_top_p_top_k do_min_p = self._do_min_p - is_prompt = self._is_prompt logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Apply presence and frequency penalties. if do_penalties: - if self.head_idx >= 0 and is_prompt: - # when multihead output and prompt phase, we do not apply penalties - # because the prompt tokens are not same as the output tokens - pass - else: - skip_prompt_repetition = self.head_idx >= 0 and not is_prompt - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties, - skip_prompt_repetition=skip_prompt_repetition) + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) # Use float32 to apply temperature scaling. # Use in-place division to avoid creating a new tensor. 
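# Standalone sketch (illustrative, not the vLLM kernel) of the prompt/output
# masks that _apply_penalties in the following hunk relies on: token ids are
# bin-counted per sequence, the counts drive the frequency penalty, and the
# boolean masks decide where the repetition and presence penalties apply.
import torch

vocab_size = 8
prompt = torch.tensor([[1, 2, 2, 3]])   # prompt token ids for one sequence
output = torch.tensor([[2, 5]])         # generated token ids so far

prompt_counts = torch.zeros(1, vocab_size, dtype=torch.long).scatter_add_(
    1, prompt, torch.ones_like(prompt))
output_counts = torch.zeros(1, vocab_size, dtype=torch.long).scatter_add_(
    1, output, torch.ones_like(output))
prompt_mask, output_mask = prompt_counts > 0, output_counts > 0

logits = torch.randn(1, vocab_size)
rep = 1.3  # repetition_penalty
scale = torch.where(prompt_mask | output_mask,
                    torch.full_like(logits, rep), torch.ones_like(logits))
logits = torch.where(logits > 0, logits / scale, logits * scale)
logits = logits - 0.5 * output_counts          # frequency_penalty = 0.5
logits = logits - 0.2 * output_mask.float()    # presence_penalty = 0.2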
@@ -260,20 +250,13 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor, - skip_prompt_repetition: bool = False) -> torch.Tensor: + repetition_penalties: torch.Tensor) -> torch.Tensor: num_seqs, vocab_size = logits.shape + _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, + num_seqs) output_bin_counts, output_mask = _get_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - if skip_prompt_repetition: - # when multihead output, we do not apply penalties for prompt tokens - # because the prompt tokens are not same as the output tokens - prompt_mask = torch.zeros((num_seqs, vocab_size), dtype=torch.bool, device=prompt_tokens_tensor.device) - else: - _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, - num_seqs) - repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) repetition_penalties[~(prompt_mask | output_mask)] = 1.0 logits = torch.where(logits > 0, logits / repetition_penalties, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 8acd6c6dadb68..94b4b14416821 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -391,7 +391,6 @@ def from_sampling_metadata( device: torch.device, dtype: torch.dtype, *, - head_idx: int = -1, extra_seeds_to_generate: int = 0, extra_entropy: Optional[Tuple[int, ...]] = None ) -> Tuple["SamplingTensors", bool, bool, bool]: @@ -414,7 +413,6 @@ def from_sampling_metadata( do_penalties = False do_top_p_top_k = False do_min_p = False - is_prompt = False if _USE_TRITON_SAMPLER: prompt_best_of: List[int] = [] @@ -503,7 +501,6 @@ def from_sampling_metadata( if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids - repetition_window = seq_group.sampling_params.repetition_window if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) @@ -515,17 +512,14 @@ def from_sampling_metadata( for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] prompt_tokens.append(seq_data.prompt_token_ids_array) - if head_idx >= 0 and seq_group.is_prompt == False: - output_tokens.append([i[head_idx] for i in seq_data.output_token_ids_array[-repetition_window:]]) - else: - output_tokens.append(seq_data.output_token_ids_array) + output_tokens.append(seq_data.output_token_ids_array) sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, frequency_penalties, repetition_penalties, sampling_seeds, sample_indices, prompt_tokens, output_tokens, vocab_size, extra_seeds_to_generate, device, dtype) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p, is_prompt) + return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod def from_lists(cls, temperatures: List[float], top_ps: List[float], From c34c64f269cd24ce1d58424009f7c91d5c024947 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 19 Sep 2024 13:53:31 +0000 Subject: [PATCH 47/61] cleanup --- vllm/model_executor/models/ttslm.py | 6 ------ vllm/multimodal/speech.py | 24 ------------------------ 2 files changed, 30 deletions(-) diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/ttslm.py index 086c00f1a8b85..8d75dfa05aafb 100644 --- a/vllm/model_executor/models/ttslm.py +++ 
b/vllm/model_executor/models/ttslm.py @@ -127,12 +127,6 @@ def sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - # head_logits = logits.permute(1, 0, 2) - # next_tokens = self.samplers[0](head_logits[0], sampling_metadata) - # for i in range(self.num_output_head - 1): - # output = self.samplers[i](head_logits[i + 1], sampling_metadata) - # self.merge_sample_results(next_tokens, output) - next_tokens = self.sampler(logits, sampling_metadata) return next_tokens diff --git a/vllm/multimodal/speech.py b/vllm/multimodal/speech.py index af4e04a507292..08f01c6eef1a5 100644 --- a/vllm/multimodal/speech.py +++ b/vllm/multimodal/speech.py @@ -18,30 +18,6 @@ import base64 import pickle -class FishSpeechPlugin(MultiModalPlugin): - - def get_data_key(self) -> str: - return "audio1" - - def _default_input_mapper(self, ctx: InputContext, - data: object) -> MultiModalInputs: - if isinstance(data, str): - base64_decoded = base64.b64decode(data) - deserialized_data = pickle.loads(base64_decoded) - tensor = torch.from_numpy(deserialized_data) - return MultiModalInputs({"audio": tensor}) - elif isinstance(data, torch.Tensor): - raise NotImplementedError("Embeddings input is not supported yet") - - raise TypeError(f"Invalid image type: {type(data)}") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - return 16 - - @staticmethod - def get_default_audio(): - return 'a' - class SpeechPlugin(MultiModalPlugin): def get_data_key(self) -> str: From 901738702023f8917285fbf3d550572dfb16015f Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 19 Sep 2024 14:27:01 +0000 Subject: [PATCH 48/61] clean code --- testllama.py | 15 +- tts_fish.py | 216 +++++++++++------- vllm/model_executor/models/__init__.py | 2 +- .../models/{ttslm.py => fishtts.py} | 11 +- vllm/multimodal/speech.py | 24 -- 5 files changed, 131 insertions(+), 137 deletions(-) rename vllm/model_executor/models/{ttslm.py => fishtts.py} (93%) diff --git a/testllama.py b/testllama.py index 41853eba76a01..382276c8164f4 100644 --- a/testllama.py +++ b/testllama.py @@ -21,17 +21,4 @@ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") texts.append(generated_text) end = time.time() -print(f"Time taken: {end - start:.2f}s") -# for text in texts: -# print(text) - -# for i in range(5): -# prompts.append(prompts[0]) -# prompts.append(prompts[1]) - -# sampling_params = SamplingParams(temperature=1, top_k=1, max_tokens=100) -# outputs = llm.generate(prompts, sampling_params) -# for output in outputs: -# prompt = output.prompt -# generated_text = output.outputs[0].text -# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file +print(f"Time taken: {end - start:.2f}s") \ No newline at end of file diff --git a/tts_fish.py b/tts_fish.py index 5ff27bbeedc4e..7250fb0ccfedd 100644 --- a/tts_fish.py +++ b/tts_fish.py @@ -1,83 +1,90 @@ +import asyncio from vllm import LLM, SamplingParams from tokenizers import Tokenizer import pypinyin import torch + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine torch.random.manual_seed(999) -# tts1 = torch.load('/home/zhn/ttslm_dev/GPT_merged_emb_nonorm.pt') -# tts2 = torch.load('/home/zhn/fishtts/checkpoint-1400000.bak') - -# layer = 24 -# dim = 1536 -# num_audio_tokens = 1026 -# num_text_tokens = 7002 -# llama = tts2['model']['llama'] - -# llama.pop('freqs_cis') -# llama.pop('causal_mask') - -# text_emb = llama['text_embeddings.weight'] -# for i in range(100): -# 
text_emb = torch.cat([text_emb, torch.zeros((1,dim), device=text_emb.device)], 0) -# llama['emb_text.weight'] = text_emb -# llama.pop('text_embeddings.weight') - -# llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:num_audio_tokens] -# llama['emb_code.1.weight'] = llama['code_embeddings.weight'][num_audio_tokens-2:num_audio_tokens - 2 + num_audio_tokens] -# llama.pop('code_embeddings.weight') - -# for i in range(layer): -# qkv_name = f'layers.{i}.attention.wqkv.weight' -# q = llama[qkv_name][0:dim] -# k = llama[qkv_name][dim:2*dim] -# v = llama[qkv_name][2*dim:] -# llama[f'gpt.layers.{i}.self_attn.q_proj.weight'] = q -# llama[f'gpt.layers.{i}.self_attn.k_proj.weight'] = k -# llama[f'gpt.layers.{i}.self_attn.v_proj.weight'] = v -# llama.pop(qkv_name) - -# wo_name = f'layers.{i}.attention.wo.weight' -# wo = llama[wo_name] -# llama[f'gpt.layers.{i}.self_attn.o_proj.weight'] = wo -# llama.pop(wo_name) - -# gate_proj_name = f'layers.{i}.feed_forward.w1.weight' -# w_gate = llama[gate_proj_name] -# llama[f'gpt.layers.{i}.mlp.gate_proj.weight'] = w_gate -# llama.pop(gate_proj_name) - -# gate_up_proj_name = f'layers.{i}.feed_forward.w3.weight' -# w_gate_up = llama[gate_up_proj_name] -# llama[f'gpt.layers.{i}.mlp.up_proj.weight'] = w_gate_up -# llama.pop(gate_up_proj_name) - -# gate_down_proj_name = f'layers.{i}.feed_forward.w2.weight' -# w_gate_down = llama[gate_down_proj_name] -# llama[f'gpt.layers.{i}.mlp.down_proj.weight'] = w_gate_down -# llama.pop(gate_down_proj_name) - -# attn_norm_name = f'layers.{i}.attention_norm.weight' -# w_attn_norm = llama[attn_norm_name] -# llama[f'gpt.layers.{i}.input_layernorm.weight'] = w_attn_norm -# llama.pop(attn_norm_name) - -# ffn_norm_name = f'layers.{i}.ffn_norm.weight' -# w_ffn_norm = llama[ffn_norm_name] -# llama[f'gpt.layers.{i}.post_attention_layernorm.weight'] = w_ffn_norm -# llama.pop(ffn_norm_name) - - -# norm_name = 'norm.weight' -# w_norm = llama[norm_name] -# llama['gpt.norm.weight'] = w_norm -# llama.pop(norm_name) - -# output_name = 'output.weight' -# w_output = llama[output_name] -# llama['lm_head.0.weight'] = w_output[num_text_tokens:num_text_tokens+num_audio_tokens] -# llama['lm_head.1.weight'] = w_output[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2] -# llama.pop(output_name) - -# torch.save(llama, '/home/zhn/fishtts/llama.pt') + +def convert_model(): + tts2 = torch.load('/home/zhn/fishtts/checkpoint-1400000.bak') + + layer = 24 + dim = 1536 + num_audio_tokens = 1026 + num_text_tokens = 7002 + llama = tts2['model']['llama'] + + llama.pop('freqs_cis') + llama.pop('causal_mask') + + text_emb = llama['text_embeddings.weight'] + for i in range(100): + text_emb = torch.cat([text_emb, torch.zeros((1,dim), device=text_emb.device)], 0) + llama['emb_text.weight'] = text_emb + llama.pop('text_embeddings.weight') + + llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:num_audio_tokens] + llama['emb_code.1.weight'] = llama['code_embeddings.weight'][num_audio_tokens-2:num_audio_tokens - 2 + num_audio_tokens] + llama.pop('code_embeddings.weight') + + for i in range(layer): + qkv_name = f'layers.{i}.attention.wqkv.weight' + q = llama[qkv_name][0:dim] + k = llama[qkv_name][dim:2*dim] + v = llama[qkv_name][2*dim:] + llama[f'gpt.layers.{i}.self_attn.q_proj.weight'] = q + llama[f'gpt.layers.{i}.self_attn.k_proj.weight'] = k + llama[f'gpt.layers.{i}.self_attn.v_proj.weight'] = v + llama.pop(qkv_name) + + wo_name = f'layers.{i}.attention.wo.weight' + wo = llama[wo_name] + 
llama[f'gpt.layers.{i}.self_attn.o_proj.weight'] = wo + llama.pop(wo_name) + + gate_proj_name = f'layers.{i}.feed_forward.w1.weight' + w_gate = llama[gate_proj_name] + llama[f'gpt.layers.{i}.mlp.gate_proj.weight'] = w_gate + llama.pop(gate_proj_name) + + gate_up_proj_name = f'layers.{i}.feed_forward.w3.weight' + w_gate_up = llama[gate_up_proj_name] + llama[f'gpt.layers.{i}.mlp.up_proj.weight'] = w_gate_up + llama.pop(gate_up_proj_name) + + gate_down_proj_name = f'layers.{i}.feed_forward.w2.weight' + w_gate_down = llama[gate_down_proj_name] + llama[f'gpt.layers.{i}.mlp.down_proj.weight'] = w_gate_down + llama.pop(gate_down_proj_name) + + attn_norm_name = f'layers.{i}.attention_norm.weight' + w_attn_norm = llama[attn_norm_name] + llama[f'gpt.layers.{i}.input_layernorm.weight'] = w_attn_norm + llama.pop(attn_norm_name) + + ffn_norm_name = f'layers.{i}.ffn_norm.weight' + w_ffn_norm = llama[ffn_norm_name] + llama[f'gpt.layers.{i}.post_attention_layernorm.weight'] = w_ffn_norm + llama.pop(ffn_norm_name) + + + norm_name = 'norm.weight' + w_norm = llama[norm_name] + llama['gpt.norm.weight'] = w_norm + llama.pop(norm_name) + + output_name = 'output.weight' + w_output = llama[output_name] + llama['lm_head.0.weight'] = w_output[num_text_tokens:num_text_tokens+num_audio_tokens] + llama['lm_head.1.weight'] = w_output[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2] + llama.pop(output_name) + + torch.save(llama, '/home/zhn/fishtts/llama.pt') + +streaming=True texts = [ '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。', @@ -96,16 +103,49 @@ token_ids.append(7003) llm_inputs.append(token_ids) -llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) -prompts = [ - {"prompt_token_ids": llm_input} for llm_input in llm_inputs -] - -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) -outputs = llm.generate(prompts, sampling_params) -for output in outputs: - print(output.prompt) - token_ids = output.outputs[0].token_ids - for token_id in token_ids: - print([x - 0 for x in token_id]) - print(len(token_ids)) +if not streaming: + llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) + prompts = [ + {"prompt_token_ids": llm_input} for llm_input in llm_inputs + ] + + sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + print(output.prompt) + token_ids = output.outputs[0].token_ids + for token_id in token_ids: + print([x - 0 for x in token_id]) + print(len(token_ids)) + +else: + engine_args = AsyncEngineArgs(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) + model = AsyncLLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) + prompts = [ + {"prompt_token_ids": llm_input} for llm_input in llm_inputs + ] + + async def generate_streaming(prompt, id): + results_generator = model.generate(prompt, sampling_params, request_id=id) + count=0 + tokens = [] + async for request_output in results_generator: + token_ids = request_output.outputs[0].token_ids + print(f'{id} {[x - 0 for 
x in token_ids[-1]]}') + tokens.append([x - 0 for x in token_ids[-1]]) + count+=1 + + print(id) + print(len(tokens)) + for token in tokens: + print(token) + + async def generate(): + tasks = [] + for i in range(2): + t = generate_streaming(prompts[i%2], i) + tasks.append(t) + await asyncio.gather(*tasks) + + asyncio.run(generate()) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 208e10ff8a0d2..f83f45ce2925c 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -65,7 +65,7 @@ "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), - "ChatTtsLlm": ("ttslm", "ChatTtsLlm") + "FishTtsLlm": ("fishtts", "FishTtsLlm") } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/ttslm.py b/vllm/model_executor/models/fishtts.py similarity index 93% rename from vllm/model_executor/models/ttslm.py rename to vllm/model_executor/models/fishtts.py index 2a2b1ad671658..5eadfdb63db5c 100644 --- a/vllm/model_executor/models/ttslm.py +++ b/vllm/model_executor/models/fishtts.py @@ -45,7 +45,7 @@ def get_max_speech_tokens(ctx: InputContext): @MULTIMODAL_REGISTRY.register_speech_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ttsllm) @MULTIMODAL_REGISTRY.register_max_speech_tokens(get_max_speech_tokens) -class ChatTtsLlm(nn.Module): +class FishTtsLlm(nn.Module): def __init__(self, config: LlamaConfig, cache_config: Optional[CacheConfig] = None, @@ -72,7 +72,6 @@ def __init__(self, ]) self.logits_processor = LogitsProcessor(self.num_audio_tokens) self.sampler = MultiheadSampler() - # self.samplers = [Sampler(head_idx) for head_idx in range(self.num_output_head)] def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -149,12 +148,6 @@ def sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - # head_logits = logits.permute(1, 0, 2) - # next_tokens = self.samplers[0](head_logits[0], sampling_metadata) - # for i in range(self.num_output_head - 1): - # output = self.samplers[i](head_logits[i + 1], sampling_metadata) - # self.merge_sample_results(next_tokens, output) - next_tokens = self.sampler(logits, sampling_metadata) return next_tokens @@ -173,8 +166,6 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids, is_prompt) - # spk_emb = kwargs.get("speech", None) - # self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) model_output = self.gpt( input_ids=input_ids, inputs_embeds=hidden_states, diff --git a/vllm/multimodal/speech.py b/vllm/multimodal/speech.py index 4b2dfbcf90592..08f01c6eef1a5 100644 --- a/vllm/multimodal/speech.py +++ b/vllm/multimodal/speech.py @@ -18,30 +18,6 @@ import base64 import pickle -class FishSpeechPlugin(MultiModalPlugin): - - def get_data_key(self) -> str: - return "audio" - - def _default_input_mapper(self, ctx: InputContext, - data: object) -> MultiModalInputs: - if isinstance(data, str): - base64_decoded = base64.b64decode(data) - deserialized_data = pickle.loads(base64_decoded) - tensor = torch.from_numpy(deserialized_data) - return MultiModalInputs({"audio": tensor}) - elif isinstance(data, torch.Tensor): - raise NotImplementedError("Embeddings input is not supported yet") - - raise TypeError(f"Invalid image type: {type(data)}") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: 
- return 16 - - @staticmethod - def get_default_audio(): - return 'a' - class SpeechPlugin(MultiModalPlugin): def get_data_key(self) -> str: From a6321bd8e431a1c3de15512a55f4f665f0c6f189 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 19 Sep 2024 14:29:17 +0000 Subject: [PATCH 49/61] clean code --- tts_fish.py => fishtts_sample.py | 0 testllama.py | 24 --------------- tts_async.py | 52 -------------------------------- 3 files changed, 76 deletions(-) rename tts_fish.py => fishtts_sample.py (100%) delete mode 100644 testllama.py delete mode 100644 tts_async.py diff --git a/tts_fish.py b/fishtts_sample.py similarity index 100% rename from tts_fish.py rename to fishtts_sample.py diff --git a/testllama.py b/testllama.py deleted file mode 100644 index 382276c8164f4..0000000000000 --- a/testllama.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch -import time -from vllm import LLM, SamplingParams - -torch.random.manual_seed(999) - -llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct', gpu_memory_utilization=0.5, enforce_eager=True) -prompts = [ - "Hi my name is ", - "Tell me a joke ", -] - -texts = [] -start = time.time() -for i in range(10): - sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=200, top_p=1, repetition_penalty=0.9) - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - texts.append(generated_text) -end = time.time() -print(f"Time taken: {end - start:.2f}s") \ No newline at end of file diff --git a/tts_async.py b/tts_async.py deleted file mode 100644 index 36c0ba350fe64..0000000000000 --- a/tts_async.py +++ /dev/null @@ -1,52 +0,0 @@ -import asyncio -import time -from tokenizers import Tokenizer -import pypinyin -import torch -from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams - -texts = [ - '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。', - '在这个数字的世界里,你是我的唯一,爱情如同网络连接,无论距离多遥远。我们的心相互链接,在虚拟的空间中漫游,每条信息都是爱的表达,每个瞬间都是甜蜜的时刻。爱情不再是纸上文字,而是数码世界的交流,在屏幕上,我们相拥相视,你是我的电子爱情。'] -llm_inputs = [] -tokenizer = Tokenizer.from_file('/home/zhn/fishtts/vocab.json') -for text in texts: - pinyin = "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) - txt = f"[zh-cn]{pinyin}" - txt = txt.replace(" ", "[SPACE]") - token_ids = tokenizer.encode(txt).ids - token_ids.insert(0, 7001) - token_ids.append(0) - token_ids.append(7003) - llm_inputs.append(token_ids) - -engine_args = AsyncEngineArgs(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) -model = AsyncLLMEngine.from_engine_args(engine_args) -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) -prompts = [ - {"prompt_token_ids": llm_input} for llm_input in llm_inputs -] - -async def generate_streaming(prompt, id): - results_generator = model.generate(prompt, sampling_params, request_id=id) - count=0 - tokens = [] - async for request_output in results_generator: - token_ids = request_output.outputs[0].token_ids - print(f'{id} {[x - 0 for x in token_ids[-1]]}') - tokens.append([x - 0 for x in token_ids[-1]]) - count+=1 - - print(id) - print(len(tokens)) - for token in tokens: - print(token) - -async def generate(): - tasks = [] - for i in range(2): - t = generate_streaming(prompts[i%2], i) - tasks.append(t) - await asyncio.gather(*tasks) - 
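# A reduced sketch of the streaming consumption pattern in generate_streaming
# above, swapped to a generic async generator so it runs on its own; the real
# script iterates AsyncLLMEngine.generate() and keeps only the newest token id
# from each partial result.
import asyncio
from typing import AsyncGenerator, List

async def fake_engine(n_steps: int) -> AsyncGenerator[List[int], None]:
    # Stand-in for the engine: yields the growing token-id list each step.
    tokens: List[int] = []
    for step in range(n_steps):
        tokens.append(step)
        yield list(tokens)
        await asyncio.sleep(0)

async def consume(request_id: int) -> List[int]:
    collected: List[int] = []
    async for token_ids in fake_engine(5):
        collected.append(token_ids[-1])   # only the latest token is new
    print(request_id, collected)
    return collected

async def main() -> None:
    await asyncio.gather(consume(0), consume(1))

asyncio.run(main())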
-asyncio.run(generate()) From 9dd86bd4550aafad3ac85eef0624e17e5873ce57 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 20 Sep 2024 06:32:05 +0000 Subject: [PATCH 50/61] replace torch layer to vllm layers --- vllm/model_executor/models/chattts.py | 24 +++++++++++-- vllm/model_executor/models/fishtts.py | 50 ++++++++++++++++++++++++--- vllm/model_executor/models/llama.py | 26 ++++++++++---- 3 files changed, 88 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/chattts.py b/vllm/model_executor/models/chattts.py index f3ed059a0f669..9a112a06db80f 100644 --- a/vllm/model_executor/models/chattts.py +++ b/vllm/model_executor/models/chattts.py @@ -49,7 +49,8 @@ class ChatTtsLlm(nn.Module): def __init__(self, config: LlamaConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None,) -> None: super().__init__() # static parameters, put them in config later @@ -65,8 +66,27 @@ def __init__(self, VocabParallelEmbedding(self.num_audio_tokens, self.model_dim) for _ in range(self.num_output_head) ]) + ParallelLMHead( + self.num_audio_tokens, + self.model_dim, + org_num_embeddings=self.num_audio_tokens, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) self.lm_head = nn.ModuleList([ - nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_output_head) + ParallelLMHead( + self.num_audio_tokens, + self.model_dim, + org_num_embeddings=self.num_audio_tokens, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) for _ in range(self.num_output_head) ]) self.logits_processor = LogitsProcessor(self.num_audio_tokens) self.sampler = MultiheadSampler() diff --git a/vllm/model_executor/models/fishtts.py b/vllm/model_executor/models/fishtts.py index 5eadfdb63db5c..c296f4260a35d 100644 --- a/vllm/model_executor/models/fishtts.py +++ b/vllm/model_executor/models/fishtts.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding +from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -45,11 +46,43 @@ def get_max_speech_tokens(ctx: InputContext): @MULTIMODAL_REGISTRY.register_speech_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ttsllm) @MULTIMODAL_REGISTRY.register_max_speech_tokens(get_max_speech_tokens) -class FishTtsLlm(nn.Module): +class FishTtsLlm(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": 
"output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__(self, config: LlamaConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None,) -> None: super().__init__() self.config = config @@ -60,7 +93,7 @@ def __init__(self, self.num_output_head = config.num_output_head self.audio_start_token_id = config.audio_start_token_id - self.gpt = LlamaModel(config) + self.gpt = LlamaModel(config, lora_config=lora_config) self.model_dim = self.gpt.config.hidden_size self.emb_text = VocabParallelEmbedding(self.num_text_tokens, self.model_dim) self.emb_code = nn.ModuleList([ @@ -68,7 +101,16 @@ def __init__(self, ]) self.lm_head = nn.ModuleList([ - nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_output_head) + ParallelLMHead( + self.num_audio_tokens, + self.model_dim, + org_num_embeddings=self.num_audio_tokens, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) for _ in range(self.num_output_head) ]) self.logits_processor = LogitsProcessor(self.num_audio_tokens) self.sampler = MultiheadSampler() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 592fb786a76d5..83ef03f55d98b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -68,15 +68,29 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.gate_proj = RowParallelLinear(input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.up_proj = RowParallelLinear(input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") def forward(self, x): - y1 = F.silu(self.gate_proj(x)) - y2 = self.up_proj(x) + y1, _ = self.gate_proj(x) + y1 = F.silu(y1) + y2, _ = self.up_proj(x) y = y1 * y2 - return self.down_proj(y) + y, _ = self.down_proj(y) + return y class LlamaMLP(nn.Module): From 26a016949385d61b89c5f64d6dba2b670a36f8f2 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Tue, 24 Sep 2024 09:33:00 +0000 Subject: [PATCH 51/61] optimize chattts convert code --- chattts_sample.py | 61 ++++++++++++++--------------------------------- 1 file changed, 18 insertions(+), 43 deletions(-) diff --git a/chattts_sample.py b/chattts_sample.py index 752399760c75d..0892f59141955 100644 --- a/chattts_sample.py +++ b/chattts_sample.py @@ -5,49 +5,21 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine torch.random.manual_seed(999) -# tts = torch.load('/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') -# text_emb_count = 
tts['emb_text.weight'].shape[0] -# audio_emb_count = tts['emb_code.0.weight'].shape[0] -# model_dim = tts['emb_text.weight'].shape[1] - -# # append audio embeddings to text embeddings -# # all_0 = text_emb + audio_emb_0 -# all_0 = torch.cat([tts['emb_text.weight'], tts['emb_code.0.weight']], dim=0) - -# # all_1 = zero + audio_emb_1 -# all_1 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.1.weight']], dim=0) - -# # all_2 = zero + audio_emb_2 -# all_2 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.2.weight']], dim=0) - -# # all_3 = zero + audio_emb_3 -# all_3 = torch.cat([torch.zeros(text_emb_count, model_dim), tts['emb_code.3.weight']], dim=0) - -# # remove text emb and audio emb in the model -# tts.pop('emb_text.weight') -# tts.pop('emb_code.0.weight') -# tts.pop('emb_code.1.weight') -# tts.pop('emb_code.2.weight') -# tts.pop('emb_code.3.weight') - -# # add new embeddings to the model -# tts['emb_all.0.weight'] = all_0 -# tts['emb_all.1.weight'] = all_1 -# tts['emb_all.2.weight'] = all_2 -# tts['emb_all.3.weight'] = all_3 - -# for i in range(4): -# original0 = tts[f'head_code.{i}.parametrizations.weight.original0'] -# original1 = tts[f'head_code.{i}.parametrizations.weight.original1'] -# # get the normalized weights based on the original 0 and 1 -# weight_norm0 = torch._weight_norm(original1, original0, dim=0) -# tts.pop(f'head_code.{i}.parametrizations.weight.original0') -# tts.pop(f'head_code.{i}.parametrizations.weight.original1') -# tts[f'lm_head.{i}.weight'] = weight_norm0 - -# # save the model -# torch.save(tts, '/home/zhn/ttslm/GPT_merged_emb_nonorm.pt') +def convert_model(): + chatts = torch.load('/home/zhn/g/ChatTTS/asset/GPT.pt') + + chatts.pop('head_text.parametrizations.weight.original0') + chatts.pop('head_text.parametrizations.weight.original1') + for i in range(4): + original0 = chatts[f'head_code.{i}.parametrizations.weight.original0'] + original1 = chatts[f'head_code.{i}.parametrizations.weight.original1'] + # get the normalized weights based on the original 0 and 1 + weight_norm0 = torch._weight_norm(original1, original0, dim=0) + chatts.pop(f'head_code.{i}.parametrizations.weight.original0') + chatts.pop(f'head_code.{i}.parametrizations.weight.original1') + chatts[f'lm_head.{i}.weight'] = weight_norm0 + torch.save(chatts, '/home/zhn/ttslm_dev/chattts.pt') streaming = False llm = LLM(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float32) @@ -55,10 +27,13 @@ { "prompt": "[Stts][spk_emb][speed_5]Your text one[Ptts]", "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, + }, + { + "prompt": "[Stts][spk_emb][speed_5]Your text two[Ptts]", + "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, } ] - if not streaming: sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) outputs = llm.generate(prompts, sampling_params) From 00255ef498ccb76404c198c16b23be08bfdcf499 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 26 Sep 2024 03:12:05 +0000 Subject: [PATCH 52/61] optimize chattts convert code --- vllm/model_executor/models/chattts.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/model_executor/models/chattts.py b/vllm/model_executor/models/chattts.py index 9a112a06db80f..09f9887e0b0bf 100644 --- a/vllm/model_executor/models/chattts.py +++ b/vllm/model_executor/models/chattts.py @@ -66,16 +66,6 @@ def __init__(self, VocabParallelEmbedding(self.num_audio_tokens, self.model_dim) for _ in range(self.num_output_head) ]) - ParallelLMHead( - self.num_audio_tokens, - self.model_dim, - org_num_embeddings=self.num_audio_tokens, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - quant_config=quant_config, - ) self.lm_head = nn.ModuleList([ ParallelLMHead( self.num_audio_tokens, From 4ae3509653a89386fdf2942c6de505aef343a03d Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Fri, 27 Sep 2024 09:58:29 +0000 Subject: [PATCH 53/61] e2e streaming --- fishtts_sample.py | 209 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 167 insertions(+), 42 deletions(-) diff --git 
a/fishtts_sample.py b/fishtts_sample.py index 7250fb0ccfedd..c4683d068a57c 100644 --- a/fishtts_sample.py +++ b/fishtts_sample.py @@ -1,4 +1,10 @@ import asyncio +import threading +import time +from typing import List + +import numpy +import onnx from vllm import LLM, SamplingParams from tokenizers import Tokenizer import pypinyin @@ -6,10 +12,14 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +import onnxruntime +import soundfile as sf +import queue + torch.random.manual_seed(999) def convert_model(): - tts2 = torch.load('/home/zhn/fishtts/checkpoint-1400000.bak') + tts2 = torch.load('/data/fishtts/checkpoint-1400000.bak') layer = 24 dim = 1536 @@ -82,17 +92,33 @@ def convert_model(): llama['lm_head.1.weight'] = w_output[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2] llama.pop(output_name) - torch.save(llama, '/home/zhn/fishtts/llama.pt') + torch.save(llama, '/data/fishtts/llama.pt') -streaming=True +def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() texts = [ '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。', '在这个数字的世界里,你是我的唯一,爱情如同网络连接,无论距离多遥远。我们的心相互链接,在虚拟的空间中漫游,每条信息都是爱的表达,每个瞬间都是甜蜜的时刻。爱情不再是纸上文字,而是数码世界的交流,在屏幕上,我们相拥相视,你是我的电子爱情。', - '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。' + '探索清新世界的钥匙在此!用海洋微风洗衣粉,让您的衣物充满清晨海边的清新气息。我们的高效配方深层清洁衣物纤维去除顽固污渍的同时,带来持久的清香。不只是清洗更是衣物的焕新旅程。', + '从现在开始,让我们的多功能厨师机器人成为你厨房里的得力助手。它可以搅拌,切碎,烹饪,烘焙,满足你所有烹饪需求。创新美食,只需轻松一按。', + '打造完美家居生活,只需一款智能净化器。它能有效过滤空气中的污染物,释放负离子,让你每天呼吸到的都是最纯净的空气,为家人的健康护航。', + '我刚看完《流浪地球》,这部电影真的很打动我。它不仅仅展示了科幻世界中的宏大景象,更通过对人类团结和奉献精神的刻画,让我对未来充满了思考。影片中的视觉效果和细腻的情感描写,让我觉得这是一部值得反复琢磨的作品。如果你喜欢有深度的科幻电影,这部绝对不会让你失望。', + '下个月我计划去日本体验当地的文化。我特别期待去京都的古寺庙,想感受一下传统的日式建筑和庭园。东京的市场也让我兴奋,我打算品尝各种地道的小吃。此外,我计划学习一些基本的日语,这样能更好地融入当地生活,提升旅行的整体体验。你们有没有什么建议或者特别推荐的地方?', + '在保持健康方面,我尝试了一些新的饮食习惯。现在我更多地选择新鲜的蔬菜和水果,减少了糖分和加工食品的摄入。我发现这种饮食方式不仅改善了我的体重,还提升了整体的能量水平。此外,保持充足的水分摄入也是关键,它有助于身体的代谢和排毒。你们有什么其他的健康饮食建议吗?', + '为了提高学习效率,我采取了一些新方法。例如,我将复杂的学习任务拆分成小的目标,每完成一个小目标就能获得成就感。此外,我还使用了番茄工作法,设定25分钟专注学习,然后休息5分钟,这样可以有效避免疲劳。通过这些方法,我发现自己在学习过程中更加专注和高效。', + '有一本书《思考,快与慢》给我留下了深刻的印象。这本书由丹尼尔·卡尼曼撰写,详细探讨了人类思维的两种模式——快速直观和缓慢理性。通过丰富的实证研究,作者揭示了我们在日常决策中的思维偏差。这本书不仅在理论上很有趣,对实际生活中的决策也提供了很多有益的启示。', + '提升工作效率需要良好的时间管理。我发现将任务分解成小步骤,并逐步完成,能让工作变得更有条理。同时,使用待办事项列表和设置提醒也能帮助我保持高效。此外,我还注意到合理的休息和调整对工作效率至关重要。这样不仅提高了我的工作质量,也让我保持了良好的工作状态。', + '探索不同的音乐风格是我最近的兴趣之一。我特别喜欢电子音乐,尤其是那些融合了传统乐器的作品。这种音乐风格不仅提供了新的听觉体验,也让我对音乐的表现形式有了更深的理解。我发现,了解和欣赏不同风格的音乐,能够丰富我的音乐视野和审美体验。', + '照顾宠物需要全面的关注和细心的呵护。我了解到,定期带狗狗散步有助于它们的身体健康,同时提供丰富的玩具和定期的健康检查也很重要。此外,保持良好的饮食习惯对宠物的整体健康有很大影响。照顾宠物的过程中,了解它们的需求并给予关爱,能让它们生活得更加愉快和健康。', + '处理社交媒体信息过载,是我近期面临的一个问题。为了避免被海量的信息分散注意力,我开始设置每天查看社交媒体的时间限制,同时选择关注一些高质量的内容。此外,我还定期清理不感兴趣的账号,这样能够保持信息的有效性和对内容的专注。你们有什么管理社交媒体的好方法吗?', + '每个人都可以在日常生活中采取一些简单的环保行动。我开始减少一次性塑料的使用,进行垃圾分类,并尽量节约能源。这些小措施虽然看似微不足道,但积累起来对环境的保护却能产生积极影响。我相信,关注环保不仅是为了现在的生活,也为未来的子孙着想。你们在环保方面有什么实用的建议吗?', + '她给我们发了一张照片,呃,在一个满是山、山珍海味婚礼上她拿了一个巨小的饭盒在吃,反正就一个特别清淡的啊,减脂营备的餐,然后她呢当时在群里发这个是,嗯,为了求表扬,哈哈哈!', + '我这周末过得我觉得,我真的以为是真正意义上的休息,但是没想到周一的时候去上班,呃的时候,我,我还是感觉,呃,很奇怪,就是很提不起精神的感觉哎!', + '嗯,我刚刚就在想,你说创造一个环境其实,今年,呃,我去采访一位,呃,很有就是,阅历的同龄人的时候,她劝我做一件事情就是找一个心理咨询师,呃,去聊一聊。' ] llm_inputs = [] -tokenizer = Tokenizer.from_file('/home/zhn/fishtts/vocab.json') +tokenizer = Tokenizer.from_file('/data/fishtts/vocab.json') for text in texts: pinyin = "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) txt = f"[zh-cn]{pinyin}" @@ -103,49 +129,148 @@ def convert_model(): token_ids.append(7003) llm_inputs.append(token_ids) 
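# A hedged refactoring sketch of the preprocessing loop above (the helper name
# is illustrative, not part of the patch): the text is converted to tone3
# pinyin, tagged with [zh-cn], spaces are mapped to [SPACE], and the sequence
# is framed with the ids 7001 / 0 / 7003 used by this checkpoint. It relies on
# the pypinyin and Tokenizer imports already present in this script.
def text_to_prompt_ids(text: str, tokenizer: Tokenizer) -> list:
    pinyin = "".join(
        p[0] for p in pypinyin.pinyin(
            text,
            style=pypinyin.Style.TONE3,
            heteronym=False,
            neutral_tone_with_five=True))
    txt = f"[zh-cn]{pinyin}".replace(" ", "[SPACE]")
    token_ids = tokenizer.encode(txt).ids
    return [7001] + token_ids + [0, 7003]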
+streaming=False +chunk_size=20 +frame_shift=1200 +hidden_size = 1536 +speaker_embedding = torch.zeros((1, 192, 1), dtype=torch.float32).to('cuda') +so = onnxruntime.SessionOptions() +so.enable_profiling = True +ort_session = onnxruntime.InferenceSession('/data/fishtts/genertor.onnx', providers=['CUDAExecutionProvider'], sess_options=so) +sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) +prompts = [ + {"prompt_token_ids": llm_input} for llm_input in llm_inputs +] + +class Metrics: + def __init__(self): + self.time_start = 0 + self.time_end = 0 + self.time_first_byte = 0 + self.token_times = [] + + def calc_non_streaming(self): + total_time = self.time_end - self.time_start + audio_time = len(self.token_times) * 50 / 1000 + rtf = total_time / audio_time + latent_time = self.token_times[-1] - self.time_start + first_byte_time = self.time_first_byte - self.time_start + print(f'latent time: {latent_time}, first byte time: {first_byte_time}, total time: {total_time}, audio time: {audio_time}, rtf: {rtf}') + + def calc_streaming(self): + total_time = self.time_end - self.time_start + audio_time = len(self.token_times) * 50 / 1000 + rtf = total_time / audio_time + first_chunk_time = self.token_times[chunk_size - 1] - self.time_start + first_byte_time = self.time_first_byte - self.time_start + print(f'first chunk time: {first_chunk_time}, first byte time: {first_byte_time}, total time: {total_time}, audio time: {audio_time}, rtf: {rtf}') + +def generate_chunk_audio(latent): + # pad to chunk_size + latent_len = latent.size(1) + if latent_len < chunk_size: + latent = torch.cat([latent, torch.zeros((1, chunk_size - latent_len % chunk_size, hidden_size), dtype=torch.float32).to('cuda')], 1) + onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), (latent, speaker_embedding))} + onnxruntime_outputs = ort_session.run(None, onnxruntime_input) + onnxruntime_outputs = onnxruntime_outputs[0][0][0] + if latent_len < chunk_size: + return onnxruntime_outputs[:latent_len * frame_shift] + return onnxruntime_outputs + +def save_audio(total_audio, path): + total_audio = numpy.concatenate(total_audio, axis=0) + sf.write(path, total_audio, 24000) + if not streaming: - llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) - prompts = [ - {"prompt_token_ids": llm_input} for llm_input in llm_inputs - ] + llm = LLM(model='/data/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) + for i in range(len(prompts)): + metrics = Metrics() + metrics.time_start = time.perf_counter() + outputs = llm.generate(prompts[i], sampling_params) + for output in outputs: + token_ids = output.outputs[0].token_ids + output_len = len(token_ids) + cur_time = time.perf_counter() + metrics.token_times.extend([cur_time] * output_len) + print(f'{i}: {output_len}') + + time_latent = time.perf_counter() + total_audio = [] + chunk_num = output_len // chunk_size + for j in range(chunk_num): + latent = torch.stack(output.outputs[0].hidden_states, 0).unsqueeze(0).to('cuda')[:,j*chunk_size:(j+1)*chunk_size] + onnxruntime_outputs = generate_chunk_audio(latent) + metrics.time_first_byte = time.perf_counter() + total_audio.append(onnxruntime_outputs) + + if output_len % chunk_size != 0: + latent = torch.stack(output.outputs[0].hidden_states, 0).unsqueeze(0).to('cuda')[:,chunk_num*chunk_size:] + onnxruntime_outputs 
= generate_chunk_audio(latent) + total_audio.append(onnxruntime_outputs) + + save_audio(total_audio, f'hh_{i}.wav') + print(f'save audio {i}') + + metrics.time_end = time.perf_counter() + metrics.calc_non_streaming() - sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - print(output.prompt) - token_ids = output.outputs[0].token_ids - for token_id in token_ids: - print([x - 0 for x in token_id]) - print(len(token_ids)) else: - engine_args = AsyncEngineArgs(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) + engine_args = AsyncEngineArgs(model='/data/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) model = AsyncLLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) - prompts = [ - {"prompt_token_ids": llm_input} for llm_input in llm_inputs - ] - async def generate_streaming(prompt, id): + def generate_audio_streaming(latent_queue, id, metrics: Metrics): + latent_buffer = [] + audio_data_buffer = [] + while True: + latent = latent_queue.get() + if latent is None: + break + latent_buffer.append(latent) + if len(latent_buffer) == chunk_size: + latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') + onnxruntime_outputs = generate_chunk_audio(latent) + audio_data_buffer.append(onnxruntime_outputs) + latent_buffer = [] + + if metrics.time_first_byte == 0: + metrics.time_first_byte = time.perf_counter() + + if len(latent_buffer) > 0: + latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') + onnxruntime_outputs = generate_chunk_audio(latent) + audio_data_buffer.append(onnxruntime_outputs) + + save_audio(audio_data_buffer, f'hh_{id}.wav') + print(f'save audio {id}') + + + async def generate_token_streaming(prompt, id, latent_queue, metrics: Metrics): results_generator = model.generate(prompt, sampling_params, request_id=id) - count=0 tokens = [] async for request_output in results_generator: - token_ids = request_output.outputs[0].token_ids - print(f'{id} {[x - 0 for x in token_ids[-1]]}') - tokens.append([x - 0 for x in token_ids[-1]]) - count+=1 - - print(id) - print(len(tokens)) - for token in tokens: - print(token) - - async def generate(): - tasks = [] - for i in range(2): - t = generate_streaming(prompts[i%2], i) - tasks.append(t) - await asyncio.gather(*tasks) - - asyncio.run(generate()) + metrics.token_times.append(time.perf_counter()) + token_ids = request_output.outputs[0].token_ids[-1] + latent = request_output.outputs[0].hidden_states[-1] + tokens.append(token_ids) + latent_queue.put(latent) + + latent_queue.put(None) + print(f'{id}: {len(tokens)}') + + async def generate(prompts): + for i in range(len(prompts)): + metrics = Metrics() + metrics.time_start = time.perf_counter() + + q = queue.Queue() + t = asyncio.create_task(generate_token_streaming(prompts[i], i, q, metrics)) + g = threading.Thread(target=generate_audio_streaming, args=(q, i, metrics)) + g.start() + await t + g.join() + metrics.time_end = time.perf_counter() + metrics.calc_streaming() + + asyncio.run(generate(prompts)) + From 1eadf24df344da5d53cb7159a6baa5ad608aedf1 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Sun, 29 Sep 2024 10:35:20 +0000 Subject: [PATCH 54/61] update 
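For reference, the timing helpers introduced above treat every decoded latent frame as 50 ms of audio: the vocoder emits frame_shift = 1200 samples per frame at the 24 kHz rate passed to sf.write, so audio_time = tokens * 0.05 s and RTF is wall-clock time divided by that duration. A minimal sketch of the same bookkeeping, with illustrative names only and not part of the patch:

SAMPLE_RATE = 24000   # sample rate used by sf.write above
FRAME_SHIFT = 1200    # vocoder samples produced per latent frame

def report_rtf(time_start, time_end, token_times, audio_chunk_times):
    seconds_per_token = FRAME_SHIFT / SAMPLE_RATE            # 1200 / 24000 = 0.05 s
    audio_time = len(token_times) * seconds_per_token         # duration of synthesized audio
    total_time = time_end - time_start                        # wall-clock generation time
    rtf = total_time / audio_time                             # < 1.0 means faster than real time
    first_byte = audio_chunk_times[0] - time_start            # latency until the first audio chunk
    print(f'first byte: {first_byte:.3f}s, total: {total_time:.3f}s, '
          f'audio: {audio_time:.3f}s, rtf: {rtf:.3f}')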
--- fishtts_sample.py | 168 +++++++++++++++++++++++++++++----------------- 1 file changed, 107 insertions(+), 61 deletions(-) diff --git a/fishtts_sample.py b/fishtts_sample.py index c4683d068a57c..af7478d22b069 100644 --- a/fishtts_sample.py +++ b/fishtts_sample.py @@ -1,9 +1,10 @@ +import argparse import asyncio import threading import time from typing import List -import numpy +import numpy as np import onnx from vllm import LLM, SamplingParams from tokenizers import Tokenizer @@ -129,14 +130,10 @@ def to_numpy(tensor): token_ids.append(7003) llm_inputs.append(token_ids) -streaming=False chunk_size=20 frame_shift=1200 hidden_size = 1536 speaker_embedding = torch.zeros((1, 192, 1), dtype=torch.float32).to('cuda') -so = onnxruntime.SessionOptions() -so.enable_profiling = True -ort_session = onnxruntime.InferenceSession('/data/fishtts/genertor.onnx', providers=['CUDAExecutionProvider'], sess_options=so) sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) prompts = [ {"prompt_token_ids": llm_input} for llm_input in llm_inputs @@ -178,11 +175,11 @@ def generate_chunk_audio(latent): return onnxruntime_outputs def save_audio(total_audio, path): - total_audio = numpy.concatenate(total_audio, axis=0) + total_audio = np.concatenate(total_audio, axis=0) sf.write(path, total_audio, 24000) -if not streaming: - llm = LLM(model='/data/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) +def run(): + llm = LLM(model='/data/fishtts', gpu_memory_utilization=0.7, dtype=torch.float32, skip_tokenizer_init=True) for i in range(len(prompts)): metrics = Metrics() metrics.time_start = time.perf_counter() @@ -214,63 +211,112 @@ def save_audio(total_audio, path): metrics.time_end = time.perf_counter() metrics.calc_non_streaming() +async def generate_audio_streaming(latent_queue: asyncio.Queue, id, metrics: Metrics): + latent_buffer = [] + audio_data_buffer = [] + while True: + latent = await latent_queue.get() + if latent is None: + break + latent_buffer.append(latent) + if len(latent_buffer) == chunk_size: + latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') + audio_data_buffer.append(generate_chunk_audio(latent)) + latent_buffer = [] -else: - engine_args = AsyncEngineArgs(model='/data/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True) - model = AsyncLLMEngine.from_engine_args(engine_args) - - def generate_audio_streaming(latent_queue, id, metrics: Metrics): - latent_buffer = [] - audio_data_buffer = [] - while True: - latent = latent_queue.get() - if latent is None: - break - latent_buffer.append(latent) - if len(latent_buffer) == chunk_size: - latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') - onnxruntime_outputs = generate_chunk_audio(latent) - audio_data_buffer.append(onnxruntime_outputs) - latent_buffer = [] - - if metrics.time_first_byte == 0: - metrics.time_first_byte = time.perf_counter() + if metrics.time_first_byte == 0: + metrics.time_first_byte = time.perf_counter() - if len(latent_buffer) > 0: - latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') - onnxruntime_outputs = generate_chunk_audio(latent) - audio_data_buffer.append(onnxruntime_outputs) - - save_audio(audio_data_buffer, f'hh_{id}.wav') - print(f'save audio {id}') + if len(latent_buffer) > 0: + latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') + audio_data_buffer.append(generate_chunk_audio(latent)) + 
save_audio(audio_data_buffer, f'hh_{id}.wav') + print(f'save audio {id}') + +async def generate_token_streaming(engine: AsyncLLMEngine, prompt, id, latent_queue: asyncio.Queue, metrics: Metrics): + results_generator = engine.generate(prompt, sampling_params, request_id=id) + tokens = [] + async for request_output in results_generator: + metrics.token_times.append(time.perf_counter()) + token_ids = request_output.outputs[0].token_ids[-1] + latent = request_output.outputs[0].hidden_states[-1] + tokens.append(token_ids) + latent_queue.put_nowait(latent) + + latent_queue.put_nowait(None) + print(f'{id}: {len(tokens)}') + +async def get_request( + input_requests, + request_rate: float, +): + requests = iter(input_requests) + for request in requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + +async def generate_streaming(engine, prompt, request_id) -> Metrics: + metrics = Metrics() + metrics.time_start = time.perf_counter() - async def generate_token_streaming(prompt, id, latent_queue, metrics: Metrics): - results_generator = model.generate(prompt, sampling_params, request_id=id) - tokens = [] - async for request_output in results_generator: - metrics.token_times.append(time.perf_counter()) - token_ids = request_output.outputs[0].token_ids[-1] - latent = request_output.outputs[0].hidden_states[-1] - tokens.append(token_ids) - latent_queue.put(latent) - - latent_queue.put(None) - print(f'{id}: {len(tokens)}') - - async def generate(prompts): + latent_queue = asyncio.Queue() + vllm_task = asyncio.create_task(generate_token_streaming(engine, prompt, request_id, latent_queue, metrics)) + generator_task = asyncio.create_task(generate_audio_streaming(latent_queue, request_id, metrics)) + await vllm_task + await generator_task + metrics.time_end = time.perf_counter() + return metrics + +async def run_streaming(request_rate): + engine_args = AsyncEngineArgs(model='/data/fishtts', gpu_memory_utilization=0.7, dtype=torch.float32, skip_tokenizer_init=True) + engine = AsyncLLMEngine.from_engine_args(engine_args) + if request_rate < 0: for i in range(len(prompts)): - metrics = Metrics() - metrics.time_start = time.perf_counter() - - q = queue.Queue() - t = asyncio.create_task(generate_token_streaming(prompts[i], i, q, metrics)) - g = threading.Thread(target=generate_audio_streaming, args=(q, i, metrics)) - g.start() - await t - g.join() - metrics.time_end = time.perf_counter() + me = await generate_streaming(engine, prompts[i], i) + me.calc_streaming() + else: + tasks: List[asyncio.Task] = [] + request_id = 0 + me = await generate_streaming(engine, prompts[0], 0) + async for prompt in get_request(prompts, request_rate): + tasks.append(asyncio.create_task(generate_streaming(engine, prompt, request_id))) + request_id += 1 + + metrics_list: List[Metrics] = await asyncio.gather(*tasks) + for metrics in metrics_list: metrics.calc_streaming() - asyncio.run(generate(prompts)) - +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--streaming", action="store_true") + parser.add_argument("--request-rate", + type=float, + default=-1, + help="request rate per second") + parser.add_argument("--chunk-size", + type=int, + default=20, + help="audio chunk size") + + args = parser.parse_args() + + 
if args.chunk_size: + chunk_size = args.chunk_size + + ort_session = onnxruntime.InferenceSession('/data/fishtts/genertor.onnx', providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider']) + warmup_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), (torch.zeros(1, chunk_size, hidden_size).to('cuda'), speaker_embedding))} + warmup_outputs = ort_session.run(None, warmup_input) + + if not args.streaming: + run() + else: + asyncio.run(run_streaming(args.request_rate)) \ No newline at end of file From 06244eb1a1c8e59eb6e5d5c05dfccf75cb9fd617 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 30 Sep 2024 10:34:49 +0000 Subject: [PATCH 55/61] try add trt --- fishtts_sample.py | 174 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 121 insertions(+), 53 deletions(-) diff --git a/fishtts_sample.py b/fishtts_sample.py index af7478d22b069..ce7c3d6b93571 100644 --- a/fishtts_sample.py +++ b/fishtts_sample.py @@ -17,6 +17,10 @@ import soundfile as sf import queue +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit + torch.random.manual_seed(999) def convert_model(): @@ -130,49 +134,129 @@ def to_numpy(tensor): token_ids.append(7003) llm_inputs.append(token_ids) -chunk_size=20 -frame_shift=1200 -hidden_size = 1536 -speaker_embedding = torch.zeros((1, 192, 1), dtype=torch.float32).to('cuda') sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) prompts = [ {"prompt_token_ids": llm_input} for llm_input in llm_inputs ] +ctx = cuda.Device(0).make_context() + class Metrics: - def __init__(self): + def __init__(self, chunk_size=20): + self.chunk_size = chunk_size self.time_start = 0 self.time_end = 0 - self.time_first_byte = 0 self.token_times = [] + self.audio_chunk_times = [] def calc_non_streaming(self): total_time = self.time_end - self.time_start audio_time = len(self.token_times) * 50 / 1000 rtf = total_time / audio_time latent_time = self.token_times[-1] - self.time_start - first_byte_time = self.time_first_byte - self.time_start + first_byte_time = self.audio_chunk_times[0] - self.time_start print(f'latent time: {latent_time}, first byte time: {first_byte_time}, total time: {total_time}, audio time: {audio_time}, rtf: {rtf}') def calc_streaming(self): total_time = self.time_end - self.time_start audio_time = len(self.token_times) * 50 / 1000 rtf = total_time / audio_time - first_chunk_time = self.token_times[chunk_size - 1] - self.time_start - first_byte_time = self.time_first_byte - self.time_start + first_chunk_time = self.token_times[self.chunk_size - 1] - self.time_start + first_byte_time = self.audio_chunk_times[0] - self.time_start print(f'first chunk time: {first_chunk_time}, first byte time: {first_byte_time}, total time: {total_time}, audio time: {audio_time}, rtf: {rtf}') -def generate_chunk_audio(latent): - # pad to chunk_size - latent_len = latent.size(1) - if latent_len < chunk_size: - latent = torch.cat([latent, torch.zeros((1, chunk_size - latent_len % chunk_size, hidden_size), dtype=torch.float32).to('cuda')], 1) - onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), (latent, speaker_embedding))} - onnxruntime_outputs = ort_session.run(None, onnxruntime_input) - onnxruntime_outputs = onnxruntime_outputs[0][0][0] - if latent_len < chunk_size: - return onnxruntime_outputs[:latent_len * frame_shift] - return onnxruntime_outputs +class Generator: + def __init__(self, 
model_path, use_trt=False, chunk_size=20, hidden_size=1536, frame_shift=1200): + self.onnx_session = None + self.trt_engine = None + self.use_trt = use_trt + self.model_path = model_path + self.chunk_size = chunk_size + self.hidden_size = hidden_size + self.frame_shift = frame_shift + self.speaker_embedding = torch.zeros((1, 192, 1), dtype=torch.float32).to('cuda') + + def generate_audio_onnx(self, latent): + onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(self.onnx_session.get_inputs(), (latent, self.speaker_embedding))} + onnxruntime_outputs = self.onnx_session.run(None, onnxruntime_input) + onnxruntime_outputs = onnxruntime_outputs[0][0][0] + return onnxruntime_outputs + + def generate_audio_trt(self, latent): + with self.trt_engine.create_execution_context() as context: + ctx.push() + + stream = cuda.Stream() + context.set_input_shape('input', (1, self.chunk_size, self.hidden_size)) + context.set_input_shape('speaker_embedding', (1, 192, 1)) + + bindings = [] + bindings.append(latent.data_ptr()) + bindings.append(self.speaker_embedding.data_ptr()) + dtype = trt.nptype(self.trt_engine.get_tensor_dtype("output")) + size = trt.volume(context.get_tensor_shape('output')) + output_buffer = cuda.pagelocked_empty(size, dtype) + output_memory = cuda.mem_alloc(output_buffer.nbytes) + bindings.append(int(output_memory)) + + for i in range(len(bindings)): + context.set_tensor_address(self.trt_engine.get_tensor_name(i), bindings[i]) + + context.execute_async_v3(stream_handle=stream.handle) + stream.synchronize() + + cuda.memcpy_dtoh_async(output_buffer, output_memory, stream) + + ctx.pop() + return output_buffer + + def generate_audio(self, latent, metric: Metrics): + latent_len = latent.size(1) + total_audio = [] + chunk_num = latent_len // self.chunk_size + for j in range(chunk_num): + latent_chunk = latent[:,j*self.chunk_size:(j+1)*self.chunk_size] + if self.use_trt: + audio_outputs = self.generate_audio_trt(latent_chunk) + else: + audio_outputs = self.generate_audio_onnx(latent_chunk) + total_audio.append(audio_outputs) + metric.audio_chunk_times.append(time.perf_counter()) + if latent_len % self.chunk_size != 0: + latent_chunk = latent[:,chunk_num*self.chunk_size:] + latent_chunk = torch.cat([latent_chunk, torch.zeros((1, self.chunk_size - latent_chunk.size(1), self.hidden_size), dtype=torch.float32).to('cuda')], 1) + if self.use_trt: + audio_outputs = self.generate_audio_trt(latent_chunk) + else: + audio_outputs = self.generate_audio_onnx(latent_chunk) + audio_outputs = audio_outputs[:latent_len % self.chunk_size * self.frame_shift] + total_audio.append(audio_outputs) + metric.audio_chunk_times.append(time.perf_counter()) + + return total_audio + + def warm_up_onnx(self): + self.onnx_session = onnxruntime.InferenceSession('/data/fishtts/genertor.onnx', providers=['CUDAExecutionProvider']) + warmup_input = torch.zeros(1, self.chunk_size, self.hidden_size).to('cuda') + self.generate_audio_onnx(warmup_input) + print(f'warmup onnx done') + + def warm_up_trt(self): + trt_logger = trt.Logger(trt.Logger.INFO) + trt_runtime = trt.Runtime(trt_logger) + with open('/data/fishtts/genertor.trt', 'rb') as f: + self.trt_engine = trt_runtime.deserialize_cuda_engine(f.read()) + warmup_input = torch.zeros(1, self.chunk_size, self.hidden_size).to('cuda') + self.generate_audio_trt(warmup_input) + print(f'warmup trt done') + + + def warm_up(self, use_trt): + self.use_trt = use_trt + if use_trt: + self.warm_up_trt() + else: + self.warm_up_onnx() def save_audio(total_audio, path): total_audio = 
np.concatenate(total_audio, axis=0) @@ -181,7 +265,7 @@ def save_audio(total_audio, path): def run(): llm = LLM(model='/data/fishtts', gpu_memory_utilization=0.7, dtype=torch.float32, skip_tokenizer_init=True) for i in range(len(prompts)): - metrics = Metrics() + metrics = Metrics(chunk_size=generator.chunk_size) metrics.time_start = time.perf_counter() outputs = llm.generate(prompts[i], sampling_params) for output in outputs: @@ -190,20 +274,8 @@ def run(): cur_time = time.perf_counter() metrics.token_times.extend([cur_time] * output_len) print(f'{i}: {output_len}') - - time_latent = time.perf_counter() - total_audio = [] - chunk_num = output_len // chunk_size - for j in range(chunk_num): - latent = torch.stack(output.outputs[0].hidden_states, 0).unsqueeze(0).to('cuda')[:,j*chunk_size:(j+1)*chunk_size] - onnxruntime_outputs = generate_chunk_audio(latent) - metrics.time_first_byte = time.perf_counter() - total_audio.append(onnxruntime_outputs) - - if output_len % chunk_size != 0: - latent = torch.stack(output.outputs[0].hidden_states, 0).unsqueeze(0).to('cuda')[:,chunk_num*chunk_size:] - onnxruntime_outputs = generate_chunk_audio(latent) - total_audio.append(onnxruntime_outputs) + latent = torch.stack(output.outputs[0].hidden_states, 0).unsqueeze(0).to('cuda') + total_audio = generator.generate_audio(latent, metrics) save_audio(total_audio, f'hh_{i}.wav') print(f'save audio {i}') @@ -219,17 +291,14 @@ async def generate_audio_streaming(latent_queue: asyncio.Queue, id, metrics: Met if latent is None: break latent_buffer.append(latent) - if len(latent_buffer) == chunk_size: + if len(latent_buffer) == generator.chunk_size: latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') - audio_data_buffer.append(generate_chunk_audio(latent)) + audio_data_buffer.extend(generator.generate_audio(latent, metrics)) latent_buffer = [] - if metrics.time_first_byte == 0: - metrics.time_first_byte = time.perf_counter() - if len(latent_buffer) > 0: latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') - audio_data_buffer.append(generate_chunk_audio(latent)) + audio_data_buffer.extend(generator.generate_audio(latent, metrics)) save_audio(audio_data_buffer, f'hh_{id}.wav') print(f'save audio {id}') @@ -247,10 +316,7 @@ async def generate_token_streaming(engine: AsyncLLMEngine, prompt, id, latent_qu latent_queue.put_nowait(None) print(f'{id}: {len(tokens)}') -async def get_request( - input_requests, - request_rate: float, -): +async def get_request(input_requests, request_rate: float,): requests = iter(input_requests) for request in requests: yield request @@ -265,7 +331,7 @@ async def get_request( await asyncio.sleep(interval) async def generate_streaming(engine, prompt, request_id) -> Metrics: - metrics = Metrics() + metrics = Metrics(chunk_size=generator.chunk_size) metrics.time_start = time.perf_counter() latent_queue = asyncio.Queue() @@ -295,27 +361,29 @@ async def run_streaming(request_rate): for metrics in metrics_list: metrics.calc_streaming() + +generator = Generator('/data/fishtts', use_trt=False, chunk_size=20, hidden_size=1536, frame_shift=1200) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--streaming", action="store_true") + parser.add_argument("--streaming", action="store_true", default=False) parser.add_argument("--request-rate", type=float, default=-1, help="request rate per second") parser.add_argument("--chunk-size", type=int, - default=20, + default=None, help="audio chunk size") + parser.add_argument("--use-trt", action="store_true", 
default=False) args = parser.parse_args() if args.chunk_size: - chunk_size = args.chunk_size - - ort_session = onnxruntime.InferenceSession('/data/fishtts/genertor.onnx', providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider']) - warmup_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), (torch.zeros(1, chunk_size, hidden_size).to('cuda'), speaker_embedding))} - warmup_outputs = ort_session.run(None, warmup_input) + generator.chunk_size = args.chunk_size + generator.warm_up(args.use_trt) + if not args.streaming: run() else: From d8f9a59ae25c1233c7787cddca2d65e338545fee Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 30 Sep 2024 15:06:00 +0000 Subject: [PATCH 56/61] trt work --- fishtts_sample.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/fishtts_sample.py b/fishtts_sample.py index ce7c3d6b93571..a78a7f53a8e7b 100644 --- a/fishtts_sample.py +++ b/fishtts_sample.py @@ -139,7 +139,7 @@ def to_numpy(tensor): {"prompt_token_ids": llm_input} for llm_input in llm_inputs ] -ctx = cuda.Device(0).make_context() +# ctx = cuda.Device(0).make_context() class Metrics: def __init__(self, chunk_size=20): @@ -184,15 +184,25 @@ def generate_audio_onnx(self, latent): def generate_audio_trt(self, latent): with self.trt_engine.create_execution_context() as context: - ctx.push() + # ctx.push() stream = cuda.Stream() context.set_input_shape('input', (1, self.chunk_size, self.hidden_size)) context.set_input_shape('speaker_embedding', (1, 192, 1)) bindings = [] + + # d_input = cuda.mem_alloc(latent.nbytes) + # cuda.memcpy_htod_async(d_input, latent.cpu().numpy(), stream) + # cuda.memcpy_dtod_async(d_input, latent.data_ptr(), latent.nbytes, stream) bindings.append(latent.data_ptr()) + + # d_speaker_embedding = cuda.mem_alloc(self.speaker_embedding.nbytes) + # cuda.memcpy_htod_async(d_speaker_embedding, self.speaker_embedding.cpu().numpy(), stream) + # cuda.memcpy_dtod_async(d_speaker_embedding, self.speaker_embedding.data_ptr(), self.speaker_embedding.nbytes, stream) + # bindings.append(int(d_speaker_embedding)) bindings.append(self.speaker_embedding.data_ptr()) + dtype = trt.nptype(self.trt_engine.get_tensor_dtype("output")) size = trt.volume(context.get_tensor_shape('output')) output_buffer = cuda.pagelocked_empty(size, dtype) @@ -207,7 +217,7 @@ def generate_audio_trt(self, latent): cuda.memcpy_dtoh_async(output_buffer, output_memory, stream) - ctx.pop() + # ctx.pop() return output_buffer def generate_audio(self, latent, metric: Metrics): @@ -242,9 +252,9 @@ def warm_up_onnx(self): print(f'warmup onnx done') def warm_up_trt(self): - trt_logger = trt.Logger(trt.Logger.INFO) + trt_logger = trt.Logger(trt.Logger.ERROR) trt_runtime = trt.Runtime(trt_logger) - with open('/data/fishtts/genertor.trt', 'rb') as f: + with open('/data/fishtts/genertor.fp16.trt', 'rb') as f: self.trt_engine = trt_runtime.deserialize_cuda_engine(f.read()) warmup_input = torch.zeros(1, self.chunk_size, self.hidden_size).to('cuda') self.generate_audio_trt(warmup_input) @@ -376,6 +386,7 @@ async def run_streaming(request_rate): default=None, help="audio chunk size") parser.add_argument("--use-trt", action="store_true", default=False) + parser.add_argument("--fp16", action="store_true", default=False) args = parser.parse_args() From 81e66fb357e30896a556ba48cdae2ee344906d16 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 19 Dec 2024 18:53:50 +0800 Subject: [PATCH 57/61] fish_tts_changes --- requirements-cuda.txt | 4 + vllm/lora/models.py | 33 +++- 
vllm/model_executor/models/fishtts.py | 28 ++- vllm/multimodal/speech.py | 25 +-- xtts/__init__.py | 0 xtts/generator.py | 188 ++++++++++++++++++ xtts/metrics.py | 27 +++ xtts/model_setting.py | 38 ++++ xtts/preceiver_resampler.py | 35 ++++ xtts/run.py | 77 ++++++++ xtts/tts_engine.py | 255 ++++++++++++++++++++++++ xtts/utils.py | 274 ++++++++++++++++++++++++++ 12 files changed, 949 insertions(+), 35 deletions(-) create mode 100644 xtts/__init__.py create mode 100644 xtts/generator.py create mode 100644 xtts/metrics.py create mode 100644 xtts/model_setting.py create mode 100644 xtts/preceiver_resampler.py create mode 100644 xtts/run.py create mode 100644 xtts/tts_engine.py create mode 100644 xtts/utils.py diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 5d4dee8c7129a..b7cef9e700fbf 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,3 +8,7 @@ torch == 2.5.1; platform_machine != 'aarch64' # These must be updated alongside torch torchvision == 0.20.1; platform_machine != 'aarch64' # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1 +onnxruntime-gpu +pycuda +pypinyin +omegaconf \ No newline at end of file diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 70806a77b9fff..b634c53db5802 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -486,14 +486,31 @@ def _create_lora_modules(self): new_module.scaling_factor_to_offset # (yard1): TODO make this more robust if "lm_head" in module_name: - logits_processor_module = self.model.get_submodule( - "logits_processor") - new_module = replace_submodule( - self.model, "logits_processor", - from_layer_logits_processor(logits_processor_module, - module, self.lora_slots, - self.lora_config, - self.model.config)) + if hasattr(self.model.config, "num_output_head") and \ + self.model.config.num_output_head > 1: + # hack for multi-head output + # we need to replace the logits processor for each head + for i in range(self.model.config.num_output_head): + logits_processor_module = self.model.get_submodule(f"logits_processor.{i}") + new_module = replace_submodule( + self.model, f"logits_processor.{i}", + from_layer_logits_processor(logits_processor_module, + module[i], self.lora_slots, + self.lora_config, + self.model.config)) + self.register_module(f"{module_name}.{i}", new_module) + self._register_packed_modules(f"{module_name}.{i}") + new_module.set_mapping(self.punica_wrapper) + continue + else: + logits_processor_module = self.model.get_submodule( + "logits_processor") + new_module = replace_submodule( + self.model, "logits_processor", + from_layer_logits_processor(logits_processor_module, + module, self.lora_slots, + self.lora_config, + self.model.config)) # In some models, especially multimodal ones, layers with the same # name may have different types, such as nn.Linear and diff --git a/vllm/model_executor/models/fishtts.py b/vllm/model_executor/models/fishtts.py index c296f4260a35d..557a2c800028c 100644 --- a/vllm/model_executor/models/fishtts.py +++ b/vllm/model_executor/models/fishtts.py @@ -92,6 +92,7 @@ def __init__(self, self.num_text_tokens = config.num_text_tokens self.num_output_head = config.num_output_head self.audio_start_token_id = config.audio_start_token_id + self.audio_ref_token_id = config.audio_ref_start_token_id self.gpt = LlamaModel(config, lora_config=lora_config) self.model_dim = 
self.gpt.config.hidden_size @@ -100,9 +101,12 @@ def __init__(self, VocabParallelEmbedding(self.num_audio_tokens, self.model_dim) for _ in range(self.num_output_head) ]) + unpadded_vocab_size = self.num_audio_tokens + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = nn.ModuleList([ ParallelLMHead( - self.num_audio_tokens, + unpadded_vocab_size, self.model_dim, org_num_embeddings=self.num_audio_tokens, padding_size=DEFAULT_VOCAB_PADDING_SIZE @@ -112,7 +116,7 @@ def __init__(self, quant_config=quant_config, ) for _ in range(self.num_output_head) ]) - self.logits_processor = LogitsProcessor(self.num_audio_tokens) + self.logits_processor = nn.ModuleList([LogitsProcessor(unpadded_vocab_size, self.num_audio_tokens) for _ in range(self.num_output_head)]) self.sampler = MultiheadSampler() def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -155,9 +159,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): except KeyError: pass - def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor, audio_ref: torch.Tensor, is_prompt: bool) -> torch.Tensor: if is_prompt: - emb = self.emb_text(input_ids) + emb: torch.Tensor = self.emb_text(input_ids) audio_start = torch.tensor([1024, 1024], device=input_ids.device) code_emb = [ self.emb_code[i](audio_start[i]) @@ -165,10 +169,19 @@ def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torc ] start_token = torch.stack(code_emb, 1).sum(1).to(emb.dtype) - # find the index of the speaker token + # find the index of the audio BOS token indices = (input_ids == self.audio_start_token_id).nonzero(as_tuple=True) if indices[0].size(0) != 0: emb.index_put_(indices, start_token) + + # batch size = 2 + # inpudId 7004 7004 XXXX 7004 7001 1 2 34 | 7003 7004 7004 XXXX 7004 7001 1 2 34 7003 + # speaker ref [16*2, 1536] + indices = (input_ids == self.audio_ref_token_id).nonzero(as_tuple=True) + if indices[0].size(0) != 0: + for idx, audio_ref_start in enumerate(indices[0]): + emb[audio_ref_start:audio_ref_start+16] = audio_ref[idx].to(emb.dtype) + else: code_emb = [ self.emb_code[i](input_ids[:,i]) for i in range(self.num_output_head) @@ -179,7 +192,7 @@ def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torc def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = [ - self.logits_processor(self.lm_head[i], hidden_states, sampling_metadata) + self.logits_processor[i](self.lm_head[i], hidden_states, sampling_metadata) for i in range(self.num_output_head) ] logits = torch.stack(logits, 0).permute(1, 0, 2) @@ -207,7 +220,8 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids, is_prompt) + audio_ref = kwargs.get("audio", None) + hidden_states = self.get_input_embeddings(input_ids, audio_ref, is_prompt) model_output = self.gpt( input_ids=input_ids, inputs_embeds=hidden_states, diff --git a/vllm/multimodal/speech.py b/vllm/multimodal/speech.py index 42375ac6634cf..7456204ce464f 100644 --- a/vllm/multimodal/speech.py +++ b/vllm/multimodal/speech.py @@ -23,32 +23,17 @@ class SpeechPlugin(MultiModalPlugin): def get_data_key(self) -> str: return "audio" - def _decode_spk_emb(self, spk_emb: str) -> np.ndarray: - n = base64.b64decode(spk_emb) - return np.frombuffer(n, dtype=np.float16).copy() - def _default_input_mapper(self, ctx: 
InputContext, data: object) -> MultiModalInputs: model_config = ctx.model_config - if isinstance(data, str): - n =F.normalize( - torch.from_numpy(self._decode_spk_emb(data)), - p=2.0, - dim=0, - eps=1e-12, - ) - - return MultiModalInputs({"speech": n}) - elif isinstance(data, torch.Tensor): - raise NotImplementedError("Embeddings input is not supported yet") - - raise TypeError(f"Invalid image type: {type(data)}") + if data is None: + return MultiModalInputs({"audio": torch.zeros(16, model_config.hf_config.hidden_size)}) + else: + return MultiModalInputs({"audio": data}) def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: return 3000 @staticmethod def sample_random_speaker() -> str: - n = np.random.randn(768).astype(np.float16) - s = base64.b64encode(n).decode("utf-8") - return s + return None diff --git a/xtts/__init__.py b/xtts/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/xtts/generator.py b/xtts/generator.py new file mode 100644 index 0000000000000..739495daef49e --- /dev/null +++ b/xtts/generator.py @@ -0,0 +1,188 @@ +import os +import time +from typing import List +import onnxruntime +from onnxruntime import InferenceSession, SessionOptions, RunOptions +import torch +from metrics import TtsMetrics + +import numpy as np +import soundfile as sf +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit + +from utils import * +from model_setting import ModelSetting + + +class AudioGenerator: + def __init__(self, model_setting: ModelSetting): + self.onnx_session: InferenceSession = None + self.trt_engine = None + self.model_setting = model_setting + self.speaker_embedding = torch.zeros((1, 192, 1), dtype=self.model_setting.dtype).to('cuda') + self.model_path: str = None + + self.post_init() + + def post_init(self): + extension = 'onnx' if self.model_setting.runtime == 'onnx' else 'trt' + self.model_path = os.path.join(self.model_setting.model_dir, f'generator.{self.model_setting.dtype_str}.{extension}') + logger.info(f'Loading generator model from {self.model_path}') + if self.model_setting.runtime == 'onnx': + if self.model_setting.use_onnx_graph: + providers = [("CUDAExecutionProvider", {"enable_cuda_graph": '1'})] + sess_options = onnxruntime.SessionOptions() + self.onnx_session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=providers) + self.onnx_input_buffer = torch.zeros(1, self.model_setting.chunk_size, self.model_setting.hidden_size, dtype=self.model_setting.dtype).to('cuda') + self.onnx_output_buffer = torch.zeros(1, 1, self.model_setting.chunk_size * self.model_setting.frame_shift, dtype=self.model_setting.dtype).to('cuda') + else: + providers = ["CUDAExecutionProvider"] + sess_options = onnxruntime.SessionOptions() + self.onnx_session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=providers) + else: + trt_logger = trt.Logger(trt.Logger.ERROR) + trt_runtime = trt.Runtime(trt_logger) + with open(self.model_path, 'rb') as f: + self.trt_engine = trt_runtime.deserialize_cuda_engine(f.read()) + logger.info('Generator model loaded') + self.warm_up() + + def warm_up(self): + logger.info('warmup generator...') + warmup_input = torch.zeros(1, self.model_setting.chunk_size, self.model_setting.hidden_size).to('cuda').to(self.model_setting.dtype) + metrics = TtsMetrics() + self.generate_audio(warmup_input, metrics) + logger.info('warmup generator done') + + def generate_audio_onnx(self, latent: torch.Tensor) -> np.ndarray: + if 
self.model_setting.use_onnx_graph: + io_binding = self.onnx_session.io_binding() + latent_len = latent.size(1) + ro = onnxruntime.RunOptions() + ro.add_run_config_entry("gpu_graph_id", "1") + # copy latent to input buffer + self.onnx_input_buffer.copy_(latent) + io_binding.bind_input('input', + device_type='cuda', + device_id=0, + element_type=np.float16, + shape=tuple(self.onnx_input_buffer.shape), + buffer_ptr=self.onnx_input_buffer.data_ptr()) + io_binding.bind_input('speaker_embedding', + device_type='cuda', + device_id=0, + element_type=np.float16, + shape=tuple(self.speaker_embedding.shape), + buffer_ptr=self.speaker_embedding.data_ptr()) + io_binding.bind_output('output', + device_type='cuda', + device_id=0, + element_type=np.float16, + shape=tuple(self.onnx_output_buffer.shape), + buffer_ptr=self.onnx_output_buffer.data_ptr()) + self.onnx_session.run_with_iobinding(io_binding, ro) + + # copy output tensor to host + onnxruntime_outputs = self.onnx_output_buffer.cpu().numpy()[0][0] + else: + onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(self.onnx_session.get_inputs(), (latent, self.speaker_embedding))} + onnxruntime_outputs = self.onnx_session.run(None, onnxruntime_input) + onnxruntime_outputs = onnxruntime_outputs[0][0][0] + return onnxruntime_outputs + + def generate_audio_trt(self, latent: torch.Tensor) -> np.ndarray: + with self.trt_engine.create_execution_context() as context: + stream = cuda.Stream() + + # set input shape as it supports dynamic shape + latent_len = latent.size(1) + context.set_input_shape('input', (1, latent_len, self.model_setting.hidden_size)) + context.set_input_shape('speaker_embedding', (1, 192, 1)) + + # set input data + bindings = [] + input_buffer_1 = latent.to('cuda').to(torch.float32) + bindings.append(input_buffer_1.data_ptr()) + input_buffer_2 = self.speaker_embedding.to('cuda').to(torch.float32) + bindings.append(input_buffer_2.data_ptr()) + + # get output size based on input + # and set output buffer + dtype = trt.nptype(self.trt_engine.get_tensor_dtype("output")) + size = trt.volume(context.get_tensor_shape('output')) + logger.debug(f'output dtype: {dtype}, size: {size}') + output_buffer = cuda.pagelocked_empty(size, dtype) + output_memory = cuda.mem_alloc(output_buffer.nbytes) + bindings.append(int(output_memory)) + + # set input and output buffer to context + for i in range(len(bindings)): + context.set_tensor_address(self.trt_engine.get_tensor_name(i), bindings[i]) + + # execute inference + logger.debug('execute trt engine') + context.execute_async_v3(stream_handle=stream.handle) + stream.synchronize() + logger.debug('execute trt engine done') + + # copy output buffer to host + cuda.memcpy_dtoh_async(output_buffer, output_memory, stream) + + return output_buffer + + def generate_chunk_audio(self, latent: torch.Tensor, metric: TtsMetrics, padding: bool, trim_begin: bool, trim_end: bool) -> np.ndarray: + latent_len = latent.size(1) + + if latent_len % self.model_setting.chunk_size != 0 and padding: + # pad to chunk size with last frame + pad_len = self.model_setting.chunk_size - latent_len % self.model_setting.chunk_size + latent = torch.cat([latent, latent[:, -1:, :].repeat(1, pad_len, 1)], dim=1) + + if self.model_setting.runtime == 'trt': + audio_outputs = self.generate_audio_trt(latent) + else: + audio_outputs = self.generate_audio_onnx(latent) + + metric.audio_chunk_times.append(time.perf_counter()) + + if padding: + audio_outputs = audio_outputs[:latent_len * self.model_setting.frame_shift] + + if 
self.model_setting.overlap_window > 0 and trim_begin: + audio_outputs = audio_outputs[self.model_setting.overlap_window * self.model_setting.frame_shift:] + + if self.model_setting.overlap_window > 0 and trim_end: + audio_outputs = audio_outputs[:-self.model_setting.overlap_window * self.model_setting.frame_shift] + + # convert to fp32 + audio_outputs = audio_outputs.astype(np.float32) + return audio_outputs + + def generate_audio(self, latent: torch.Tensor, metric: TtsMetrics) -> List[np.ndarray]: + logger.debug(f'latent shape: {latent.shape}') + latent_len = latent.size(1) + total_audio: List[np.ndarray] = [] + padding = self.model_setting.chunk_padding + overlap_window = self.model_setting.overlap_window + chunk_size = self.model_setting.chunk_size + + if latent.shape[1] > chunk_size: + total_len = latent.shape[1] + start_idx = 0 + audio_seg = self.generate_chunk_audio(latent[:, :chunk_size, :], metric, False, False, True) + total_audio.append(audio_seg) + start_idx = chunk_size - overlap_window * 2 + while start_idx < total_len: + if total_len <= start_idx + chunk_size: + audio_seg = self.generate_chunk_audio(latent[:, start_idx:, :], metric, padding, True, False) + else: + audio_seg = self.generate_chunk_audio(latent[:, start_idx:start_idx + chunk_size, :], metric, False, True, True) + total_audio.append(audio_seg) + start_idx += chunk_size - overlap_window * 2 + else: + audio_outputs = self.generate_chunk_audio(latent, metric, padding, False, False) + total_audio.append(audio_outputs) + + return total_audio diff --git a/xtts/metrics.py b/xtts/metrics.py new file mode 100644 index 0000000000000..84c0189adede3 --- /dev/null +++ b/xtts/metrics.py @@ -0,0 +1,27 @@ +from typing import List +from utils import * + +class TtsMetrics: + def __init__(self, chunk_size: int = 20, first_chunk_size: int = 10): + self.chunk_size = chunk_size + self.first_chunk_size = first_chunk_size + self.time_start: float = 0 + self.time_end: float = 0 + self.token_times: List[float] = [] + self.audio_chunk_times: List[float] = [] + + def calc_non_streaming(self): + total_time = self.time_end - self.time_start + audio_time = len(self.token_times) * 50 / 1000 + rtf = total_time / audio_time + latent_time = self.token_times[-1] - self.time_start + first_byte_time = self.audio_chunk_times[0] - self.time_start + print(f'latent time: {latent_time}, first byte time: {first_byte_time}, total time: {total_time}, audio time: {audio_time}, rtf: {rtf}') + + def calc_streaming(self): + total_time = self.time_end - self.time_start + audio_time = len(self.token_times) * 50 / 1000 + rtf = total_time / audio_time + first_chunk_time = self.token_times[self.first_chunk_size - 1] - self.time_start + first_byte_time = self.audio_chunk_times[0] - self.time_start + print(f'first chunk time: {first_chunk_time}, first byte time: {first_byte_time}, total time: {total_time}, audio time: {audio_time}, rtf: {rtf}') \ No newline at end of file diff --git a/xtts/model_setting.py b/xtts/model_setting.py new file mode 100644 index 0000000000000..300c1562885ce --- /dev/null +++ b/xtts/model_setting.py @@ -0,0 +1,38 @@ +import torch + + +class ModelSetting: + def __init__(self, + model_dir: str = None, + runtime: str = 'onnx', + dtype: str = 'float32', + chunk_size: int = 20, + overlap_window: int = 0, + hidden_size: int = 1536, + frame_shift: int = 1200, + streaming: bool = False, + first_chunk_size: int = 10, + chunk_padding: bool = True, + cut_tail: int = 150, + support_lora: bool = False, + scale_rate: float = 2.7): + self.model_dir = 
model_dir + self.runtime = runtime + self.chunk_size = chunk_size + self.overlap_window = overlap_window + self.hidden_size = hidden_size + self.frame_shift = frame_shift + self.streaming = streaming + self.first_chunk_size = first_chunk_size + self.chunk_padding = chunk_padding + self.dtype_str = dtype + if dtype == 'float32': + self.dtype = torch.float32 + elif dtype == 'float16': + self.dtype = torch.float16 + self.cut_tail = cut_tail + self.support_lora = support_lora + self.use_onnx_graph = False + + self.gpu_memory_utilization = 0.7 + self.scale_rate = scale_rate \ No newline at end of file diff --git a/xtts/preceiver_resampler.py b/xtts/preceiver_resampler.py new file mode 100644 index 0000000000000..8a4c49d73e7b4 --- /dev/null +++ b/xtts/preceiver_resampler.py @@ -0,0 +1,35 @@ +from typing import List +from numpy import ndarray +import numpy +from onnxruntime import InferenceSession +import os + +from utils import * +from model_setting import ModelSetting + +class PreceiverResampler: + def __init__(self, model_setting: ModelSetting): + self.onnx_session: InferenceSession = None + self.ref_mel: ndarray = None + self.model_setting = model_setting + self.magic_number_1 = 0.4 + self.post_init() + + def post_init(self): + self.onnx_session = InferenceSession(os.path.join(self.model_setting.model_dir, 'preceiver_resampler.onnx'), providers=['CUDAExecutionProvider']) + self.ref_mel = np.load(os.path.join(self.model_setting.model_dir, 'ref_mel.npy')) + + def get_reference_audio(self, text_token_count: int) -> torch.Tensor: + ref_length = max(round(text_token_count * self.model_setting.scale_rate * self.magic_number_1), 2) + conds: ndarray = self.ref_mel.copy() + cut_len = conds.shape[1] % ref_length if conds.shape[1] % ref_length != 0 else 0 + if cut_len != 0: + conds = conds[:, :-cut_len] + if conds.shape[1] // ref_length > 0: + conds = numpy.split(conds, conds.shape[1] // ref_length, axis=1) + conds = numpy.concatenate(conds, axis=0) + onnxruntime_input = {k.name: v for k, v in zip(self.onnx_session.get_inputs(), (conds,))} + onnxruntime_outputs = self.onnx_session.run(None, onnxruntime_input) + conds_output = torch.Tensor(onnxruntime_outputs[0]).mean(0, keepdim=True).to('cuda') + return conds_output[0] + diff --git a/xtts/run.py b/xtts/run.py new file mode 100644 index 0000000000000..d8418d88b41dd --- /dev/null +++ b/xtts/run.py @@ -0,0 +1,77 @@ +import argparse +import asyncio +import logging +import os + +from tts_engine import XTtsEngine +from model_setting import ModelSetting +from utils import * + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--input", type=str, required=True) + parser.add_argument("--output", type=str, required=True) + parser.add_argument("--streaming", action="store_true", default=False) + parser.add_argument("--request-rate", + type=float, + default=-1, + help="request rate per second") + parser.add_argument("--chunk-size", + type=int, + default=20, + help="audio chunk size") + parser.add_argument("--first-chunk-size", + type=int, + default=10, + help="audio chunk size") + parser.add_argument("--overlap-window", type=int, default=0, help="overlap window size") + parser.add_argument("--runtime", type=str, default="onnx") + parser.add_argument("--lora", type=str, default=None, help="lora model path") + parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--log-level", type=str, default="INFO") + + parser.add_argument("--top_k", 
type=int, default=-1) + parser.add_argument("--top_p", type=float, default=1) + parser.add_argument("--temperature", type=float, default=1) + parser.add_argument("--scale_rate", type=float, default=2.7) + parser.add_argument("--cut_tail", type=int, default=0) + + args = parser.parse_args() + + # set log level + logging.basicConfig(level=args.log_level) + + # convert_model('/home/zhn/fishtts/checkpoint-1734000.bak', '/home/zhn/fishtts/llama.pt') + # convert_model_lora('/home/zhn/fishtts/lora1/lora.bak', '/home/zhn/fishtts/lora1/adapter_model.bin') + + model_setting = ModelSetting(model_dir=args.model, + runtime=args.runtime, + dtype=args.dtype, + streaming=args.streaming, + overlap_window=args.overlap_window, + chunk_size=args.chunk_size, + first_chunk_size=args.first_chunk_size, + cut_tail=args.cut_tail, + scale_rate=args.scale_rate) + if args.lora: + model_setting.support_lora = True + tts_engine = XTtsEngine(model_setting) + + # warm up + logger.info('E2E warmup with lora...') + tts_engine.warm_up(args.lora) + logger.info('E2E warmup done') + + with open(args.input, "r") as f: + texts = f.readlines() + # if output directory does not exist, create it + if not os.path.exists(args.output): + os.makedirs(args.output) + if args.streaming: + asyncio.run(tts_engine.synthesize_async(texts=texts, output_dir=args.output, request_rate=args.request_rate, lora_path=args.lora, + top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)) + else: + tts_engine.synthesize(texts=texts, output_dir=args.output, lora_path=args.lora, + top_k=args.top_k, top_p=args.top_p, temperature=args.temperature) diff --git a/xtts/tts_engine.py b/xtts/tts_engine.py new file mode 100644 index 0000000000000..940ec0d71aeef --- /dev/null +++ b/xtts/tts_engine.py @@ -0,0 +1,255 @@ +import asyncio +import time +from typing import Any, Union + +import numpy +import onnxruntime +import pypinyin +from tokenizers import Tokenizer +import torch +from transformers import LlamaConfig + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.llm import LLM + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from generator import AudioGenerator +from metrics import TtsMetrics +from preceiver_resampler import PreceiverResampler +from utils import * +from model_setting import ModelSetting + +class XTtsEngine: + def __init__(self, model_setting: ModelSetting): + self.model_setting : ModelSetting = model_setting + self.tokenizer = None + self.llm_engine : Union[AsyncLLMEngine, LLM]= None + self.audio_generator: AudioGenerator = None + self.preceiever_sampler = None + self.sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) + self.post_init() + + def post_init(self): + + # initialize tokenizer + logger.info('Loading tokenizer...') + self.tokenizer = Tokenizer.from_file('/home/zhn/fishtts/vocab.json') + logger.info('Tokenizer loaded.') + + # initialize LLM + logger.info('Loading LLM...') + if self.model_setting.streaming: + logger.info('Using AsyncLLMEngine...') + engine_args = AsyncEngineArgs(model=self.model_setting.model_dir, + gpu_memory_utilization=self.model_setting.gpu_memory_utilization, + dtype=self.model_setting.dtype, + skip_tokenizer_init=True) + if self.model_setting.support_lora: + engine_args.enable_lora = True + engine_args.max_lora_rank = 128 + self.llm_engine = 
AsyncLLMEngine.from_engine_args(engine_args) + else: + logger.info('Using LLM...') + if self.model_setting.support_lora: + self.llm_engine = LLM(self.model_setting.model_dir, + gpu_memory_utilization=self.model_setting.gpu_memory_utilization, + dtype=self.model_setting.dtype, + skip_tokenizer_init=True, enable_lora=True, max_lora_rank=128) + else: + self.llm_engine = LLM(self.model_setting.model_dir, + gpu_memory_utilization=self.model_setting.gpu_memory_utilization, + dtype=self.model_setting.dtype, + skip_tokenizer_init=True) + + logger.info('LLM loaded.') + + # initialize audio generator + logger.info('Loading audio generator...') + self.audio_generator = AudioGenerator(self.model_setting) + logger.info('Audio generator loaded.') + + self.preceiever_sampler = PreceiverResampler(self.model_setting) + logger.info('Preceiver resampler loaded.') + + def warm_up(self, lora_path: str = None): + lora_request = None + if lora_path: + lora_request=LoRARequest("lora", 1, lora_local_path=lora_path) + else: + lora_request = None + if self.model_setting.streaming: + prompts = self.text_to_prompts("你好") + asyncio.run(self.generate_streaming(self.llm_engine, prompts[0], self.sampling_params, 'warmup', '.', lora_request=lora_request)) + + def text_to_prompts(self, texts: Union[str, List[str]]) -> List[dict[str, Any]]: + if isinstance(texts, str): + texts = [texts] + llm_inputs = [] + for text in texts: + text_split = mix_sentence_spliter(text) + token_ids_merge: List[int] = [] + for idx, sub_text in enumerate(text_split): + locale = 'zh' if re.search(r'[\u4e00-\u9fa50-9]', sub_text) else 'en' + txt = text_normalizer({'text': sub_text, 'locale': locale}, idx == len(text_split) - 1) + if locale == 'zh': + txt = "".join([p[0] for p in pypinyin.pinyin(txt, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) + + locale = "zh-cn" if locale == "zh" else locale + txt = f"[{locale}]{txt}" + txt = txt.replace(" ", "[SPACE]") + token_ids: List[int] = self.tokenizer.encode(txt).ids + token_ids_merge.extend(token_ids if idx == 0 else token_ids[1:]) + + token_ids_merge.insert(0, 7001) + token_ids_merge.append(0) + token_ids_merge.append(7003) + + # append reference audio embeddings + token_ids_merge = [7004] + [7005]* 15 + token_ids_merge + llm_inputs.append(token_ids_merge) + prompts = [ + {"prompt_token_ids": llm_input, "multi_modal_data":{ "audio": self.preceiever_sampler.get_reference_audio(len(llm_input) - 16 - 3) } } for llm_input in llm_inputs + ] + return prompts + + def synthesize(self, texts: List[str], output_dir: str, lora_path: str = None, + top_k: int = 1, top_p: float = 1, temperature: float = 1.0): + if isinstance(texts, str): + texts = [texts] + logger.info(f'Synthesizing {len(texts)} texts...') + prompts = self.text_to_prompts(texts) + lora_request = None + if lora_path: + lora_request=LoRARequest("lora", 1, lora_local_path=lora_path) + for i,p in enumerate(prompts): + logger.debug(f'Processing text {i+1}/{len(prompts)}...') + metrics = TtsMetrics(chunk_size=self.model_setting.chunk_size) + metrics.time_start = time.perf_counter() + sampling_params = SamplingParams(detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, + repetition_penalty=1.5, repetition_window=16, + top_k=top_k, top_p=top_p, temperature=temperature) + outputs = self.llm_engine.generate(p, sampling_params, lora_request=lora_request) + output = outputs[0].outputs[0] + token_ids = output.token_ids + token_ids = token_ids[:-1] + output_len = len(token_ids) + cur_time = time.perf_counter() + 
metrics.token_times.extend([cur_time] * output_len) + logger.info(f'{output_len} tokens generated:') + logger.debug(token_ids) + latent = torch.stack(output.hidden_states, 0).unsqueeze(0).to('cuda') + # the last token is EOS, should be excluded + latent = latent[:, :output_len, :] + total_audio = self.audio_generator.generate_audio(latent, metrics) + save_audio(total_audio, f'{output_dir}/{i:03d}.wav', self.model_setting.cut_tail) + logger.debug(f'Audio generated') + metrics.time_end = time.perf_counter() + metrics.calc_non_streaming() + + async def generate_token_streaming(self, engine: AsyncLLMEngine, + prompt: dict[str, Any], lora_request: LoRARequest, id: str, sampling_params: SamplingParams, + latent_queue: asyncio.Queue, metrics: TtsMetrics): + results_generator = engine.generate(prompt, sampling_params, request_id=id, lora_request=lora_request) + tokens = [] + logger.debug(f'Generating tokens for {id}...') + async for request_output in results_generator: + metrics.token_times.append(time.perf_counter()) + token_ids = request_output.outputs[0].token_ids[-1] + latent = request_output.outputs[0].hidden_states[-1] + tokens.append(token_ids) + latent_queue.put_nowait(latent) + # the last token is EOS, should be excluded + tokens = tokens[:-1] + logger.info(f'Tokens generated for {id}, {len(tokens)} tokens generated.') + latent_queue.put_nowait(None) + + + async def generate_audio_streaming(self, latent_queue: asyncio.Queue, id: str, metrics: TtsMetrics, output_dir: str): + latent_buffer = [] + audio_data_buffer = [] + chunk_id = 0 + padding = self.model_setting.chunk_padding + while True: + latent = await latent_queue.get() + if latent is None: + break + latent_buffer.append(latent) + + if chunk_id == 0 and len(latent_buffer) == self.model_setting.first_chunk_size: + latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') + trim_end = True if self.model_setting.overlap_window > 0 else False + audio_data_buffer.append(self.audio_generator.generate_chunk_audio(latent, metrics, padding=False, trim_begin=False, trim_end=trim_end)) + logger.debug(f'Chunk audio generated for promot {id} chunk {chunk_id}...') + chunk_id += 1 + if trim_end: + latent_buffer = latent_buffer[-self.model_setting.overlap_window*2:] + else: + latent_buffer = [] + + elif len(latent_buffer) == self.model_setting.chunk_size: + latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') + trim_begin = trim_end = True if self.model_setting.overlap_window > 0 else False + audio_data_buffer.append(self.audio_generator.generate_chunk_audio(latent, metrics, padding=False, trim_begin=trim_begin, trim_end=trim_end)) + logger.debug(f'Chunk audio generated for promot {id} chunk {chunk_id}...') + chunk_id += 1 + if trim_end: + latent_buffer = latent_buffer[-self.model_setting.overlap_window*2:] + else: + latent_buffer = [] + + if len(latent_buffer) > 0: + latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') + # the last token is EOS, should be excluded + latent = latent[:, :-1, :] + trim_begin = True if self.model_setting.overlap_window > 0 else False + audio_data_buffer.append(self.audio_generator.generate_chunk_audio(latent, metrics, padding=padding, trim_begin=trim_begin, trim_end=False)) + logger.debug(f'Chunk audio generated for promot {id} chunk {chunk_id}...') + + save_audio(audio_data_buffer, f'{output_dir}/{id}.wav', self.model_setting.cut_tail) + logger.debug(f'Audio generated for prompt {id}.') + + async def generate_streaming(self, engine: AsyncLLMEngine, + prompt: dict[str, Any], + sampling_params: 
SamplingParams, + request_id: str, + output_dir: str, + lora_request: LoRARequest = None) -> TtsMetrics: + metrics = TtsMetrics(chunk_size=self.model_setting.chunk_size, first_chunk_size=self.model_setting.first_chunk_size) + metrics.time_start = time.perf_counter() + + latent_queue = asyncio.Queue() + vllm_task = asyncio.create_task(self.generate_token_streaming(engine, prompt, lora_request, request_id, sampling_params, latent_queue, metrics)) + generator_task = asyncio.create_task(self.generate_audio_streaming(latent_queue, request_id, metrics, output_dir)) + await vllm_task + await generator_task + metrics.time_end = time.perf_counter() + return metrics + + async def synthesize_async(self, texts: List[str], output_dir: str, lora_path: str, request_rate: float = -1, + top_k: int = 1, top_p: float = 1, temperature: float = 1.0): + if isinstance(texts, str): + texts = [texts] + logger.info(f'Synthesizing {len(texts)} texts streaming...') + sampling_params = SamplingParams(detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, + repetition_penalty=1.5, repetition_window=16, + top_k=top_k, top_p=top_p, temperature=temperature) + prompts = self.text_to_prompts(texts) + lora_request = None + if lora_path: + lora_request=LoRARequest("lora", 1, lora_local_path=lora_path) + if request_rate < 0: + for i in range(len(prompts)): + me = await self.generate_streaming(self.llm_engine, prompts[i], sampling_params, f'{i:03d}', output_dir, lora_request=lora_request) + me.calc_streaming() + else: + tasks: List[asyncio.Task] = [] + request_id = 0 + async for prompt in get_request(prompts, request_rate): + tasks.append(asyncio.create_task(self.generate_streaming(self.llm_engine, prompt, sampling_params, f'{request_id:03d}', output_dir, lora_request=lora_request))) + request_id += 1 + + metrics_list: List[TtsMetrics] = await asyncio.gather(*tasks) + for metrics in metrics_list: + metrics.calc_streaming() \ No newline at end of file diff --git a/xtts/utils.py b/xtts/utils.py new file mode 100644 index 0000000000000..402d4249016f3 --- /dev/null +++ b/xtts/utils.py @@ -0,0 +1,274 @@ +import asyncio +import logging +from typing import List + +import numpy as np +import soundfile as sf +import torch +import re + +logger = logging.getLogger(__name__) + +def to_numpy(tensor: torch.Tensor) -> np.ndarray: + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + +def save_audio(total_audio: List[np.ndarray], path: str, cut_tail: int = 0): + total_audio = np.concatenate(total_audio, axis=0) + if cut_tail > 0: + total_audio = total_audio[:-cut_tail * 24] + sf.write(path, total_audio, 24000) + +async def get_request(input_requests, request_rate: float): + requests = iter(input_requests) + for request in requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
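+        # Drawing the interval from an exponential distribution makes the arrival
+        # pattern a Poisson process with an average of `request_rate` requests per
+        # second, a common way to emulate production traffic in benchmarks.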
+ await asyncio.sleep(interval) + +def convert_model(path_torch: str, path_vllm: str): + tts2 = torch.load(path_torch) + + layer = 24 + dim = 1536 + num_audio_tokens = 1026 + num_text_tokens = 7002 + llama = tts2['model']['llama'] + + llama.pop('freqs_cis') + llama.pop('causal_mask') + + text_emb = llama['text_embeddings.weight'] + for i in range(100): + text_emb = torch.cat([text_emb, torch.zeros((1,dim), device=text_emb.device)], 0) + llama['emb_text.weight'] = text_emb + llama.pop('text_embeddings.weight') + + # 0-1023: audio1, 1024 bos 1026 eos + # 1027-2050: audio2, 2051 bos 2053 eos + + # 0-1023: audio1, 1024-2047: audio2 + # bos1 2048, eos1 2049 + # bos2 2050, eos2 2051 + llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:num_audio_tokens].clone() + llama['emb_code.1.weight'] = llama['code_embeddings.weight'][num_audio_tokens-2:num_audio_tokens - 2 + num_audio_tokens].clone() + llama['emb_code.0.weight'][1024]=llama['code_embeddings.weight'][2048] + llama['emb_code.1.weight'][1024]=llama['code_embeddings.weight'][2050] + llama.pop('code_embeddings.weight') + + for i in range(layer): + qkv_name = f'layers.{i}.attention.wqkv.weight' + q = llama[qkv_name][0:dim] + k = llama[qkv_name][dim:2*dim] + v = llama[qkv_name][2*dim:] + llama[f'gpt.layers.{i}.self_attn.q_proj.weight'] = q + llama[f'gpt.layers.{i}.self_attn.k_proj.weight'] = k + llama[f'gpt.layers.{i}.self_attn.v_proj.weight'] = v + llama.pop(qkv_name) + + wo_name = f'layers.{i}.attention.wo.weight' + wo = llama[wo_name] + llama[f'gpt.layers.{i}.self_attn.o_proj.weight'] = wo + llama.pop(wo_name) + + gate_proj_name = f'layers.{i}.feed_forward.w1.weight' + w_gate = llama[gate_proj_name] + llama[f'gpt.layers.{i}.mlp.gate_proj.weight'] = w_gate + llama.pop(gate_proj_name) + + gate_up_proj_name = f'layers.{i}.feed_forward.w3.weight' + w_gate_up = llama[gate_up_proj_name] + llama[f'gpt.layers.{i}.mlp.up_proj.weight'] = w_gate_up + llama.pop(gate_up_proj_name) + + gate_down_proj_name = f'layers.{i}.feed_forward.w2.weight' + w_gate_down = llama[gate_down_proj_name] + llama[f'gpt.layers.{i}.mlp.down_proj.weight'] = w_gate_down + llama.pop(gate_down_proj_name) + + attn_norm_name = f'layers.{i}.attention_norm.weight' + w_attn_norm = llama[attn_norm_name] + llama[f'gpt.layers.{i}.input_layernorm.weight'] = w_attn_norm + llama.pop(attn_norm_name) + + ffn_norm_name = f'layers.{i}.ffn_norm.weight' + w_ffn_norm = llama[ffn_norm_name] + llama[f'gpt.layers.{i}.post_attention_layernorm.weight'] = w_ffn_norm + llama.pop(ffn_norm_name) + + + norm_name = 'norm.weight' + w_norm = llama[norm_name] + llama['gpt.norm.weight'] = w_norm + llama.pop(norm_name) + + output_name = 'output.weight' + w_output = llama[output_name] + llama['lm_head.0.weight'] = w_output[num_text_tokens:num_text_tokens+num_audio_tokens] + llama['lm_head.1.weight'] = w_output[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2] + llama.pop(output_name) + + torch.save(llama, path_vllm) + +def convert_model_lora(path_torch: str, path_vllm: str): + lora = torch.load(path_torch) + + layer = 24 + dim = 1536 + for i in range(layer): + qkv_name_A = f'layers.{i}.attention.wqkv.lora_A' + q_A = lora[qkv_name_A] + k_A = lora[qkv_name_A] + v_A = lora[qkv_name_A] + lora[f'base_model.model.gpt.layers.{i}.self_attn.q_proj.lora_A.weight'] = q_A + lora[f'base_model.model.gpt.layers.{i}.self_attn.k_proj.lora_A.weight'] = k_A + lora[f'base_model.model.gpt.layers.{i}.self_attn.v_proj.lora_A.weight'] = v_A + lora.pop(qkv_name_A) + + qkv_name_B = 
f'layers.{i}.attention.wqkv.lora_B' + q_B = lora[qkv_name_B][0:dim] + k_B = lora[qkv_name_B][dim:2*dim] + v_B = lora[qkv_name_B][2*dim:] + lora[f'base_model.model.gpt.layers.{i}.self_attn.q_proj.lora_B.weight'] = q_B + lora[f'base_model.model.gpt.layers.{i}.self_attn.k_proj.lora_B.weight'] = k_B + lora[f'base_model.model.gpt.layers.{i}.self_attn.v_proj.lora_B.weight'] = v_B + lora.pop(qkv_name_B) + + wo_name_A = f'layers.{i}.attention.wo.lora_A' + wo_A = lora[wo_name_A] + lora[f'base_model.model.gpt.layers.{i}.self_attn.o_proj.lora_A.weight'] = wo_A + lora.pop(wo_name_A) + + wo_name_B = f'layers.{i}.attention.wo.lora_B' + wo_B = lora[wo_name_B] + lora[f'base_model.model.gpt.layers.{i}.self_attn.o_proj.lora_B.weight'] = wo_B + lora.pop(wo_name_B) + + gate_proj_name_A = f'layers.{i}.feed_forward.w1.lora_A' + w_gate_A = lora[gate_proj_name_A] + lora[f'base_model.model.gpt.layers.{i}.mlp.gate_proj.lora_A.weight'] = w_gate_A + lora.pop(gate_proj_name_A) + + gate_proj_name_B = f'layers.{i}.feed_forward.w1.lora_B' + w_gate_B = lora[gate_proj_name_B] + lora[f'base_model.model.gpt.layers.{i}.mlp.gate_proj.lora_B.weight'] = w_gate_B + lora.pop(gate_proj_name_B) + + gate_up_proj_name_A = f'layers.{i}.feed_forward.w3.lora_A' + w_gate_up_A = lora[gate_up_proj_name_A] + lora[f'base_model.model.gpt.layers.{i}.mlp.up_proj.lora_A.weight'] = w_gate_up_A + lora.pop(gate_up_proj_name_A) + + gate_up_proj_name_B = f'layers.{i}.feed_forward.w3.lora_B' + w_gate_up_B = lora[gate_up_proj_name_B] + lora[f'base_model.model.gpt.layers.{i}.mlp.up_proj.lora_B.weight'] = w_gate_up_B + lora.pop(gate_up_proj_name_B) + + gate_down_proj_name_A = f'layers.{i}.feed_forward.w2.lora_A' + w_gate_down_A = lora[gate_down_proj_name_A] + lora[f'base_model.model.gpt.layers.{i}.mlp.down_proj.lora_A.weight'] = w_gate_down_A + lora.pop(gate_down_proj_name_A) + + gate_down_proj_name_B = f'layers.{i}.feed_forward.w2.lora_B' + w_gate_down_B = lora[gate_down_proj_name_B] + lora[f'base_model.model.gpt.layers.{i}.mlp.down_proj.lora_B.weight'] = w_gate_down_B + lora.pop(gate_down_proj_name_B) + + num_audio_tokens = 1026 + num_text_tokens = 7002 + output_name_A = 'output.lora_A' + w_output_A = lora[output_name_A] + lora['base_model.model.lm_head.0.lora_A.weight'] = w_output_A + lora['base_model.model.lm_head.1.lora_A.weight'] = w_output_A + lora.pop(output_name_A) + + output_name_B = 'output.lora_B' + w_output_B = lora[output_name_B] + lora['base_model.model.lm_head.0.lora_B.weight'] = w_output_B[num_text_tokens:num_text_tokens+num_audio_tokens] + lora['base_model.model.lm_head.1.lora_B.weight'] = w_output_B[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2] + lora.pop(output_name_B) + + # convert the model to fp16 + for k, v in lora.items(): + if isinstance(v, torch.Tensor): + lora[k] = v.half() + + torch.save(lora, path_vllm) + +def mix_sentence_spliter(text): + segments_with_punctuation = re.findall(r'[\u4e00-\u9fff0-9]+|[a-zA-Z\s\'-]+|[,.!?,。!?"“”::;;—()(){}]', text) + combined_segments = [] + for i, segment in enumerate(segments_with_punctuation): + if i > 0 and re.match(r'[,.!?,。!?"“”::;;—()(){}]', segment): + combined_segments[-1] += segment + else: + if len(combined_segments) > 0 and re.match(r'[\u4e00-\u9fff0-9]', combined_segments[-1]) and re.match( + r'[\u4e00-\u9fff0-9]', segment): + combined_segments[-1] += segment + elif len(combined_segments) > 0 and re.match(r'[a-zA-Z\s\'-]', combined_segments[-1]) and re.match( + r'[a-zA-Z\s\'-]', segment): + combined_segments[-1] += segment + else: + 
combined_segments.append(segment) + + out_combined_segments = [] + for segment in combined_segments: + if segment.strip(): + if out_combined_segments and out_combined_segments[-1] in '.,!?,。!?"“”::;;—()(){}': + out_combined_segments[-1] += segment.strip() + else: + out_combined_segments.append(segment.strip()) + + return out_combined_segments + +def text_normalizer(x, is_lastsegment=True): + x['before_norm_text'] = x['text'] + if x['locale'] == "zh": + x['text'] = x['text'].replace(" ", "") + x['text'] = x['text'].replace('.', '。').replace('!', '!').replace('?', '?').replace(',', ',').replace(':', ':') + x['text'] = x['text'].replace('“', '"').replace('”', '"').replace('‘', '"').replace('’', '"') + if is_lastsegment: + if len(x['text']) > 0 and x['text'][-1] == '"': + x['text'] = x['text'][:-1] + + if len(x['text']) > 0 and x['text'][-1] == ',': + x['text'] = x['text'][:-1] + '。' + + x['text'] = x['text'].replace('—', '。') + if len(x['text']) > 0 and x['text'][-1] not in ['。', '!', '?']: + x['text'] += '。' + + if re.search('[a-zA-Z]', x['before_norm_text']): + x['text'] = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9。,,!?《》、:"\' ]', '', x['text']) + x['text'] = re.sub(r'(?<=[\u4e00-\u9fa5。,,!?《》、:"\'])\s+(?=[\u4e00-\u9fa5。,,!?《》、:"\'])', '', x['text']) + x['text'] = re.sub(r'(?<=[\u4e00-\u9fa5])\s+(?=[a-zA-Z])', '', x['text']) + x['text'] = re.sub(r'(?<=[a-zA-Z])\s+(?=[\u4e00-\u9fa5])', '', x['text']) + else: + x['text'] = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9。,,!?《》、:"\']', '', x['text']) + + x['text'] = re.sub(r'([。,,!?])\1+', r'\1', x['text']) + else: + x['text'] = re.sub(r'[^\w.,!?\'" ]', '', x['text']) + + if is_lastsegment: + if len(x['text']) > 0 and x['text'][-1] == ',': + x['text'] = x['text'][:-1] + '.' + if len(x['text']) > 0 and x['text'][-1] not in ['.', '!', '?', '"', '\'']: + x['text'] += '.' 
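+        # Cleanup for the English branch below: collapse runs of repeated , ! ?,
+        # squash two or more dots into "...", and drop stray spaces next to
+        # punctuation before the final lower-casing step.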
+ + x['text'] = re.sub(r'([,!?])\1+', r'\1', x['text']) + x['text'] = re.sub(r'\.{2,}', '...', x['text']) + x['text'] = re.sub(r'\s+([.,!?"\'])', r'\1', x['text']) + x['text'] = re.sub(r'([.,!?"\'])\s+', r'\1', x['text']) + + x['text'] = re.sub(r"\s+", ' ', x['text'].lower()).strip() + return x['text'] + + +# trtexec --onnx=/home/zhn/fishtts/genertor.onnx --saveEngine=/home/zhn/fishtts/genertor.trt --memPoolSize=workspace:10000 --minShapes=input:1x1x1536,speaker_embedding:1x192x1 --maxShapes=input:1x512x1536,speaker_embedding:1x192x1 --optShapes=input:1x20x1536,speaker_embedding:1x192x1 --device=0 \ No newline at end of file From a2c67c322ae16bc5a4ff295f853571d70217c67b Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 23 Dec 2024 07:32:26 +0000 Subject: [PATCH 58/61] make e2e work after merge --- vllm/engine/llm_engine.py | 20 +++++++++++------- vllm/engine/output_processor/stop_checker.py | 2 +- .../layers/multi_head_sampler.py | 9 +------- vllm/model_executor/models/fishtts.py | 21 ++++++++++--------- vllm/model_executor/models/registry.py | 3 +++ vllm/multimodal/speech.py | 14 ++++++------- vllm/sequence.py | 14 ++++++------- vllm/worker/model_runner.py | 4 ++-- 8 files changed, 43 insertions(+), 44 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 971d87356b013..6d4bbb5c4dad0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1263,21 +1263,25 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] - if sample.output_tokens and len(sample.output_tokens) > 1: - seq.append_token_id(sample.output_tokens, - sample.logprobs) - else: - seq.append_token_id(sample.output_token, - sample.logprobs) if self.scheduler_config.is_multi_step: is_prefill_append = seq.data.get_num_uncomputed_tokens( ) == 0 - seq.append_token_id(sample.output_token, sample.logprobs) + if sample.output_tokens and len(sample.output_tokens) > 1: + seq.append_token_id(sample.output_tokens, + sample.logprobs) + else: + seq.append_token_id(sample.output_token, + sample.logprobs) if not is_prefill_append: seq_group.update_num_computed_tokens(1) else: - seq.append_token_id(sample.output_token, sample.logprobs) + if sample.output_tokens and len(sample.output_tokens) > 1: + seq.append_token_id(sample.output_tokens, + sample.logprobs) + else: + seq.append_token_id(sample.output_token, + sample.logprobs) def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. 
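Note on the `_advance_to_next_step` hunk above: the multi-head TTS sampler emits one token per audio codebook at every decoding step, so `sample.output_tokens` carries a short list of ids (two codebooks for the FishTTS heads `lm_head.0`/`lm_head.1`, four for the ChatTTS heads) while plain text models still populate the single `sample.output_token`. The following is a minimal, self-contained sketch of that dispatch only; `DummySequence` is a hypothetical stand-in for vLLM's real `Sequence` class, not part of the patch.

from typing import Dict, List, Optional, Union


class DummySequence:
    """Illustrative stand-in for vllm.sequence.Sequence (not the real class)."""

    def __init__(self) -> None:
        self.appended: List[Union[int, List[int]]] = []

    def append_token_id(self, token: Union[int, List[int]],
                        logprobs: Optional[Dict] = None) -> None:
        # The real Sequence also records logprobs and updates cached lengths;
        # here we only keep what was appended.
        self.appended.append(token)


def advance_to_next_step(seq: DummySequence,
                         output_token: int,
                         output_tokens: Optional[List[int]]) -> None:
    # Multi-head TTS sampling returns one id per codebook in `output_tokens`;
    # ordinary single-head models only set `output_token`.
    if output_tokens and len(output_tokens) > 1:
        seq.append_token_id(output_tokens)      # whole per-codebook group at once
    else:
        seq.append_token_id(output_token)       # regular single-token path


seq = DummySequence()
advance_to_next_step(seq, output_token=42, output_tokens=None)       # text-style step
advance_to_next_step(seq, output_token=0, output_tokens=[17, 903])   # 2-codebook TTS step
print(seq.appended)   # [42, [17, 903]]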
diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 0624ed2fd8f60..66233637f8ec9 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py index 0f7ff7a46d719..3a01e03dca338 100644 --- a/vllm/model_executor/layers/multi_head_sampler.py +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -10,9 +10,6 @@ from vllm.model_executor.layers.sampler import MaybeDeferredSampleResultType, SampleResultArgsType, SampleReturnType, SamplerOutput, _apply_top_k_top_p, _get_bin_counts_and_mask from vllm.triton_utils import HAS_TRITON -if HAS_TRITON: - from vllm.model_executor.layers.ops.sample import sample as sample_triton - from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, SequenceGroupToSample) @@ -20,8 +17,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad, - maybe_expand_dim) + is_pin_memory_available, make_tensor_with_pad) from einops import rearrange @@ -242,9 +238,6 @@ def from_sampling_metadata(self, min_ps=None, presence_penalties=None, frequency_penalties=None, - sampling_seeds=None, - sample_indices=None, - extra_seeds=None, prompt_tokens=None ) diff --git a/vllm/model_executor/models/fishtts.py b/vllm/model_executor/models/fishtts.py index 557a2c800028c..67872013a358f 100644 --- a/vllm/model_executor/models/fishtts.py +++ b/vllm/model_executor/models/fishtts.py @@ -8,8 +8,8 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.config import CacheConfig, VllmConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, DummyData from vllm.inputs.registry import InputContext from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.multi_head_sampler import MultiheadSampler @@ -23,6 +23,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.speech import SpeechPlugin from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors +from .interfaces import SupportsMultiModal from einops import rearrange from transformers.generation import TopKLogitsWarper, TopPLogitsWarper @@ -38,7 +39,7 @@ def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int, mm_counts: Mapping[st dummy_seq_data = SequenceData([0] * seq_len) dummy_multi_modal_data = {"audio": SpeechPlugin.sample_random_speaker()} - return dummy_seq_data, dummy_multi_modal_data + return DummyData(dummy_seq_data, dummy_multi_modal_data, None) def get_max_speech_tokens(ctx: InputContext): return 16 @@ -46,7 +47,7 @@ def get_max_speech_tokens(ctx: InputContext): @MULTIMODAL_REGISTRY.register_speech_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ttsllm) @MULTIMODAL_REGISTRY.register_max_speech_tokens(get_max_speech_tokens) -class FishTtsLlm(nn.Module, SupportsLoRA): +class FishTtsLlm(nn.Module, SupportsLoRA, SupportsMultiModal): 
packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -78,13 +79,13 @@ class FishTtsLlm(nn.Module, SupportsLoRA): "up_proj": ("gate_up_proj", 1), } - def __init__(self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None,) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + quant_config = vllm_config.quant_config + self.config = config # static parameters, put them in config later @@ -94,7 +95,7 @@ def __init__(self, self.audio_start_token_id = config.audio_start_token_id self.audio_ref_token_id = config.audio_ref_start_token_id - self.gpt = LlamaModel(config, lora_config=lora_config) + self.gpt = LlamaModel(vllm_config=vllm_config, prefix=prefix) self.model_dim = self.gpt.config.hidden_size self.emb_text = VocabParallelEmbedding(self.num_text_tokens, self.model_dim) self.emb_code = nn.ModuleList([ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 68a2467a813a1..8c8283b3042e4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -168,6 +168,9 @@ "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 + + # TTS + "FishTtsLlm": ("fishtts", "FishTtsLlm"), } _SPECULATIVE_DECODING_MODELS = { diff --git a/vllm/multimodal/speech.py b/vllm/multimodal/speech.py index 7456204ce464f..ce69b377ab9e9 100644 --- a/vllm/multimodal/speech.py +++ b/vllm/multimodal/speech.py @@ -8,15 +8,13 @@ from PIL import Image from transformers import PreTrainedTokenizerBase -from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.image_processor import get_image_processor from vllm.transformers_utils.tokenizer import get_tokenizer -from .base import MultiModalInputs, MultiModalPlugin -import base64 -import pickle +from .base import MultiModalPlugin +from .inputs import MultiModalKwargs + class SpeechPlugin(MultiModalPlugin): @@ -24,12 +22,12 @@ def get_data_key(self) -> str: return "audio" def _default_input_mapper(self, ctx: InputContext, - data: object) -> MultiModalInputs: + data: object) -> MultiModalKwargs: model_config = ctx.model_config if data is None: - return MultiModalInputs({"audio": torch.zeros(16, model_config.hf_config.hidden_size)}) + return MultiModalKwargs({"audio": torch.zeros(16, model_config.hf_config.hidden_size)}) else: - return MultiModalInputs({"audio": data}) + return MultiModalKwargs({"audio": data}) def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: return 3000 diff --git a/vllm/sequence.py b/vllm/sequence.py index 17238371438c8..bb87843948862 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -203,17 +203,17 @@ def from_seqs( Construct a :class:`SequenceData` instance from prompt and output token sequences. 
""" - prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - prompt_token_ids) + # prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + # prompt_token_ids) if output_token_ids is None: - return SequenceData(prompt_token_ids_arr) + return SequenceData(prompt_token_ids) - output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - output_token_ids) + # output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + # output_token_ids) - return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr) + return SequenceData(prompt_token_ids, + _output_token_ids=output_token_ids) def __post_init__(self) -> None: self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e2654f14bf929..8551c4505b385 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -513,12 +513,12 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, if inter_data.is_prompt: inter_data.input_tokens[seq_idx].extend(tokens) else: - inter_data.input_tokens[seq_idx].append(tokens) + inter_data.input_tokens[seq_idx].extend(tokens) else: if isinstance(tokens, list): inter_data.input_tokens[seq_idx].extend(tokens) else: - inter_data.input_tokens[seq_idx].append(tokens) + inter_data.input_tokens[seq_idx].extend(tokens) if seq_data.mrope_position_delta is not None: if inter_data.mrope_input_positions is None: From 4d53ec8127e82fe151fb1096fe1e4bbc0b3d12df Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Mon, 23 Dec 2024 07:55:59 +0000 Subject: [PATCH 59/61] make e2e work after merge --- chattts_sample.py | 71 ---- fishtts_sample.py | 401 ------------------ requirements-cuda.txt | 2 +- vllm/inputs/registry.py | 4 - .../model_executor/layers/logits_processor.py | 18 +- vllm/model_executor/models/chattts.py | 197 --------- 6 files changed, 13 insertions(+), 680 deletions(-) delete mode 100644 chattts_sample.py delete mode 100644 fishtts_sample.py delete mode 100644 vllm/model_executor/models/chattts.py diff --git a/chattts_sample.py b/chattts_sample.py deleted file mode 100644 index 0892f59141955..0000000000000 --- a/chattts_sample.py +++ /dev/null @@ -1,71 +0,0 @@ -import asyncio -from vllm import LLM, SamplingParams -import torch - -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -torch.random.manual_seed(999) - -def convert_model(): - chatts = torch.load('/home/zhn/g/ChatTTS/asset/GPT.pt') - - chatts.pop('head_text.parametrizations.weight.original0') - chatts.pop('head_text.parametrizations.weight.original1') - for i in range(4): - original0 = chatts[f'head_code.{i}.parametrizations.weight.original0'] - original1 = chatts[f'head_code.{i}.parametrizations.weight.original1'] - # get the normalized weights based on the original 0 and 1 - weight_norm0 = torch._weight_norm(original1, original0, dim=0) - chatts.pop(f'head_code.{i}.parametrizations.weight.original0') - chatts.pop(f'head_code.{i}.parametrizations.weight.original1') - chatts[f'lm_head.{i}.weight'] = weight_norm0 - torch.save(chatts, '/home/zhn/ttslm_dev/chattts.pt') - -streaming = False -llm = LLM(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float32) -prompts = [ - { - "prompt": "[Stts][spk_emb][speed_5]Your text one[Ptts]", - "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, - }, - { - "prompt": "[Stts][spk_emb][speed_5]Your text two[Ptts]", - "multi_modal_data": {"audio": 
'Dj9nNtQ7e0B4P9G7mjvYsJuukjacMNE9FzrKLxo3VzX6PZGyAjXVuhvBOznqOvC5ikDRvdsp/bEiPuE99TANNhq77L+RtQEx7DzPPDK21rQyvK69077OPxS4ArcDvTg6db7Psiq7MTYWuLm8N6faNkIs1zQTuc61YbnBtYG66L6tMW62hb5oN0K7pirVOSEwjzzwvHA2ern3uqG4/LQwK6m8NTtzNyW+9D2VP7c8wrFIvDg8c7bkPKoy0Du0Nac5YCElPxAwbzaIP2S9Fb3BO7G8xj+etvu5sT/sPXg/QDvkuqQ/SrRrOpQ79rDqOGE7BjZmvMY0Tj2Oucqw9rZmP2E5hTYpulA4ZbkTKBa5aTmdOKqyVjZfvPe9CDdVMRU5xECOO027eDxDvWO3XqzaPvW5nbpNrk+5gjxSvJe8Abf4t7g/mDuaLee8wToAvC6+wTwSNV47GjXbOjW5I7x0OW26hLlKt3okqrWAvSa5csAAOWm+a7ytKuu5LD4svUY0B6msvDA4fbEtOpc6VEBZPei1qbNmPD45Hzl4vjqm57l7NEiuSDsTNUo6dreOu5C2W6QRNoVAaammu4m1a7e0Nyq+DCwCMW4zx7ids7K95bdYuKg2oLxAOvo9JMDSNmo5GblbtkGwxLTytLS+nDxSPIK63r0rOJu4izActYg5jznOuYm9Vb7iOMG8DrjFuLuqv78KMjE4qS35tDwuYjaMwdM/PjY4uUG25jdurxDAIUATrxU8ibklPpU03zguvNg8hTk1PQa5srecQD28WrCMPEEuqjhktuw1Or2iPpCvcblVOV+3Vr1jvY05G7FHsFA8grzYuSi2Ci5mqZe046versK5gbpkNCG7obz/tAyxg0CjuYw7+jnBvDA6vLxRwdo4Br83N3m5oz5GPEG3gbnbvA63fLTWtQQxoTiYNR26wL11PYS5OLkruF+/prett/Y4ljm5G3Q7cDUEpCYxGrJRtkm0g7x9NPK7vDr8s1Yq5zcmOUi+n7ILO8u+MjaDwBe88iigvuK7472YPkExijggs6i1d76rtBi63TqQPN490jrKuHe8IjqzNRO/sDWEOosr7LF2usm9ZzrytHmz7j56Oe81jDAUPuE+cDuoODDA6y8qsAMuizt7wY64UjR4t6u99DpuNMW7CTWTOmsdaDoCLVO2o7yArTgyAjter1I3uLtDuhGvOLyQOmk3dzRsNQtAgLdnwGEyeqokuNW7B7fOMQa400Cvu1i6lMAGNAjAUbhptK659T71O0c9D7uEPPA3LDoePEW7yrbnOlDBDj5tONqpDjMfvJm1ZzwxPWRBu6xFuR24Erz8rCA2YqhjPH+zj7WmuIi+Rb63vDU9KTuAPVS94L9IuaAx+zcxPIw0+D4gs067ur4du78zHLt+O6Y9vcDqOnvAbTbmPCcsHr2vuDC01bF+sgiv1TLWvL+zEz4othq3UDwewbo8Drk/MSouijtItG25MT26wDUlZL0YugZAtak+MiC/1blKMq017zhrIDc4QzdSv70oJTq4ONK1crp5vDKwiDzPv0e5j7xIOm+4iLfxsiA3R7YYwlk4BjcGs1a8oTURNoW377Jxu7exNL6tNo67nzOgvbi4Eq+lvl8vEbb4vje91zgqOzmwejYqM0Y1JT0Hvmk5HbYts9a0AbbovN4xrz4vq6yVgLmFr3U37bVDt3Q2WLsKNJA6cT3JrXg9Izz4u9+84bQePKCszrg1Ppc0pMATNSk4ODMyN06sRr3utZG7drsNvT03RC4stzK5/6VRPALACDZeuIq/yL9NuwYzSjaDuxE8sT1WNru4fzwLOvA6QLmrwTxAcr6UqMNASjWYOwK+HL9DuF08irVbulYyVDzIvdi8T7phPIQzREB+O36oDz3FMqS5cTmSuaW0UD17rm26brcWPGy4MbYVuOMxK7Zovhm7drv5PIA6szgLuIe6D7g1vP9AfsCDOCO9rq7nu7a8kLwFP0GwILpyOE0hwzlMv9w0Ljqlviw3yT6bo5I6XTmpuRcpRb45t0A0yTlNJsM5abyCru8k'}, - } -] - -if not streaming: - sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - print(output.prompt) - token_ids = output.outputs[0].token_ids - for token_id in token_ids: - print([x - 0 for x in token_id]) -else: - engine_args = AsyncEngineArgs(model='/home/zhn/ttslm_dev', gpu_memory_utilization=0.5, dtype=torch.float16) - model = AsyncLLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[625], max_tokens=2048, top_k=1) - - async def generate_streaming(prompt, id): - results_generator = model.generate(prompt, sampling_params, request_id=id) - count=0 - tokens = [] - async for request_output in results_generator: - token_ids = request_output.outputs[0].token_ids - print(f'{id} {[x - 0 for x in token_ids[-1]]}') - tokens.append([x - 0 for x in token_ids[-1]]) - count+=1 - - print(prompt['prompt']) - for token in tokens: - print(token) - - async def generate(): - tasks = [] - for i in range(1): - t = generate_streaming(prompts[i%2], i) - tasks.append(t) - await asyncio.gather(*tasks) - - asyncio.run(generate()) diff --git a/fishtts_sample.py b/fishtts_sample.py deleted file mode 100644 index a78a7f53a8e7b..0000000000000 --- a/fishtts_sample.py +++ /dev/null @@ -1,401 +0,0 @@ -import argparse -import asyncio -import threading -import time -from typing import List - 
-import numpy as np -import onnx -from vllm import LLM, SamplingParams -from tokenizers import Tokenizer -import pypinyin -import torch - -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -import onnxruntime -import soundfile as sf -import queue - -import tensorrt as trt -import pycuda.driver as cuda -import pycuda.autoinit - -torch.random.manual_seed(999) - -def convert_model(): - tts2 = torch.load('/data/fishtts/checkpoint-1400000.bak') - - layer = 24 - dim = 1536 - num_audio_tokens = 1026 - num_text_tokens = 7002 - llama = tts2['model']['llama'] - - llama.pop('freqs_cis') - llama.pop('causal_mask') - - text_emb = llama['text_embeddings.weight'] - for i in range(100): - text_emb = torch.cat([text_emb, torch.zeros((1,dim), device=text_emb.device)], 0) - llama['emb_text.weight'] = text_emb - llama.pop('text_embeddings.weight') - - llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:num_audio_tokens] - llama['emb_code.1.weight'] = llama['code_embeddings.weight'][num_audio_tokens-2:num_audio_tokens - 2 + num_audio_tokens] - llama.pop('code_embeddings.weight') - - for i in range(layer): - qkv_name = f'layers.{i}.attention.wqkv.weight' - q = llama[qkv_name][0:dim] - k = llama[qkv_name][dim:2*dim] - v = llama[qkv_name][2*dim:] - llama[f'gpt.layers.{i}.self_attn.q_proj.weight'] = q - llama[f'gpt.layers.{i}.self_attn.k_proj.weight'] = k - llama[f'gpt.layers.{i}.self_attn.v_proj.weight'] = v - llama.pop(qkv_name) - - wo_name = f'layers.{i}.attention.wo.weight' - wo = llama[wo_name] - llama[f'gpt.layers.{i}.self_attn.o_proj.weight'] = wo - llama.pop(wo_name) - - gate_proj_name = f'layers.{i}.feed_forward.w1.weight' - w_gate = llama[gate_proj_name] - llama[f'gpt.layers.{i}.mlp.gate_proj.weight'] = w_gate - llama.pop(gate_proj_name) - - gate_up_proj_name = f'layers.{i}.feed_forward.w3.weight' - w_gate_up = llama[gate_up_proj_name] - llama[f'gpt.layers.{i}.mlp.up_proj.weight'] = w_gate_up - llama.pop(gate_up_proj_name) - - gate_down_proj_name = f'layers.{i}.feed_forward.w2.weight' - w_gate_down = llama[gate_down_proj_name] - llama[f'gpt.layers.{i}.mlp.down_proj.weight'] = w_gate_down - llama.pop(gate_down_proj_name) - - attn_norm_name = f'layers.{i}.attention_norm.weight' - w_attn_norm = llama[attn_norm_name] - llama[f'gpt.layers.{i}.input_layernorm.weight'] = w_attn_norm - llama.pop(attn_norm_name) - - ffn_norm_name = f'layers.{i}.ffn_norm.weight' - w_ffn_norm = llama[ffn_norm_name] - llama[f'gpt.layers.{i}.post_attention_layernorm.weight'] = w_ffn_norm - llama.pop(ffn_norm_name) - - - norm_name = 'norm.weight' - w_norm = llama[norm_name] - llama['gpt.norm.weight'] = w_norm - llama.pop(norm_name) - - output_name = 'output.weight' - w_output = llama[output_name] - llama['lm_head.0.weight'] = w_output[num_text_tokens:num_text_tokens+num_audio_tokens] - llama['lm_head.1.weight'] = w_output[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2] - llama.pop(output_name) - - torch.save(llama, '/data/fishtts/llama.pt') - -def to_numpy(tensor): - return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() - -texts = [ - '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。', - '在这个数字的世界里,你是我的唯一,爱情如同网络连接,无论距离多遥远。我们的心相互链接,在虚拟的空间中漫游,每条信息都是爱的表达,每个瞬间都是甜蜜的时刻。爱情不再是纸上文字,而是数码世界的交流,在屏幕上,我们相拥相视,你是我的电子爱情。', - '探索清新世界的钥匙在此!用海洋微风洗衣粉,让您的衣物充满清晨海边的清新气息。我们的高效配方深层清洁衣物纤维去除顽固污渍的同时,带来持久的清香。不只是清洗更是衣物的焕新旅程。', - '从现在开始,让我们的多功能厨师机器人成为你厨房里的得力助手。它可以搅拌,切碎,烹饪,烘焙,满足你所有烹饪需求。创新美食,只需轻松一按。', - 
'打造完美家居生活,只需一款智能净化器。它能有效过滤空气中的污染物,释放负离子,让你每天呼吸到的都是最纯净的空气,为家人的健康护航。', - '我刚看完《流浪地球》,这部电影真的很打动我。它不仅仅展示了科幻世界中的宏大景象,更通过对人类团结和奉献精神的刻画,让我对未来充满了思考。影片中的视觉效果和细腻的情感描写,让我觉得这是一部值得反复琢磨的作品。如果你喜欢有深度的科幻电影,这部绝对不会让你失望。', - '下个月我计划去日本体验当地的文化。我特别期待去京都的古寺庙,想感受一下传统的日式建筑和庭园。东京的市场也让我兴奋,我打算品尝各种地道的小吃。此外,我计划学习一些基本的日语,这样能更好地融入当地生活,提升旅行的整体体验。你们有没有什么建议或者特别推荐的地方?', - '在保持健康方面,我尝试了一些新的饮食习惯。现在我更多地选择新鲜的蔬菜和水果,减少了糖分和加工食品的摄入。我发现这种饮食方式不仅改善了我的体重,还提升了整体的能量水平。此外,保持充足的水分摄入也是关键,它有助于身体的代谢和排毒。你们有什么其他的健康饮食建议吗?', - '为了提高学习效率,我采取了一些新方法。例如,我将复杂的学习任务拆分成小的目标,每完成一个小目标就能获得成就感。此外,我还使用了番茄工作法,设定25分钟专注学习,然后休息5分钟,这样可以有效避免疲劳。通过这些方法,我发现自己在学习过程中更加专注和高效。', - '有一本书《思考,快与慢》给我留下了深刻的印象。这本书由丹尼尔·卡尼曼撰写,详细探讨了人类思维的两种模式——快速直观和缓慢理性。通过丰富的实证研究,作者揭示了我们在日常决策中的思维偏差。这本书不仅在理论上很有趣,对实际生活中的决策也提供了很多有益的启示。', - '提升工作效率需要良好的时间管理。我发现将任务分解成小步骤,并逐步完成,能让工作变得更有条理。同时,使用待办事项列表和设置提醒也能帮助我保持高效。此外,我还注意到合理的休息和调整对工作效率至关重要。这样不仅提高了我的工作质量,也让我保持了良好的工作状态。', - '探索不同的音乐风格是我最近的兴趣之一。我特别喜欢电子音乐,尤其是那些融合了传统乐器的作品。这种音乐风格不仅提供了新的听觉体验,也让我对音乐的表现形式有了更深的理解。我发现,了解和欣赏不同风格的音乐,能够丰富我的音乐视野和审美体验。', - '照顾宠物需要全面的关注和细心的呵护。我了解到,定期带狗狗散步有助于它们的身体健康,同时提供丰富的玩具和定期的健康检查也很重要。此外,保持良好的饮食习惯对宠物的整体健康有很大影响。照顾宠物的过程中,了解它们的需求并给予关爱,能让它们生活得更加愉快和健康。', - '处理社交媒体信息过载,是我近期面临的一个问题。为了避免被海量的信息分散注意力,我开始设置每天查看社交媒体的时间限制,同时选择关注一些高质量的内容。此外,我还定期清理不感兴趣的账号,这样能够保持信息的有效性和对内容的专注。你们有什么管理社交媒体的好方法吗?', - '每个人都可以在日常生活中采取一些简单的环保行动。我开始减少一次性塑料的使用,进行垃圾分类,并尽量节约能源。这些小措施虽然看似微不足道,但积累起来对环境的保护却能产生积极影响。我相信,关注环保不仅是为了现在的生活,也为未来的子孙着想。你们在环保方面有什么实用的建议吗?', - '她给我们发了一张照片,呃,在一个满是山、山珍海味婚礼上她拿了一个巨小的饭盒在吃,反正就一个特别清淡的啊,减脂营备的餐,然后她呢当时在群里发这个是,嗯,为了求表扬,哈哈哈!', - '我这周末过得我觉得,我真的以为是真正意义上的休息,但是没想到周一的时候去上班,呃的时候,我,我还是感觉,呃,很奇怪,就是很提不起精神的感觉哎!', - '嗯,我刚刚就在想,你说创造一个环境其实,今年,呃,我去采访一位,呃,很有就是,阅历的同龄人的时候,她劝我做一件事情就是找一个心理咨询师,呃,去聊一聊。' - ] -llm_inputs = [] -tokenizer = Tokenizer.from_file('/data/fishtts/vocab.json') -for text in texts: - pinyin = "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) - txt = f"[zh-cn]{pinyin}" - txt = txt.replace(" ", "[SPACE]") - token_ids = tokenizer.encode(txt).ids - token_ids.insert(0, 7001) - token_ids.append(0) - token_ids.append(7003) - llm_inputs.append(token_ids) - -sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16) -prompts = [ - {"prompt_token_ids": llm_input} for llm_input in llm_inputs -] - -# ctx = cuda.Device(0).make_context() - -class Metrics: - def __init__(self, chunk_size=20): - self.chunk_size = chunk_size - self.time_start = 0 - self.time_end = 0 - self.token_times = [] - self.audio_chunk_times = [] - - def calc_non_streaming(self): - total_time = self.time_end - self.time_start - audio_time = len(self.token_times) * 50 / 1000 - rtf = total_time / audio_time - latent_time = self.token_times[-1] - self.time_start - first_byte_time = self.audio_chunk_times[0] - self.time_start - print(f'latent time: {latent_time}, first byte time: {first_byte_time}, total time: {total_time}, audio time: {audio_time}, rtf: {rtf}') - - def calc_streaming(self): - total_time = self.time_end - self.time_start - audio_time = len(self.token_times) * 50 / 1000 - rtf = total_time / audio_time - first_chunk_time = self.token_times[self.chunk_size - 1] - self.time_start - first_byte_time = self.audio_chunk_times[0] - self.time_start - print(f'first chunk time: {first_chunk_time}, first byte time: {first_byte_time}, total time: {total_time}, audio time: {audio_time}, rtf: {rtf}') - -class Generator: - def __init__(self, model_path, use_trt=False, chunk_size=20, hidden_size=1536, frame_shift=1200): - 
self.onnx_session = None - self.trt_engine = None - self.use_trt = use_trt - self.model_path = model_path - self.chunk_size = chunk_size - self.hidden_size = hidden_size - self.frame_shift = frame_shift - self.speaker_embedding = torch.zeros((1, 192, 1), dtype=torch.float32).to('cuda') - - def generate_audio_onnx(self, latent): - onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(self.onnx_session.get_inputs(), (latent, self.speaker_embedding))} - onnxruntime_outputs = self.onnx_session.run(None, onnxruntime_input) - onnxruntime_outputs = onnxruntime_outputs[0][0][0] - return onnxruntime_outputs - - def generate_audio_trt(self, latent): - with self.trt_engine.create_execution_context() as context: - # ctx.push() - - stream = cuda.Stream() - context.set_input_shape('input', (1, self.chunk_size, self.hidden_size)) - context.set_input_shape('speaker_embedding', (1, 192, 1)) - - bindings = [] - - # d_input = cuda.mem_alloc(latent.nbytes) - # cuda.memcpy_htod_async(d_input, latent.cpu().numpy(), stream) - # cuda.memcpy_dtod_async(d_input, latent.data_ptr(), latent.nbytes, stream) - bindings.append(latent.data_ptr()) - - # d_speaker_embedding = cuda.mem_alloc(self.speaker_embedding.nbytes) - # cuda.memcpy_htod_async(d_speaker_embedding, self.speaker_embedding.cpu().numpy(), stream) - # cuda.memcpy_dtod_async(d_speaker_embedding, self.speaker_embedding.data_ptr(), self.speaker_embedding.nbytes, stream) - # bindings.append(int(d_speaker_embedding)) - bindings.append(self.speaker_embedding.data_ptr()) - - dtype = trt.nptype(self.trt_engine.get_tensor_dtype("output")) - size = trt.volume(context.get_tensor_shape('output')) - output_buffer = cuda.pagelocked_empty(size, dtype) - output_memory = cuda.mem_alloc(output_buffer.nbytes) - bindings.append(int(output_memory)) - - for i in range(len(bindings)): - context.set_tensor_address(self.trt_engine.get_tensor_name(i), bindings[i]) - - context.execute_async_v3(stream_handle=stream.handle) - stream.synchronize() - - cuda.memcpy_dtoh_async(output_buffer, output_memory, stream) - - # ctx.pop() - return output_buffer - - def generate_audio(self, latent, metric: Metrics): - latent_len = latent.size(1) - total_audio = [] - chunk_num = latent_len // self.chunk_size - for j in range(chunk_num): - latent_chunk = latent[:,j*self.chunk_size:(j+1)*self.chunk_size] - if self.use_trt: - audio_outputs = self.generate_audio_trt(latent_chunk) - else: - audio_outputs = self.generate_audio_onnx(latent_chunk) - total_audio.append(audio_outputs) - metric.audio_chunk_times.append(time.perf_counter()) - if latent_len % self.chunk_size != 0: - latent_chunk = latent[:,chunk_num*self.chunk_size:] - latent_chunk = torch.cat([latent_chunk, torch.zeros((1, self.chunk_size - latent_chunk.size(1), self.hidden_size), dtype=torch.float32).to('cuda')], 1) - if self.use_trt: - audio_outputs = self.generate_audio_trt(latent_chunk) - else: - audio_outputs = self.generate_audio_onnx(latent_chunk) - audio_outputs = audio_outputs[:latent_len % self.chunk_size * self.frame_shift] - total_audio.append(audio_outputs) - metric.audio_chunk_times.append(time.perf_counter()) - - return total_audio - - def warm_up_onnx(self): - self.onnx_session = onnxruntime.InferenceSession('/data/fishtts/genertor.onnx', providers=['CUDAExecutionProvider']) - warmup_input = torch.zeros(1, self.chunk_size, self.hidden_size).to('cuda') - self.generate_audio_onnx(warmup_input) - print(f'warmup onnx done') - - def warm_up_trt(self): - trt_logger = trt.Logger(trt.Logger.ERROR) - trt_runtime = 
trt.Runtime(trt_logger) - with open('/data/fishtts/genertor.fp16.trt', 'rb') as f: - self.trt_engine = trt_runtime.deserialize_cuda_engine(f.read()) - warmup_input = torch.zeros(1, self.chunk_size, self.hidden_size).to('cuda') - self.generate_audio_trt(warmup_input) - print(f'warmup trt done') - - - def warm_up(self, use_trt): - self.use_trt = use_trt - if use_trt: - self.warm_up_trt() - else: - self.warm_up_onnx() - -def save_audio(total_audio, path): - total_audio = np.concatenate(total_audio, axis=0) - sf.write(path, total_audio, 24000) - -def run(): - llm = LLM(model='/data/fishtts', gpu_memory_utilization=0.7, dtype=torch.float32, skip_tokenizer_init=True) - for i in range(len(prompts)): - metrics = Metrics(chunk_size=generator.chunk_size) - metrics.time_start = time.perf_counter() - outputs = llm.generate(prompts[i], sampling_params) - for output in outputs: - token_ids = output.outputs[0].token_ids - output_len = len(token_ids) - cur_time = time.perf_counter() - metrics.token_times.extend([cur_time] * output_len) - print(f'{i}: {output_len}') - latent = torch.stack(output.outputs[0].hidden_states, 0).unsqueeze(0).to('cuda') - total_audio = generator.generate_audio(latent, metrics) - - save_audio(total_audio, f'hh_{i}.wav') - print(f'save audio {i}') - - metrics.time_end = time.perf_counter() - metrics.calc_non_streaming() - -async def generate_audio_streaming(latent_queue: asyncio.Queue, id, metrics: Metrics): - latent_buffer = [] - audio_data_buffer = [] - while True: - latent = await latent_queue.get() - if latent is None: - break - latent_buffer.append(latent) - if len(latent_buffer) == generator.chunk_size: - latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') - audio_data_buffer.extend(generator.generate_audio(latent, metrics)) - latent_buffer = [] - - if len(latent_buffer) > 0: - latent = torch.stack(latent_buffer, 0).unsqueeze(0).to('cuda') - audio_data_buffer.extend(generator.generate_audio(latent, metrics)) - - save_audio(audio_data_buffer, f'hh_{id}.wav') - print(f'save audio {id}') - -async def generate_token_streaming(engine: AsyncLLMEngine, prompt, id, latent_queue: asyncio.Queue, metrics: Metrics): - results_generator = engine.generate(prompt, sampling_params, request_id=id) - tokens = [] - async for request_output in results_generator: - metrics.token_times.append(time.perf_counter()) - token_ids = request_output.outputs[0].token_ids[-1] - latent = request_output.outputs[0].hidden_states[-1] - tokens.append(token_ids) - latent_queue.put_nowait(latent) - - latent_queue.put_nowait(None) - print(f'{id}: {len(tokens)}') - -async def get_request(input_requests, request_rate: float,): - requests = iter(input_requests) - for request in requests: - yield request - - if request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. - continue - - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) - # The next request will be sent after the interval. 
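# Exponential inter-arrival times make the arrivals a Poisson process with a mean
# rate of `request_rate` requests per second (mean interval 1 / request_rate seconds).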
- await asyncio.sleep(interval) - -async def generate_streaming(engine, prompt, request_id) -> Metrics: - metrics = Metrics(chunk_size=generator.chunk_size) - metrics.time_start = time.perf_counter() - - latent_queue = asyncio.Queue() - vllm_task = asyncio.create_task(generate_token_streaming(engine, prompt, request_id, latent_queue, metrics)) - generator_task = asyncio.create_task(generate_audio_streaming(latent_queue, request_id, metrics)) - await vllm_task - await generator_task - metrics.time_end = time.perf_counter() - return metrics - -async def run_streaming(request_rate): - engine_args = AsyncEngineArgs(model='/data/fishtts', gpu_memory_utilization=0.7, dtype=torch.float32, skip_tokenizer_init=True) - engine = AsyncLLMEngine.from_engine_args(engine_args) - if request_rate < 0: - for i in range(len(prompts)): - me = await generate_streaming(engine, prompts[i], i) - me.calc_streaming() - else: - tasks: List[asyncio.Task] = [] - request_id = 0 - me = await generate_streaming(engine, prompts[0], 0) - async for prompt in get_request(prompts, request_rate): - tasks.append(asyncio.create_task(generate_streaming(engine, prompt, request_id))) - request_id += 1 - - metrics_list: List[Metrics] = await asyncio.gather(*tasks) - for metrics in metrics_list: - metrics.calc_streaming() - - -generator = Generator('/data/fishtts', use_trt=False, chunk_size=20, hidden_size=1536, frame_shift=1200) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--streaming", action="store_true", default=False) - parser.add_argument("--request-rate", - type=float, - default=-1, - help="request rate per second") - parser.add_argument("--chunk-size", - type=int, - default=None, - help="audio chunk size") - parser.add_argument("--use-trt", action="store_true", default=False) - parser.add_argument("--fp16", action="store_true", default=False) - - args = parser.parse_args() - - if args.chunk_size: - generator.chunk_size = args.chunk_size - - generator.warm_up(args.use_trt) - - if not args.streaming: - run() - else: - asyncio.run(run_streaming(args.request_rate)) \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt index b7cef9e700fbf..49317a1add901 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,7 +8,7 @@ torch == 2.5.1; platform_machine != 'aarch64' # These must be updated alongside torch torchvision == 0.20.1; platform_machine != 'aarch64' # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1 -onnxruntime-gpu +onnxruntime-gpu == 1.19.2 pycuda pypinyin omegaconf \ No newline at end of file diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 31412424f17d8..fb02627eb22bd 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -247,10 +247,6 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - # dummy_seq_data = SequenceData([0] * seq_len) - # dummy_multi_modal_data = None - - # return dummy_seq_data, dummy_multi_modal_data return DummyData(SequenceData.from_prompt_token_counts((0, seq_len))) def register_dummy_data(self, factory: DummyDataFactory): diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 6040704de2d80..8b4cf22cc4b2a 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -82,13 +82,19 @@ def _get_logits(self, hidden_states: torch.Tensor, lm_head: Union[VocabParallelEmbedding, nn.Linear], embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: # Get the logits for the next tokens. - if isinstance(lm_head, nn.Linear): - logits = lm_head(hidden_states) + logits = lm_head.linear_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + if self.use_gather: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) else: - logits = lm_head.linear_method.apply(lm_head, - hidden_states, - bias=embedding_bias) - logits = tensor_model_parallel_gather(logits) + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). 
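# (The padded entries beyond org_vocab_size presumably come from the parallel LM head's
# vocab padding; slicing keeps only the real vocabulary logits.)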
if logits is not None: logits = logits[..., :self.org_vocab_size] diff --git a/vllm/model_executor/models/chattts.py b/vllm/model_executor/models/chattts.py deleted file mode 100644 index 09f9887e0b0bf..0000000000000 --- a/vllm/model_executor/models/chattts.py +++ /dev/null @@ -1,197 +0,0 @@ -from array import array -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union - -import torch -from torch import nn -import torch.nn.functional as F -from torch.nn.utils.parametrizations import weight_norm -from transformers import LlamaConfig - -from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.inputs.registry import InputContext -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.multi_head_sampler import MultiheadSampler -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding -from vllm.model_executor.models.llama import LlamaModel -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.speech import SpeechPlugin -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors - -from einops import rearrange -from transformers.generation import TopKLogitsWarper, TopPLogitsWarper - -import lzma -import numpy as np - -def dummy_data_for_ttsllm(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): - - from vllm.sequence import SequenceData - - - dummy_seq_data = SequenceData([0] * seq_len) - dummy_multi_modal_data = {"audio": SpeechPlugin.sample_random_speaker()} - - return dummy_seq_data, dummy_multi_modal_data - -def get_max_speech_tokens(ctx: InputContext): - return 16 - -@MULTIMODAL_REGISTRY.register_speech_input_mapper() -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ttsllm) -@MULTIMODAL_REGISTRY.register_max_speech_tokens(get_max_speech_tokens) -class ChatTtsLlm(nn.Module): - def __init__(self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None,) -> None: - super().__init__() - - # static parameters, put them in config later - self.num_audio_tokens = config.num_audio_tokens - self.num_text_tokens = config.num_text_tokens - self.num_output_head = config.num_output_head - self.spk_emb_token_id = config.spk_emb_token_id - - self.gpt = LlamaModel(config) - self.model_dim = self.gpt.config.hidden_size - self.emb_text = VocabParallelEmbedding(self.num_text_tokens, self.model_dim) - self.emb_code = nn.ModuleList([ - VocabParallelEmbedding(self.num_audio_tokens, self.model_dim) for _ in range(self.num_output_head) - ]) - - self.lm_head = nn.ModuleList([ - ParallelLMHead( - self.num_audio_tokens, - self.model_dim, - org_num_embeddings=self.num_audio_tokens, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - quant_config=quant_config, - ) for _ in range(self.num_output_head) - ]) - self.logits_processor = LogitsProcessor(self.num_audio_tokens) - 
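# One LM head per VQ codebook: compute_logits stacks the per-head logits into a
# (batch, num_output_head, num_audio_tokens) tensor, and MultiheadSampler then draws
# one audio token per head at each decoding step.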
self.sampler = MultiheadSampler() - # self.samplers = [Sampler(head_idx) for head_idx in range(self.num_output_head)] - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - try: - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - except KeyError: - pass - break - else: - try: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - except KeyError: - pass - - def get_input_embeddings(self, input_ids: torch.Tensor, is_prompt: bool) -> torch.Tensor: - if is_prompt: - emb = self.emb_text(input_ids) - else: - code_emb = [ - self.emb_code[i](input_ids[:,i]) - for i in range(self.num_output_head) - ] - emb = torch.stack(code_emb, 2).sum(2) - return emb - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = [ - self.logits_processor(self.lm_head[i], hidden_states, sampling_metadata) - for i in range(self.num_output_head) - ] - logits = torch.stack(logits, 0).permute(1, 0, 2) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - is_prompt: bool = False, - **kwargs: object - ) -> Union[torch.Tensor, IntermediateTensors]: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids, is_prompt) - spk_emb = kwargs.get("speech", None) - if spk_emb is not None: - self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids) - model_output = self.gpt( - input_ids=input_ids, - inputs_embeds=hidden_states, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors - ) - return model_output - - def apply_spk_emb( - self, - emb: torch.Tensor, - spk_emb: torch.Tensor, - attn_metadata: AttentionMetadata, - input_ids: torch.Tensor, - ): - assert emb.size(-1) == spk_emb.size(-1) - assert attn_metadata.seq_lens_tensor.size(0) == spk_emb.size(0) - # convert spk_emb to the same dtype as emb - spk_emb = spk_emb.to(emb.dtype) - # find the index of the speaker token - indices = (input_ids == self.spk_emb_token_id).nonzero(as_tuple=True) - if indices[0].size(0) == 0: - return - emb.index_put_(indices, spk_emb) - - def merge_sample_results( - self, - source: SamplerOutput, - target: SamplerOutput, - ): - for o_a, o_b in zip(source.outputs, target.outputs): - for s_a, s_b in zip(o_a.samples, o_b.samples): - s_a.output_tokens.append(s_b.output_token) - \ No newline at end of file From 4dde441c7f76d66cccda1b7abcb9285f36acbeb3 Mon Sep 17 00:00:00 2001 From: Zheng Niu 
Date: Tue, 24 Dec 2024 09:38:09 +0000 Subject: [PATCH 60/61] make llama work --- vllm/model_executor/layers/multi_head_sampler.py | 4 ++-- vllm/model_executor/models/fishtts.py | 9 ++------- vllm/sequence.py | 2 +- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/multi_head_sampler.py b/vllm/model_executor/layers/multi_head_sampler.py index 3a01e03dca338..74cf755b55c5c 100644 --- a/vllm/model_executor/layers/multi_head_sampler.py +++ b/vllm/model_executor/layers/multi_head_sampler.py @@ -74,7 +74,7 @@ def forward( ) sampled_token_ids_tensor = maybe_sampled_tokens_tensor.reshape(-1, num_heads) - id_next = sampled_token_ids_tensor.tolist() + id_next = sampled_token_ids_tensor.cpu().numpy() if self.include_gpu_probs_tensor: # Since we will defer sampler result Pythonization, @@ -273,7 +273,7 @@ def build_sampler_output(self, seq_outputs: List[SequenceOutput] = [] log_prob = { sample_result[0]: Logprob(logprob=inf, rank=None, decoded_token=None) } seq_output = SequenceOutput(seq_ids[parent_id], sample_result[0], log_prob) - seq_output.output_tokens = sample_result + seq_output.output_tokens = sample_result.tolist() seq_outputs.append(seq_output) sampler_output.append(CompletionSequenceGroupOutput(seq_outputs, prompt_logprobs=None)) diff --git a/vllm/model_executor/models/fishtts.py b/vllm/model_executor/models/fishtts.py index 67872013a358f..74a1eb2025ea2 100644 --- a/vllm/model_executor/models/fishtts.py +++ b/vllm/model_executor/models/fishtts.py @@ -171,17 +171,12 @@ def get_input_embeddings(self, input_ids: torch.Tensor, audio_ref: torch.Tensor, start_token = torch.stack(code_emb, 1).sum(1).to(emb.dtype) # find the index of the audio BOS token - indices = (input_ids == self.audio_start_token_id).nonzero(as_tuple=True) - if indices[0].size(0) != 0: - emb.index_put_(indices, start_token) + emb[-1] = start_token # batch size = 2 # inpudId 7004 7004 XXXX 7004 7001 1 2 34 | 7003 7004 7004 XXXX 7004 7001 1 2 34 7003 # speaker ref [16*2, 1536] - indices = (input_ids == self.audio_ref_token_id).nonzero(as_tuple=True) - if indices[0].size(0) != 0: - for idx, audio_ref_start in enumerate(indices[0]): - emb[audio_ref_start:audio_ref_start+16] = audio_ref[idx].to(emb.dtype) + emb[0:16] = audio_ref[0].to(emb.dtype) else: code_emb = [ diff --git a/vllm/sequence.py b/vllm/sequence.py index bb87843948862..67f118179bbbf 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -192,7 +192,7 @@ def from_prompt_token_counts( (array_full(token_id, count) for token_id, count in token_counts), ) - return SequenceData(prompt_token_ids_arr) + return SequenceData(prompt_token_ids_arr.tolist()) @staticmethod def from_seqs( From 9b5c3c1160ef9efdeaf20be69072156fe1217763 Mon Sep 17 00:00:00 2001 From: Zheng Niu Date: Thu, 26 Dec 2024 08:14:36 +0000 Subject: [PATCH 61/61] support profile run --- xtts/model_setting.py | 8 +++++--- xtts/run.py | 4 +++- xtts/tts_engine.py | 24 +++++++++++++++++++++--- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/xtts/model_setting.py b/xtts/model_setting.py index 300c1562885ce..b9ac071848fbe 100644 --- a/xtts/model_setting.py +++ b/xtts/model_setting.py @@ -15,7 +15,8 @@ def __init__(self, chunk_padding: bool = True, cut_tail: int = 150, support_lora: bool = False, - scale_rate: float = 2.7): + scale_rate: float = 2.7, + profile_run: bool = False): self.model_dir = model_dir self.runtime = runtime self.chunk_size = chunk_size @@ -34,5 +35,6 @@ def __init__(self, self.support_lora = support_lora self.use_onnx_graph = False - 
self.gpu_memory_utilization = 0.7 - self.scale_rate = scale_rate \ No newline at end of file + self.gpu_memory_utilization = 0.3 + self.scale_rate = scale_rate + self.profile_run = profile_run \ No newline at end of file diff --git a/xtts/run.py b/xtts/run.py index d8418d88b41dd..47f412d7dbc11 100644 --- a/xtts/run.py +++ b/xtts/run.py @@ -37,6 +37,7 @@ parser.add_argument("--temperature", type=float, default=1) parser.add_argument("--scale_rate", type=float, default=2.7) parser.add_argument("--cut_tail", type=int, default=0) + parser.add_argument("--profile-run", action="store_true", default=False) args = parser.parse_args() @@ -54,7 +55,8 @@ chunk_size=args.chunk_size, first_chunk_size=args.first_chunk_size, cut_tail=args.cut_tail, - scale_rate=args.scale_rate) + scale_rate=args.scale_rate, + profile_run=args.profile_run) if args.lora: model_setting.support_lora = True tts_engine = XTtsEngine(model_setting) diff --git a/xtts/tts_engine.py b/xtts/tts_engine.py index 940ec0d71aeef..18801f97e81cb 100644 --- a/xtts/tts_engine.py +++ b/xtts/tts_engine.py @@ -1,4 +1,5 @@ import asyncio +import os import time from typing import Any, Union @@ -32,7 +33,9 @@ def __init__(self, model_setting: ModelSetting): self.post_init() def post_init(self): - + if self.model_setting.profile_run: + os.environ["VLLM_TORCH_PROFILER_DIR"] = "/home/zhn/vllm_profile3" + # initialize tokenizer logger.info('Loading tokenizer...') self.tokenizer = Tokenizer.from_file('/home/zhn/fishtts/vocab.json') @@ -62,7 +65,8 @@ def post_init(self): gpu_memory_utilization=self.model_setting.gpu_memory_utilization, dtype=self.model_setting.dtype, skip_tokenizer_init=True) - + self.max_tokens = 2048 + logger.info('LLM loaded.') # initialize audio generator @@ -123,11 +127,21 @@ def synthesize(self, texts: List[str], output_dir: str, lora_path: str = None, lora_request = None if lora_path: lora_request=LoRARequest("lora", 1, lora_local_path=lora_path) + + if self.model_setting.profile_run: + os.environ["VLLM_TORCH_PROFILER_DIR"] = "/home/zhn/vllm_profile3" + self.max_tokens = 10 + self.llm_engine.start_profile() + for i,p in enumerate(prompts): logger.debug(f'Processing text {i+1}/{len(prompts)}...') + if self.model_setting.profile_run and i == 3: + logger.debug('Early stop for profiling') + break + metrics = TtsMetrics(chunk_size=self.model_setting.chunk_size) metrics.time_start = time.perf_counter() - sampling_params = SamplingParams(detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, + sampling_params = SamplingParams(detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=self.max_tokens, repetition_penalty=1.5, repetition_window=16, top_k=top_k, top_p=top_p, temperature=temperature) outputs = self.llm_engine.generate(p, sampling_params, lora_request=lora_request) @@ -147,6 +161,10 @@ def synthesize(self, texts: List[str], output_dir: str, lora_path: str = None, logger.debug(f'Audio generated') metrics.time_end = time.perf_counter() metrics.calc_non_streaming() + + if self.model_setting.profile_run: + self.llm_engine.stop_profile() + time.sleep(15) async def generate_token_streaming(self, engine: AsyncLLMEngine, prompt: dict[str, Any], lora_request: LoRARequest, id: str, sampling_params: SamplingParams,