Skip to content
This repository has been archived by the owner on Nov 29, 2024. It is now read-only.

Commit

Permalink
Merge pull request #354 from h2oai/issue-1756/implement-prediction-in…
Browse files Browse the repository at this point in the history
…terval

Issue 1756/implement prediction interval
  • Loading branch information
jackjii79 authored May 18, 2023
2 parents c1b9a50 + 09de615 commit 09b49db
Show file tree
Hide file tree
Showing 7 changed files with 534 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
package ai.h2o.mojos.deploy.common.transform;

import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;

import ai.h2o.mojos.deploy.common.rest.model.ContributionResponse;
import ai.h2o.mojos.deploy.common.rest.model.PredictionInterval;
import ai.h2o.mojos.deploy.common.rest.model.Row;
import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest;
import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import com.google.common.base.Strings;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
Expand All @@ -27,26 +30,215 @@
public class MojoFrameToScoreResponseConverter
implements BiFunction<MojoFrame, ScoreRequest, ScoreResponse> {

private static final String LOWER_BOUND = ".lower";
private static final String UPPER_BOUND = ".upper";

// If true then pipeline support prediction interval, otherwise false.
// Note: guarantee that pipeline supports Prediction interval.
private Boolean supportPredictionInterval;

public MojoFrameToScoreResponseConverter(boolean supportPredictionInterval) {
this.supportPredictionInterval = supportPredictionInterval;
}

public MojoFrameToScoreResponseConverter() {
supportPredictionInterval = false;
}

/**
* Transform MOJO response frame into ScoreResponse.
* @param mojoFrame mojo response frame.
* @param scoreRequest score request.
* @return score response.
*/
@Override
public ScoreResponse apply(MojoFrame mojoFrame, ScoreRequest scoreRequest) {
public ScoreResponse apply(
MojoFrame mojoFrame, ScoreRequest scoreRequest) {
Set<String> includedFields = getSetOfIncludedFields(scoreRequest);
List<Row> outputRows =
Stream.generate(Row::new).limit(mojoFrame.getNrows()).collect(Collectors.toList());
copyFilteredInputFields(scoreRequest, includedFields, outputRows);
Utils.copyResultFields(mojoFrame, outputRows);
fillOutputRows(mojoFrame, outputRows);

ScoreResponse response = new ScoreResponse();
response.setScore(outputRows);

if (!Boolean.TRUE.equals(scoreRequest.isNoFieldNamesInOutput())) {
List<String> outputFieldNames = getFilteredInputFieldNames(scoreRequest, includedFields);
outputFieldNames.addAll(asList(mojoFrame.getColumnNames()));
outputFieldNames.addAll(getTargetField(mojoFrame));
response.setFields(outputFieldNames);
}

fillWithPredictionInterval(mojoFrame, scoreRequest, response);
return response;
}

/**
* Populate target column rows into outputRows.
* When prediction interval is returned from MOJO
* response frame, only one column rows will
* be populated into the outputRows to ensure
* backward compatible.
*/
private void fillOutputRows(
MojoFrame mojoFrame, List<Row> outputRows) {
List<Row> targetRows = getTargetRows(mojoFrame);
for (int rowIdx = 0; rowIdx < mojoFrame.getNrows(); rowIdx++) {
outputRows.get(rowIdx).addAll(targetRows.get(rowIdx));
}
}

/**
* Populate Prediction Interval value into response field.
* Only when score request set requestPredictionIntervals be true
* and MOJO pipeline support prediction interval.
*/
private void fillWithPredictionInterval(
MojoFrame mojoFrame, ScoreRequest scoreRequest, ScoreResponse scoreResponse) {
if (Boolean.TRUE.equals(scoreRequest.isRequestPredictionIntervals())) {
if (!supportPredictionInterval) {
throw new IllegalStateException(
"Unexpected error, prediction interval should be supported, but actually not");
}
if (mojoFrame.getNcols() > 1) {
int targetIdx = getTargetColIdx(Arrays.asList(mojoFrame.getColumnNames()));
PredictionInterval predictionInterval = new PredictionInterval();
predictionInterval.setFields(getPredictionIntervalFields(mojoFrame, targetIdx));
predictionInterval.setRows(getPredictionIntervalRows(mojoFrame, targetIdx));
scoreResponse.setPredictionIntervals(predictionInterval);
} else {
scoreResponse.setPredictionIntervals(
new PredictionInterval().fields(new Row()).rows(Collections.emptyList()));
}
}
}

/**
* Extract target column rows from MOJO response frame.
* Note: To ensure backward compatibility,
* if prediction interval is enabled then extracts only one
* column rows from response columns.
*/
private List<Row> getTargetRows(MojoFrame mojoFrame) {
List<Row> taretRows = Stream
.generate(Row::new)
.limit(mojoFrame.getNrows())
.collect(Collectors.toList());
for (int row = 0; row < mojoFrame.getNrows(); row++) {
for (int col : getTargetFieldIndices(mojoFrame)) {
String cell = mojoFrame.getColumn(col).getDataAsStrings()[row];
taretRows.get(row).add(cell);
}
}
return taretRows;
}

/**
* Extract target columns from MOJO response frame.
* When prediction interval is enabled, extracts only one
* column from MOJO frame, otherwise all columns names
* will be extracted.
*/
private List<String> getTargetField(
MojoFrame mojoFrame) {
if (mojoFrame.getNcols() > 0) {
List<String> targetColumns = Arrays.asList(mojoFrame.getColumnNames());
if (supportPredictionInterval) {
int targetIdx = getTargetColIdx(targetColumns);
if (targetIdx < 0) {
throw new IllegalStateException(
"Unexpected error, target column does not exist in MOJO response frame"
);
}
return targetColumns.subList(targetIdx, targetIdx + 1);
}
return targetColumns;
} else {
return Collections.emptyList();
}
}

/**
* Extract target columns indices from MOJO response frame.
* When prediction interval is enabled, extracts only one
* column index from MOJO frame, otherwise all
* columns indices will be extracted.
*/
private List<Integer> getTargetFieldIndices(MojoFrame mojoFrame) {
if (mojoFrame.getNcols() > 0) {
List<String> targetColumns = Arrays.asList(mojoFrame.getColumnNames());
if (supportPredictionInterval) {
int targetIdx = getTargetColIdx(targetColumns);
if (targetIdx < 0) {
throw new IllegalStateException(
"Unexpected error, target column does not exist in MOJO response frame"
);
}
return Collections.singletonList(targetIdx);
}
return IntStream.range(0, mojoFrame.getNcols()).boxed().collect(Collectors.toList());
} else {
return Collections.emptyList();
}
}

/**
* Extract prediction interval columns rows from MOJO response frame.
* Note: Assumption is prediction interval should already be enabled
* and response frame has expected structure.
*/
private List<Row> getPredictionIntervalRows(MojoFrame mojoFrame, int targetIdx) {
List<Row> predictionIntervalRows = Stream
.generate(Row::new)
.limit(mojoFrame.getNrows())
.collect(Collectors.toList());
for (int row = 0; row < mojoFrame.getNrows(); row++) {
for (int col = 0; col < mojoFrame.getNcols(); col++) {
if (col == targetIdx) {
continue;
}
String cell = mojoFrame.getColumn(col).getDataAsStrings()[row];
predictionIntervalRows.get(row).add(cell);
}
}
return predictionIntervalRows;
}

/**
* Extract prediction interval columns names from MOJO response frame.
* Note: Assumption is prediction interval should already be enabled
* and response frame has expected structure.
*/
private Row getPredictionIntervalFields(MojoFrame mojoFrame, int targetIdx) {
Row row = new Row();
List<String> mojoColumns = Arrays.asList(mojoFrame.getColumnNames());

row.addAll(mojoColumns.subList(0, targetIdx));
row.addAll(mojoColumns.subList(targetIdx + 1, mojoFrame.getNcols()));
return row;
}

/**
* Extract target column index from list of column names.
* Note: Assumption is prediction interval should already be enabled
* and response columns list has expected structure.
*/
private int getTargetColIdx(List<String> mojoColumns) {
if (mojoColumns.size() == 1) {
return 0;
}
String[] columns = mojoColumns.toArray(new String[0]);
Arrays.sort(columns);
StringBuilder builder = new StringBuilder();
for (int idx = 0, cmpIdx = columns.length - 1; idx < columns[0].length(); idx++) {
if (columns[0].charAt(idx) == columns[cmpIdx].charAt(idx)) {
builder.append(columns[0].charAt(idx));
} else {
break;
}
}
return mojoColumns.indexOf(builder.toString());
}

private static void copyFilteredInputFields(
ScoreRequest scoreRequest, Set<String> includedFields, List<Row> outputRows) {
if (includedFields.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ai.h2o.mojos.deploy.common.rest.model.ShapleyType;
import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.api.MojoPipelineService;
import ai.h2o.mojos.runtime.api.PipelineConfig;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.lic.LicenseException;
Expand Down Expand Up @@ -44,8 +45,11 @@ public class MojoScorer {

private static final String MOJO_PIPELINE_PATH_PROPERTY = "mojo.path";
private static final String MOJO_PIPELINE_PATH = System.getProperty(MOJO_PIPELINE_PATH_PROPERTY);
private static final MojoPipeline pipeline = loadMojoPipelineFromFile();

public static final boolean supportPredictionInterval = checkIfPredictionIntervalSupport();
private static final MojoPipeline pipeline =
supportPredictionInterval
? loadMojoPipelineFromFile(buildPipelineConfigWithPredictionInterval())
: loadMojoPipelineFromFile();
private final ShapleyLoadOption enabledShapleyTypes;
private final boolean shapleyEnabled;
private static MojoPipeline pipelineTransformedShapley;
Expand Down Expand Up @@ -96,11 +100,19 @@ public MojoScorer(
* @return response {@link ScoreResponse}
*/
public ScoreResponse score(ScoreRequest request) {
if (Boolean.TRUE.equals(request.isRequestPredictionIntervals())
&& !supportPredictionInterval) {
throw new IllegalArgumentException(
"requestPredictionIntervals set to true, but model does not support it"
);
}

scoreRequestTransformer.accept(request, getModelInfo().getSchema().getInputFields());
MojoFrame requestFrame = scoreRequestConverter
.apply(request, pipeline.getInputFrameBuilder());
MojoFrame responseFrame = doScore(requestFrame);
ScoreResponse response = scoreResponseConverter.apply(responseFrame, request);
ScoreResponse response = scoreResponseConverter.apply(
responseFrame, request);
response.id(pipeline.getUuid());

ShapleyType requestShapleyType = request.getRequestShapleyValueType();
Expand Down Expand Up @@ -245,7 +257,8 @@ public ScoreResponse scoreCsv(String csvFilePath) throws IOException {
requestFrame = csvConverter.apply(csvStream, pipeline.getInputFrameBuilder());
}
MojoFrame responseFrame = doScore(requestFrame);
ScoreResponse response = scoreResponseConverter.apply(responseFrame, new ScoreRequest());
ScoreResponse response = scoreResponseConverter.apply(
responseFrame, new ScoreRequest());
response.id(pipeline.getUuid());
return response;
}
Expand Down Expand Up @@ -319,6 +332,10 @@ public ShapleyLoadOption getEnabledShapleyTypes() {
return enabledShapleyTypes;
}

public boolean isPredictionIntervalSupport() {
return supportPredictionInterval;
}

/**
* Method to load mojo pipelines for shapley scoring based on configuration
*
Expand Down Expand Up @@ -356,6 +373,32 @@ private void loadMojoPipelinesForShapley() {
}

private static MojoPipeline loadMojoPipelineFromFile() {
File mojoFile = getMojoFile();
try {
MojoPipeline mojoPipeline = MojoPipelineService.loadPipeline(mojoFile);
log.info("Mojo pipeline successfully loaded ({}).", mojoPipeline.getUuid());
return mojoPipeline;
} catch (IOException e) {
throw new RuntimeException("Unable to load mojo from " + mojoFile, e);
} catch (LicenseException e) {
throw new RuntimeException("License file not found", e);
}
}

private static MojoPipeline loadMojoPipelineFromFile(PipelineConfig pipelineConfig) {
File mojoFile = getMojoFile();
try {
MojoPipeline mojoPipeline = MojoPipelineService.loadPipeline(mojoFile, pipelineConfig);
log.info("Mojo pipeline successfully loaded ({}).", mojoPipeline.getUuid());
return mojoPipeline;
} catch (IOException e) {
throw new RuntimeException("Unable to load mojo from " + mojoFile, e);
} catch (LicenseException e) {
throw new RuntimeException("License file not found", e);
}
}

private static File getMojoFile() {
Preconditions.checkArgument(
!Strings.isNullOrEmpty(MOJO_PIPELINE_PATH),
"Path to mojo pipeline not specified, set the %s property.",
Expand All @@ -372,14 +415,25 @@ private static MojoPipeline loadMojoPipelineFromFile() {
if (!mojoFile.isFile()) {
throw new RuntimeException("Could not load mojo from file: " + mojoFile);
}
return mojoFile;
}

private static boolean checkIfPredictionIntervalSupport() {
File mojoFile = getMojoFile();
try {
MojoPipeline mojoPipeline = MojoPipelineService.loadPipeline(mojoFile);
log.info("Mojo pipeline successfully loaded ({}).", mojoPipeline.getUuid());
return mojoPipeline;
} catch (IOException e) {
MojoPipelineService.loadPipeline(mojoFile, buildPipelineConfigWithPredictionInterval());
return true;
} catch (IllegalArgumentException e) {
log.debug("Prediction interval is not supported for the given model", e);
return false;
} catch (IOException | LicenseException e) {
throw new RuntimeException("Unable to load mojo from " + mojoFile, e);
} catch (LicenseException e) {
throw new RuntimeException("License file not found", e);
}
}

private static PipelineConfig buildPipelineConfigWithPredictionInterval() {
PipelineConfig.Builder builder = PipelineConfig.builder();
builder.withPredictionInterval(true);
return builder.build();
}
}
Loading

0 comments on commit 09b49db

Please sign in to comment.