-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_utils.py
89 lines (73 loc) · 2.68 KB
/
image_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from __future__ import print_function
from future import standard_library
standard_library.install_aliases()
from builtins import range
import urllib.request, urllib.error, urllib.parse, os, tempfile
import numpy as np
from scipy.misc import imread, imresize
"""
Utility functions used for viewing and processing images.
"""
def blur_image(X):
    """
    Apply a very gentle blur to a batch of images; meant to be used as a
    regularizer during image generation.

    Inputs:
    - X: Image data of shape (N, 3, H, W)

    Returns:
    - X_blur: Blurred version of X, of shape (N, 3, H, W)
    """
    from cs231n.fast_layers import conv_forward_fast

    # 3x3 kernel whose entries sum to 200; after the division by 200 below it
    # is a weighted average that keeps 188/200 of each pixel's own value.
    kernel = np.asarray([[1, 2, 1], [2, 188, 2], [1, 2, 1]], dtype=np.float32)

    weights = np.zeros((3, 3, 3, 3))
    for c in range(3):
        # Only the diagonal [c, c] slices are non-zero, so (assuming
        # conv_forward_fast uses an (out, in, kh, kw) weight layout — confirm)
        # channels are blurred independently and never mixed.
        weights[c, c] = kernel
    weights /= 200.0

    bias = np.zeros(3)
    conv_param = {'stride': 1, 'pad': 1}
    result = conv_forward_fast(X, weights, bias, conv_param)
    return result[0]
# Per-channel (3-element) normalization statistics for SqueezeNet inputs,
# expressed on a [0, 1] pixel scale (preprocess_image divides by 255 before
# applying them). Presumably the standard ImageNet RGB mean / std.
SQUEEZENET_MEAN = np.float32([0.485, 0.456, 0.406])
SQUEEZENET_STD = np.float32([0.229, 0.224, 0.225])
def preprocess_image(img):
    """Preprocess an image for squeezenet.

    Converts to float32 and divides by 255, then subtracts the per-channel
    pixel mean and divides by the per-channel standard deviation.

    Inputs:
    - img: image array; the 3-element statistics broadcast over its last
      axis (presumably an (H, W, 3) uint8 image — confirm with callers).

    Returns:
    - float32 array of the same shape as img.
    """
    unit_scaled = img.astype(np.float32) / 255.0
    centered = unit_scaled - SQUEEZENET_MEAN
    return centered / SQUEEZENET_STD
def deprocess_image(img, rescale=False):
    """Undo preprocessing on an image and convert back to uint8.

    Inputs:
    - img: preprocessed image; the 3-element statistics broadcast over its
      last axis (presumably (H, W, 3) — confirm with callers).
    - rescale: if True, linearly stretch the un-normalized values so the
      minimum maps to 0 and the maximum to 255. A constant image (max == min)
      maps to all zeros instead of dividing by zero.

    Returns:
    - uint8 array with values clipped to [0, 255].
    """
    img = img * SQUEEZENET_STD + SQUEEZENET_MEAN
    if rescale:
        vmin, vmax = img.min(), img.max()
        img = img - vmin
        # Guard the constant-image case: 0 / 0 would yield NaN, and casting
        # NaN to uint8 is undefined.
        if vmax > vmin:
            img = img / (vmax - vmin)
    # Round instead of truncating so that preprocess_image followed by
    # deprocess_image recovers the original uint8 pixels (up to float error)
    # rather than drifting down by one level.
    return np.rint(np.clip(255 * img, 0.0, 255.0)).astype(np.uint8)
def image_from_url(url):
    """
    Read an image from a URL. Returns a numpy array with the pixel data.

    The response body is written to a temporary file and read back with
    imread (which is given a filename); the temporary file is always removed.

    Inputs:
    - url: URL of the image to download.

    Returns:
    - img: numpy array of pixel data, or None if the download failed with an
      HTTP or URL error (the error is printed, not raised).
    """
    import contextlib

    try:
        with contextlib.closing(urllib.request.urlopen(url)) as response:
            data = response.read()
    except urllib.error.HTTPError as e:
        # HTTPError is a subclass of URLError, so it must be caught first or
        # this branch is unreachable.
        print('HTTP Error: ', e.code, url)
        return None
    except urllib.error.URLError as e:
        print('URL Error: ', e.reason, url)
        return None

    # mkstemp hands back an open OS-level descriptor; wrapping it with fdopen
    # ensures it is closed instead of leaked.
    fd, fname = tempfile.mkstemp()
    try:
        with os.fdopen(fd, 'wb') as ff:
            ff.write(data)
        return imread(fname)
    finally:
        # Remove the temp file even if imread fails.
        os.remove(fname)
def load_image(filename, size=None):
    """Load and resize an image from disk.

    Inputs:
    - filename: path to file
    - size: size of shortest dimension after rescaling; if None, the image is
      returned at its original size. The aspect ratio is preserved.

    Returns:
    - img: numpy array of pixel data as produced by imread / imresize.
    """
    img = imread(filename)
    if size is not None:
        orig_shape = np.array(img.shape[:2])
        scale_factor = float(size) / orig_shape.min()
        # Compute the target (height, width) explicitly with rounding. Passing
        # the bare float scale lets imresize truncate orig * scale, and float
        # error (e.g. 1/49*49 < 1) can then leave the shortest side one pixel
        # short of `size`. Plain ints keep PIL happy.
        new_shape = np.round(orig_shape * scale_factor).astype(int)
        img = imresize(img, tuple(int(d) for d in new_shape))
    return img