forked from torch/tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path1_data.lua
204 lines (177 loc) · 5.85 KB
/
1_data.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
-- File-wide debug switch: when set to true, the dataset functions in this file
-- print diagnostic statistics (patch position, min/max, std, ...) as they run.
local data_verbose = false
--- Build a dataset that serves random square patches cut from a stack of images.
--
-- `datafile` is read with torch.load(datafile, 'ascii'). The loaded object is
-- used as a 3-D tensor: dimension 1 indexes the image, dimensions 2 and 3 are
-- the image rows and columns.
--
-- @param datafile   path handed to torch.load
-- @param inputsize  side length of the square patch returned by dataset[k]
-- @param std        a patch is accepted only when its standard deviation is
--                   strictly greater than this value (default 0.2)
-- @return a table with
--   dataset:size()              number of images in the file
--   dataset:selectPatch(nr,nc)  -> patch, image index, image
--   dataset:conv()              switch the sample buffer to 1 x inputsize x inputsize
--   dataset[k]                  -> {sample, sample, image}; the key is ignored
--                               and a fresh random patch is drawn on each access
function getdata(datafile, inputsize, std)
   local data = torch.load(datafile, 'ascii')
   local dataset = {}
   local std = std or 0.2
   local nsamples = data:size(1)
   local nrows = data:size(2)
   local ncols = data:size(3)

   function dataset:size()
      return nsamples
   end

   -- Draw random (image, row, column) triples until a patch of nr x nc whose
   -- standard deviation exceeds `std` is found. Note that this never returns
   -- if no patch in the data passes the threshold.
   function dataset:selectPatch(nr, nc)
      -- A patch larger than the image can never be narrowed out; fail early
      -- with a readable message instead of an index error from narrow().
      if nr > nrows or nc > ncols then
         error(string.format('selectPatch: cannot cut a %sx%s patch from %sx%s images',
                             tostring(nr), tostring(nc),
                             tostring(nrows), tostring(ncols)), 2)
      end
      if data_verbose then
         print('selectPatch')
      end
      while true do
         -- image index; the 1e-12 lower bound keeps math.ceil from yielding 0
         local i = math.ceil(torch.uniform(1e-12, nsamples))
         local im = data:select(1, i)
         -- top-left corner of the candidate patch
         local ri = math.ceil(torch.uniform(1e-12, nrows - nr))
         local ci = math.ceil(torch.uniform(1e-12, ncols - nc))
         local patch = im:narrow(1, ri, nr):narrow(2, ci, nc)
         local patchstd = patch:std()
         if data_verbose then
            print('Image ' .. i .. ' ri= ' .. ri .. ' ci= ' .. ci .. ' std= ' .. patchstd)
         end
         if patchstd > std then
            if data_verbose then
               print(patch:min(), patch:max())
            end
            return patch, i, im
         end
      end
   end

   -- Single buffer that every dataset[k] access overwrites and returns;
   -- callers that want to keep a sample must clone it.
   local dsample = torch.Tensor(inputsize * inputsize)

   -- Replace the flat buffer by a 1 x inputsize x inputsize one.
   function dataset:conv()
      dsample = torch.Tensor(1, inputsize, inputsize)
   end

   setmetatable(dataset, {__index = function(self, index)
      local sample, i, im = self:selectPatch(inputsize, inputsize)
      dsample:copy(sample)
      return {dsample, dsample, im}
   end})

   return dataset
