From 2918cbeec82c90db4b08f96d815746291ed17cb2 Mon Sep 17 00:00:00 2001 From: Jakub Kosek Date: Tue, 5 Sep 2023 03:12:21 -0700 Subject: [PATCH] Updated expected patterns for NeMo prompt test --- CHANGELOG.md | 2 +- .../test.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5677d59..f4d6fc9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ limitations under the License. # Changelog -## 0.3.0 (2023-09-01) +## 0.3.0 (2023-09-05) - new: Support for multiple Python versions starting from 3.8+ - new: Added support for [decoupled models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/decoupled_models.md) enabling to support streaming models (alpha state) diff --git a/tests/functional/L1_example_nemo_megatron_gpt_prompt_learning/test.py b/tests/functional/L1_example_nemo_megatron_gpt_prompt_learning/test.py index 4798870..d38cb1d 100755 --- a/tests/functional/L1_example_nemo_megatron_gpt_prompt_learning/test.py +++ b/tests/functional/L1_example_nemo_megatron_gpt_prompt_learning/test.py @@ -42,14 +42,18 @@ def verify_client_output(client_output): else: LOGGER.info(f'Found "{expected_pattern}" in client output') - expected_patterns = [r"neutral", r"set the alarm", r"seven am"] - for expected_pattern in expected_patterns: - output_match = re.search(expected_pattern, client_output, re.MULTILINE) - output_array = output_match.group(0) if output_match else None - if not output_array: - raise ValueError(f"Could not find {expected_pattern} in client output. Output: {client_output}") + # NeMo model might return neutral or positive sentiment for given task - both are acceptable in test + expected_patterns = [[r"neutral", r"positive"], [r"set the alarm"], [r"seven am"]] + for patterns in expected_patterns: + matches = [re.search(pattern, client_output, re.MULTILINE) for pattern in patterns] + output_array = [match.group(0) if match else None for match in matches] + + if not any(output_array): + raise ValueError( + f'Could not find any of patterns "{", ".join(patterns)}" in client output. Output: {client_output}' + ) else: - LOGGER.info(f'Found "{expected_pattern}" in client output') + LOGGER.info(f'Found at least one of patterns "{", ".join(patterns)}" in client output') def main():