diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverter.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverter.java index ab2bd426..1001f40b 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverter.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverter.java @@ -1,16 +1,18 @@ package ai.h2o.mojos.deploy.common.transform; -import static java.util.Arrays.asList; import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; -import ai.h2o.mojos.deploy.common.rest.model.ContributionResponse; +import ai.h2o.mojos.deploy.common.rest.model.PredictionInterval; import ai.h2o.mojos.deploy.common.rest.model.Row; import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest; import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse; import ai.h2o.mojos.runtime.frame.MojoFrame; import com.google.common.base.Strings; + import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -18,6 +20,7 @@ import java.util.UUID; import java.util.function.BiFunction; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; /** @@ -27,26 +30,215 @@ public class MojoFrameToScoreResponseConverter implements BiFunction { + private static final String LOWER_BOUND = ".lower"; + private static final String UPPER_BOUND = ".upper"; + + // If true then pipeline support prediction interval, otherwise false. + // Note: guarantee that pipeline supports Prediction interval. + private Boolean supportPredictionInterval; + + public MojoFrameToScoreResponseConverter(boolean supportPredictionInterval) { + this.supportPredictionInterval = supportPredictionInterval; + } + + public MojoFrameToScoreResponseConverter() { + supportPredictionInterval = false; + } + + /** + * Transform MOJO response frame into ScoreResponse. + * @param mojoFrame mojo response frame. + * @param scoreRequest score request. + * @return score response. + */ @Override - public ScoreResponse apply(MojoFrame mojoFrame, ScoreRequest scoreRequest) { + public ScoreResponse apply( + MojoFrame mojoFrame, ScoreRequest scoreRequest) { Set includedFields = getSetOfIncludedFields(scoreRequest); List outputRows = Stream.generate(Row::new).limit(mojoFrame.getNrows()).collect(Collectors.toList()); copyFilteredInputFields(scoreRequest, includedFields, outputRows); - Utils.copyResultFields(mojoFrame, outputRows); + fillOutputRows(mojoFrame, outputRows); ScoreResponse response = new ScoreResponse(); response.setScore(outputRows); if (!Boolean.TRUE.equals(scoreRequest.isNoFieldNamesInOutput())) { List outputFieldNames = getFilteredInputFieldNames(scoreRequest, includedFields); - outputFieldNames.addAll(asList(mojoFrame.getColumnNames())); + outputFieldNames.addAll(getTargetField(mojoFrame)); response.setFields(outputFieldNames); } - + fillWithPredictionInterval(mojoFrame, scoreRequest, response); return response; } + /** + * Populate target column rows into outputRows. + * When prediction interval is returned from MOJO + * response frame, only one column rows will + * be populated into the outputRows to ensure + * backward compatible. + */ + private void fillOutputRows( + MojoFrame mojoFrame, List outputRows) { + List targetRows = getTargetRows(mojoFrame); + for (int rowIdx = 0; rowIdx < mojoFrame.getNrows(); rowIdx++) { + outputRows.get(rowIdx).addAll(targetRows.get(rowIdx)); + } + } + + /** + * Populate Prediction Interval value into response field. + * Only when score request set requestPredictionIntervals be true + * and MOJO pipeline support prediction interval. + */ + private void fillWithPredictionInterval( + MojoFrame mojoFrame, ScoreRequest scoreRequest, ScoreResponse scoreResponse) { + if (Boolean.TRUE.equals(scoreRequest.isRequestPredictionIntervals())) { + if (!supportPredictionInterval) { + throw new IllegalStateException( + "Unexpected error, prediction interval should be supported, but actually not"); + } + if (mojoFrame.getNcols() > 1) { + int targetIdx = getTargetColIdx(Arrays.asList(mojoFrame.getColumnNames())); + PredictionInterval predictionInterval = new PredictionInterval(); + predictionInterval.setFields(getPredictionIntervalFields(mojoFrame, targetIdx)); + predictionInterval.setRows(getPredictionIntervalRows(mojoFrame, targetIdx)); + scoreResponse.setPredictionIntervals(predictionInterval); + } else { + scoreResponse.setPredictionIntervals( + new PredictionInterval().fields(new Row()).rows(Collections.emptyList())); + } + } + } + + /** + * Extract target column rows from MOJO response frame. + * Note: To ensure backward compatibility, + * if prediction interval is enabled then extracts only one + * column rows from response columns. + */ + private List getTargetRows(MojoFrame mojoFrame) { + List taretRows = Stream + .generate(Row::new) + .limit(mojoFrame.getNrows()) + .collect(Collectors.toList()); + for (int row = 0; row < mojoFrame.getNrows(); row++) { + for (int col : getTargetFieldIndices(mojoFrame)) { + String cell = mojoFrame.getColumn(col).getDataAsStrings()[row]; + taretRows.get(row).add(cell); + } + } + return taretRows; + } + + /** + * Extract target columns from MOJO response frame. + * When prediction interval is enabled, extracts only one + * column from MOJO frame, otherwise all columns names + * will be extracted. + */ + private List getTargetField( + MojoFrame mojoFrame) { + if (mojoFrame.getNcols() > 0) { + List targetColumns = Arrays.asList(mojoFrame.getColumnNames()); + if (supportPredictionInterval) { + int targetIdx = getTargetColIdx(targetColumns); + if (targetIdx < 0) { + throw new IllegalStateException( + "Unexpected error, target column does not exist in MOJO response frame" + ); + } + return targetColumns.subList(targetIdx, targetIdx + 1); + } + return targetColumns; + } else { + return Collections.emptyList(); + } + } + + /** + * Extract target columns indices from MOJO response frame. + * When prediction interval is enabled, extracts only one + * column index from MOJO frame, otherwise all + * columns indices will be extracted. + */ + private List getTargetFieldIndices(MojoFrame mojoFrame) { + if (mojoFrame.getNcols() > 0) { + List targetColumns = Arrays.asList(mojoFrame.getColumnNames()); + if (supportPredictionInterval) { + int targetIdx = getTargetColIdx(targetColumns); + if (targetIdx < 0) { + throw new IllegalStateException( + "Unexpected error, target column does not exist in MOJO response frame" + ); + } + return Collections.singletonList(targetIdx); + } + return IntStream.range(0, mojoFrame.getNcols()).boxed().collect(Collectors.toList()); + } else { + return Collections.emptyList(); + } + } + + /** + * Extract prediction interval columns rows from MOJO response frame. + * Note: Assumption is prediction interval should already be enabled + * and response frame has expected structure. + */ + private List getPredictionIntervalRows(MojoFrame mojoFrame, int targetIdx) { + List predictionIntervalRows = Stream + .generate(Row::new) + .limit(mojoFrame.getNrows()) + .collect(Collectors.toList()); + for (int row = 0; row < mojoFrame.getNrows(); row++) { + for (int col = 0; col < mojoFrame.getNcols(); col++) { + if (col == targetIdx) { + continue; + } + String cell = mojoFrame.getColumn(col).getDataAsStrings()[row]; + predictionIntervalRows.get(row).add(cell); + } + } + return predictionIntervalRows; + } + + /** + * Extract prediction interval columns names from MOJO response frame. + * Note: Assumption is prediction interval should already be enabled + * and response frame has expected structure. + */ + private Row getPredictionIntervalFields(MojoFrame mojoFrame, int targetIdx) { + Row row = new Row(); + List mojoColumns = Arrays.asList(mojoFrame.getColumnNames()); + + row.addAll(mojoColumns.subList(0, targetIdx)); + row.addAll(mojoColumns.subList(targetIdx + 1, mojoFrame.getNcols())); + return row; + } + + /** + * Extract target column index from list of column names. + * Note: Assumption is prediction interval should already be enabled + * and response columns list has expected structure. + */ + private int getTargetColIdx(List mojoColumns) { + if (mojoColumns.size() == 1) { + return 0; + } + String[] columns = mojoColumns.toArray(new String[0]); + Arrays.sort(columns); + StringBuilder builder = new StringBuilder(); + for (int idx = 0, cmpIdx = columns.length - 1; idx < columns[0].length(); idx++) { + if (columns[0].charAt(idx) == columns[cmpIdx].charAt(idx)) { + builder.append(columns[0].charAt(idx)); + } else { + break; + } + } + return mojoColumns.indexOf(builder.toString()); + } + private static void copyFilteredInputFields( ScoreRequest scoreRequest, Set includedFields, List outputRows) { if (includedFields.isEmpty()) { diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java index 629f259b..395946cd 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java @@ -9,6 +9,7 @@ import ai.h2o.mojos.deploy.common.rest.model.ShapleyType; import ai.h2o.mojos.runtime.MojoPipeline; import ai.h2o.mojos.runtime.api.MojoPipelineService; +import ai.h2o.mojos.runtime.api.PipelineConfig; import ai.h2o.mojos.runtime.frame.MojoFrame; import ai.h2o.mojos.runtime.frame.MojoFrameMeta; import ai.h2o.mojos.runtime.lic.LicenseException; @@ -44,8 +45,11 @@ public class MojoScorer { private static final String MOJO_PIPELINE_PATH_PROPERTY = "mojo.path"; private static final String MOJO_PIPELINE_PATH = System.getProperty(MOJO_PIPELINE_PATH_PROPERTY); - private static final MojoPipeline pipeline = loadMojoPipelineFromFile(); - + public static final boolean supportPredictionInterval = checkIfPredictionIntervalSupport(); + private static final MojoPipeline pipeline = + supportPredictionInterval + ? loadMojoPipelineFromFile(buildPipelineConfigWithPredictionInterval()) + : loadMojoPipelineFromFile(); private final ShapleyLoadOption enabledShapleyTypes; private final boolean shapleyEnabled; private static MojoPipeline pipelineTransformedShapley; @@ -96,11 +100,19 @@ public MojoScorer( * @return response {@link ScoreResponse} */ public ScoreResponse score(ScoreRequest request) { + if (Boolean.TRUE.equals(request.isRequestPredictionIntervals()) + && !supportPredictionInterval) { + throw new IllegalArgumentException( + "requestPredictionIntervals set to true, but model does not support it" + ); + } + scoreRequestTransformer.accept(request, getModelInfo().getSchema().getInputFields()); MojoFrame requestFrame = scoreRequestConverter .apply(request, pipeline.getInputFrameBuilder()); MojoFrame responseFrame = doScore(requestFrame); - ScoreResponse response = scoreResponseConverter.apply(responseFrame, request); + ScoreResponse response = scoreResponseConverter.apply( + responseFrame, request); response.id(pipeline.getUuid()); ShapleyType requestShapleyType = request.getRequestShapleyValueType(); @@ -245,7 +257,8 @@ public ScoreResponse scoreCsv(String csvFilePath) throws IOException { requestFrame = csvConverter.apply(csvStream, pipeline.getInputFrameBuilder()); } MojoFrame responseFrame = doScore(requestFrame); - ScoreResponse response = scoreResponseConverter.apply(responseFrame, new ScoreRequest()); + ScoreResponse response = scoreResponseConverter.apply( + responseFrame, new ScoreRequest()); response.id(pipeline.getUuid()); return response; } @@ -319,6 +332,10 @@ public ShapleyLoadOption getEnabledShapleyTypes() { return enabledShapleyTypes; } + public boolean isPredictionIntervalSupport() { + return supportPredictionInterval; + } + /** * Method to load mojo pipelines for shapley scoring based on configuration * @@ -356,6 +373,32 @@ private void loadMojoPipelinesForShapley() { } private static MojoPipeline loadMojoPipelineFromFile() { + File mojoFile = getMojoFile(); + try { + MojoPipeline mojoPipeline = MojoPipelineService.loadPipeline(mojoFile); + log.info("Mojo pipeline successfully loaded ({}).", mojoPipeline.getUuid()); + return mojoPipeline; + } catch (IOException e) { + throw new RuntimeException("Unable to load mojo from " + mojoFile, e); + } catch (LicenseException e) { + throw new RuntimeException("License file not found", e); + } + } + + private static MojoPipeline loadMojoPipelineFromFile(PipelineConfig pipelineConfig) { + File mojoFile = getMojoFile(); + try { + MojoPipeline mojoPipeline = MojoPipelineService.loadPipeline(mojoFile, pipelineConfig); + log.info("Mojo pipeline successfully loaded ({}).", mojoPipeline.getUuid()); + return mojoPipeline; + } catch (IOException e) { + throw new RuntimeException("Unable to load mojo from " + mojoFile, e); + } catch (LicenseException e) { + throw new RuntimeException("License file not found", e); + } + } + + private static File getMojoFile() { Preconditions.checkArgument( !Strings.isNullOrEmpty(MOJO_PIPELINE_PATH), "Path to mojo pipeline not specified, set the %s property.", @@ -372,14 +415,25 @@ private static MojoPipeline loadMojoPipelineFromFile() { if (!mojoFile.isFile()) { throw new RuntimeException("Could not load mojo from file: " + mojoFile); } + return mojoFile; + } + + private static boolean checkIfPredictionIntervalSupport() { + File mojoFile = getMojoFile(); try { - MojoPipeline mojoPipeline = MojoPipelineService.loadPipeline(mojoFile); - log.info("Mojo pipeline successfully loaded ({}).", mojoPipeline.getUuid()); - return mojoPipeline; - } catch (IOException e) { + MojoPipelineService.loadPipeline(mojoFile, buildPipelineConfigWithPredictionInterval()); + return true; + } catch (IllegalArgumentException e) { + log.debug("Prediction interval is not supported for the given model", e); + return false; + } catch (IOException | LicenseException e) { throw new RuntimeException("Unable to load mojo from " + mojoFile, e); - } catch (LicenseException e) { - throw new RuntimeException("License file not found", e); } } + + private static PipelineConfig buildPipelineConfigWithPredictionInterval() { + PipelineConfig.Builder builder = PipelineConfig.builder(); + builder.withPredictionInterval(true); + return builder.build(); + } } diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverterTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverterTest.java index d0158c52..90d4ea7a 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverterTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverterTest.java @@ -19,21 +19,25 @@ import ai.h2o.mojos.runtime.frame.MojoFrameBuilder; import ai.h2o.mojos.runtime.frame.MojoFrameMeta; import ai.h2o.mojos.runtime.frame.MojoRowBuilder; + +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Stream; + import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; class MojoFrameToScoreResponseConverterTest { - private final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); @Test void convertEmptyRowsResponse_succeeds() { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); ScoreRequest scoreRequest = new ScoreRequest(); MojoFrame mojoFrame = new MojoFrameBuilder( @@ -51,13 +55,16 @@ void convertEmptyRowsResponse_succeeds() { @Test void convertSingleFieldResponse_succeeds() { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); String[] fields = {"field"}; Type[] types = {Str}; String[][] values = {{"value"}}; ScoreRequest scoreRequest = new ScoreRequest(); // When - ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply( + buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).containsExactly(asRow("value")); @@ -67,6 +74,8 @@ void convertSingleFieldResponse_succeeds() { @Test void convertSingleFieldResponse_withoutFieldNames_succeeds() { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); String[] fields = {"field"}; Type[] types = {Str}; String[][] values = {{"value"}}; @@ -74,7 +83,8 @@ void convertSingleFieldResponse_withoutFieldNames_succeeds() { scoreRequest.setNoFieldNamesInOutput(true); // When - ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply( + buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).containsExactly(asRow("value")); @@ -84,6 +94,8 @@ void convertSingleFieldResponse_withoutFieldNames_succeeds() { @Test void convertIncludesOneField_succeeds() { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); String[] fields = {"outputField"}; Type[] types = {Str}; String[][] values = {{"outputValue"}}; @@ -93,7 +105,8 @@ void convertIncludesOneField_succeeds() { scoreRequest.addRowsItem(asRow("inputValue")); // When - ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply( + buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).containsExactly(asRow("inputValue", "outputValue")); @@ -103,6 +116,8 @@ void convertIncludesOneField_succeeds() { @Test void convertIncludesSomeFields_succeeds() { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); String[] fields = {"outputField1", "outputField2"}; Type[] types = {Str, Str}; String[][] values = {{"outputValue1", "outputValue2"}}; @@ -112,7 +127,8 @@ void convertIncludesSomeFields_succeeds() { scoreRequest.addRowsItem(asRow("inputValue1", "omittedValue", "inputValue3")); // When - ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply( + buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()) @@ -125,6 +141,8 @@ void convertIncludesSomeFields_succeeds() { @Test void convertIncludePresentIdField_succeeds() { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); String[] fields = {"outputField"}; Type[] types = {Str}; String[][] values = {{"outputValue"}}; @@ -135,7 +153,8 @@ void convertIncludePresentIdField_succeeds() { scoreRequest.setIdField("id"); // When - ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply( + buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).containsExactly(asRow("inputValue", "testId", "outputValue")); @@ -145,6 +164,8 @@ void convertIncludePresentIdField_succeeds() { @Test void convertIncludeMissingIdField_generateUuid() { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); String[] fields = {"outputField"}; Type[] types = {Str}; String[][] values = {{"outputValue"}}; @@ -155,7 +176,8 @@ void convertIncludeMissingIdField_generateUuid() { scoreRequest.setIdField("id"); // When - ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply( + buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).hasSize(1); @@ -171,13 +193,16 @@ void convertIncludeMissingIdField_generateUuid() { @Test void convertMoreRowsResponse_succeeds() { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); String[] fields = {"field"}; Type[] types = {Str}; String[][] values = {{"value1"}, {"value2"}, {"value3"}}; ScoreRequest scoreRequest = new ScoreRequest(); // When - ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply( + buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()) @@ -191,36 +216,33 @@ void convertMoreRowsResponse_succeeds() { @MethodSource("provideValues_convertMoreTypesResponse_succeeds") void convertMoreTypesResponse_succeeds(String[][] values) { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); Type[] types = {Str, Float32, Float64, Bool, Int32, Int64}; ScoreRequest scoreRequest = new ScoreRequest(); // When ScoreResponse result = converter.apply( - buildMojoFrame( - Stream.of(types).map(Object::toString).toArray(String[]::new), types, values), - scoreRequest); + buildMojoFrame( + Stream.of(types).map(Object::toString).toArray(String[]::new), types, values), + scoreRequest); // Then assertThat(result.getScore()) .containsExactly(Stream.of(values) - .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); + .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); assertThat(result.getFields()) - .containsExactly("Str", "Float32", "Float64", "Bool", "Int32", "Int64") - .inOrder(); - } - - @SuppressWarnings("unused") - private static Stream provideValues_convertMoreTypesResponse_succeeds() { - return Stream.of( - Arguments.of((Object) new String[][] {{"str", "1.1", "2.2", "1", "123", "123456789"}}), - Arguments.of((Object) new String[][] {{null, null, null, null, null, null}})); + .containsExactly("Str", "Float32", "Float64", "Bool", "Int32", "Int64") + .inOrder(); } @ParameterizedTest @MethodSource("provideValues_convertMoreTypesResponse_actualValues_succeeds") void convertMoreTypesResponse_actualValues_succeeds(Object[][] values, String[][] expValues) { // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); Type[] types = {Str, Float32, Float64, Bool, Int32, Int64}; ScoreRequest scoreRequest = new ScoreRequest(); @@ -231,7 +253,7 @@ void convertMoreTypesResponse_actualValues_succeeds(Object[][] values, String[][ Stream.of(types).map(Object::toString).toArray(String[]::new), types, values, - (rb, type, col, value) -> setJavaValue(rb, type, col, value)), + MojoFrameToScoreResponseConverterTest::setJavaValue), scoreRequest); // Then @@ -243,6 +265,187 @@ void convertMoreTypesResponse_actualValues_succeeds(Object[][] values, String[][ .inOrder(); } + @ParameterizedTest + @MethodSource("provideValues_predictionIntervalEnabledResponse_succeeds") + void convertMoreTypesResponse_enablePredictionIntervalSameType_succeeds( + Object[][] values, String[][] expValues) { + // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(true); + Type[] types = {Float64, Float64, Float64}; + ScoreRequest scoreRequest = new ScoreRequest().requestPredictionIntervals(true); + + // When + ScoreResponse result = + converter.apply( + buildMojoFrame( + new String[]{"result.upper", "result", "result.lower"}, + types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); + + // Then + assertThat(result.getScore()) + .containsExactly( + Stream.of(expValues) + .map(input -> Arrays.asList(input).subList(1, 2)) + .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); + assertThat(result.getFields()) + .containsExactly("result") + .inOrder(); + assertThat(result.getPredictionIntervals().getFields()) + .containsExactly("result.upper", "result.lower") + .inOrder(); + assertThat(result.getPredictionIntervals().getRows()) + .containsExactly( + Stream.of(expValues) + .map(input -> { + List intervalRow = new ArrayList<>(2); + intervalRow.add(input[0]); + intervalRow.add(input[2]); + return intervalRow; + }) + .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); + } + + @ParameterizedTest + @MethodSource("provideValues_predictionIntervalEnabledResponse_succeeds") + void convertMoreTypesResponse_disablePredictionIntervalSameType_succeeds( + Object[][] values, String[][] expValues) { + // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(true); + Type[] types = {Float64, Float64, Float64}; + ScoreRequest scoreRequest = new ScoreRequest(); + + // When + ScoreResponse result = + converter.apply( + buildMojoFrame( + new String[]{"result.upper", "result", "result.lower"}, + types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); + + // Then + assertThat(result.getScore()) + .containsExactly( + Stream.of(expValues) + .map(input -> Arrays.asList(input).subList(1, 2)) + .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); + assertThat(result.getFields()) + .containsExactly("result") + .inOrder(); + assertThat(result.getPredictionIntervals()) + .isNull(); + } + + @ParameterizedTest + @MethodSource("provideValues_predictionIntervalEnabledResponse_succeeds") + void convertMoreTypesResponse_disablePredictionIntervalNotSupportSameType_succeeds( + Object[][] values, String[][] expValues) { + // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); + Type[] types = {Float64, Float64, Float64}; + ScoreRequest scoreRequest = new ScoreRequest(); + + // When + ScoreResponse result = + converter.apply( + buildMojoFrame( + new String[]{"result.upper", "result", "result.lower"}, + types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); + + // Then + assertThat(result.getScore()) + .containsExactly( + Stream.of(expValues) + .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); + assertThat(result.getFields()) + .containsExactly("result.upper", "result", "result.lower") + .inOrder(); + assertThat(result.getPredictionIntervals()) + .isNull(); + } + + @ParameterizedTest + @MethodSource("provideValue_predictionIntervalEnabledResponse_fails") + void convertMoreTypesResponse_enablePredictionIntervalDiffType_fails( + Object[][] values, Type[] types) { + // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(false); + ScoreRequest scoreRequest = new ScoreRequest().requestPredictionIntervals(true); + + // When & Then + try { + converter.apply( + buildMojoFrame( + Stream.of(types).map(Object::toString).toArray(String[]::new), + types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); + } catch (Exception e) { + assertThat(e instanceof IllegalStateException).isTrue(); + } + } + + @ParameterizedTest + @MethodSource("provideValue_predictionIntervalEnabledResponse_fails") + void convertMoreTypesResponse_disablePredictionIntervalNotSupportDiffType_succeeds( + Object[][] values, Type[] types) { + // Given + final MojoFrameToScoreResponseConverter converter + = new MojoFrameToScoreResponseConverter(); + ScoreRequest scoreRequest = new ScoreRequest(); + + // When + ScoreResponse result = + converter.apply( + buildMojoFrame( + Stream.of(types).map(Object::toString).toArray(String[]::new), + types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); + + // Then + assertThat(result.getFields()) + .containsExactly(Stream.of(types).map(Object::toString).toArray()) + .inOrder(); + assertThat(result.getPredictionIntervals()) + .isNull(); + } + + private static Stream provideValues_predictionIntervalEnabledResponse_succeeds() { + return Stream.of( + Arguments.of( + new Double[][]{{3.5, -1.0, 2.0}, {3.3, 11.9, 10.3}}, + new String[][]{{"3.5", "-1.0", "2.0"}, {"3.3", "11.9", "10.3"}}), + Arguments.of( + new Double[][]{{2.7, 3.4, 5.9}, {1.1, 2.2, 3.3}}, + new String[][]{{"2.7", "3.4", "5.9"}, {"1.1", "2.2", "3.3"}}) + ); + } + + private static Stream provideValue_predictionIntervalEnabledResponse_fails() { + return Stream.of( + Arguments.of(new Double[][]{{12.2, 11.221},{1.1, 99.1}}, new Type[]{Float64, Float64}), + Arguments.of(new Double[][]{{10.1}, {121.1}}, new Type[]{Float64}), + Arguments.of(new Double[][]{}, new Type[]{}), + Arguments.of(new Object[][]{ + {"abc", null, 12}, {"bbc", 12.4f, 15}}, new Type[]{Str, Float32, Int32}), + Arguments.of(new Object[][]{ + {90L, 1.21f, 12}, {11L, 12.4f, 15}}, new Type[]{Int64, Float32, Int32}), + Arguments.of(new Object[][]{ + {false, true, false}, {true, null, false}}, new Type[]{Bool, Bool, Bool}) + ); + } + + @SuppressWarnings("unused") + private static Stream provideValues_convertMoreTypesResponse_succeeds() { + return Stream.of( + Arguments.of((Object) new String[][] {{"str", "1.1", "2.2", "1", "123", "123456789"}}), + Arguments.of((Object) new String[][] {{null, null, null, null, null, null}})); + } + @SuppressWarnings("unused") private static Stream provideValues_convertMoreTypesResponse_actualValues_succeeds() { return Stream.of( @@ -300,6 +503,13 @@ private static Row asRow(String... values) { return row; } + private static Row asRow(List values) { + Row row = new Row(); + row.ensureCapacity(values.size()); + row.addAll(values); + return row; + } + private static Object[][] aao(Object... values) { return new Object[][] {values}; } diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoScorerTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoScorerTest.java index 5b6c65e5..df03cc8e 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoScorerTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoScorerTest.java @@ -17,6 +17,7 @@ import ai.h2o.mojos.runtime.api.BasePipelineListener; import ai.h2o.mojos.runtime.api.MojoColumnMeta; import ai.h2o.mojos.runtime.api.MojoPipelineService; +import ai.h2o.mojos.runtime.api.PipelineConfig; import ai.h2o.mojos.runtime.frame.MojoColumn; import ai.h2o.mojos.runtime.frame.MojoFrame; import ai.h2o.mojos.runtime.frame.MojoFrameBuilder; @@ -35,6 +36,7 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -46,6 +48,7 @@ class MojoScorerTest { private static final String MOJO_PIPELINE_PATH = "src/test/resources/multinomial-pipeline.mojo"; private static final String TEST_UUID = "TEST_UUID"; + private static MockedStatic pipelineSettings = null; @Mock private ScoreRequestToMojoFrameConverter scoreRequestConverter; @Mock private MojoFrameToScoreResponseConverter scoreResponseConverter; @@ -62,11 +65,17 @@ static void setup() { } private static void mockDummyPipeline() { + if (pipelineSettings != null) { + pipelineSettings.close(); + } MojoPipeline dummyPipeline = new DummyPipeline(TEST_UUID, MojoFrameMeta.getEmpty(), MojoFrameMeta.getEmpty()); - MockedStatic theMock = Mockito.mockStatic(MojoPipelineService.class); - theMock.when(() -> MojoPipelineService - .loadPipeline(new File(MOJO_PIPELINE_PATH))).thenReturn(dummyPipeline); + pipelineSettings = Mockito.mockStatic(MojoPipelineService.class); + pipelineSettings.when(() -> MojoPipelineService + .loadPipeline(new File(MOJO_PIPELINE_PATH))).thenReturn(dummyPipeline); + pipelineSettings.when(() -> MojoPipelineService + .loadPipeline(Mockito.eq(new File(MOJO_PIPELINE_PATH)), any(PipelineConfig.class))) + .thenReturn(dummyPipeline); } @AfterAll diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/config/ScorerConfiguration.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/config/ScorerConfiguration.java index f967fbc3..7d1bbb70 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/config/ScorerConfiguration.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/config/ScorerConfiguration.java @@ -16,7 +16,7 @@ class ScorerConfiguration { @Bean public MojoFrameToScoreResponseConverter responseConverter() { - return new MojoFrameToScoreResponseConverter(); + return new MojoFrameToScoreResponseConverter(MojoScorer.supportPredictionInterval); } @Bean diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java index 20d8e648..5b5cf56b 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java @@ -13,6 +13,7 @@ import ai.h2o.mojos.deploy.local.rest.error.ErrorUtil; import com.google.common.base.Strings; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -49,7 +50,7 @@ public ModelsApiController(MojoScorer scorer, SampleRequestBuilder sampleRequest this.scorer = scorer; this.sampleRequestBuilder = sampleRequestBuilder; this.supportedCapabilities = assembleSupportedCapabilities( - scorer.getEnabledShapleyTypes() + scorer.getEnabledShapleyTypes(), scorer.isPredictionIntervalSupport() ); } @@ -159,20 +160,28 @@ public ResponseEntity getSampleRequest() { } private static List assembleSupportedCapabilities( - ShapleyLoadOption enabledShapleyTypes) { + ShapleyLoadOption enabledShapleyTypes, boolean supportPredictionInterval) { + List result = new ArrayList<>(); + if (supportPredictionInterval) { + result.add(CapabilityType.SCORE_PREDICTION_INTERVAL); + } switch (enabledShapleyTypes) { case ALL: - return Arrays.asList( + result.addAll(Arrays.asList( CapabilityType.SCORE, CapabilityType.CONTRIBUTION_ORIGINAL, - CapabilityType.CONTRIBUTION_TRANSFORMED); + CapabilityType.CONTRIBUTION_TRANSFORMED)); + break; case ORIGINAL: - return Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_ORIGINAL); + result.addAll(Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_ORIGINAL)); + break; case TRANSFORMED: - return Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED); + result.addAll(Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED)); + break; case NONE: default: - return Arrays.asList(CapabilityType.SCORE); + result.add(CapabilityType.SCORE); } + return result; } } diff --git a/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java index c544b87a..e099cff7 100644 --- a/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java +++ b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java @@ -13,6 +13,7 @@ import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption; import ai.h2o.mojos.runtime.MojoPipeline; import ai.h2o.mojos.runtime.api.MojoPipelineService; +import ai.h2o.mojos.runtime.api.PipelineConfig; import java.io.File; import java.io.IOException; @@ -46,7 +47,8 @@ private static void mockMojoPipeline(File tmpModel) { MojoPipeline mojoPipeline = Mockito.mock(MojoPipeline.class); MockedStatic theMock = Mockito.mockStatic(MojoPipelineService.class); theMock.when(() -> MojoPipelineService - .loadPipeline(new File(tmpModel.getAbsolutePath()))).thenReturn(mojoPipeline); + .loadPipeline(Mockito.eq(new File(tmpModel.getAbsolutePath())), any(PipelineConfig.class))) + .thenReturn(mojoPipeline); } @Test @@ -56,6 +58,7 @@ void verifyCapabilities_DefaultShapley_ReturnsExpected() { MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.NONE); + when(scorer.isPredictionIntervalSupport()).thenReturn(false); ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); @@ -75,6 +78,7 @@ void verifyCapabilities_AllShapleyEnabled_ReturnsExpected() { CapabilityType.CONTRIBUTION_TRANSFORMED); MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ALL); + when(scorer.isPredictionIntervalSupport()).thenReturn(false); ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); @@ -93,6 +97,7 @@ void verifyCapabilities_OriginalShapleyEnabled_ReturnsExpected() { CapabilityType.CONTRIBUTION_ORIGINAL); MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ORIGINAL); + when(scorer.isPredictionIntervalSupport()).thenReturn(false); ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); @@ -111,6 +116,7 @@ void verifyCapabilities_TransformedShapleyEnabled_ReturnsExpected() { CapabilityType.CONTRIBUTION_TRANSFORMED); MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED); + when(scorer.isPredictionIntervalSupport()).thenReturn(false); ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); @@ -126,6 +132,7 @@ void verifyScore_Fails_ReturnsException() { // Given MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED); + when(scorer.isPredictionIntervalSupport()).thenReturn(false); when(scorer.score(any())).thenThrow(new IllegalStateException("Test Exception")); ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); @@ -145,6 +152,7 @@ void verifyScoreByFile_Fails_ReturnsException() throws IOException { // Given MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED); + when(scorer.isPredictionIntervalSupport()).thenReturn(false); ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); @@ -163,6 +171,7 @@ void verifyScoreContribution_Fails_ReturnsException() { // Given MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED); + when(scorer.isPredictionIntervalSupport()).thenReturn(false); ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder);