From 56c5bcf5f5115ebbeeb09f6113b101cb918e9048 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 03:22:41 +0000 Subject: [PATCH 1/5] refactor(gptfast): simplify the code --- gptfast/generate.py | 372 +++++++++++--------------------------------- 1 file changed, 89 insertions(+), 283 deletions(-) diff --git a/gptfast/generate.py b/gptfast/generate.py index 1e8446e..214fe84 100644 --- a/gptfast/generate.py +++ b/gptfast/generate.py @@ -1,11 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -import contextlib -import itertools - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. import sys import time from pathlib import Path @@ -19,38 +14,6 @@ from PIL import Image from torch.nn.attention import SDPBackend from transformers import AutoProcessor, AutoTokenizer - - -def get_model_size_in_bytes(model, ignore_embeddings=False): - """ - Returns the model size in bytes. The option to ignore embeddings - is useful for models with disproportionately large embeddings compared - to other model parameters that get quantized/sparsified. - """ - - def flat_size(tensor): - if hasattr(tensor, "__tensor_flatten__"): - size = 0 - # 0th element is a list of attributes that - # hold tensors - for attr_name in tensor.__tensor_flatten__()[0]: - sub_tensor = getattr(tensor, attr_name) - size += flat_size(sub_tensor) - return size - else: - return tensor.numel() * tensor.element_size() - - model_size = 0 - for name, child in model.named_children(): - if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings): - for p in itertools.chain( - child.parameters(recurse=False), child.buffers(recurse=False) - ): - model_size += flat_size(p) - model_size += get_model_size_in_bytes(child, ignore_embeddings) - return model_size - - from model import ModelArgs @@ -332,258 +295,101 @@ def setup_model_compilation( prefill = torch.compile(prefill, fullgraph=True, dynamic=True) -def process_generation( - model, - inputs, - tokenizer, - i, - num_samples, - profile, - device, - stop_strings=None, - **generation_kwargs, -): - t0 = time.perf_counter() - - # Encode stop strings once at the start - stop_sequences = None - if stop_strings: - stop_sequences = [ - torch.tensor(tokenizer.encode(stop), dtype=torch.int, device=device) - for stop in stop_strings - ] - - prof = ( - torch.profiler.profile(with_stack=True) - if i == num_samples - 1 and profile - else contextlib.nullcontext() - ) - - with prof: - - def callback(new_tokens): - if stop_sequences: - generated = torch.cat(new_tokens) - return any( - generated.size(0) >= stop_seq.size(0) - and torch.equal(generated[-stop_seq.size(0) :], stop_seq) - for stop_seq in stop_sequences - ) - return False - - output = generate(model, **inputs, callback=callback, **generation_kwargs) - - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - return None - - if hasattr(prof, "export_chrome_trace"): - prof.export_chrome_trace(f"{profile}.json") - - device_sync(device=device) - generation_time = time.perf_counter() - t0 - - print(tokenizer.decode(output)) - return output, generation_time - - -def print_metrics(tokens_per_sec, model_size): - print("==========") - tokpersec = torch.mean(torch.tensor(tokens_per_sec)).item() - bandwidth = model_size * tokpersec - mem = torch.cuda.max_memory_reserved() / 1e9 - print(f"Average tokens/sec: {tokpersec:.2f}") - print(f"Average Bandwidth: {bandwidth:.02f} GB/s") - print(f"Peak Memory 
Usage: {mem:.02f} GB") - print(f"Model Size: {model_size:.02f} GB") - - -def main( - checkpoint_path, - prompt: str = "Hello, my name is", - image_path: str = None, - num_samples: int = 5, - max_new_tokens: int = 100, - top_k: int = 200, - temperature: float = 0.8, - cache_size: Optional[int] = None, - linear_causal_mask: bool = False, - compile: bool = True, - compile_prefill: bool = False, - apply_regional_compilation: bool = False, - profile: Optional[Path] = None, - memory_profile: Optional[Path] = None, - device=default_device, - precision=torch.bfloat16, - stop_strings: Optional[list] = None, -) -> None: - recommended_inductor_config_setter() - assert checkpoint_path.is_file(), checkpoint_path - - model, tokenizer, processor = load_model_and_tokenizer( - checkpoint_path, device, precision - ) - - inputs = ( - prepare_image_inputs(image_path, prompt, processor, precision) - if image_path - else prepare_text_inputs(prompt, tokenizer) - ) - inputs = {k: v.to(device) if v is not None else v for k, v in inputs.items()} - - prompt_length = inputs["input_ids"].size(1) - torch.manual_seed(1234) - model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 - - setup_model_compilation(model, compile, compile_prefill, apply_regional_compilation) - - if memory_profile: - torch.cuda.memory._record_memory_history( - True, trace_alloc_max_entries=250000, trace_alloc_record_context=True +class GenerationConfig: + """Configuration class for text generation parameters.""" + def __init__( + self, + max_new_tokens: int = 100, + top_k: int = 200, + temperature: float = 0.8, + cache_size: Optional[int] = None, + linear_causal_mask: bool = False, + stop_strings: Optional[list[str]] = None + ): + self.max_new_tokens = max_new_tokens + self.top_k = top_k + self.temperature = temperature + self.cache_size = cache_size + self.linear_causal_mask = linear_causal_mask + self.stop_strings = stop_strings or ["<|im_end|>"] + +class ModelConfig: + """Configuration class for model loading and compilation settings.""" + def __init__( + self, + checkpoint_path: Path, + device: str = default_device, + precision: torch.dtype = torch.bfloat16, + compile: bool = False, + compile_prefill: bool = False, + apply_regional_compilation: bool = False + ): + self.checkpoint_path = checkpoint_path + self.device = device + self.precision = precision + self.compile = compile + self.compile_prefill = compile_prefill + self.apply_regional_compilation = apply_regional_compilation + +class Generator: + """Main class for handling text generation.""" + def __init__( + self, + model_config: ModelConfig, + generation_config: GenerationConfig + ): + self.model_config = model_config + self.generation_config = generation_config + self.model = None + self.tokenizer = None + self.processor = None + + self._setup_model() + + def _setup_model(self): + """Initialize model, tokenizer and processor.""" + self.model, self.tokenizer, self.processor = load_model_and_tokenizer( + self.model_config.checkpoint_path, + self.model_config.device, + self.model_config.precision ) - - tokens_per_sec = [] - start = -1 if compile or apply_regional_compilation else 0 - - generation_kwargs = { - "max_new_tokens": max_new_tokens, - "temperature": temperature, - "top_k": top_k, - "cache_size": cache_size, - "linear_causal_mask": linear_causal_mask, - "stop_strings": stop_strings, - } - - for i in range(start, num_samples): - if i == 0: - torch.cuda.reset_peak_memory_stats() - device_sync(device=device) - - result = process_generation( - model, - inputs, - 
tokenizer, - i, - num_samples, - profile, - device, - **generation_kwargs, + setup_model_compilation( + self.model, + self.model_config.compile, + self.model_config.compile_prefill, + self.model_config.apply_regional_compilation ) - if result is None: - continue - - output, generation_time = result - tokens_generated = output.size(0) - prompt_length - print(f"Tokens generated: {tokens_generated}") - tokens_sec = tokens_generated / generation_time - tokens_per_sec.append(tokens_sec) - print( - f"Time for inference {i + 1}: {generation_time:.02f} sec total, {tokens_sec:.02f} tokens/sec" + def generate(self, prompt: str, image_path: Optional[str] = None) -> str: + """Generate text from prompt and optional image.""" + inputs = ( + prepare_image_inputs(image_path, prompt, self.processor, self.model_config.precision) + if image_path + else prepare_text_inputs(prompt, self.tokenizer) ) - print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") - - if memory_profile and i == 0: - snapshot = torch.cuda.memory._snapshot() - with open(f"{memory_profile}.pickle", "wb") as f: - from pickle import dump - - dump(snapshot, f) - print( - f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", - "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html", - ) - break - - print_metrics(tokens_per_sec, model_size) + inputs = {k: v.to(self.model_config.device) if v is not None else v for k, v in inputs.items()} + + output = generate( + self.model, + **inputs, + max_new_tokens=self.generation_config.max_new_tokens, + temperature=self.generation_config.temperature, + top_k=self.generation_config.top_k, + cache_size=self.generation_config.cache_size, + linear_causal_mask=self.generation_config.linear_causal_mask, + ) + + return self.tokenizer.decode(output) if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Your CLI description.") - parser.add_argument( - "checkpoint_path", - type=Path, - help="Model checkpoint path.", - ) - parser.add_argument( - "--prompt", - type=str, - default="Explain what is the meaning of life", - help="Input prompt.", + model_config = ModelConfig( + checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"), ) - parser.add_argument("--image_path", type=str, default=None, help="Image path.") - parser.add_argument("--num_samples", type=int, default=5, help="Number of samples.") - parser.add_argument( - "--max_new_tokens", type=int, default=200, help="Maximum number of new tokens." - ) - parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") - parser.add_argument( - "--temperature", type=float, default=0.8, help="Temperature for sampling." - ) - parser.add_argument( - "--cache_size", - type=int, - default=None, - help="Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size", - ) - parser.add_argument( - "--linear_causal_mask", - action="store_true", - help="Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)", - ) - parser.add_argument( - "--compile", action="store_true", help="Whether to compile the model." 
- ) - parser.add_argument( - "--compile_prefill", - action="store_true", - help="Whether to compile the prefill (improves prefill perf, but higher compile times)", - ) - parser.add_argument( - "--apply_regional_compilation", - action="store_true", - help="Whether to apply regional compilation to the layers of the model", - ) - parser.add_argument("--profile", type=Path, default=None, help="Profile path.") - parser.add_argument( - "--memory_profile", type=Path, default=None, help="filename for memory profile." - ) - parser.add_argument( - "--device", type=str, default=default_device, help="Device to use" - ) - parser.add_argument( - "--precision", - type=lambda x: getattr(torch, x.split(".")[-1]), - default=torch.bfloat16, - help="dtype precision to use", - ) - parser.add_argument( - "--stop_strings", - type=str, - nargs="+", - default=["<|im_end|>"], - help="List of strings that will stop generation when encountered at the end", - ) - - args = parser.parse_args() - main( - args.checkpoint_path, - args.prompt, - args.image_path, - args.num_samples, - args.max_new_tokens, - args.top_k, - args.temperature, - args.cache_size, - args.linear_causal_mask, - args.compile, - args.compile_prefill, - args.apply_regional_compilation, - args.profile, - args.memory_profile, - args.device, - args.precision, - args.stop_strings, + generation_config = GenerationConfig( + max_new_tokens=100, + top_k=200, + temperature=0.8, ) + generator = Generator(model_config, generation_config) + print(generator.generate("Hello, world!")) From 81cd3a1a079d4339f54ac2e5b07daf85a8387f11 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 03:38:00 +0000 Subject: [PATCH 2/5] feat(gptfast): add stop_strings support --- gptfast/generate.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/gptfast/generate.py b/gptfast/generate.py index 214fe84..3c08cd3 100644 --- a/gptfast/generate.py +++ b/gptfast/generate.py @@ -369,6 +369,16 @@ def generate(self, prompt: str, image_path: Optional[str] = None) -> str: ) inputs = {k: v.to(self.model_config.device) if v is not None else v for k, v in inputs.items()} + def early_stop_generation(tokens): + # This is not efficient, but it works + for stop_string in self.generation_config.stop_strings: + + token_list = torch.cat(tokens) + decoded_string = self.tokenizer.decode(token_list) + if decoded_string.endswith(stop_string): + return True + return False + output = generate( self.model, **inputs, @@ -377,6 +387,7 @@ def generate(self, prompt: str, image_path: Optional[str] = None) -> str: top_k=self.generation_config.top_k, cache_size=self.generation_config.cache_size, linear_causal_mask=self.generation_config.linear_causal_mask, + callback=early_stop_generation ) return self.tokenizer.decode(output) From fd709fe05b51ce15fb3988cb4954e58b0a978542 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 05:43:47 +0000 Subject: [PATCH 3/5] support image understanding --- gptfast/generate.py | 73 ++++++++++++++------------------------------- gptfast/model.py | 23 ++++++++++++-- 2 files changed, 42 insertions(+), 54 deletions(-) diff --git a/gptfast/generate.py b/gptfast/generate.py index 3c08cd3..6050365 100644 --- a/gptfast/generate.py +++ b/gptfast/generate.py @@ -14,7 +14,7 @@ from PIL import Image from torch.nn.attention import SDPBackend from transformers import AutoProcessor, AutoTokenizer -from model import ModelArgs +from model import ModelArgs, prepare_inputs_for_model def device_sync(device): @@ -128,72 +128,41 @@ def 
generate( # create an empty tensor of the expected final shape and fill in the current tokens device = input_ids.device - original_prompt_token_count = input_ids.numel() - - if pixel_values is not None: - input_embeds = model.prepare_embeddings(input_ids, pixel_values, pixel_mask) - prompt_token_count_after_inserting_image_tokens = input_embeds.shape[1] - else: - input_embeds = None - prompt_token_count_after_inserting_image_tokens = input_ids.numel() + T = input_ids.numel() # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size) - max_seq_length = min( - prompt_token_count_after_inserting_image_tokens + max_new_tokens, - model.config.block_size, - ) - new_tokens = max_seq_length - prompt_token_count_after_inserting_image_tokens + max_seq_length = min(T + max_new_tokens, model.config.block_size) + new_tokens = max_seq_length - T # full prompt+output will be stored in seq seq = torch.empty(max_seq_length, dtype=input_ids.dtype, device=device) - seq[:original_prompt_token_count] = input_ids.view(-1) + seq[:T] = input_ids.view(-1) # setup model caches with torch.device(device): if cache_size is None: cache_size = max_seq_length - assert ( - cache_size >= max_seq_length - ), "need cache_size to be greater than max_new_tokens + size-of-prompt" - model.setup_caches( - max_batch_size=1, - max_seq_length=cache_size, - linear_causal_mask=linear_causal_mask, - prompt_length=prompt_token_count_after_inserting_image_tokens, - ) + assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt" + model.setup_caches(max_batch_size=1, max_seq_length=cache_size, linear_causal_mask=linear_causal_mask, prompt_length=T) - input_pos = torch.arange( - 0, - prompt_token_count_after_inserting_image_tokens, - device=device, - dtype=torch.int, - ) + # format model input + x, input_pos = prepare_inputs_for_model(input_ids, max_new_tokens) + if pixel_values is not None: + input_embeds = model.prepare_embeddings(x, pixel_values, pixel_mask) + else: + input_embeds = None # execute prefill - next_token = prefill( - model, input_ids, input_pos, input_embeds, **sampling_kwargs - ).clone() - seq[original_prompt_token_count] = next_token - input_pos = torch.tensor( - [prompt_token_count_after_inserting_image_tokens], - device=device, - dtype=torch.int, - ) + next_token = prefill(model, x, input_pos, input_embeds, **sampling_kwargs).clone() + seq[T] = next_token # execute token generation - generated_tokens, _ = decode_n_tokens( - model, - next_token.view(1, -1), - input_pos, - new_tokens - 1, - callback=callback, - **sampling_kwargs, - ) + input_pos = torch.tensor([T], device=device, dtype=torch.int) + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) - seq = torch.cat((seq[: original_prompt_token_count + 1], *generated_tokens)) + seq = torch.cat((seq[:T+1], *generated_tokens)) return seq - def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = tokenizer.encode(string) if bos: @@ -398,9 +367,11 @@ def early_stop_generation(tokens): checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"), ) generation_config = GenerationConfig( - max_new_tokens=100, + max_new_tokens=500, top_k=200, temperature=0.8, ) generator = Generator(model_config, generation_config) - print(generator.generate("Hello, world!")) + + image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" + 
print(generator.generate("describe the image", image_path)) diff --git a/gptfast/model.py b/gptfast/model.py index 7fa029c..fece907 100644 --- a/gptfast/model.py +++ b/gptfast/model.py @@ -642,10 +642,27 @@ def prepare_embeddings(self, idx: Tensor, pixel_values: Tensor, pixel_mask: Tens ) inputs_embeds = self.llm.tok_embeddings(idx) - x = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, idx + + n_image_tokens = (idx == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (idx == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, image_features ) - return x + return inputs_embeds def forward( self, From b2a15af95377af10bb769524e797d79de43aff1d Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 05:44:10 +0000 Subject: [PATCH 4/5] make format happy --- gptfast/generate.py | 60 ++++++++++++++++++++++++++++++--------------- gptfast/model.py | 8 ++---- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/gptfast/generate.py b/gptfast/generate.py index 6050365..dc3c676 100644 --- a/gptfast/generate.py +++ b/gptfast/generate.py @@ -10,11 +10,10 @@ import torch import torch._dynamo.config import torch._inductor.config -from model import Aria, ModelArgs, Transformer +from model import Aria, ModelArgs, Transformer, prepare_inputs_for_model from PIL import Image from torch.nn.attention import SDPBackend from transformers import AutoProcessor, AutoTokenizer -from model import ModelArgs, prepare_inputs_for_model def device_sync(device): @@ -142,8 +141,15 @@ def generate( with torch.device(device): if cache_size is None: cache_size = max_seq_length - assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt" - model.setup_caches(max_batch_size=1, max_seq_length=cache_size, linear_causal_mask=linear_causal_mask, prompt_length=T) + assert ( + cache_size >= max_seq_length + ), "need cache_size to be greater than max_new_tokens + size-of-prompt" + model.setup_caches( + max_batch_size=1, + max_seq_length=cache_size, + linear_causal_mask=linear_causal_mask, + prompt_length=T, + ) # format model input x, input_pos = prepare_inputs_for_model(input_ids, max_new_tokens) @@ -157,12 +163,20 @@ def generate( seq[T] = next_token # execute token generation input_pos = torch.tensor([T], device=device, dtype=torch.int) - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(1, -1), + input_pos, + new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) - seq = torch.cat((seq[:T+1], *generated_tokens)) + seq = torch.cat((seq[: T + 1], *generated_tokens)) return seq + def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = tokenizer.encode(string) if bos: @@ -266,6 +280,7 @@ def setup_model_compilation( class GenerationConfig: """Configuration class for text generation parameters.""" + def __init__( self, max_new_tokens: int = 100, @@ -273,7 +288,7 @@ def __init__( 
temperature: float = 0.8, cache_size: Optional[int] = None, linear_causal_mask: bool = False, - stop_strings: Optional[list[str]] = None + stop_strings: Optional[list[str]] = None, ): self.max_new_tokens = max_new_tokens self.top_k = top_k @@ -282,8 +297,10 @@ def __init__( self.linear_causal_mask = linear_causal_mask self.stop_strings = stop_strings or ["<|im_end|>"] + class ModelConfig: """Configuration class for model loading and compilation settings.""" + def __init__( self, checkpoint_path: Path, @@ -291,7 +308,7 @@ def __init__( precision: torch.dtype = torch.bfloat16, compile: bool = False, compile_prefill: bool = False, - apply_regional_compilation: bool = False + apply_regional_compilation: bool = False, ): self.checkpoint_path = checkpoint_path self.device = device @@ -300,19 +317,17 @@ def __init__( self.compile_prefill = compile_prefill self.apply_regional_compilation = apply_regional_compilation + class Generator: """Main class for handling text generation.""" - def __init__( - self, - model_config: ModelConfig, - generation_config: GenerationConfig - ): + + def __init__(self, model_config: ModelConfig, generation_config: GenerationConfig): self.model_config = model_config self.generation_config = generation_config self.model = None self.tokenizer = None self.processor = None - + self._setup_model() def _setup_model(self): @@ -320,23 +335,28 @@ def _setup_model(self): self.model, self.tokenizer, self.processor = load_model_and_tokenizer( self.model_config.checkpoint_path, self.model_config.device, - self.model_config.precision + self.model_config.precision, ) setup_model_compilation( self.model, self.model_config.compile, self.model_config.compile_prefill, - self.model_config.apply_regional_compilation + self.model_config.apply_regional_compilation, ) def generate(self, prompt: str, image_path: Optional[str] = None) -> str: """Generate text from prompt and optional image.""" inputs = ( - prepare_image_inputs(image_path, prompt, self.processor, self.model_config.precision) + prepare_image_inputs( + image_path, prompt, self.processor, self.model_config.precision + ) if image_path else prepare_text_inputs(prompt, self.tokenizer) ) - inputs = {k: v.to(self.model_config.device) if v is not None else v for k, v in inputs.items()} + inputs = { + k: v.to(self.model_config.device) if v is not None else v + for k, v in inputs.items() + } def early_stop_generation(tokens): # This is not efficient, but it works @@ -356,9 +376,9 @@ def early_stop_generation(tokens): top_k=self.generation_config.top_k, cache_size=self.generation_config.cache_size, linear_causal_mask=self.generation_config.linear_causal_mask, - callback=early_stop_generation + callback=early_stop_generation, ) - + return self.tokenizer.decode(output) diff --git a/gptfast/model.py b/gptfast/model.py index fece907..5220bed 100644 --- a/gptfast/model.py +++ b/gptfast/model.py @@ -656,12 +656,8 @@ def prepare_embeddings(self, idx: Tensor, pixel_values: Tensor, pixel_mask: Tens .expand_as(inputs_embeds) .to(inputs_embeds.device) ) - image_features = image_features.to( - inputs_embeds.device, inputs_embeds.dtype - ) - inputs_embeds = inputs_embeds.masked_scatter( - special_image_mask, image_features - ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) return inputs_embeds def forward( From 7f70c95c3f58d6db8b93cf334e4395a56c1b3809 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 05:44:48 
+0000 Subject: [PATCH 5/5] remove unused code --- gptfast/model.py | 72 ------------------------------------------------ 1 file changed, 72 deletions(-) diff --git a/gptfast/model.py b/gptfast/model.py index 5220bed..358566e 100644 --- a/gptfast/model.py +++ b/gptfast/model.py @@ -562,78 +562,6 @@ def __init__(self, config: ModelArgs): self.llm = Transformer(config) - def _merge_input_ids_with_image_features( - self, image_features, inputs_embeds, input_ids - ): - num_images, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_embed_dim = ( - num_special_image_tokens.max() * (num_image_patches - 1) - ) + sequence_length - batch_indices, non_image_indices = torch.where( - input_ids != self.config.image_token_index - ) - - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = ( - torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - - 1 - ) - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, - max_embed_dim, - embed_dim, - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[ - batch_indices, non_image_indices - ] - - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), - True, - dtype=torch.bool, - device=inputs_embeds.device, - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): - raise ValueError( - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." 
- ) - - final_embedding[image_to_overwrite] = ( - image_features.contiguous().reshape(-1, embed_dim).to(target_device) - ) - - return final_embedding - def prepare_embeddings(self, idx: Tensor, pixel_values: Tensor, pixel_mask: Tensor): image_outputs, image_attn_mask = self.vision_tower(pixel_values, pixel_mask) selected_image_feature = image_outputs.last_hidden_state
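
A minimal usage sketch of the Generator API introduced across PATCH 1/5-3/5, added here for reference. It assumes gptfast/ is on the import path; the checkpoint path and image URL are the placeholders used in the patches' own __main__ block, and whether prepare_image_inputs resolves remote URLs as well as local paths depends on code not shown in this series, so treat the image call as an assumption rather than a guaranteed behavior.

from pathlib import Path

from generate import GenerationConfig, Generator, ModelConfig

# One-time model setup: checkpoint, device/precision, and compilation flags.
model_config = ModelConfig(
    checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
    compile=True,  # optional; compilation is off by default in ModelConfig
)

# Per-call sampling knobs, kept separate from model setup.
generation_config = GenerationConfig(
    max_new_tokens=500,
    temperature=0.8,
    top_k=200,
    stop_strings=["<|im_end|>"],  # generation stops once the decoded tail matches a stop string
)

generator = Generator(model_config, generation_config)

# Text-only generation.
print(generator.generate("Explain what is the meaning of life"))

# Multimodal generation; the URL below is the example from PATCH 3/5 and is
# assumed to be accepted by prepare_image_inputs.
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
print(generator.generate("describe the image", image_path=image_url))

Keeping ModelConfig separate from GenerationConfig isolates the expensive one-time work (loading and optionally compiling the model) from the cheap per-request sampling parameters, so a single Generator instance can serve many prompts with different generation settings.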