forked from rh01/unet-caffe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mydatalayer.py
83 lines (68 loc) · 2.5 KB
/
mydatalayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import caffe
import numpy as np
import cv2
import numpy.random as random
class DataLayer(caffe.Layer):
    """Caffe Python data layer that emits one image/mask pair per forward pass.

    The layer takes no bottoms and produces exactly two tops:

    * ``top[0]`` -- the grayscale input image, scaled to ``[-0.5, 0.5]``.
    * ``top[1]`` -- the binarized label mask (values 0 or 1).

    File names are read from ``self.imgtxt`` (one name per line) and are
    resolved against ``self.imgdir`` for images and ``self.maskdir`` for
    masks. Samples are visited either in random order or sequentially,
    depending on ``self.random``.
    """

    # (width, height) every image and mask is resized to before being emitted.
    INPUT_SIZE = (572, 572)

    def setup(self, bottom, top):
        """Configure paths, validate the blob layout and pick the first sample.

        Raises:
            Exception: if the layer is not wired with exactly two tops and no
                bottoms, or if the list file contains no file names.
        """
        self.imgdir = "/root/UNet/data/train/image/"
        self.maskdir = "/root/UNet/data/train/label/"
        self.imgtxt = "/root/UNet/data/img.txt"
        self.random = True
        self.seed = None

        if len(top) != 2:
            raise Exception("Need to define two tops: data and mask.")
        if len(bottom) != 0:
            raise Exception("Do not define a bottom.")

        # Close the list file deterministically and drop blank lines (such as
        # a trailing newline), which would otherwise resolve to a bare
        # directory path and fail at load time.
        with open(self.imgtxt, 'r') as list_file:
            self.lines = [line for line in list_file if line.strip()]
        if not self.lines:
            raise Exception("No image names found in %s" % self.imgtxt)

        self.idx = 0
        if self.random:
            random.seed(self.seed)
            self.idx = self._random_index()

    def reshape(self, bottom, top):
        """Load the current image/mask pair and size the tops to match."""
        # load image + label image pair
        self.data = self.load_image(self.idx)
        self.mask = self.load_mask(self.idx)
        # reshape tops to fit (leading 1 is for batch dimension)
        top[0].reshape(1, *self.data.shape)
        top[1].reshape(1, *self.mask.shape)

    def forward(self, bottom, top):
        """Copy the loaded pair into the tops and advance to the next sample."""
        # assign output
        top[0].data[...] = self.data
        top[1].data[...] = self.mask

        # pick next input
        if self.random:
            self.idx = self._random_index()
        else:
            # Sequential mode wraps around at the end of the list.
            self.idx += 1
            if self.idx == len(self.lines):
                self.idx = 0

    def backward(self, top, propagate_down, bottom):
        """No-op: a data layer has nothing to back-propagate."""
        pass

    def load_image(self, idx):
        """Load the ``idx``-th image as a ``(1, H, W)`` float64 array.

        The image is resized to ``INPUT_SIZE``, converted to grayscale,
        scaled by 1/255 and shifted by -0.5.

        Raises:
            IOError: if OpenCV cannot read the image file.
        """
        imname = (self.imgdir + self.lines[idx]).strip()
        print('load img %s?' % imname)
        im = self._read_image(imname)
        im = cv2.resize(im, self.INPUT_SIZE)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        im = np.array(im, np.float64)
        im /= 255.0
        im -= 0.5
        # Prepend the channel axis.
        return im[np.newaxis, :]

    def load_mask(self, idx):
        """Load the ``idx``-th label mask as a ``(1, H, W)`` binary array.

        The mask is converted to grayscale, resized to ``INPUT_SIZE`` and
        thresholded so that pixels above 0.5 become 1 and all others 0.

        Raises:
            IOError: if OpenCV cannot read the mask file.
        """
        imname = (self.maskdir + self.lines[idx]).strip()
        print('load mask %s' % imname)
        im = self._read_image(imname)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        im = cv2.resize(im, self.INPUT_SIZE)
        _, img = cv2.threshold(im, 0.5, 1.0, cv2.THRESH_BINARY)
        # Prepend the channel axis.
        return img[np.newaxis, :]

    def _random_index(self):
        """Return a uniformly random index into ``self.lines``."""
        # numpy's randint excludes the upper bound, so pass len(lines) itself;
        # passing len(lines) - 1 would never select the last entry and would
        # raise ValueError for a single-entry list.
        return random.randint(0, len(self.lines))

    def _read_image(self, path):
        """Read ``path`` with OpenCV, failing loudly when it cannot be loaded."""
        im = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if im is None:
            raise IOError("Could not read image file: %s" % path)
        return im