diff --git a/infer_vae.py b/infer_vae.py index e89f776..2d876df 100644 --- a/infer_vae.py +++ b/infer_vae.py @@ -1,21 +1,14 @@ import argparse -import glob -import hashlib import os import random -import re -import shutil from dataclasses import dataclass -from datetime import datetime from typing import Optional import accelerate -import PIL import torch from accelerate.utils import ProjectConfiguration from datasets import Dataset, Image, load_dataset from torchvision.utils import save_image -from tqdm import tqdm from muse_maskgit_pytorch import ( VQGanVAE, @@ -28,6 +21,7 @@ ) from muse_maskgit_pytorch.utils import ( get_latest_checkpoints, + vae_folder_validation, ) from muse_maskgit_pytorch.vqvae import VQVAE @@ -458,106 +452,7 @@ def main(): save_image(recon, f"{args.results_dir}/outputs/output.png") if args.input_folder: - # Create output directory and save input images and reconstructions as grids - output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder)) - os.makedirs(output_dir, exist_ok=True) - - for i in tqdm(range(len(dataset))): - retries = 0 - while True: - try: - save_image(dataset[i], f"{output_dir}/input.png") - - if not args.use_paintmind: - # encode - _, ids, _ = vae.encode( - dataset[i][None].to( - "cpu" - if args.cpu - else accelerator.device - if args.gpu == 0 - else f"cuda:{args.gpu}" - ) - ) - # decode - recon = vae.decode_from_ids(ids) - # print (recon.shape) # torch.Size([1, 3, 512, 1136]) - save_image(recon, f"{output_dir}/output.png") - else: - # encode - encoded, _, _ = vae.encode( - dataset[i][None].to( - "cpu" - if args.cpu - else accelerator.device - if args.gpu == 0 - else f"cuda:{args.gpu}" - ) - ) - - # decode - recon = vae.decode(encoded).squeeze(0) - recon = torch.clamp(recon, -1.0, 1.0) - save_image(recon, f"{output_dir}/output.png") - - # Load input and output images - input_image = PIL.Image.open(f"{output_dir}/input.png") - output_image = PIL.Image.open(f"{output_dir}/output.png") - - # Create horizontal grid with input and output images - grid_image = PIL.Image.new( - "RGB" if args.channels == 3 else "RGBA", - (input_image.width + output_image.width, input_image.height), - ) - grid_image.paste(input_image, (0, 0)) - grid_image.paste(output_image, (input_image.width, 0)) - - # Save grid - now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S") - hash = hashlib.sha1(input_image.tobytes()).hexdigest() - - filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png" - grid_image.save(f"{output_dir}/{filename}", format="PNG") - - if not args.save_originals: - # Remove input and output images after the grid was made. - os.remove(f"{output_dir}/input.png") - os.remove(f"{output_dir}/output.png") - else: - os.makedirs(os.path.join(output_dir, "originals"), exist_ok=True) - shutil.move( - f"{output_dir}/input.png", - f"{os.path.join(output_dir, 'originals')}/input_{now}.png", - ) - shutil.move( - f"{output_dir}/output.png", - f"{os.path.join(output_dir, 'originals')}/output_{now}.png", - ) - - del _ - del ids - del recon - - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - - break # Exit the retry loop if there were no errors - - except RuntimeError as e: - if "out of memory" in str(e) and retries < args.max_retries: - retries += 1 - # print(f"Out of Memory. Retry #{retries}") - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - continue # Retry the loop - - else: - if "out of memory" not in str(e): - print(f"\n{e}") - else: - print(f"Skipping image {i} after {retries} retries due to out of memory error") - break # Exit the retry loop after too many retries - + vae_folder_validation(accelerator, vae, dataset, args=args, checkpoint_name=args.vae_path, save_originals=args.save_originals) if __name__ == "__main__": main() diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 6594e65..ae0bc89 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -1,4 +1,4 @@ -import torch +import torch, os from accelerate import Accelerator from diffusers.optimization import get_scheduler from ema_pytorch import EMA @@ -8,7 +8,10 @@ from torch.utils.data import DataLoader from torchvision.utils import make_grid, save_image from tqdm import tqdm - +from typing import Optional +from muse_maskgit_pytorch.utils import ( + vae_folder_validation, +) from muse_maskgit_pytorch.trainers.base_accelerated_trainer import ( BaseAcceleratedTrainer, get_optimizer, @@ -66,6 +69,7 @@ def __init__( use_8bit_adam=False, num_cycles=1, scheduler_power=1.0, + validation_folder_at_end_of_epoch: Optional[DataLoader] = None, args=None, ): super().__init__( @@ -91,6 +95,8 @@ def __init__( # we are going to use them later to save them to a config file. self.args = args + self.validation_folder_at_end_of_epoch = validation_folder_at_end_of_epoch + self.current_step = current_step # vae @@ -266,6 +272,7 @@ def train(self): else: proc_label = f"[P{self.accelerator.process_index:03d}][Worker]" + for epoch in range(self.current_step // len(self.dl), self.num_epochs): for img in self.dl: loss = 0.0 @@ -340,7 +347,11 @@ def train(self): ) logs["lr"] = self.lr_scheduler.get_last_lr()[0] - self.accelerator.log(logs, step=steps) + try: + self.accelerator.log(logs, step=steps) + except ConnectionResetError: + print ("There was an error with the Wandb connection. Retrying...") + self.accelerator.log(logs, step=steps) # update exponential moving averaged generator @@ -386,6 +397,15 @@ def train(self): self.steps += 1 + # + + if self.validation_folder_at_end_of_epoch: + vae_folder_validation(self.accelerator, self.model, self.validation_folder_at_end_of_epoch, + self.args, + checkpoint_name=os.path.join(self.results_dir, f'vae.{steps}.pt'), + + ) + # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: # self.accelerator.print( # f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." diff --git a/muse_maskgit_pytorch/utils.py b/muse_maskgit_pytorch/utils.py index 83b1cd0..3b96b1a 100644 --- a/muse_maskgit_pytorch/utils.py +++ b/muse_maskgit_pytorch/utils.py @@ -1,11 +1,15 @@ from __future__ import print_function import glob +import shutil import os import re - +import PIL import torch - +import hashlib +from tqdm import tqdm +from torchvision.utils import save_image +from datetime import datetime def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_image_size=False): """Gets the latest checkpoint paths for both the non-ema and ema VAEs. @@ -142,3 +146,101 @@ def remove_duplicate_weights(ema_state_dict, non_ema_state_dict): if key in non_ema_state_dict and torch.equal(ema_state_dict[key], non_ema_state_dict[key]): del ema_state_dict_copy[key] return ema_state_dict_copy + +def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name="vae", save_originals=False): + + # Create output directory and save input images and reconstructions as grids + output_dir = os.path.join(args.results_dir, "outputs", + os.path.basename(args.input_folder if args.input_folder else args.validation_folder_at_end_of_epoch)) + os.makedirs(output_dir, exist_ok=True) + + for i in tqdm(range(len(dataset))): + retries = 0 + while True: + try: + save_image(dataset[i], f"{output_dir}/input.png") + + try: + # encode + encoded, _, _ = vae.encode( + dataset[i][None].to( + "cpu" + if args.cpu + else accelerator.device + if args.gpu == 0 + else f"cuda:{args.gpu}" + ) + ) + except AttributeError: + # encode + encoded, _, _ = vae.encode( + dataset[i][None].to( + accelerator.device + if accelerator.device + else f"cuda:{args.gpu}" + ) + ) + + # decode + recon = vae.decode(encoded).squeeze(0) + recon = torch.clamp(recon, -1.0, 1.0) + save_image(recon, f"{output_dir}/output.png") + + # Load input and output images + input_image = PIL.Image.open(f"{output_dir}/input.png") + output_image = PIL.Image.open(f"{output_dir}/output.png") + + # Create horizontal grid with input and output images + grid_image = PIL.Image.new( + "RGB" if args.channels == 3 else "RGBA", + (input_image.width + output_image.width, input_image.height), + ) + grid_image.paste(input_image, (0, 0)) + grid_image.paste(output_image, (input_image.width, 0)) + + # Save grid + now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S") + hash = hashlib.sha1(input_image.tobytes()).hexdigest() + + filename = f"{hash}_{now}-{os.path.basename(checkpoint_name)}.png" + grid_image.save(f"{output_dir}/{filename}", format="PNG") + + if not save_originals: + # Remove input and output images after the grid was made. + os.remove(f"{output_dir}/input.png") + os.remove(f"{output_dir}/output.png") + else: + os.makedirs(os.path.join(output_dir, "originals"), exist_ok=True) + shutil.move( + f"{output_dir}/input.png", + f"{os.path.join(output_dir, 'originals')}/input_{now}.png", + ) + shutil.move( + f"{output_dir}/output.png", + f"{os.path.join(output_dir, 'originals')}/output_{now}.png", + ) + + del _ + del recon + + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + dataset[i][None].to("cpu") + + break # Exit the retry loop if there were no errors + + except RuntimeError as e: + if "out of memory" in str(e) and retries < args.max_retries: + retries += 1 + # print(f"Out of Memory. Retry #{retries}") + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + continue # Retry the loop + + else: + if "out of memory" not in str(e): + print(f"\n{e}") + else: + print(f"Skipping image {i} after {retries} retries due to out of memory error") + break # Exit the retry loop after too many retries \ No newline at end of file diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 3df69ad..ce72b64 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -613,7 +613,7 @@ def main(): else: ema_vae = None - print(f"Resuming VAE from latest checkpoint: {args.resume_path}") + print(f"Resuming VAE from latest checkpoint: {args.vae_path}") else: accelerator.print("Resuming VAE from: ", args.vae_path) ema_vae = None diff --git a/train_muse_vae.py b/train_muse_vae.py index 7ae1460..c336ccf 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -5,7 +5,7 @@ import wandb from accelerate.utils import ProjectConfiguration -from datasets import load_dataset +from datasets import load_dataset, Dataset, Image from omegaconf import OmegaConf from muse_maskgit_pytorch import ( @@ -319,6 +319,19 @@ action="store_true", help="Use F.mse_loss instead of F.l1_loss.", ) +parser.add_argument( + "--validation_folder_at_end_of_epoch", + type=str, + default=None, + help="Path to a folder containing images that will be used for validation/reconstruction." + " At the end of each Epoch this folder will be used for validation and reconstructions will be saved to a subfolder called 'outputs/validation'.", +) +parser.add_argument( + "--exclude_folders", + type=str, + default=None, + help="List of folders we want to exclude when doing reconstructions from an input folder.", +) @dataclass @@ -383,6 +396,9 @@ class Arguments: use_l2_recon_loss: bool = False debug: bool = False config_path: Optional[str] = None + validation_folder_at_end_of_epoch: Optional[str] = None + input_folder = None + exclude_folders: Optional[str] = None def preprocess_webdataset(args, image): @@ -575,6 +591,43 @@ def main(): dataloader, validation_dataloader = split_dataset_into_dataloaders( dataset, args.valid_frac, args.seed, args.batch_size ) + + if args.validation_folder_at_end_of_epoch: + # Create dataset from input folder + extensions = ["jpg", "jpeg", "png", "webp"] + exclude_folders = args.exclude_folders.split(",") if args.exclude_folders else [] + + filepaths = [] + for root, dirs, files in os.walk(args.validation_folder_at_end_of_epoch, followlinks=True): + # Resolve symbolic link to actual path and exclude based on actual path + resolved_root = os.path.realpath(root) + for exclude_folder in exclude_folders: + if exclude_folder in resolved_root: + dirs[:] = [] + break + for file in files: + if file.lower().endswith(tuple(extensions)): + filepaths.append(os.path.join(root, file)) + + if not filepaths: + print(f"No images with extensions {extensions} found in {args.validation_folder_at_end_of_epoch}.") + exit(1) + + epoch_validation_dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image()) + + epoch_validation_dataset = ImageDataset( + epoch_validation_dataset, + image_size=512, + image_column=args.image_column, + center_crop=False, + flip=False, + random_crop=False, + alpha_channel=False if args.channels == 3 else True, + ) + + else: + epoch_validation_dataset = None + trainer = VQGanVAETrainer( vae, dataloader, @@ -606,6 +659,7 @@ def main(): num_cycles=args.num_cycles, scheduler_power=args.scheduler_power, num_epochs=args.num_epochs, + validation_folder_at_end_of_epoch=epoch_validation_dataset, args=args, )