Skip to content

Commit

Permalink
[TF FE] Add layer test for tf.keras.layers.TextVectorization and for …
Browse files Browse the repository at this point in the history
…LookupTableFindV2 with string key (openvinotoolkit#23011)

**Details:** Add layer test for tf.keras.layers.TextVectorization and
for LookupTableFindV2 with string key

**Ticket:** 132910

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Feb 22, 2024
1 parent a4dcf65 commit 547f5df
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np

import pytest
import tensorflow as tf

from common.tf2_layer_test_class import CommonTF2LayerTest
from tensorflow.keras.layers import TextVectorization

rng = np.random.default_rng()


class TestTextVectorization(CommonTF2LayerTest):
def _prepare_input(self, inputs_info):
assert 'text_input' in inputs_info
input_shape = inputs_info['text_input']
inputs_data = {}
strings_dictionary = ['hi OpenVINO here ', ' hello OpenVINO there', ' привет ОПЕНВИНО \n',
'hello PyTorch here ', ' hi TensorFlow here', ' hi JAX here \t']
inputs_data['text_input'] = rng.choice(strings_dictionary, input_shape)
return inputs_data

def create_text_vectorization_net(self, input_shape, vocabulary, output_mode, output_sequence_length):
assert len(input_shape) > 0
tf.keras.backend.clear_session()
text_input = tf.keras.Input(shape=input_shape[1:], name='text_input',
dtype=tf.string)
vectorized_text = TextVectorization(vocabulary=vocabulary,
output_mode=output_mode,
output_sequence_length=output_sequence_length,
name='text_vectorizer')(text_input)
tf2_net = tf.keras.Model(inputs=[text_input], outputs=[vectorized_text])

return tf2_net, None

@pytest.mark.parametrize('input_shape', [[2, 1], [2, 3]])
@pytest.mark.parametrize('vocabulary', [['hello', 'there', 'OpenVINO', 'check', 'привет', 'ОПЕНВИНО']])
@pytest.mark.parametrize('output_mode', ['int'])
@pytest.mark.parametrize('output_sequence_length', [32, 64])
@pytest.mark.precommit_tf_fe
@pytest.mark.xfail(reason='132692 - Add support of TextVectorization')
@pytest.mark.nightly
def test_text_vectorization(self, input_shape, vocabulary, output_mode, output_sequence_length, ie_device,
precision, ir_version, temp_dir, use_legacy_frontend):
params = {}
params['input_shape'] = input_shape
params['vocabulary'] = vocabulary
params['output_mode'] = output_mode
params['output_sequence_length'] = output_sequence_length
self._test(*self.create_text_vectorization_net(**params), ie_device, precision,
temp_dir=temp_dir, ir_version=ir_version, use_legacy_frontend=use_legacy_frontend, **params)
8 changes: 8 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_LookupTableFind.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def _prepare_input(self, inputs_info):
if np.issubdtype(self.keys_type, np.integer):
data = rng.choice(self.all_keys, keys_shape)
inputs_data['keys:0'] = mix_array_with_value(data, self.invalid_key)
elif self.keys_type == str:
data = rng.choice(self.all_keys + [self.invalid_key], keys_shape)
inputs_data['keys:0'] = data
else:
raise "Unsupported type {}".format(self.keys_type)

Expand Down Expand Up @@ -64,6 +67,11 @@ def create_lookup_table_find_net(self, hash_table_type, keys_shape, keys_type, v
dict(keys_type=np.int32, values_type=tf.string, all_keys=[20, 10, 33, -22, 44, 11],
all_values=['PyTorch', 'TensorFlow', 'JAX', 'Lightning', 'MindSpore', 'OpenVINO'],
default_value='UNKNOWN', invalid_key=1000),
pytest.param(dict(keys_type=str, values_type=np.int64,
all_keys=['PyTorch', 'TensorFlow', 'JAX', 'Lightning', 'MindSpore', 'OpenVINO'],
all_values=[200, 100, 0, -3, 10, 1],
default_value=0, invalid_key='AbraCadabra'),
marks=pytest.mark.xfail(reason="132669 - Support LookupTableFindV2 with string key")),
]

@pytest.mark.parametrize("hash_table_type", [0, 1])
Expand Down

0 comments on commit 547f5df

Please sign in to comment.