-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeal_matlab_img.py
79 lines (68 loc) · 2.75 KB
/
deal_matlab_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/python
#coding=utf-8
''' Load and preprocess the house-number image data stored in MATLAB format.
The data files come from:
http://ufldl.stanford.edu/housenumbers/train_32x32.mat
http://ufldl.stanford.edu/housenumbers/test_32x32.mat
On Windows, installing the numpy build from the page below resolves the
failing "from scipy.linalg import _fblas" import:
http://www.lfd.uci.edu/~gohlke/pythonlibs/#numpy
'''
import time
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
def reformat(samples, labels):
    """Reorder the image axes and one-hot encode the labels.

    Args:
        samples: array laid out as (height, width, channels, count),
            i.e. axes (0, 1, 2, 3).
        labels: one row per image whose first element is the class label.
            The digit 0 is represented by the label 10; every other label
            is used directly as the class index. A flat 1-D sequence of
            labels is accepted as well.

    Returns:
        A pair ``(images, one_hot)``. ``images`` is ``samples`` transposed
        to (count, height, width, channels) as float32. ``one_hot`` is a
        float32 array of shape (count, 10), for example
        2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] and
        10 -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0].

    Raises:
        IndexError: if a label does not index one of the ten classes.
    """
    # (0, 1, 2, 3) -> (3, 0, 1, 2): move the image count to the front.
    images = np.transpose(samples, (3, 0, 1, 2)).astype(np.float32)

    label_ids = np.asarray(labels)
    if label_ids.ndim > 1:
        # Labels arrive as a column; keep only the first entry of each row.
        label_ids = label_ids[:, 0]
    label_ids = label_ids.astype(np.intp)

    # The label 10 stands for the digit 0, so it goes into class 0.
    class_ids = np.where(label_ids == 10, 0, label_ids)

    # Build the whole table at once; an empty input still yields (0, 10).
    count = class_ids.shape[0]
    one_hot = np.zeros((count, 10), dtype=np.float32)
    one_hot[np.arange(count), class_ids] = 1.0
    return images, one_hot
def normalize(samples):
    """Convert images to a single grey channel and rescale the values.

    Greyscale: the three colour channels are collapsed into one as
    (R + G + B) / 3, which saves memory and speeds up training.
    The result is then mapped linearly with ``x / 128 - 1``, so 0 becomes
    -1.0 and 255 becomes 0.9921875.

    Args:
        samples: array laid out as (count, height, width, channels).

    Returns:
        A new array with the same first three axes and a last axis of
        length 1; ``samples`` itself is not modified.
    """
    # Sum over the channel axis but keep it, so the output stays 4-D.
    channel_total = np.sum(samples, axis=3, keepdims=True)
    grey = channel_total / 3.0
    scaled = grey / 128.0
    return scaled - 1.0
def inspect(dataset, labels, i):
    """Print the label of image ``i`` and show the image in a figure.

    Args:
        dataset: images indexed by image number on the first axis. If a
            fourth axis of length 1 is present (the RGB channels were
            merged into one), it is dropped so each image is 2-D.
        labels: labels indexed in the same order as ``dataset``.
        i: index of the image to display.
    """
    # Check ndim first so an already 3-D dataset does not raise IndexError.
    if dataset.ndim == 4 and dataset.shape[3] == 1:
        dataset = dataset.reshape(dataset.shape[:3])
    # Single-argument form prints the same text on Python 2 and Python 3.
    print(labels[i])
    plt.figure(figsize=(3, 3), dpi=80)
    plt.imshow(dataset[i])
    plt.show()
def getdata(nslicenum, path='./train_32x32.mat'):
    """Load the first ``nslicenum`` samples and labels from a .mat file.

    Args:
        nslicenum: number of leading samples to keep.
        path: location of the .mat file; it must contain an ``X`` array
            with the sample index on its fourth axis and a ``y`` array
            with one row per sample. Defaults to the training file in the
            current directory.

    Returns:
        A pair ``(train_samples, train_labels)``: independent copies of
        ``X[:, :, :, :nslicenum]`` and ``y[:nslicenum, :]``.
    """
    train = loadmat(path)
    samples, labels = train['X'], train['y']
    # Each print builds one string so the output is the same on Python 2
    # and Python 3 (including the double space after the tag).
    print('train  %s %s' % (samples.shape, labels.shape))
    print('%s %s' % (type(samples), type(labels)))
    # Copy the slices so they do not keep the full arrays alive once this
    # function returns.
    train_samples = samples[:, :, :, 0:nslicenum].copy()
    train_labels = labels[0:nslicenum, :].copy()
    print('slice  %s %s' % (train_samples.shape, train_labels.shape))
    return train_samples, train_labels
def main():
    """Load 5000 training samples, preprocess them and display one image."""
    raw_samples, raw_labels = getdata(5000)
    samples, labels = reformat(raw_samples, raw_labels)
    # normalize() returns a new array rather than modifying its argument,
    # so the result has to be kept; previously it was thrown away and the
    # un-normalized images were displayed.
    samples = normalize(samples)
    inspect(samples, labels, 993)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()