Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 14, 2023
1 parent 9533f48 commit 446a465
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions infer_vae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import argparse, glob, hashlib
import os, random, re, shutil
import argparse
import glob
import hashlib
import os
import random
import re
import shutil
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
Expand Down Expand Up @@ -362,7 +367,7 @@ def main():
channels=args.channels,
layers=args.layers,
discr_layers=args.discr_layers,
).to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
).to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")

if args.latest_checkpoint:
accelerator.print("Finding latest checkpoint...")
Expand Down Expand Up @@ -427,7 +432,7 @@ def main():
args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2

# move vae to device
vae = vae.to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
vae = vae.to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")

# Use the parameters() method to get an iterator over all the learnable parameters of the model
total_params = sum(p.numel() for p in vae.parameters())
Expand Down Expand Up @@ -458,7 +463,9 @@ def main():
)

_, ids, _ = vae.encode(
dataset[image_id][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
dataset[image_id][None].to(
"cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}"
)
)
recon = vae.decode_from_ids(ids)
save_image(recon, f"{args.results_dir}/outputs/output.{str(args.input_image).split('.')[-1]}")
Expand All @@ -471,7 +478,9 @@ def main():
save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png")

_, ids, _ = vae.encode(
dataset[image_id][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
dataset[image_id][None].to(
"cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}"
)
)
recon = vae.decode_from_ids(ids)
save_image(recon, f"{args.results_dir}/outputs/output.png")
Expand All @@ -490,7 +499,13 @@ def main():
if not args.use_paintmind:
# encode
_, ids, _ = vae.encode(
dataset[i][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
)
)
# decode
recon = vae.decode_from_ids(ids)
Expand All @@ -499,7 +514,13 @@ def main():
else:
# encode
encoded, _, _ = vae.encode(
dataset[i][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
)
)

# decode
Expand Down Expand Up @@ -531,10 +552,15 @@ def main():
os.remove(f"{output_dir}/input.png")
os.remove(f"{output_dir}/output.png")
else:
os.makedirs(os.path.join(output_dir, 'originals'), exist_ok=True)
shutil.move(f"{output_dir}/input.png", f"{os.path.join(output_dir, 'originals')}/input_{now}.png")
shutil.move(f"{output_dir}/output.png", f"{os.path.join(output_dir, 'originals')}/output_{now}.png")

os.makedirs(os.path.join(output_dir, "originals"), exist_ok=True)
shutil.move(
f"{output_dir}/input.png",
f"{os.path.join(output_dir, 'originals')}/input_{now}.png",
)
shutil.move(
f"{output_dir}/output.png",
f"{os.path.join(output_dir, 'originals')}/output_{now}.png",
)

del _
del ids
Expand Down

0 comments on commit 446a465

Please sign in to comment.