diff --git a/.gitignore b/.gitignore index 920c459b..e3c91991 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ data/raw/ascii data/raw/lineStrokes data/raw/original data/processed - +__pycache__ +img logs predictions diff --git a/rnn_ops.py b/rnn_ops.py index c202236e..1f9a603f 100644 --- a/rnn_ops.py +++ b/rnn_ops.py @@ -6,11 +6,11 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops.rnn_cell_impl import _concat, _like_rnncell +from tensorflow.python.ops.rnn_cell_impl import _concat, assert_like_rnncell from tensorflow.python.ops.rnn import _maybe_tensor_shape_from_tensor from tensorflow.python.util import nest from tensorflow.python.framework import tensor_shape -from tensorflow.python.eager import context +from tensorflow.python.util import is_in_graph_mode def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None): @@ -26,8 +26,8 @@ def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=No final cell state, ) """ - if not _like_rnncell(cell): - raise TypeError("cell must be an instance of RNNCell") + assert_like_rnncell("Raw rnn cell",cell) + if not callable(loop_fn): raise TypeError("loop_fn must be a callable") @@ -37,7 +37,7 @@ def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=No # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if context.in_graph_mode(): + if is_in_graph_mode.IS_IN_GRAPH_MODE(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device)