From 18b95edaf84fae88c93dee3389aa0ebeabdc8fb2 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 6 Oct 2023 12:51:50 -0700 Subject: [PATCH] Fixed error for case when mltensor has data as null Signed-off-by: Martin Gaievski --- .../ml/MLCommonsClientAccessor.java | 9 +++++ .../ml/MLCommonsClientAccessorTests.java | 34 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 1c09f5996..59855bbf4 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -12,12 +12,14 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; +import org.apache.logging.log4j.util.Strings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -187,6 +189,13 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { for (final ModelTensors tensors : tensorOutputList) { final List tensorsList = tensors.getMlModelTensors(); for (final ModelTensor tensor : tensorsList) { + if (Objects.isNull(tensor.getData())) { + String exceptionMessage = "the system encountered an unexpected error during processing"; + if (Objects.nonNull(tensor.getDataAsMap()) && Strings.isNotBlank((String) tensor.getDataAsMap().get("message"))) { + exceptionMessage = (String) tensor.getDataAsMap().get("message"); + } + throw new IllegalStateException(exceptionMessage); + } vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList())); } } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index ce2773f2f..d093b9e71 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.ml; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -328,6 +329,21 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); } + public void testInferenceMultimodal_whenInvalidInputAndEmptyTensorOutput_thenFail() { + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createEmptyModelTensorOutput()); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(any()); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -355,4 +371,22 @@ private ModelTensorOutput createModelTensorOutput(final Map map) tensorsList.add(modelTensors); return new ModelTensorOutput(tensorsList); } + + private ModelTensorOutput createEmptyModelTensorOutput() { + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + final ModelTensor tensor = new ModelTensor( + "someValue", + null, + new long[] { 1, 2 }, + MLResultDataType.FLOAT64, + ByteBuffer.wrap(new byte[12]), + "mockResult", + ImmutableMap.of("message", "The system encountered an unexpected error during processing. Try your request again.") + ); + mlModelTensorList.add(tensor); + final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); + tensorsList.add(modelTensors); + return new ModelTensorOutput(tensorsList); + } }