Skip to content

Commit

Permalink
Add ignoreFailure and pipelineContext (#152)
Browse files Browse the repository at this point in the history
* Add ignoreFailure and pipelineContext

Signed-off-by: Mingshi Liu <[email protected]>

* use lowercase

Signed-off-by: Mingshi Liu <[email protected]>

---------

Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl authored Jul 11, 2023
1 parent 25d3e08 commit c15274e
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public List<SearchExtSpec<?>> getSearchExts() {
}

@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Processor.Parameters parameters) {
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
return Map.of(PersonalizeRankingResponseProcessor.TYPE, new PersonalizeRankingResponseProcessor.Factory(PersonalizeClientSettings.getClientSettings(parameters.env.settings())),
KendraRankingResponseProcessor.TYPE, new KendraRankingResponseProcessor.Factory(this.kendraClientSettings));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class ConfigurationUtils {
* Get result transformer configurations from Search Request
*
* @param settings all index settings configured for this plugin
* @param resultTransformerMap map of transformed results
* @return ordered and validated list of result transformers, empty list if not specified
*/
public static List<ResultTransformerConfiguration> getResultTransformersFromIndexConfiguration(Settings settings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.profile.SearchProfileShardResults;
Expand All @@ -35,7 +36,7 @@
/**
* This is a {@link SearchResponseProcessor} that applies kendra intelligence ranking
*/
public class KendraRankingResponseProcessor implements SearchResponseProcessor {
public class KendraRankingResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {
/**
* key to reference this processor type from a search pipeline
*/
Expand All @@ -54,13 +55,14 @@ public class KendraRankingResponseProcessor implements SearchResponseProcessor {
*
* @param tag processor tag
* @param description processor description
* @param ignoreFailure processor ignoreFailure config
* @param titleField titleField applied to kendra re-ranking
* @param bodyField bodyField applied to kendra re-ranking
* @param inputDocLimit docLimit applied to kendra re-ranking
* @param kendraClient kendraClient to connect with kendra
*/
public KendraRankingResponseProcessor(String tag, String description, List<String> titleField, List<String> bodyField, Integer inputDocLimit, KendraHttpClient kendraClient) {
super();
public KendraRankingResponseProcessor(String tag, String description, boolean ignoreFailure, List<String> titleField, List<String> bodyField, Integer inputDocLimit, KendraHttpClient kendraClient) {
super(tag, description, ignoreFailure);
this.titleField = titleField;
this.bodyField = bodyField;
this.tag = tag;
Expand Down Expand Up @@ -99,6 +101,7 @@ public String getDescription() {
return description;
}


/**
* Transform the response hit and apply kendra re-ranking logic
*/
Expand Down Expand Up @@ -156,7 +159,9 @@ public KendraRankingResponseProcessor create(
Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
String tag,
String description,
Map<String, Object> config
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws Exception {
List<String> titleField = Collections.singletonList(ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, "title_field"));
List<String> bodyField = Collections.singletonList(ConfigurationUtils.readStringProperty(TYPE, tag, config, "body_field"));
Expand All @@ -168,7 +173,7 @@ public KendraRankingResponseProcessor create(
} else {
docLimit = Integer.parseInt(inputDocLimit);
}
return new KendraRankingResponseProcessor(tag, description, titleField, bodyField, docLimit, kendraClient);
return new KendraRankingResponseProcessor(tag, description, ignoreFailure, titleField, bodyField, docLimit, kendraClient);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.profile.SearchProfileShardResults;
Expand All @@ -36,7 +37,7 @@
/**
* This is a {@link SearchResponseProcessor} that applies Personalized intelligent ranking
*/
public class PersonalizeRankingResponseProcessor implements SearchResponseProcessor {
public class PersonalizeRankingResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {

private static final Logger logger = LogManager.getLogger(PersonalizeRankingResponseProcessor.class);

Expand All @@ -51,14 +52,16 @@ public class PersonalizeRankingResponseProcessor implements SearchResponseProces
*
* @param tag processor tag
* @param description processor description
* @param ignoreFailure processor ignoreFailure config
* @param rankerConfig personalize ranker config
* @param client personalize client
*/
public PersonalizeRankingResponseProcessor(String tag,
String description,
boolean ignoreFailure,
PersonalizeIntelligentRankerConfiguration rankerConfig,
PersonalizeClient client) {
super();
super(tag, description, ignoreFailure);
this.tag = tag;
this.description = description;
this.rankerConfig = rankerConfig;
Expand Down Expand Up @@ -150,7 +153,7 @@ public Factory(PersonalizeClientSettings settings) {
}

@Override
public PersonalizeRankingResponseProcessor create(Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, String tag, String description, Map<String, Object> config) throws Exception {
public PersonalizeRankingResponseProcessor create(Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, String tag, String description, boolean ignoreFailure, Map<String, Object> config, PipelineContext pipelineContext) throws Exception {
String personalizeCampaign = ConfigurationUtils.readStringProperty(TYPE, tag, config, CAMPAIGN_ARN_CONFIG_NAME);
String iamRoleArn = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, IAM_ROLE_ARN_CONFIG_NAME);
String recipe = ConfigurationUtils.readStringProperty(TYPE, tag, config, RECIPE_CONFIG_NAME);
Expand All @@ -162,7 +165,7 @@ public PersonalizeRankingResponseProcessor create(Map<String, Processor.Factory<
new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, awsRegion, weight);
AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(personalizeClientSettings, iamRoleArn, awsRegion);
PersonalizeClient personalizeClient = clientBuilder.apply(credentialsProvider, awsRegion);
return new PersonalizeRankingResponseProcessor(tag, description, rankerConfig, personalizeClient);
return new PersonalizeRankingResponseProcessor(tag, description, ignoreFailure, rankerConfig, personalizeClient);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ public void testFactory() throws Exception {
Collections.emptyMap(),
null,
null,
Collections.emptyMap()
false,
Collections.emptyMap(),
null
));

//test create with all fields
Expand All @@ -76,15 +78,15 @@ public void testFactory() throws Exception {
configuration.put("title_field","field");
configuration.put("body_field","body");
configuration.put("doc_limit","500");
KendraRankingResponseProcessor processorWithAllFields = factory.create(Collections.emptyMap(),"tmp0","testingAllFields", configuration);
KendraRankingResponseProcessor processorWithAllFields = factory.create(Collections.emptyMap(),"tmp0","testingAllFields", false, configuration,null);
assertEquals(TYPE, processorWithAllFields.getType());
assertEquals("tmp0", processorWithAllFields.getTag());
assertEquals("testingAllFields", processorWithAllFields.getDescription());

//test create with required field
Map<String,Object> shortConfiguration = new HashMap<>();
shortConfiguration.put("body_field","body");
KendraRankingResponseProcessor processorWithOneFields = factory.create(Collections.emptyMap(),"tmp1","testingBodyField", shortConfiguration);
KendraRankingResponseProcessor processorWithOneFields = factory.create(Collections.emptyMap(),"tmp1","testingBodyField", false, shortConfiguration, null);
assertEquals(TYPE, processorWithOneFields.getType());
assertEquals("tmp1", processorWithOneFields.getTag());
assertEquals("testingBodyField", processorWithOneFields.getDescription());
Expand All @@ -93,7 +95,7 @@ public void testFactory() throws Exception {
Map<String,Object> nullDocLimitConfiguration = new HashMap<>();
nullDocLimitConfiguration.put("body_field","body");
nullDocLimitConfiguration.put("doc_limit",null);
KendraRankingResponseProcessor processorWithNullDocLimit = factory.create(Collections.emptyMap(),"tmp2","testingNullDocLimit", nullDocLimitConfiguration);
KendraRankingResponseProcessor processorWithNullDocLimit = factory.create(Collections.emptyMap(),"tmp2","testingNullDocLimit", false, nullDocLimitConfiguration, null );
assertEquals(TYPE, processorWithNullDocLimit.getType());
assertEquals("tmp2", processorWithNullDocLimit.getTag());
assertEquals("testingNullDocLimit", processorWithNullDocLimit.getDescription());
Expand All @@ -102,7 +104,7 @@ public void testFactory() throws Exception {
Map<String,Object> nullTitleConfiguration = new HashMap<>();
nullTitleConfiguration.put("body_field","body");
nullTitleConfiguration.put("title_field",null);
KendraRankingResponseProcessor processorWithNullTitleField = factory.create(Collections.emptyMap(),"tmp3","testingNullTitleField", nullTitleConfiguration);
KendraRankingResponseProcessor processorWithNullTitleField = factory.create(Collections.emptyMap(),"tmp3","testingNullTitleField", false, nullTitleConfiguration, null);
assertEquals(TYPE, processorWithNullTitleField.getType());
assertEquals("tmp3", processorWithNullTitleField.getTag());
assertEquals("testingNullTitleField", processorWithNullTitleField.getDescription());
Expand All @@ -116,18 +118,18 @@ public void testRankingResponse() throws Exception {
bodyField.add("body");

//test response with titleField, bodyField and docLimit
KendraRankingResponseProcessor processorWtOptionalConfig = new KendraRankingResponseProcessor(null,null,titleField,bodyField,500,kendraClient);
KendraRankingResponseProcessor processorWtOptionalConfig = new KendraRankingResponseProcessor(null,null,false, titleField,bodyField,500,kendraClient);
int size = 5;
SearchResponse reRankedResponse0 = processorWtOptionalConfig.processResponse(createRequest(),createResponse(size));
assertEquals(size,reRankedResponse0.getHits().getHits().length);

//test response with null doc limit
KendraRankingResponseProcessor processorWtTwoConfig = new KendraRankingResponseProcessor(null,null,titleField,bodyField,null,kendraClient);
KendraRankingResponseProcessor processorWtTwoConfig = new KendraRankingResponseProcessor(null,null,false, titleField,bodyField,null,kendraClient);
SearchResponse reRankedResponse1 = processorWtTwoConfig.processResponse(createRequest(),createResponse(size));
assertEquals(size,reRankedResponse1.getHits().getHits().length);

//test response with null doc limit and null title field
KendraRankingResponseProcessor processorWtOneConfig = new KendraRankingResponseProcessor(null,null,null,bodyField,null,kendraClient);
KendraRankingResponseProcessor processorWtOneConfig = new KendraRankingResponseProcessor(null,null,false,null,bodyField,null,kendraClient);
SearchResponse reRankedResponse2 = processorWtOneConfig.processResponse(createRequest(),createResponse(size));
assertEquals(size,reRankedResponse2.getHits().getHits().length);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ public void testCreateFactoryThrowsExceptionWithEmptyConfig() {
Collections.emptyMap(),
null,
null,
Collections.emptyMap()
false,
Collections.emptyMap(),
null
));
}

Expand All @@ -72,7 +74,7 @@ public void testCreateFactoryWithAllPersonalizeConfig() throws Exception {
configuration.put("aws_region", region);

PersonalizeRankingResponseProcessor personalizeResponseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration);
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null);

assertEquals(TYPE, personalizeResponseProcessor.getType());
assertEquals("testTag", personalizeResponseProcessor.getTag());
Expand All @@ -94,7 +96,7 @@ public void testProcessorWithNoHits() throws Exception {
configuration.put("aws_region", region);

PersonalizeRankingResponseProcessor personalizeResponseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration);
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null);
SearchRequest searchRequest = new SearchRequest();
SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f);
SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0);
Expand All @@ -118,7 +120,7 @@ public void testProcessorWithHits() throws Exception {
configuration.put("aws_region", region);

PersonalizeRankingResponseProcessor personalizeResponseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration);
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null);
SearchRequest searchRequest = new SearchRequest();
SearchHit[] searchHits = new SearchHit[10];
for (int i = 0; i < searchHits.length; i++) {
Expand Down Expand Up @@ -147,7 +149,7 @@ public void testProcessorWithHitsAndSearchProcessorExt() throws Exception {
configuration.put("aws_region", region);

PersonalizeRankingResponseProcessor personalizeResponseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration);
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null);

Map<String, Object> personalizeContext = new HashMap<>();
personalizeContext.put("contextKey2", "contextValue2");
Expand Down Expand Up @@ -186,7 +188,7 @@ public void testProcessorWithHitsWithInvalidPersonalizeContext() throws Exceptio
configuration.put("aws_region", region);

PersonalizeRankingResponseProcessor personalizeResponseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration);
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration,null);

Map<String, Object> personalizeContext = new HashMap<>();
personalizeContext.put("contextKey2", 5);
Expand Down

0 comments on commit c15274e

Please sign in to comment.