Merge pull request Sygil-Dev#73 from ZeroCool940711/dev
ZeroCool940711 authored Sep 16, 2023
2 parents 70c7bde + 32d1acb commit c219b44
Showing 5 changed files with 185 additions and 114 deletions.
109 changes: 2 additions & 107 deletions infer_vae.py
@@ -1,21 +1,14 @@
import argparse
import glob
import hashlib
import os
import random
import re
import shutil
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import accelerate
import PIL
import torch
from accelerate.utils import ProjectConfiguration
from datasets import Dataset, Image, load_dataset
from torchvision.utils import save_image
from tqdm import tqdm

from muse_maskgit_pytorch import (
VQGanVAE,
@@ -28,6 +21,7 @@
)
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
vae_folder_validation,
)
from muse_maskgit_pytorch.vqvae import VQVAE

@@ -458,106 +452,7 @@ def main():
save_image(recon, f"{args.results_dir}/outputs/output.png")

if args.input_folder:
# Create output directory and save input images and reconstructions as grids
output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder))
os.makedirs(output_dir, exist_ok=True)

for i in tqdm(range(len(dataset))):
retries = 0
while True:
try:
save_image(dataset[i], f"{output_dir}/input.png")

if not args.use_paintmind:
# encode
_, ids, _ = vae.encode(
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
)
)
# decode
recon = vae.decode_from_ids(ids)
# print (recon.shape) # torch.Size([1, 3, 512, 1136])
save_image(recon, f"{output_dir}/output.png")
else:
# encode
encoded, _, _ = vae.encode(
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
)
)

# decode
recon = vae.decode(encoded).squeeze(0)
recon = torch.clamp(recon, -1.0, 1.0)
save_image(recon, f"{output_dir}/output.png")

# Load input and output images
input_image = PIL.Image.open(f"{output_dir}/input.png")
output_image = PIL.Image.open(f"{output_dir}/output.png")

# Create horizontal grid with input and output images
grid_image = PIL.Image.new(
"RGB" if args.channels == 3 else "RGBA",
(input_image.width + output_image.width, input_image.height),
)
grid_image.paste(input_image, (0, 0))
grid_image.paste(output_image, (input_image.width, 0))

# Save grid
now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
hash = hashlib.sha1(input_image.tobytes()).hexdigest()

filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png"
grid_image.save(f"{output_dir}/{filename}", format="PNG")

if not args.save_originals:
# Remove input and output images after the grid was made.
os.remove(f"{output_dir}/input.png")
os.remove(f"{output_dir}/output.png")
else:
os.makedirs(os.path.join(output_dir, "originals"), exist_ok=True)
shutil.move(
f"{output_dir}/input.png",
f"{os.path.join(output_dir, 'originals')}/input_{now}.png",
)
shutil.move(
f"{output_dir}/output.png",
f"{os.path.join(output_dir, 'originals')}/output_{now}.png",
)

del _
del ids
del recon

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

break # Exit the retry loop if there were no errors

except RuntimeError as e:
if "out of memory" in str(e) and retries < args.max_retries:
retries += 1
# print(f"Out of Memory. Retry #{retries}")
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
continue # Retry the loop

else:
if "out of memory" not in str(e):
print(f"\n{e}")
else:
print(f"Skipping image {i} after {retries} retries due to out of memory error")
break # Exit the retry loop after too many retries

vae_folder_validation(accelerator, vae, dataset, args=args, checkpoint_name=args.vae_path, save_originals=args.save_originals)

if __name__ == "__main__":
main()
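
In short, this commit replaces roughly a hundred lines of inline reconstruction-and-grid logic in infer_vae.py with a single call to a shared helper. A minimal sketch of the resulting input-folder branch, using only names that appear in the diff above:

    # infer_vae.py after this commit (sketch of the input-folder branch only)
    from muse_maskgit_pytorch.utils import vae_folder_validation

    if args.input_folder:
        vae_folder_validation(
            accelerator,
            vae,
            dataset,
            args=args,
            checkpoint_name=args.vae_path,
            save_originals=args.save_originals,
        )

The helper preserves the previous behavior: per-image encode/decode, a side-by-side input/output grid per image, optional retention of the originals, and the out-of-memory retry loop.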
26 changes: 23 additions & 3 deletions muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -1,4 +1,4 @@
import torch
import torch, os
from accelerate import Accelerator
from diffusers.optimization import get_scheduler
from ema_pytorch import EMA
@@ -8,7 +8,10 @@
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from typing import Optional
from muse_maskgit_pytorch.utils import (
vae_folder_validation,
)
from muse_maskgit_pytorch.trainers.base_accelerated_trainer import (
BaseAcceleratedTrainer,
get_optimizer,
@@ -66,6 +69,7 @@ def __init__(
use_8bit_adam=False,
num_cycles=1,
scheduler_power=1.0,
validation_folder_at_end_of_epoch: Optional[DataLoader] = None,
args=None,
):
super().__init__(
@@ -91,6 +95,8 @@ def __init__(
# we are going to use them later to save them to a config file.
self.args = args

self.validation_folder_at_end_of_epoch = validation_folder_at_end_of_epoch

self.current_step = current_step

# vae
@@ -266,6 +272,7 @@ def train(self):
else:
proc_label = f"[P{self.accelerator.process_index:03d}][Worker]"


for epoch in range(self.current_step // len(self.dl), self.num_epochs):
for img in self.dl:
loss = 0.0
@@ -340,7 +347,11 @@
)

logs["lr"] = self.lr_scheduler.get_last_lr()[0]
self.accelerator.log(logs, step=steps)
try:
self.accelerator.log(logs, step=steps)
except ConnectionResetError:
print ("There was an error with the Wandb connection. Retrying...")
self.accelerator.log(logs, step=steps)

# update exponential moving averaged generator

@@ -386,6 +397,15 @@ def train(self):

self.steps += 1

#

if self.validation_folder_at_end_of_epoch:
vae_folder_validation(self.accelerator, self.model, self.validation_folder_at_end_of_epoch,
self.args,
checkpoint_name=os.path.join(self.results_dir, f'vae.{steps}.pt'),

)

# if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps:
# self.accelerator.print(
# f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..."
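
The trainer change adds an optional validation_folder_at_end_of_epoch argument; when it is set, vae_folder_validation runs after each epoch and labels the output grids with the current vae.{steps}.pt checkpoint name. How a training script wires this up is not part of this diff; a hypothetical sketch (the trainer class name and the val_images variable are assumptions, not shown in the commit):

    # Hypothetical wiring -- not shown in this commit
    from muse_maskgit_pytorch.trainers.vqvae_trainers import VQGanVAETrainer  # assumed class name

    trainer = VQGanVAETrainer(
        ...,  # existing trainer arguments, unchanged
        validation_folder_at_end_of_epoch=val_images,
        args=args,
    )

Note that although the new parameter is annotated Optional[DataLoader], vae_folder_validation iterates its input with range(len(...)) and integer indexing, so what the helper actually expects is an indexable collection of image tensors.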
106 changes: 104 additions & 2 deletions muse_maskgit_pytorch/utils.py
@@ -1,11 +1,15 @@
from __future__ import print_function

import glob
import shutil
import os
import re

import PIL
import torch

import hashlib
from tqdm import tqdm
from torchvision.utils import save_image
from datetime import datetime

def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_image_size=False):
"""Gets the latest checkpoint paths for both the non-ema and ema VAEs.
@@ -142,3 +146,101 @@ def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
if key in non_ema_state_dict and torch.equal(ema_state_dict[key], non_ema_state_dict[key]):
del ema_state_dict_copy[key]
return ema_state_dict_copy

def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name="vae", save_originals=False):

# Create output directory and save input images and reconstructions as grids
output_dir = os.path.join(args.results_dir, "outputs",
os.path.basename(args.input_folder if args.input_folder else args.validation_folder_at_end_of_epoch))
os.makedirs(output_dir, exist_ok=True)

for i in tqdm(range(len(dataset))):
retries = 0
while True:
try:
save_image(dataset[i], f"{output_dir}/input.png")

try:
# encode
encoded, _, _ = vae.encode(
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
)
)
except AttributeError:
# encode
encoded, _, _ = vae.encode(
dataset[i][None].to(
accelerator.device
if accelerator.device
else f"cuda:{args.gpu}"
)
)

# decode
recon = vae.decode(encoded).squeeze(0)
recon = torch.clamp(recon, -1.0, 1.0)
save_image(recon, f"{output_dir}/output.png")

# Load input and output images
input_image = PIL.Image.open(f"{output_dir}/input.png")
output_image = PIL.Image.open(f"{output_dir}/output.png")

# Create horizontal grid with input and output images
grid_image = PIL.Image.new(
"RGB" if args.channels == 3 else "RGBA",
(input_image.width + output_image.width, input_image.height),
)
grid_image.paste(input_image, (0, 0))
grid_image.paste(output_image, (input_image.width, 0))

# Save grid
now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
hash = hashlib.sha1(input_image.tobytes()).hexdigest()

filename = f"{hash}_{now}-{os.path.basename(checkpoint_name)}.png"
grid_image.save(f"{output_dir}/{filename}", format="PNG")

if not save_originals:
# Remove input and output images after the grid was made.
os.remove(f"{output_dir}/input.png")
os.remove(f"{output_dir}/output.png")
else:
os.makedirs(os.path.join(output_dir, "originals"), exist_ok=True)
shutil.move(
f"{output_dir}/input.png",
f"{os.path.join(output_dir, 'originals')}/input_{now}.png",
)
shutil.move(
f"{output_dir}/output.png",
f"{os.path.join(output_dir, 'originals')}/output_{now}.png",
)

del _
del recon

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

dataset[i][None].to("cpu")

break # Exit the retry loop if there were no errors

except RuntimeError as e:
if "out of memory" in str(e) and retries < args.max_retries:
retries += 1
# print(f"Out of Memory. Retry #{retries}")
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
continue # Retry the loop

else:
if "out of memory" not in str(e):
print(f"\n{e}")
else:
print(f"Skipping image {i} after {retries} retries due to out of memory error")
break # Exit the retry loop after too many retries
2 changes: 1 addition & 1 deletion train_muse_maskgit.py
@@ -613,7 +613,7 @@ def main():
else:
ema_vae = None

print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
print(f"Resuming VAE from latest checkpoint: {args.vae_path}")
else:
accelerator.print("Resuming VAE from: ", args.vae_path)
ema_vae = None
(Diff for the fifth changed file did not load and is not shown.)