Skip to content

Commit

Permalink
Add text embedding processor to neural search (#18)
Browse files Browse the repository at this point in the history
* # This is a combination of 14 commits.
# This is the 1st commit message:

Add text embedding processor to neural search

Signed-off-by: Zan Niu <[email protected]>

# The commit message #2 will be skipped:

# Code format
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #3 will be skipped:

# Address review comments
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #4 will be skipped:

# Add blocking text embedding method for pipeline processor
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #5 will be skipped:

# Add BaseNeuralSearchIT and address other review comments
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #6 will be skipped:

# Add BaseNeuralSearchIT and address other review comments
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #7 will be skipped:

# Add BaseNeuralSearchIT and address other review comments
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #8 will be skipped:

# Fix naming convention and IT function move to base
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #9 will be skipped:

# Fix naming convention and IT function move to base
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #10 will be skipped:

# Update src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
#
# Co-authored-by: Navneet Verma <[email protected]>

# The commit message #11 will be skipped:

# Update src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java
#
# Co-authored-by: Navneet Verma <[email protected]>

# The commit message #12 will be skipped:

# Fix code review comments
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #13 will be skipped:

# Fix text embedding processor NPE
#
# Signed-off-by: Zan Niu <[email protected]>

# The commit message #14 will be skipped:

# Remove jackson dependencies and fix tests with XCoontent
#
# Signed-off-by: Zan Niu <[email protected]>

* Add text embedding processor to neural search

Signed-off-by: Zan Niu <[email protected]>

* Remove unnecessary parameters in TextEmbeddingProcessor method

Signed-off-by: Zan Niu <[email protected]>

* Remove unnecessary empty string checks

Signed-off-by: Zan Niu <[email protected]>

* Add field max depth limit to prevent malicious attack

Signed-off-by: Zan Niu <[email protected]>

Signed-off-by: Zan Niu <[email protected]>
  • Loading branch information
zane-neo authored Oct 20, 2022
1 parent 272d803 commit 799c402
Show file tree
Hide file tree
Showing 13 changed files with 1,291 additions and 16 deletions.
3 changes: 3 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
* Learn more about Gradle by exploring our samples at https://docs.gradle.org/7.5.1/samples
* This project uses @Incubating APIs which are subject to change.
*/

import org.opensearch.gradle.test.RestIntegTestTask

import java.util.concurrent.Callable

apply plugin: 'java'
Expand Down Expand Up @@ -137,6 +139,7 @@ dependencies {
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}"
compileOnly fileTree(dir: knnJarDirectory, include: '*.jar')
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
}

// From maven, we can get the k-NN plugin as a zip. In order to add the jar to the classpath, we need to unzip the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelTaskType;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
Expand Down Expand Up @@ -99,23 +102,54 @@ public void inferenceSentences(
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
final MLInput mlInput = new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING);
final List<List<Float>> vector = new ArrayList<>();

MLInput mlInput = createMLInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
}
}
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
listener.onResponse(vector);
}, listener::onFailure));
}

/**
* Abstraction to call predict function of api of MLClient with provided targetResponseFilters. It uses the
* custom model provided as modelId and run the {@link MLModelTaskType#TEXT_EMBEDDING}. The return will be sent
* using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of
* inputText. We are not making this function generic enough to take any function or TaskType as currently we need
* to run only TextEmbedding tasks only. Please note this method is a blocking method, use this only when the processing
* needs block waiting for response, otherwise please use {@link #inferenceSentences(String, List, ActionListener)}
* instead.
* @param modelId {@link String}
* @param inputText {@link List} of {@link String} on which inference needs to happen.
* @return {@link List} of {@link List} of {@link String} represents the text embedding vector result.
* @throws ExecutionException If the underlying task failed, this exception will be thrown in the future.get().
* @throws InterruptedException If the thread is interrupted, this will be thrown.
*/
public List<List<Float>> inferenceSentences(@NonNull final String modelId, @NonNull final List<String> inputText)
throws ExecutionException, InterruptedException {
final MLInput mlInput = createMLInput(TARGET_RESPONSE_FILTERS, inputText);
final ActionFuture<MLOutput> outputActionFuture = mlClient.predict(modelId, mlInput);
final List<List<Float>> vector = buildVectorFromResponse(outputActionFuture.get());
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
return vector;
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
}
}
return vector;
}

}
17 changes: 14 additions & 3 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import org.opensearch.action.ActionRequest;
Expand All @@ -19,12 +20,16 @@
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.plugin.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.transport.MLPredictAction;
import org.opensearch.neuralsearch.transport.MLPredictTransportAction;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
Expand All @@ -35,7 +40,9 @@
/**
* Neural Search plugin class
*/
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin {
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin {

private MLCommonsClientAccessor clientAccessor;

@Override
public Collection<Object> createComponents(
Expand All @@ -51,8 +58,6 @@ public Collection<Object> createComponents(
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
final MachineLearningNodeClient machineLearningNodeClient = new MachineLearningNodeClient(client);
final MLCommonsClientAccessor clientAccessor = new MLCommonsClientAccessor(machineLearningNodeClient);
NeuralQueryBuilder.initialize(clientAccessor);
return List.of(clientAccessor);
}
Expand All @@ -72,4 +77,10 @@ public List<QuerySpec<?>> getQueries() {
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent)
);
}

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
}
}
Loading

0 comments on commit 799c402

Please sign in to comment.