-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
95 lines (82 loc) · 3.18 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Original Version: Taehoon Kim (http://carpedm20.github.io)
# + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/utils.py
# + License: MIT
"""
Some code from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
import scipy.misc
import scipy.io
import numpy as np
import os
import sys
import tarfile
import zipfile
from six.moves import urllib
def get_vgg_model(dir_path, model_url):
    """Fetch the VGG weights file (if needed) and load it as MATLAB data.

    The file name is taken from the last ``/``-separated component of
    ``model_url`` and is looked up inside ``dir_path``.

    Args:
        dir_path: Directory the model file is stored in.
        model_url: URL the model file is downloaded from when it is missing.

    Returns:
        Whatever ``scipy.io.loadmat`` returns for the downloaded file.

    Raises:
        IOError: If the file is still absent after the download step.
    """
    maybe_download_and_extract(dir_path, model_url)
    model_file = os.path.join(dir_path, model_url.split("/")[-1])
    if os.path.exists(model_file):
        return scipy.io.loadmat(model_file)
    raise IOError("VGG Model not found!")
def maybe_download_and_extract(dir_path, url_name, is_tarfile=False, is_zipfile=False):
    """Download ``url_name`` into ``dir_path`` unless the file already exists.

    The local file name is the last ``/``-separated component of the URL.
    When a fresh download happens, the archive is optionally unpacked into
    ``dir_path``. If the file is already present nothing is downloaded or
    extracted.

    NOTE(review): the "extract only after a fresh download" nesting is assumed;
    confirm it against the upstream file.

    Args:
        dir_path: Destination directory; created if it does not exist.
        url_name: URL of the file to fetch.
        is_tarfile: If true, treat the download as a gzip-compressed tar
            archive and extract it.
        is_zipfile: If true (and ``is_tarfile`` is false), treat the download
            as a zip archive and extract it.

    Raises:
        Any exception raised by ``urlretrieve`` or the archive modules. A
        partially written download is removed before the error propagates.
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    filename = url_name.split('/')[-1]
    filepath = os.path.join(dir_path, filename)
    if os.path.exists(filepath):
        return

    def _progress(count, block_size, total_size):
        # urlretrieve passes a non-positive total when the size is unknown
        # (-1) or the file is empty (0); a percentage is meaningless then and
        # dividing by zero would abort the download.
        if total_size > 0:
            percent = min(float(count * block_size) / float(total_size) * 100.0, 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        else:
            sys.stdout.write('\r>> Downloading %s' % filename)
        sys.stdout.flush()

    try:
        filepath, _ = urllib.request.urlretrieve(url_name, filepath, reporthook=_progress)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a truncated file behind,
        # because the existence check above would treat it as complete.
        if os.path.exists(filepath):
            os.remove(filepath)
        raise
    print()
    statinfo = os.stat(filepath)
    print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')

    # SECURITY NOTE: extractall() trusts member paths inside the archive; only
    # use this with archives from trusted sources.
    if is_tarfile:
        # Context manager closes the archive handle (previously leaked).
        with tarfile.open(filepath, 'r:gz') as tf:
            tf.extractall(dir_path)
    elif is_zipfile:
        with zipfile.ZipFile(filepath) as zf:
            zf.extractall(dir_path)
def get_image(image_path, image_size, is_crop=True):
    """Read an image from disk and run it through ``transform``.

    Args:
        image_path: Path of the image file, passed to ``imread``.
        image_size: Forwarded to ``transform`` as its ``npx`` argument.
        is_crop: Forwarded to ``transform``; selects the cropping path.

    Returns:
        The value returned by ``transform`` for the loaded image.
    """
    raw = imread(image_path)
    return transform(raw, image_size, is_crop)
def save_images(images, size, image_path):
    """Undo the [-1, 1] scaling of a batch and write it as one grid image.

    Args:
        images: Batch of images, passed through ``inverse_transform``.
        size: Grid layout forwarded to ``imsave``.
        image_path: Output file path forwarded to ``imsave``.

    Returns:
        The value returned by ``imsave``.
    """
    unit_range = inverse_transform(images)
    return imsave(unit_range, size, image_path)
def save_image(image, save_dir, name):
    """Write a single image to ``<save_dir>/<name>.png``.

    Args:
        image: Image data handed to ``scipy.misc.imsave`` unchanged.
        save_dir: Directory the file is written into.
        name: File name without extension; ``.png`` is appended.
    """
    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2 - confirm the
    # SciPy version this project pins.
    target = os.path.join(save_dir, name + '.png')
    scipy.misc.imsave(target, image)
def imread(path, size=None):
    """Load an image file as a float32 RGB array, optionally resized.

    Args:
        path: Path of the image file.
        size: If truthy, the loaded image is passed to
            ``scipy.misc.imresize`` with target ``(size, size)``; otherwise
            the loaded array is returned as is.

    Returns:
        The float32 array from ``scipy.misc.imread`` when ``size`` is falsy,
        else the result of ``scipy.misc.imresize``.
    """
    # NOTE(review): scipy.misc.imread / imresize were removed in SciPy 1.2 -
    # confirm the SciPy version this project pins.
    pixels = scipy.misc.imread(path, mode='RGB').astype(np.float32)
    if size:
        return scipy.misc.imresize(pixels, (size, size))
    return pixels
def merge(images, size):
    """Tile a batch of images into a single grid image.

    Tiles are placed in row-major order: image ``k`` lands in row
    ``k // size[1]`` and column ``k % size[1]``. Cells with no image stay
    zero.

    Args:
        images: Array whose axis 0 enumerates images and whose axes 1 and 2
            give each image's height and width.
        size: ``(rows, cols)`` of the grid.

    Returns:
        A zero-initialised array of shape ``(h * rows, w * cols, 3)`` with
        the tiles copied in.
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    n_cols = size[1]
    canvas = np.zeros((int(tile_h * size[0]), int(tile_w * n_cols), 3))
    for position, tile in enumerate(images):
        row, col = divmod(position, n_cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile
    return canvas
def imsave(images, size, path):
    """Tile a batch into a grid and write it to ``path`` as 8-bit data.

    Args:
        images: Batch of images, forwarded to ``merge``.
        size: Grid layout, forwarded to ``merge``.
        path: Output file path.

    Returns:
        The value returned by ``scipy.misc.imsave``.
    """
    grid = merge(images, size)
    # Scale by 255 and truncate to uint8 before writing.
    as_bytes = (255 * grid).astype(np.uint8)
    return scipy.misc.imsave(path, as_bytes)
def center_crop(x, crop_h, crop_w=None, resize_w=64):
    """Cut a centred window out of ``x`` and resize it to a square.

    Args:
        x: Image array; its first two axes are treated as height and width.
        crop_h: Height of the window.
        crop_w: Width of the window; defaults to ``crop_h``.
        resize_w: Side length passed to ``scipy.misc.imresize`` for both
            output dimensions.

    Returns:
        The result of ``scipy.misc.imresize`` on the cropped window.
    """
    # NOTE(review): scipy.misc.imresize was removed in SciPy 1.2 - confirm the
    # SciPy version this project pins.
    crop_w = crop_h if crop_w is None else crop_w
    height, width = x.shape[:2]
    top = int(round((height - crop_h) / 2.))
    left = int(round((width - crop_w) / 2.))
    window = x[top:top + crop_h, left:left + crop_w]
    return scipy.misc.imresize(window, [resize_w, resize_w])
def transform(image, npx=64, is_crop=True, resize_w=64):
    """Optionally centre-crop an image, then rescale pixel values.

    Values are mapped with ``v / 127.5 - 1``, so 0 becomes -1 and 255
    becomes 1.

    Args:
        image: Image data; converted with ``np.array``.
        npx: Side length in pixels of the square window handed to
            ``center_crop``.
        is_crop: When true the image goes through ``center_crop`` first;
            otherwise it is used as given.
        resize_w: Forwarded to ``center_crop``.

    Returns:
        A NumPy array of the rescaled values.
    """
    pixels = center_crop(image, npx, resize_w=resize_w) if is_crop else image
    return np.array(pixels) / 127.5 - 1.
def inverse_transform(images):
    """Map values from the [-1, 1] range back to [0, 1].

    Args:
        images: Number or array supporting ``+`` and ``/`` with floats.

    Returns:
        ``(images + 1) / 2``.
    """
    shifted = images + 1.
    return shifted / 2.