Skip to content

Commit

Permalink
Add function test to tfcompile.
Browse files Browse the repository at this point in the history
This will serve as a regression test for future CLs changing how
functions are instantiated in tfcompile.
Change: 152085567
  • Loading branch information
skye authored and tensorflower-gardener committed Apr 4, 2017
1 parent b657f5a commit 7a825ba
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 2 deletions.
11 changes: 11 additions & 0 deletions tensorflow/compiler/aot/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ genrule(
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tffunction.pb",
],
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
tags = ["manual"],
Expand Down Expand Up @@ -114,6 +115,15 @@ tf_library(
tags = ["manual"],
)

tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
tags = ["manual"],
)

cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
Expand All @@ -122,6 +132,7 @@ cc_test(
":test_graph_tfadd",
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",
Expand Down
16 changes: 14 additions & 2 deletions tensorflow/compiler/aot/tests/make_test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
Expand Down Expand Up @@ -95,6 +96,17 @@ def tfmatmulandadd(_):
math_ops.add(x, y, name='x_y_sum')


def tffunction(_):

@function.Defun(dtypes.int32, dtypes.int32)
def test_func(a, b):
return a + b

x = constant_op.constant([1], name='x_const')
y = constant_op.constant([2], name='y_const')
test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg


def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
Expand All @@ -112,6 +124,7 @@ def main(_):
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)


if __name__ == '__main__':
Expand All @@ -121,7 +134,6 @@ def main(_):
'--out_dir',
type=str,
default='',
help='Output directory for graphs, checkpoints and savers.'
)
help='Output directory for graphs, checkpoints and savers.')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)
16 changes: 16 additions & 0 deletions tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_const" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "func_call" }
}
16 changes: 16 additions & 0 deletions tensorflow/compiler/aot/tests/tfcompile_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
Expand Down Expand Up @@ -376,6 +377,21 @@ TEST(TFCompileTest, MatMulAndAdd1) {
}
}

TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);

add_fn.arg0() = 1;
add_fn.arg1() = 2;
EXPECT_TRUE(add_fn.Run());
EXPECT_EQ(add_fn.error_msg(), "");
EXPECT_EQ(add_fn.result0(), 3);
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}

} // namespace
} // namespace tfcompile
} // namespace tensorflow

0 comments on commit 7a825ba

Please sign in to comment.