diff --git a/infer_vae.py b/infer_vae.py index eff1456..d1fb903 100644 --- a/infer_vae.py +++ b/infer_vae.py @@ -1,5 +1,10 @@ -import argparse, glob, hashlib -import os, random, re, shutil +import argparse +import glob +import hashlib +import os +import random +import re +import shutil from dataclasses import dataclass from datetime import datetime from typing import Optional @@ -362,7 +367,7 @@ def main(): channels=args.channels, layers=args.layers, discr_layers=args.discr_layers, - ).to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + ).to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") if args.latest_checkpoint: accelerator.print("Finding latest checkpoint...") @@ -427,7 +432,7 @@ def main(): args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2 # move vae to device - vae = vae.to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + vae = vae.to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") # Use the parameters() method to get an iterator over all the learnable parameters of the model total_params = sum(p.numel() for p in vae.parameters()) @@ -458,7 +463,9 @@ def main(): ) _, ids, _ = vae.encode( - dataset[image_id][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + dataset[image_id][None].to( + "cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}" + ) ) recon = vae.decode_from_ids(ids) save_image(recon, f"{args.results_dir}/outputs/output.{str(args.input_image).split('.')[-1]}") @@ -471,7 +478,9 @@ def main(): save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png") _, ids, _ = vae.encode( - dataset[image_id][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + dataset[image_id][None].to( + "cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}" + ) ) recon = vae.decode_from_ids(ids) save_image(recon, f"{args.results_dir}/outputs/output.png") @@ -490,7 +499,13 @@ def main(): if not args.use_paintmind: # encode _, ids, _ = vae.encode( - dataset[i][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + dataset[i][None].to( + "cpu" + if args.cpu + else accelerator.device + if args.gpu == 0 + else f"cuda:{args.gpu}" + ) ) # decode recon = vae.decode_from_ids(ids) @@ -499,7 +514,13 @@ def main(): else: # encode encoded, _, _ = vae.encode( - dataset[i][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + dataset[i][None].to( + "cpu" + if args.cpu + else accelerator.device + if args.gpu == 0 + else f"cuda:{args.gpu}" + ) ) # decode @@ -531,10 +552,15 @@ def main(): os.remove(f"{output_dir}/input.png") os.remove(f"{output_dir}/output.png") else: - os.makedirs(os.path.join(output_dir, 'originals'), exist_ok=True) - shutil.move(f"{output_dir}/input.png", f"{os.path.join(output_dir, 'originals')}/input_{now}.png") - shutil.move(f"{output_dir}/output.png", f"{os.path.join(output_dir, 'originals')}/output_{now}.png") - + os.makedirs(os.path.join(output_dir, "originals"), exist_ok=True) + shutil.move( + f"{output_dir}/input.png", + f"{os.path.join(output_dir, 'originals')}/input_{now}.png", + ) + shutil.move( + f"{output_dir}/output.png", + f"{os.path.join(output_dir, 'originals')}/output_{now}.png", + ) del _ del ids