-
Notifications
You must be signed in to change notification settings - Fork 1
/
generate_images.py
62 lines (52 loc) · 1.82 KB
/
generate_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from PIL import Image, ImageDraw
from os import path as osp
import os
import numpy.random as rdn
class ImageGenerator:
    """
    Class to generate adversarial task images as mentioned in the paper.

    Each image is a 36x36 RGB canvas split into a 4x4 grid of 9-pixel cells.
    For every image a label ``k`` is drawn uniformly from ``1..num_classes``
    and ``k`` distinct cells are painted as white squares. The image is saved
    as ``<count>.jpg`` in ``out_dir`` together with ``<count>.txt`` holding
    ``k``; ``count`` increases by one per saved image.
    """

    # Pixel bounding boxes [x0, y0, x1, y1] of the 16 grid cells, keyed by
    # cell id 1..16. The id -> position order is not row-major; it is kept
    # exactly as originally written so generated data stays unchanged.
    _PLACE_MAP = {
        1: [0, 0, 9, 9], 2: [18, 0, 27, 9], 3: [9, 9, 18, 18], 4: [27, 9, 36, 18],
        5: [0, 18, 9, 27], 6: [9, 18, 18, 27], 7: [27, 0, 36, 9], 8: [18, 18, 27, 27],
        9: [18, 27, 27, 36], 10: [9, 27, 18, 36], 11: [9, 0, 18, 9], 12: [18, 9, 27, 18],
        13: [27, 27, 36, 36], 14: [0, 9, 9, 18], 15: [0, 27, 9, 36], 16: [27, 18, 36, 27],
    }

    def __init__(self, bgcolor=(0, 0, 0), output_dir="data/", num_classes=10, num_squares=16):
        """
        Args:
            bgcolor: RGB fill colour of the image background.
            output_dir: directory the .jpg/.txt pairs are written to; it is
                created (including missing parents) if absent.
            num_classes: labels are drawn from 1..num_classes.
            num_squares: squares are placed on cell ids 1..num_squares.

        Raises:
            ValueError: if num_squares is outside 1..16 (only 16 cell
                positions exist) or num_classes is outside 1..num_squares
                (a label cannot exceed the number of distinct cells).
        """
        # Validate up front: previously bad values only failed later, and
        # sometimes only at random, inside draw_squares / get_prob.
        if not 1 <= num_squares <= len(self._PLACE_MAP):
            raise ValueError(
                "num_squares must be in 1..{}, got {}".format(len(self._PLACE_MAP), num_squares)
            )
        if not 1 <= num_classes <= num_squares:
            raise ValueError(
                "num_classes must be in 1..num_squares ({}), got {}".format(num_squares, num_classes)
            )
        self.bgcolor = bgcolor
        self.count = 0
        self.out_dir = output_dir
        self.num_classes = list(range(1, num_classes + 1))
        self.num_squares = list(range(1, num_squares + 1))
        self._test_out_dir()

    def save_output(self, img, label):
        """Write ``img`` to <count>.jpg and ``label`` to <count>.txt in out_dir."""
        stem = osp.join(self.out_dir, str(self.count))
        img_path = stem + ".jpg"
        img.save(img_path)
        print("Saving {}".format(img_path))
        with open(stem + ".txt", "w") as f:
            f.write(str(label))

    def get_prob(self):
        """Draw a label and the cells it occupies.

        Returns:
            (label, where): ``label`` is a Python int drawn uniformly from
            ``self.num_classes``; ``where`` is a numpy array of ``label``
            distinct cell ids drawn from ``self.num_squares``.
        """
        n = len(self.num_classes)
        # Uniform weights sized to the real class count (was hard-coded to
        # 10 entries, which made any num_classes != 10 raise in rdn.choice).
        weights = [1.0 / n] * n
        # Scalar draw (no size argument): avoids NumPy's deprecated
        # int(<1-element array>) conversion and consumes the same one sample.
        label = int(rdn.choice(self.num_classes, p=weights))
        where = rdn.choice(self.num_squares, label, replace=False)
        return label, where

    def draw_squares(self, img):
        """Paint a random number of white squares on ``img``.

        Returns:
            (img, label): the same image object, drawn on in place, and the
            number of squares painted.
        """
        label, where = self.get_prob()
        draw = ImageDraw.Draw(img)
        for place in where:
            draw.rectangle(self._PLACE_MAP[int(place)], fill=(255, 255, 255), outline=(0, 0, 0))
        return img, label

    def _test_out_dir(self):
        """Ensure the output directory exists, creating parents as needed."""
        # An empty out_dir means "current directory" (osp.join("", x) == x),
        # which already exists and cannot be passed to makedirs.
        if self.out_dir:
            os.makedirs(self.out_dir, exist_ok=True)

    def generate(self, num=10):
        """Create, draw and save ``num`` images, advancing ``self.count``."""
        for _ in range(num):
            # Honour the configured background (was hard-coded to black).
            img = Image.new('RGB', (36, 36), color=self.bgcolor)
            img, label = self.draw_squares(img)
            self.save_output(img, label)
            self.count += 1
def main():
    """Entry point: build an ImageGenerator with its default settings and run generate()."""
    ImageGenerator().generate()


if __name__ == "__main__":
    main()