This repository has been archived by the owner on Dec 29, 2022. It is now read-only.

Fixes to work with TensorFlow 1.2 #254

Open. Wants to merge 7 commits into base: master.
Changes from all commits
2 changes: 1 addition & 1 deletion pylintrc
@@ -292,7 +292,7 @@ generated-members=set_shape,np.float32
 # List of decorators that produce context managers, such as
 # contextlib.contextmanager. Add to this list to register other decorators that
 # produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
+contextmanager-decorators=contextlib.contextmanager,tensorflow.python.util.tf_contextlib.contextmanager


 [VARIABLES]
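This pylintrc entry appears to be needed because TensorFlow 1.2 decorates a number of its context managers with `tensorflow.python.util.tf_contextlib.contextmanager`, which pylint does not recognize as producing context managers by default. A minimal sketch of the kind of false positive it silences; `dummy_scope` is a made-up name, not part of seq2seq or TensorFlow:

```python
# Sketch only: `dummy_scope` is illustrative, not part of seq2seq.
from tensorflow.python.util import tf_contextlib


@tf_contextlib.contextmanager
def dummy_scope():
  # tf_contextlib.contextmanager wraps contextlib.contextmanager. Unless the
  # decorator is listed under contextmanager-decorators in pylintrc, pylint
  # flags the `with` statement below as not-context-manager.
  yield "scope"


with dummy_scope() as name:
  print(name)
```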
9 changes: 7 additions & 2 deletions seq2seq/contrib/seq2seq/helper.py
@@ -32,8 +32,13 @@

 import six

-from tensorflow.contrib.distributions.python.ops import bernoulli
-from tensorflow.contrib.distributions.python.ops import categorical
+try:
+  from tensorflow.python.ops.distributions import bernoulli
+  from tensorflow.python.ops.distributions import categorical
+except:
+  # Backwards compatibility with TensorFlow prior to 1.2.
+  from tensorflow.contrib.distributions.python.ops import bernoulli
+  from tensorflow.contrib.distributions.python.ops import categorical
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.layers import base as layers_base
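For reference, only the module location changed here; the `Bernoulli` and `Categorical` classes that the scheduled-sampling helpers draw from are available at the new core path. A quick sketch assuming TensorFlow 1.2; the probability and logits values are illustrative:

```python
# Minimal sketch, assuming TensorFlow 1.2: the relocated modules expose the
# same distribution classes the helpers sample from.
import tensorflow as tf
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical

select_sampler = bernoulli.Bernoulli(probs=0.25, dtype=tf.bool)
sample_ids = categorical.Categorical(logits=[[0.1, 0.2, 0.7]]).sample()
```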
2 changes: 2 additions & 0 deletions seq2seq/data/input_pipeline.py
@@ -29,7 +29,9 @@
 import six

 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
+# pylint: enable=no-name-in-module

 from seq2seq.configurable import Configurable
 from seq2seq.data import split_tokens_decoder, parallel_data_provider
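The `no-name-in-module` pragmas here and in the similar hunks below silence what appears to be a static-analysis false positive: pylint cannot resolve names inside `tf.contrib`'s dynamically assembled packages, even though the imports succeed at runtime. A small illustration, assuming TensorFlow 1.x with pylint; pairing `disable` with `enable` keeps the check active for the rest of the file:

```python
# pylint: disable=no-name-in-module
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
# pylint: enable=no-name-in-module

# The import resolves fine at runtime; pylint simply cannot see it statically.
decoder = tfexample_decoder.TFExampleDecoder(
    keys_to_features={}, items_to_handlers={})
```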
2 changes: 2 additions & 0 deletions seq2seq/data/parallel_data_provider.py
@@ -22,8 +22,10 @@
 import numpy as np

 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.data import data_provider
 from tensorflow.contrib.slim.python.slim.data import parallel_reader
+# pylint: enable=no-name-in-module

 from seq2seq.data import split_tokens_decoder

2 changes: 2 additions & 0 deletions seq2seq/data/sequence_example_decoder.py
@@ -14,7 +14,9 @@
 """A decoder for tf.SequenceExample"""

 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.data import data_decoder
+# pylint: enable=no-name-in-module


 class TFSEquenceExampleDecoder(data_decoder.DataDecoder):
2 changes: 2 additions & 0 deletions seq2seq/data/split_tokens_decoder.py
@@ -21,7 +21,9 @@
 from __future__ import unicode_literals

 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.data import data_decoder
+# pylint: enable=no-name-in-module


 class SplitTokensDecoder(data_decoder.DataDecoder):
2 changes: 2 additions & 0 deletions seq2seq/encoders/image_encoder.py
@@ -20,8 +20,10 @@
 from __future__ import print_function

 import tensorflow as tf
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.slim.python.slim.nets.inception_v3 \
   import inception_v3_base
+# pylint: enable=no-name-in-module

 from seq2seq.encoders.encoder import Encoder, EncoderOutput

3 changes: 1 addition & 2 deletions seq2seq/encoders/rnn_encoder.py
@@ -21,7 +21,6 @@

 import copy
 import tensorflow as tf
-from tensorflow.contrib.rnn.python.ops import rnn

 from seq2seq.encoders.encoder import Encoder, EncoderOutput
 from seq2seq.training import utils as training_utils
@@ -186,7 +185,7 @@ def encode(self, inputs, sequence_length, **kwargs):
     cells_fw = _unpack_cell(cell_fw)
     cells_bw = _unpack_cell(cell_bw)

-    result = rnn.stack_bidirectional_dynamic_rnn(
+    result = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
         cells_fw=cells_fw,
         cells_bw=cells_bw,
         inputs=inputs,
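Switching from the private `tensorflow.contrib.rnn.python.ops.rnn` module to the public `tf.contrib.rnn` endpoint also insulates the encoder from internal reshuffling, since only the public path is a supported API. A minimal usage sketch of the public function, assuming TensorFlow 1.2; the placeholder shapes and cell sizes are illustrative:

```python
import tensorflow as tf

# Illustrative inputs: [batch, time, depth] embeddings and their lengths.
inputs = tf.placeholder(tf.float32, [None, None, 32])
sequence_length = tf.placeholder(tf.int32, [None])

# Two stacked bidirectional layers, mirroring how encode() calls the API.
cells_fw = [tf.contrib.rnn.LSTMCell(64) for _ in range(2)]
cells_bw = [tf.contrib.rnn.LSTMCell(64) for _ in range(2)]

outputs, states_fw, states_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
    cells_fw=cells_fw,
    cells_bw=cells_bw,
    inputs=inputs,
    dtype=tf.float32,
    sequence_length=sequence_length)
```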
2 changes: 2 additions & 0 deletions seq2seq/metrics/metric_specs.py
@@ -28,7 +28,9 @@

 import tensorflow as tf
 from tensorflow.contrib import metrics
+# pylint: disable=no-name-in-module
 from tensorflow.contrib.learn import MetricSpec
+# pylint: enable=no-name-in-module

 from seq2seq.data import postproc
 from seq2seq.configurable import Configurable
14 changes: 10 additions & 4 deletions seq2seq/test/hooks_test.py
@@ -39,16 +39,22 @@ class TestPrintModelAnalysisHook(tf.test.TestCase):
   def test_begin(self):
     model_dir = tempfile.mkdtemp()
     outfile = tempfile.NamedTemporaryFile()
-    tf.get_variable("weigths", [128, 128])
+    tf.get_variable("weights", [128, 128])
     hook = hooks.PrintModelAnalysisHook(
         params={}, model_dir=model_dir, run_config=tf.contrib.learn.RunConfig())
     hook.begin()

     with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file:
       file_contents = file.read().strip()

-    self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n"
-                     " weigths (128x128, 16.38k/16.38k params)")
+    lines = tf.compat.as_text(file_contents).split("\n")
+    if len(lines) == 3:
+      # TensorFlow v1.2 includes an extra header line
+      self.assertEqual(lines[0], "node name | # parameters")
+
+    self.assertEqual(lines[-2], "_TFProfRoot (--/16.38k params)")
+    self.assertEqual(lines[-1], " weights (128x128, 16.38k/16.38k params)")

     outfile.close()

@@ -125,7 +131,7 @@ def tearDown(self):
   def test_capture(self):
     global_step = tf.contrib.framework.get_or_create_global_step()
     # Some test computation
-    some_weights = tf.get_variable("weigths", [2, 128])
+    some_weights = tf.get_variable("weights", [2, 128])
     computation = tf.nn.softmax(some_weights)

     hook = hooks.MetadataCaptureHook(
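A note on the test change above: `tf.compat.as_text` accepts either bytes or unicode and always returns text, so the assertion no longer depends on whether `GFile.read()` returns bytes on the running TensorFlow version. A quick illustration, assuming any TensorFlow 1.x:

```python
import tensorflow as tf

# tf.compat.as_text normalizes both representations to unicode text, which is
# why the test can split and compare lines without an explicit .decode().
assert tf.compat.as_text(b"_TFProfRoot (--/16.38k params)") == \
    u"_TFProfRoot (--/16.38k params)"
assert tf.compat.as_text(u"_TFProfRoot (--/16.38k params)") == \
    u"_TFProfRoot (--/16.38k params)"
```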