diff --git a/demo.py b/demo.py index 52103e5b..9bb166c5 100644 --- a/demo.py +++ b/demo.py @@ -1,153 +1,7 @@ -import os -import logging - import numpy as np -import svgwrite -import drawing import lyrics -from rnn import rnn - - -class Hand(object): - - def __init__(self): - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - self.nn = rnn( - log_dir='logs', - checkpoint_dir='checkpoints', - prediction_dir='predictions', - learning_rates=[.0001, .00005, .00002], - batch_sizes=[32, 64, 64], - patiences=[1500, 1000, 500], - beta1_decays=[.9, .9, .9], - validation_batch_size=32, - optimizer='rms', - num_training_steps=100000, - warm_start_init_step=17900, - regularization_constant=0.0, - keep_prob=1.0, - enable_parameter_averaging=False, - min_steps_to_checkpoint=2000, - log_interval=20, - logging_level=logging.CRITICAL, - grad_clip=10, - lstm_size=400, - output_mixture_components=20, - attention_mixture_components=10 - ) - self.nn.restore() - - def write(self, filename, lines, biases=None, styles=None, stroke_colors=None, stroke_widths=None): - valid_char_set = set(drawing.alphabet) - for line_num, line in enumerate(lines): - if len(line) > 75: - raise ValueError( - ( - "Each line must be at most 75 characters. " - "Line {} contains {}" - ).format(line_num, len(line)) - ) - - for char in line: - if char not in valid_char_set: - raise ValueError( - ( - "Invalid character {} detected in line {}. " - "Valid character set is {}" - ).format(char, line_num, valid_char_set) - ) - - strokes = self._sample(lines, biases=biases, styles=styles) - self._draw(strokes, lines, filename, stroke_colors=stroke_colors, stroke_widths=stroke_widths) - - def _sample(self, lines, biases=None, styles=None): - num_samples = len(lines) - max_tsteps = 40*max([len(i) for i in lines]) - biases = biases if biases is not None else [0.5]*num_samples - - x_prime = np.zeros([num_samples, 1200, 3]) - x_prime_len = np.zeros([num_samples]) - chars = np.zeros([num_samples, 120]) - chars_len = np.zeros([num_samples]) - - if styles is not None: - for i, (cs, style) in enumerate(zip(lines, styles)): - x_p = np.load('styles/style-{}-strokes.npy'.format(style)) - c_p = np.load('styles/style-{}-chars.npy'.format(style)).tostring().decode('utf-8') - - c_p = str(c_p) + " " + cs - c_p = drawing.encode_ascii(c_p) - c_p = np.array(c_p) - - x_prime[i, :len(x_p), :] = x_p - x_prime_len[i] = len(x_p) - chars[i, :len(c_p)] = c_p - chars_len[i] = len(c_p) - - else: - for i in range(num_samples): - encoded = drawing.encode_ascii(lines[i]) - chars[i, :len(encoded)] = encoded - chars_len[i] = len(encoded) - - [samples] = self.nn.session.run( - [self.nn.sampled_sequence], - feed_dict={ - self.nn.prime: styles is not None, - self.nn.x_prime: x_prime, - self.nn.x_prime_len: x_prime_len, - self.nn.num_samples: num_samples, - self.nn.sample_tsteps: max_tsteps, - self.nn.c: chars, - self.nn.c_len: chars_len, - self.nn.bias: biases - } - ) - samples = [sample[~np.all(sample == 0.0, axis=1)] for sample in samples] - return samples - - def _draw(self, strokes, lines, filename, stroke_colors=None, stroke_widths=None): - stroke_colors = stroke_colors or ['black']*len(lines) - stroke_widths = stroke_widths or [2]*len(lines) - - line_height = 60 - view_width = 1000 - view_height = line_height*(len(strokes) + 1) - - dwg = svgwrite.Drawing(filename=filename) - dwg.viewbox(width=view_width, height=view_height) - dwg.add(dwg.rect(insert=(0, 0), size=(view_width, view_height), fill='white')) - - initial_coord = np.array([0, -(3*line_height / 4)]) - for offsets, line, color, width in zip(strokes, lines, stroke_colors, stroke_widths): - - if not line: - initial_coord[1] -= line_height - continue - - offsets[:, :2] *= 1.5 - strokes = drawing.offsets_to_coords(offsets) - strokes = drawing.denoise(strokes) - strokes[:, :2] = drawing.align(strokes[:, :2]) - - strokes[:, 1] *= -1 - strokes[:, :2] -= strokes[:, :2].min() + initial_coord - strokes[:, 0] += (view_width - strokes[:, 0].max()) / 2 - - prev_eos = 1.0 - p = "M{},{} ".format(0, 0) - for x, y, eos in zip(*strokes.T): - p += '{}{},{} '.format('M' if prev_eos == 1.0 else 'L', x, y) - prev_eos = eos - path = svgwrite.path.Path(p) - path = path.stroke(color=color, width=width, linecap='round').fill("none") - dwg.add(path) - - initial_coord[1] -= line_height - - dwg.save() - +from hand import Hand if __name__ == '__main__': hand = Hand() diff --git a/hand.py b/hand.py new file mode 100644 index 00000000..462b54ab --- /dev/null +++ b/hand.py @@ -0,0 +1,148 @@ +import logging +import os + +import numpy as np +import svgwrite + +import drawing +from rnn import rnn + + +class Hand(object): + + def __init__(self): + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + self.nn = rnn( + log_dir='logs', + checkpoint_dir='checkpoints', + prediction_dir='predictions', + learning_rates=[.0001, .00005, .00002], + batch_sizes=[32, 64, 64], + patiences=[1500, 1000, 500], + beta1_decays=[.9, .9, .9], + validation_batch_size=32, + optimizer='rms', + num_training_steps=100000, + warm_start_init_step=17900, + regularization_constant=0.0, + keep_prob=1.0, + enable_parameter_averaging=False, + min_steps_to_checkpoint=2000, + log_interval=20, + logging_level=logging.CRITICAL, + grad_clip=10, + lstm_size=400, + output_mixture_components=20, + attention_mixture_components=10 + ) + self.nn.restore() + + def write(self, filename, lines, biases=None, styles=None, stroke_colors=None, stroke_widths=None): + valid_char_set = set(drawing.alphabet) + for line_num, line in enumerate(lines): + if len(line) > 75: + raise ValueError( + ( + "Each line must be at most 75 characters. " + "Line {} contains {}" + ).format(line_num, len(line)) + ) + + for char in line: + if char not in valid_char_set: + raise ValueError( + ( + "Invalid character {} detected in line {}. " + "Valid character set is {}" + ).format(char, line_num, valid_char_set) + ) + + strokes = self._sample(lines, biases=biases, styles=styles) + self._draw(strokes, lines, filename, stroke_colors=stroke_colors, stroke_widths=stroke_widths) + + def _sample(self, lines, biases=None, styles=None): + num_samples = len(lines) + max_tsteps = 40 * max([len(i) for i in lines]) + biases = biases if biases is not None else [0.5] * num_samples + + x_prime = np.zeros([num_samples, 1200, 3]) + x_prime_len = np.zeros([num_samples]) + chars = np.zeros([num_samples, 120]) + chars_len = np.zeros([num_samples]) + + if styles is not None: + for i, (cs, style) in enumerate(zip(lines, styles)): + x_p = np.load('styles/style-{}-strokes.npy'.format(style)) + c_p = np.load('styles/style-{}-chars.npy'.format(style)).tostring().decode('utf-8') + + c_p = str(c_p) + " " + cs + c_p = drawing.encode_ascii(c_p) + c_p = np.array(c_p) + + x_prime[i, :len(x_p), :] = x_p + x_prime_len[i] = len(x_p) + chars[i, :len(c_p)] = c_p + chars_len[i] = len(c_p) + + else: + for i in range(num_samples): + encoded = drawing.encode_ascii(lines[i]) + chars[i, :len(encoded)] = encoded + chars_len[i] = len(encoded) + + [samples] = self.nn.session.run( + [self.nn.sampled_sequence], + feed_dict={ + self.nn.prime: styles is not None, + self.nn.x_prime: x_prime, + self.nn.x_prime_len: x_prime_len, + self.nn.num_samples: num_samples, + self.nn.sample_tsteps: max_tsteps, + self.nn.c: chars, + self.nn.c_len: chars_len, + self.nn.bias: biases + } + ) + samples = [sample[~np.all(sample == 0.0, axis=1)] for sample in samples] + return samples + + def _draw(self, strokes, lines, filename, stroke_colors=None, stroke_widths=None): + stroke_colors = stroke_colors or ['black'] * len(lines) + stroke_widths = stroke_widths or [2] * len(lines) + + line_height = 60 + view_width = 1000 + view_height = line_height * (len(strokes) + 1) + + dwg = svgwrite.Drawing(filename=filename) + dwg.viewbox(width=view_width, height=view_height) + dwg.add(dwg.rect(insert=(0, 0), size=(view_width, view_height), fill='white')) + + initial_coord = np.array([0, -(3 * line_height / 4)]) + for offsets, line, color, width in zip(strokes, lines, stroke_colors, stroke_widths): + + if not line: + initial_coord[1] -= line_height + continue + + offsets[:, :2] *= 1.5 + strokes = drawing.offsets_to_coords(offsets) + strokes = drawing.denoise(strokes) + strokes[:, :2] = drawing.align(strokes[:, :2]) + + strokes[:, 1] *= -1 + strokes[:, :2] -= strokes[:, :2].min() + initial_coord + strokes[:, 0] += (view_width - strokes[:, 0].max()) / 2 + + prev_eos = 1.0 + p = "M{},{} ".format(0, 0) + for x, y, eos in zip(*strokes.T): + p += '{}{},{} '.format('M' if prev_eos == 1.0 else 'L', x, y) + prev_eos = eos + path = svgwrite.path.Path(p) + path = path.stroke(color=color, width=width, linecap='round').fill("none") + dwg.add(path) + + initial_coord[1] -= line_height + + dwg.save() diff --git a/requirements.txt b/requirements.txt index a3cdb45a..32010345 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ pandas>= 0.22.0 scikit-learn>=0.19.1 scipy>=1.0.0 svgwrite>=1.1.12 -tensorflow==1.6.0 +tensorflow diff --git a/rnn_ops.py b/rnn_ops.py index c202236e..83959eef 100644 --- a/rnn_ops.py +++ b/rnn_ops.py @@ -1,16 +1,16 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops.rnn_cell_impl import _concat, _like_rnncell from tensorflow.python.ops.rnn import _maybe_tensor_shape_from_tensor +from tensorflow.python.ops.rnn_cell_impl import _concat, _like_rnncell +from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import nest -from tensorflow.python.framework import tensor_shape -from tensorflow.python.eager import context def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None): @@ -37,7 +37,7 @@ def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=No # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if context.in_graph_mode(): + if is_in_graph_mode: if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -136,6 +136,7 @@ def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_ def _copy_some_through(current, candidate): """Copy some tensors through via array_ops.where.""" + def copy_fn(cur_i, cand_i): # TensorArray and scalar get passed through. if isinstance(cur_i, tensor_array_ops.TensorArray): @@ -145,6 +146,7 @@ def copy_fn(cur_i, cand_i): # Otherwise propagate the old or the new value. with ops.colocate_with(cand_i): return array_ops.where(elements_finished, cur_i, cand_i) + return nest.map_structure(copy_fn, current, candidate) emit_output = _copy_some_through(zero_emit, emit_output)