From b0986be486c8d9b3992fa7c04c4817807a0d6c2e Mon Sep 17 00:00:00 2001
From: Korakoe <56580073+korakoe@users.noreply.github.com>
Date: Fri, 29 Sep 2023 17:59:04 +0800
Subject: [PATCH] Early precomputation implementation

I still have to implement embed loading (a sketch of one approach follows the
diff), but it's a very slow, working POC.
---
 muse_maskgit_pytorch/dataset.py               | 50 +++++++++++--
 .../trainers/maskgit_trainer.py               | 11 +--
 train_muse_maskgit.py                         | 73 ++++++++++++++++++-
 3 files changed, 122 insertions(+), 12 deletions(-)

diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index 68f7a0d..6eb3693 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -119,6 +119,7 @@ def __init__(
         stream=False,
         using_taming=False,
         random_crop=False,
+        embeds=[],
     ):
         super().__init__(
             dataset,
@@ -132,21 +133,28 @@ def __init__(
         )
         self.caption_column: str = caption_column
         self.tokenizer: T5Tokenizer = tokenizer
+        self.embeds: list = embeds
 
     def __getitem__(self, index):
         try:
             image = self.dataset[index][self.image_column]
             descriptions = self.dataset[index][self.caption_column]
+            if self.embeds:
+                embed = self.embeds[index]
         except PIL.UnidentifiedImageError:
             print("Error reading image, most likely corrupt, skipping...")
             image_found = False
+            embed = None
             current_index = 1
             while not image_found:
                 try:
                     image = self.dataset[index + current_index][self.image_column]
                     descriptions = self.dataset[index + current_index][self.caption_column]
+                    if self.embeds:
+                        embed = self.embeds[index + current_index]
                     image_found = True
                 except PIL.UnidentifiedImageError:
+                    embed = None
                     current_index += 1
 
         if self.caption_column is None or descriptions is None:
@@ -171,9 +179,15 @@ def __getitem__(self, index):
         attn_mask = encoded.attention_mask
 
         if self.using_taming:
-            return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], []
         else:
-            return self.transform(image), input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image), input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []
 
 
 class URLTextDataset(ImageDataset):
@@ -187,6 +201,7 @@ def __init__(
         flip=True,
         center_crop=True,
         using_taming=True,
+        embeds=[],
     ):
         super().__init__(
             dataset,
@@ -198,16 +213,21 @@ def __init__(
         )
         self.caption_column: str = caption_column
         self.tokenizer: T5Tokenizer = tokenizer
+        self.embeds: list = embeds
 
     def __getitem__(self, index):
         try:
             image = pImage.open(BytesIO(requests.get(self.dataset[index][self.image_column]).content))
+            if self.embeds:
+                embed = self.embeds[index]
         except ConnectionError:
             try:
                 print("Image request failure, attempting next image")
                 index += 1
                 image = pImage.open(BytesIO(requests.get(self.dataset[index][self.image_column]).content))
+                if self.embeds:
+                    embed = self.embeds[index]
             except ConnectionError:
                 raise ConnectionError("Unable to request image from the Dataset")
 
@@ -232,10 +252,17 @@ def __getitem__(self, index):
         input_ids = encoded.input_ids
         attn_mask = encoded.attention_mask
+
         if self.using_taming:
-            return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], []
         else:
-            return self.transform(image), input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image), input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []
 
 
 class LocalTextImageDataset(Dataset):
@@ -249,10 +276,12 @@ def __init__(
         using_taming=False,
         random_crop=False,
         alpha_channel=False,
+        embeds=[],
     ):
         super().__init__()
         self.tokenizer = tokenizer
         self.using_taming = using_taming
+        self.embeds: list = embeds
 
         print("Building dataset...")
 
@@ -305,6 +334,9 @@ def __getitem__(self, index):
         else:
             text = Path(descriptions).read_text(encoding="utf-8").split("\n")
 
+        if self.embeds:
+            embed = self.embeds[index]
+
         # max length from the paper
         encoded = self.tokenizer.batch_encode_plus(
             [str(text)],
@@ -317,9 +349,15 @@ def __getitem__(self, index):
         input_ids = encoded.input_ids
         attn_mask = encoded.attention_mask
         if self.using_taming:
-            return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], []
         else:
-            return self.transform(image), input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image), input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []
 
 
 def get_directory_size(path):
diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index d00f5dc..cfce2f1 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -154,14 +154,15 @@ def train(self):
 
         # logs
         for epoch in range(self.current_step // len(self.dl), self.num_epochs):
-            for imgs, input_ids, attn_mask in iter(self.dl):
+            for imgs, input_ids, attn_mask, text_embeds in iter(self.dl):
                 train_loss = 0.0
                 steps = int(self.steps.item())
 
-                with torch.no_grad():
-                    text_embeds = t5_encode_text_from_encoded(
-                        input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
-                    )
+                if len(text_embeds) == 0:
+                    with torch.no_grad():
+                        text_embeds = t5_encode_text_from_encoded(
+                            input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
+                        )
 
                 with self.accelerator.accumulate(self.model), self.accelerator.autocast():
                     loss = self.model(imgs, text_embeds=text_embeds)
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index ce72b64..30118d8 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -7,15 +7,18 @@
 import accelerate
 import datasets
 import diffusers
+import torch
 import transformers
-import wandb
 from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
 from diffusers.optimization import SchedulerType, get_scheduler
 from omegaconf import OmegaConf
 from rich import inspect
 from torch.optim import Optimizer
+from tqdm import tqdm
 
+import wandb
+from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded
 from muse_maskgit_pytorch.utils import (
     get_latest_checkpoints,
 )
@@ -424,6 +427,12 @@
     default="flash",
     help="what type of attention to use [ein, flash, xformers] | Default: flash",
 )
+parser.add_argument(
+    "--precompute",
+    action="store_true",
+    default=False,
+    help="whether to precompute text embeds",
+)
 
 
 @dataclass
@@ -497,6 +506,7 @@ class Arguments:
     debug: bool = False
     config_path: Optional[str] = None
     attention_type: str = "flash"
+    precompute: bool = False
 
 
 def main():
@@ -852,6 +862,67 @@ def main():
             },
         )
 
+    embeds = []
+    if args.precompute:
+        accelerator.print("Beginning pre-computation of embeddings using T5...")
+        maskgit.transformer.t5.requires_grad_(False)
+        for imgs, input_ids, attn_mask, _ in tqdm(iter(dataloader)):
+            with torch.no_grad():
+                embedding = t5_encode_text_from_encoded(input_ids, attn_mask, maskgit.transformer.t5, "cpu")
+                embeds.append(embedding)
+
+        with accelerator.main_process_first():
+            if args.no_cache and args.train_data_dir:
+                dataset = LocalTextImageDataset(
+                    args.train_data_dir,
+                    args.image_size,
+                    tokenizer=transformer.tokenizer,
+                    center_crop=False if args.no_center_crop else True,
+                    flip=False if args.no_flip else True,
+                    using_taming=False if not args.taming_model_path else True,
+                    random_crop=args.random_crop if args.random_crop else False,
+                    alpha_channel=False if args.channels == 3 else True,
+                    embeds=embeds,
+                )
+            elif args.link:
+                if not args.dataset_name:
+                    raise AssertionError("You can only use links in huggingface datasets")
+
+                dataset = URLTextDataset(
+                    dataset,
+                    args.image_size,
+                    transformer.tokenizer,
+                    image_column=args.image_column,
+                    caption_column=args.caption_column,
+                    center_crop=False if args.no_center_crop else True,
+                    flip=False if args.no_flip else True,
+                    using_taming=False if not args.taming_model_path else True,
+                    embeds=embeds,
+                )
+            else:
+                dataset = ImageTextDataset(
+                    dataset,
+                    args.image_size,
+                    transformer.tokenizer,
+                    image_column=args.image_column,
+                    caption_column=args.caption_column,
+                    center_crop=False if args.no_center_crop else True,
+                    flip=False if args.no_flip else True,
+                    stream=args.streaming,
+                    using_taming=False if not args.taming_model_path else True,
+                    embeds=embeds,
+                )
+
+        accelerator.print("Embeddings pre-computed!")
+
+        # Create the dataloaders
+        dataloader, validation_dataloader = split_dataset_into_dataloaders(
+            dataset,
+            args.valid_frac if not args.streaming else 0,
+            args.seed,
+            args.batch_size,
+        )
+
     # Create the trainer
     accelerator.wait_for_everyone()
     trainer = MaskGitTrainer(
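
On the still-missing embed loading: below is a minimal sketch of one way the
precomputed T5 embeddings could be persisted and reloaded between runs. The
save_embeds/load_embeds helpers, the one-.pt-file-per-sample layout, and the
--precompute_path flag mentioned in the comments are assumptions for
illustration, not part of this patch; the sketch also assumes self.embeds holds
one embedding per sample, matching how the dataset classes index it.

# Hypothetical helpers (not part of the patch): persist the embeddings produced
# by the --precompute pass so a later run can reload them instead of re-encoding
# every caption with T5.
from pathlib import Path

import torch


def save_embeds(embeds, out_dir):
    # Assumes `embeds` holds one CPU tensor per dataset sample, in dataset order.
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, embed in enumerate(embeds):
        torch.save(embed, out_dir / f"{i}.pt")


def load_embeds(in_dir):
    # Reload the tensors in index order so self.embeds[index] lines up again.
    in_dir = Path(in_dir)
    paths = sorted(in_dir.glob("*.pt"), key=lambda p: int(p.stem))
    return [torch.load(p, map_location="cpu") for p in paths]


# Usage sketch inside main(), gated on a hypothetical --precompute_path argument:
#     if args.precompute:
#         ...  # fill `embeds` as in the patch, then save_embeds(embeds, args.precompute_path)
#     elif args.precompute_path:
#         embeds = load_embeds(args.precompute_path)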