From 547f5dffb33883448f8d91ae7cf245b7093cf05d Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Thu, 22 Feb 2024 21:01:37 +0400 Subject: [PATCH] [TF FE] Add layer test for tf.keras.layers.TextVectorization and for LookupTableFindV2 with string key (#23011) **Details:** Add layer test for tf.keras.layers.TextVectorization and for LookupTableFindV2 with string key **Ticket:** 132910 --------- Signed-off-by: Kazantsev, Roman --- .../test_tf2_keras_text_vectorization.py | 53 +++++++++++++++++++ .../test_tf_LookupTableFind.py | 8 +++ 2 files changed, 61 insertions(+) create mode 100644 tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_text_vectorization.py diff --git a/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_text_vectorization.py b/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_text_vectorization.py new file mode 100644 index 00000000000000..b42bb80fa1f330 --- /dev/null +++ b/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_text_vectorization.py @@ -0,0 +1,53 @@ +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +import pytest +import tensorflow as tf + +from common.tf2_layer_test_class import CommonTF2LayerTest +from tensorflow.keras.layers import TextVectorization + +rng = np.random.default_rng() + + +class TestTextVectorization(CommonTF2LayerTest): + def _prepare_input(self, inputs_info): + assert 'text_input' in inputs_info + input_shape = inputs_info['text_input'] + inputs_data = {} + strings_dictionary = ['hi OpenVINO here ', ' hello OpenVINO there', ' привет ОПЕНВИНО \n', + 'hello PyTorch here ', ' hi TensorFlow here', ' hi JAX here \t'] + inputs_data['text_input'] = rng.choice(strings_dictionary, input_shape) + return inputs_data + + def create_text_vectorization_net(self, input_shape, vocabulary, output_mode, output_sequence_length): + assert len(input_shape) > 0 + tf.keras.backend.clear_session() + text_input = tf.keras.Input(shape=input_shape[1:], name='text_input', + dtype=tf.string) + vectorized_text = TextVectorization(vocabulary=vocabulary, + output_mode=output_mode, + output_sequence_length=output_sequence_length, + name='text_vectorizer')(text_input) + tf2_net = tf.keras.Model(inputs=[text_input], outputs=[vectorized_text]) + + return tf2_net, None + + @pytest.mark.parametrize('input_shape', [[2, 1], [2, 3]]) + @pytest.mark.parametrize('vocabulary', [['hello', 'there', 'OpenVINO', 'check', 'привет', 'ОПЕНВИНО']]) + @pytest.mark.parametrize('output_mode', ['int']) + @pytest.mark.parametrize('output_sequence_length', [32, 64]) + @pytest.mark.precommit_tf_fe + @pytest.mark.xfail(reason='132692 - Add support of TextVectorization') + @pytest.mark.nightly + def test_text_vectorization(self, input_shape, vocabulary, output_mode, output_sequence_length, ie_device, + precision, ir_version, temp_dir, use_legacy_frontend): + params = {} + params['input_shape'] = input_shape + params['vocabulary'] = vocabulary + params['output_mode'] = output_mode + params['output_sequence_length'] = output_sequence_length + self._test(*self.create_text_vectorization_net(**params), ie_device, precision, + temp_dir=temp_dir, ir_version=ir_version, use_legacy_frontend=use_legacy_frontend, **params) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_LookupTableFind.py b/tests/layer_tests/tensorflow_tests/test_tf_LookupTableFind.py index 6f4a0c3b272e92..6ff6daeda99045 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_LookupTableFind.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_LookupTableFind.py @@ -18,6 +18,9 @@ def _prepare_input(self, inputs_info): if np.issubdtype(self.keys_type, np.integer): data = rng.choice(self.all_keys, keys_shape) inputs_data['keys:0'] = mix_array_with_value(data, self.invalid_key) + elif self.keys_type == str: + data = rng.choice(self.all_keys + [self.invalid_key], keys_shape) + inputs_data['keys:0'] = data else: raise "Unsupported type {}".format(self.keys_type) @@ -64,6 +67,11 @@ def create_lookup_table_find_net(self, hash_table_type, keys_shape, keys_type, v dict(keys_type=np.int32, values_type=tf.string, all_keys=[20, 10, 33, -22, 44, 11], all_values=['PyTorch', 'TensorFlow', 'JAX', 'Lightning', 'MindSpore', 'OpenVINO'], default_value='UNKNOWN', invalid_key=1000), + pytest.param(dict(keys_type=str, values_type=np.int64, + all_keys=['PyTorch', 'TensorFlow', 'JAX', 'Lightning', 'MindSpore', 'OpenVINO'], + all_values=[200, 100, 0, -3, 10, 1], + default_value=0, invalid_key='AbraCadabra'), + marks=pytest.mark.xfail(reason="132669 - Support LookupTableFindV2 with string key")), ] @pytest.mark.parametrize("hash_table_type", [0, 1])