Skip to content

Commit

Permalink
Allow updates on dynamic index setting max_depth_limit
Browse files — browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 3, 2023
1 parent 69813ce commit fd8b1fd
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ private void retryableInferenceSentencesWithSingleVectorResult(
MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
TextImageEmbeddingProcessor.TYPE,
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env)
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
Expand Down Expand Up @@ -50,6 +53,7 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor {

private final MLCommonsClientAccessor mlCommonsClientAccessor;
private final Environment environment;
private final ClusterService clusterService;

public TextImageEmbeddingProcessor(
final String tag,
Expand All @@ -58,7 +62,8 @@ public TextImageEmbeddingProcessor(
final String embedding,
final Map<String, String> fieldMap,
final MLCommonsClientAccessor clientAccessor,
final Environment environment
final Environment environment,
final ClusterService clusterService
) {
super(tag, description);
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it");
Expand All @@ -69,6 +74,7 @@ public TextImageEmbeddingProcessor(
this.fieldMap = fieldMap;
this.mlCommonsClientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
}

private void validateEmbeddingConfiguration(final Map<String, String> fieldMap) {
Expand Down Expand Up @@ -176,7 +182,8 @@ private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) {
}
Class<?> sourceValueClass = sourceValue.getClass();
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1);
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1, indexName);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
Expand All @@ -187,17 +194,23 @@ private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) {
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final Supplier<Integer> maxDepthSupplier) {
private void validateNestedTypeValue(
final String sourceKey,
final Object sourceValue,
final Supplier<Integer> maxDepthSupplier,
final String indexName
) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
Settings indexSettings = clusterService.state().metadata().index(indexName).getSettings();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings)) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, (List) sourceValue);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1, indexName));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,22 @@

import java.util.Map;

import lombok.AllArgsConstructor;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;

/**
* Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
@AllArgsConstructor
public class TextImageEmbeddingProcessorFactory implements Factory {

private final MLCommonsClientAccessor clientAccessor;

private final Environment environment;

public TextImageEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) {
this.clientAccessor = clientAccessor;
this.environment = environment;
}
private final ClusterService clusterService;

@Override
public TextImageEmbeddingProcessor create(
Expand All @@ -43,6 +42,15 @@ public TextImageEmbeddingProcessor create(
String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD);
Map<String, String> filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);
return new TextImageEmbeddingProcessor(processorTag, description, modelId, embedding, filedMap, clientAccessor, environment);
return new TextImageEmbeddingProcessor(
processorTag,
description,
modelId,
embedding,
filedMap,
clientAccessor,
environment,
clusterService
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Map;
import java.util.Optional;

import org.opensearch.ingest.IngestService;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
Expand Down Expand Up @@ -56,7 +57,17 @@ public void testQueryPhaseSearcher() {

public void testProcessors() {
NeuralSearch plugin = new NeuralSearch();
Processor.Parameters processorParams = mock(Processor.Parameters.class);
Processor.Parameters processorParams = new Processor.Parameters(
null,
null,
null,
null,
null,
null,
mock(IngestService.class),
null,
null
);
Map<String, Processor.Factory> processors = plugin.getProcessors(processorParams);
assertNotNull(processors);
assertNotNull(processors.get(TextEmbeddingProcessor.TYPE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchParseException;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
Expand All @@ -48,9 +53,16 @@ public class TextImageEmbeddingProcessorTests extends OpenSearchTestCase {

@Mock
private MLCommonsClientAccessor mlCommonsClientAccessor;

@Mock
private Environment env;
@Mock
private ClusterService clusterService;
@Mock
private ClusterState clusterState;
@Mock
private Metadata metadata;
@Mock
private IndexMetadata indexMetadata;

@InjectMocks
private TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory;
Expand All @@ -62,6 +74,10 @@ public void setup() {
MockitoAnnotations.openMocks(this);
Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build();
when(env.settings()).thenReturn(settings);
when(clusterService.state()).thenReturn(clusterState);
when(clusterState.metadata()).thenReturn(metadata);
when(metadata.index(anyString())).thenReturn(indexMetadata);
when(indexMetadata.getSettings()).thenReturn(settings);
}

@SneakyThrows
Expand Down Expand Up @@ -98,7 +114,16 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t
// create with null type mapping
IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> new TextImageEmbeddingProcessor(PROCESSOR_TAG, DESCRIPTION, modelId, embeddingField, null, mlCommonsClientAccessor, env)
() -> new TextImageEmbeddingProcessor(
PROCESSOR_TAG,
DESCRIPTION,
modelId,
embeddingField,
null,
mlCommonsClientAccessor,
env,
clusterService
)
);
assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage());

Expand All @@ -112,7 +137,8 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t
embeddingField,
Map.of("", "my_field"),
mlCommonsClientAccessor,
env
env,
clusterService
)
);
assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage());
Expand All @@ -131,7 +157,8 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t
embeddingField,
typeMapping,
mlCommonsClientAccessor,
env
env,
clusterService
)
);
assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage());
Expand Down Expand Up @@ -183,7 +210,11 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
Map<String, Processor.Factory> registry = new HashMap<>();
MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class);
TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(accessor, env);
TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(
accessor,
env,
clusterService
);

Map<String, Object> config = new HashMap<>();
config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand Down Expand Up @@ -223,6 +254,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "hello world");
sourceAndMetadata.put("my_text_field", ret);
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance();
BiConsumer handler = mock(BiConsumer.class);
Expand Down Expand Up @@ -254,6 +286,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", map1);
sourceAndMetadata.put("my_text_field", map2);
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance();
BiConsumer handler = mock(BiConsumer.class);
Expand All @@ -267,6 +300,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", map1);
sourceAndMetadata.put("my_text_field", map2);
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance();
BiConsumer handler = mock(BiConsumer.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.SneakyThrows;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
Expand All @@ -30,7 +31,8 @@ public class TextImageEmbeddingProcessorFactoryTests extends OpenSearchTestCase
public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() {
TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(
mock(MLCommonsClientAccessor.class),
mock(Environment.class)
mock(Environment.class),
mock(ClusterService.class)
);

final Map<String, org.opensearch.ingest.Processor.Factory> processorFactories = new HashMap<>();
Expand All @@ -55,7 +57,8 @@ public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() {
public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() {
TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(
mock(MLCommonsClientAccessor.class),
mock(Environment.class)
mock(Environment.class),
mock(ClusterService.class)
);

final Map<String, org.opensearch.ingest.Processor.Factory> processorFactories = new HashMap<>();
Expand Down Expand Up @@ -88,7 +91,8 @@ public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() {
public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() {
TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(
mock(MLCommonsClientAccessor.class),
mock(Environment.class)
mock(Environment.class),
mock(ClusterService.class)
);

final Map<String, org.opensearch.ingest.Processor.Factory> processorFactories = new HashMap<>();
Expand Down

0 comments on commit fd8b1fd

Please sign in to comment.