end
--- Build a dataset that serves random square patches cut from camera frames.
--
-- Each call to selectPatch grabs one frame from an image.Camera, converts it
-- with image.rgb2y, normalises it with the local helper `lcn`, and then draws
-- random patches from that frame until one has a standard deviation strictly
-- greater than `std`.
--
-- @param inputsize  side length of the square patch returned by dataset[k]
-- @param std        acceptance threshold on the patch standard deviation
--                   (default 0.2)
-- @return a table with
--   dataset:size()              always 10000
--   dataset:selectPatch(nr,nc)  -> patch, nil, normalised frame
--   dataset[k]                  -> {sample, sample, normalised frame}; the key
--                               is ignored
function getdatacam(inputsize, std)
   require 'camera'
   local frow = 60   -- frame height requested from the camera
   local fcol = 80   -- frame width requested from the camera
   local gs = 5      -- size of the gaussian window used by lcn
   local cam = image.Camera{width=fcol,height=frow}
   local dataset = {}
   local counter = 1 -- incremented once per grabbed frame; not read elsewhere here
   local std = std or 0.2
   local nsamples = 10000
   -- 1 x gs and gs x 1 gaussian kernels, applied one after the other in lcn
   local gfh = image.gaussian{width=gs,height=1,normalize=true}
   local gfv = image.gaussian{width=1,height=gs,normalize=true}

   function dataset:size()
      return nsamples
   end

   -- Scratch tensors shared by every lcn call so they are not reallocated
   -- per frame.
   local imsq = torch.Tensor()
   local lmnh = torch.Tensor()
   local lmn = torch.Tensor()
   local lmnsqh = torch.Tensor()
   local lmnsq = torch.Tensor()
   local lvar = torch.Tensor()

   -- Normalise `im` in place and return a view on its interior:
   --   1. subtract the global mean and divide by the global std,
   --   2. compute a gaussian-weighted local mean and local mean of squares,
   --   3. local variance = mean of squares - (mean)^2, clamped at 0,
   --   4. local std = sqrt(variance), floored at 1,
   --   5. subtract the local mean and divide by the local std.
   -- The returned view is (gs-1) smaller than `im` in each dimension, matching
   -- the size of the convolution outputs (assumes torch.conv2 runs in its
   -- default 'valid' mode -- TODO confirm).
   local function lcn(im)
      local mn = im:mean()
      local sd = im:std()
      if data_verbose then
         print('im',mn,sd,im:min(),im:max())
      end
      im:add(-mn)
      im:div(sd)
      if data_verbose then
         print('im',im:min(),im:max(),im:mean(), im:std())
      end

      imsq:resizeAs(im):copy(im):cmul(im)
      if data_verbose then
         print('imsq',imsq:min(),imsq:max())
      end

      -- local mean: horizontal pass followed by vertical pass
      torch.conv2(lmnh,im,gfh)
      torch.conv2(lmn,lmnh,gfv)
      if data_verbose then
         print('lmn',lmn:min(),lmn:max())
      end

      -- local mean of squares, same two passes
      torch.conv2(lmnsqh,imsq,gfh)
      torch.conv2(lmnsq,lmnsqh,gfv)
      if data_verbose then
         print('lmnsq',lmnsq:min(),lmnsq:max())
      end

      -- lvar = lmnsq - lmn^2
      lvar:resizeAs(lmn):copy(lmn):cmul(lmn)
      lvar:mul(-1)
      lvar:add(lmnsq)
      if data_verbose then
         print('2',lvar:min(),lvar:max())
      end
      -- negative values are clamped to 0 before the square root
      lvar:apply(function (x) if x<0 then return 0 else return x end end)
      if data_verbose then
         print('2',lvar:min(),lvar:max())
      end

      -- lstd aliases lvar: the square root is taken in place
      local lstd = lvar
      lstd:sqrt()
      -- never divide by less than 1
      lstd:apply(function (x) if x<1 then return 1 else return x end end)
      if data_verbose then
         print('lstd',lstd:min(),lstd:max())
      end

      local shift = (gs+1)/2
      local nim = im:narrow(1,shift,im:size(1)-(gs-1)):narrow(2,shift,im:size(2)-(gs-1))
      nim:add(-1,lmn)
      nim:cdiv(lstd)
      if data_verbose then
         print('nim',nim:min(),nim:max())
      end

      return nim
   end

   -- Grab one frame, normalise it, then draw random nr x nc patches from it
   -- until one passes the std threshold. The frame is grabbed once per call,
   -- so this never returns if no patch of that frame passes.
   function dataset:selectPatch(nr,nc)
      if data_verbose then
         print('selectPatch')
      end
      counter = counter + 1
      local imgray = image.rgb2y(cam:forward())
      local nim = lcn(imgray[1]:clone())
      while true do
         -- top-left corner of the candidate patch; the 1e-12 lower bound
         -- keeps math.ceil from yielding 0
         local ri = math.ceil(torch.uniform(1e-12,nim:size(1)-nr))
         local ci = math.ceil(torch.uniform(1e-12,nim:size(2)-nc))
         local patch = nim:narrow(1,ri,nr):narrow(2,ci,nc)
         local patchstd = patch:std()
         if data_verbose then
            print('Image ' .. 0 .. ' ri= ' .. ri .. ' ci= ' .. ci .. ' std= ' .. patchstd)
         end
         if patchstd > std then
            if data_verbose then
               print(patch:min(),patch:max())
            end
            -- There is no image index for a live frame; the middle slot is
            -- kept (as nil) so the return shape matches getdata's selectPatch.
            return patch,nil,nim
         end
      end
   end

   -- Single buffer that every dataset[k] access overwrites and returns;
   -- callers that want to keep a sample must clone it.
   local dsample = torch.Tensor(inputsize*inputsize)

   setmetatable(dataset, {__index = function(self, index)
      local sample,i,im = self:selectPatch(inputsize, inputsize)
      dsample:copy(sample)
      return {dsample,dsample,im}
   end})

   return dataset
end
--- Collect samples from a dataset and show them with itorch.
--
-- For each of the `nsamples` iterations, dataset[1] is read, the first entry
-- of the returned table is cloned and reshaped with unfold from a flat vector
-- of length n into a sqrt(n) x sqrt(n) image. When the global `itorch` is
-- present the list of images is passed to itorch.image; otherwise a hint is
-- printed.
--
-- @param dataset   indexable object; dataset[1] must return a table whose
--                  first entry is a 1-D tensor of square length
-- @param nsamples  how many samples to collect (default 100)
-- @param nrow      accepted for backward compatibility; currently unused
-- @param zoom      accepted for backward compatibility; currently unused
function displayData(dataset, nsamples, nrow, zoom)
   require 'image'
   local nsamples = nsamples or 100
   local ex = {}
   for k = 1, nsamples do
      -- The key is always 1: the datasets in this file ignore it and draw a
      -- new random sample on every access.
      local exx = dataset[1]
      local side = math.sqrt(exx[1]:size(1))
      -- clone first: the datasets in this file hand back a shared buffer
      ex[k] = exx[1]:clone():unfold(1, side, side)
   end
   if itorch then
      itorch.image(ex)
   else
      print('For visualization, run the script in itorch')
   end
end