forked from IBM/tensorflow-hangul-recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hangul-image-generator.py
executable file
·133 lines (108 loc) · 4.82 KB
/
hangul-image-generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python
import argparse
import glob
import io
import os
import random
import numpy
from PIL import Image, ImageFont, ImageDraw
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
# Directory that contains this script; the default locations below are
# expressed relative to it.
SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))

# Fallback locations used when the matching command-line option is omitted.
DEFAULT_LABEL_FILE = os.path.join(
    SCRIPT_PATH, '../labels/2350-common-hangul.txt')
DEFAULT_FONTS_DIR = os.path.join(SCRIPT_PATH, '../fonts')
DEFAULT_OUTPUT_DIR = os.path.join(SCRIPT_PATH, '../image-data')

# How many randomly distorted variants accompany each rendered
# (font, character) image.
DISTORTION_COUNT = 3

# Pixel dimensions of every generated image (square canvas).
IMAGE_WIDTH = IMAGE_HEIGHT = 64
def _measure_text(drawing, text, font):
    """Return the (width, height) used to centre ``text`` on the canvas.

    ``ImageDraw.textsize`` was removed in Pillow 10, so it is only used when
    the installed Pillow still provides it. Otherwise the size is derived
    from ``textbbox``.
    """
    if hasattr(drawing, 'textsize'):
        return drawing.textsize(text, font=font)
    left, _, right, bottom = drawing.textbbox((0, 0), text, font=font)
    # NOTE(review): the legacy textsize height is believed to have included
    # the glyph's top offset, which corresponds to ``bottom`` here rather than
    # ``bottom - top``; confirm against Pillow if exact placement matters.
    return right - left, bottom


def _save_and_record(image, image_dir, index, character, labels_csv):
    """Save ``image`` as hangul_<index>.jpeg and append its CSV label row."""
    file_path = os.path.join(image_dir, 'hangul_{}.jpeg'.format(index))
    image.save(file_path, 'JPEG')
    labels_csv.write(u'{},{}\n'.format(file_path, character))


def generate_hangul_images(label_file, fonts_dir, output_dir):
    """Generate Hangul image files.

    Reads one label (character) per line from ``label_file`` and renders each
    label once with every ``*.ttf`` font found in ``fonts_dir``. Every
    rendered image is followed by ``DISTORTION_COUNT`` elastically distorted
    variants.

    Images are written as JPEG files to ``<output_dir>/hangul-images`` and
    each image path is listed with its label in
    ``<output_dir>/labels-map.csv``.

    Args:
        label_file: Path to a UTF-8 text file of newline delimited labels.
        fonts_dir: Directory containing the ``*.ttf`` fonts to render with.
        output_dir: Directory receiving the images and the label CSV file.
    """
    with io.open(label_file, 'r', encoding='utf-8') as f:
        labels = f.read().splitlines()

    image_dir = os.path.join(output_dir, 'hangul-images')
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    # Load every font exactly once instead of once per character. Sorting the
    # paths keeps the image numbering reproducible across file systems.
    font_paths = sorted(glob.glob(os.path.join(fonts_dir, '*.ttf')))
    fonts = [ImageFont.truetype(font_path, 48) for font_path in font_paths]

    total_count = 0
    prev_count = 0
    csv_path = os.path.join(output_dir, 'labels-map.csv')
    # The context manager closes the CSV even if rendering or saving fails.
    with io.open(csv_path, 'w', encoding='utf-8') as labels_csv:
        for character in labels:
            # Print image count roughly every 5000 images.
            if total_count - prev_count > 5000:
                prev_count = total_count
                print('{} images generated...'.format(total_count))

            for font in fonts:
                total_count += 1
                image = Image.new('L', (IMAGE_WIDTH, IMAGE_HEIGHT), color=0)
                drawing = ImageDraw.Draw(image)
                w, h = _measure_text(drawing, character, font)
                # White glyph centred on a black canvas.
                drawing.text(
                    ((IMAGE_WIDTH-w)/2, (IMAGE_HEIGHT-h)/2),
                    character,
                    fill=(255),
                    font=font
                )
                _save_and_record(image, image_dir, total_count, character,
                                 labels_csv)

                # Convert once; every distortion starts from the same clean
                # rendering.
                arr = numpy.array(image)
                for _ in range(DISTORTION_COUNT):
                    total_count += 1
                    distorted_array = elastic_distort(
                        arr, alpha=random.randint(30, 36),
                        sigma=random.randint(5, 6)
                    )
                    _save_and_record(Image.fromarray(distorted_array),
                                     image_dir, total_count, character,
                                     labels_csv)

    print('Finished generating {} images.'.format(total_count))
def elastic_distort(image, alpha, sigma, random_state=None):
    """Perform elastic distortion on an image.

    A random displacement in [-1, 1) is drawn for every pixel along each
    axis, smoothed with a Gaussian filter and scaled by ``alpha``. The image
    is then resampled at the displaced coordinates with linear interpolation.

    Args:
        image: 2-D numpy array holding the image to distort. Square and
            non-square arrays are both supported.
        alpha: Scaling factor that controls the intensity of the deformation.
        sigma: Standard deviation of the Gaussian filter that smooths the
            displacement fields.
        random_state: Optional ``numpy.random.RandomState`` used to draw the
            displacements. When omitted, a freshly seeded one is created, so
            every call yields a different distortion.

    Returns:
        A new array with the same shape as ``image``.
    """
    if random_state is None:
        random_state = numpy.random.RandomState(None)
    shape = image.shape

    dx = gaussian_filter(
        (random_state.rand(*shape) * 2 - 1),
        sigma, mode="constant"
    ) * alpha
    dy = gaussian_filter(
        (random_state.rand(*shape) * 2 - 1),
        sigma, mode="constant"
    ) * alpha

    # numpy.meshgrid defaults to 'xy' indexing: the first argument runs along
    # the columns and the result has shape (len(second), len(first)). Passing
    # the column count first makes x and y match ``shape`` for non-square
    # images as well.
    x, y = numpy.meshgrid(numpy.arange(shape[1]), numpy.arange(shape[0]))
    # map_coordinates expects the row coordinates first, then the columns.
    indices = numpy.reshape(y+dy, (-1, 1)), numpy.reshape(x+dx, (-1, 1))
    return map_coordinates(image, indices, order=1).reshape(shape)
if __name__ == '__main__':
    # Command-line entry point: every option is optional and falls back to
    # the module-level default paths.
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--label-file', 'label_file', DEFAULT_LABEL_FILE,
         'File containing newline delimited labels.'),
        ('--font-dir', 'fonts_dir', DEFAULT_FONTS_DIR,
         'Directory of ttf fonts to use.'),
        ('--output-dir', 'output_dir', DEFAULT_OUTPUT_DIR,
         'Output directory to store generated images and '
         'label CSV file.'),
    )
    for flag, dest_name, fallback, description in option_specs:
        cli.add_argument(flag, type=str, dest=dest_name, default=fallback,
                         help=description)
    options = cli.parse_args()
    generate_hangul_images(
        options.label_file, options.fonts_dir, options.output_dir)