Skip to content
This repository has been archived by the owner on Nov 29, 2024. It is now read-only.

Commit

Permalink
address memory issue caused by open api generator (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackjii79 authored and jakubhava committed Aug 26, 2024
1 parent 93263db commit b110c00
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 73 deletions.
1 change: 1 addition & 0 deletions common/rest-java-model/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ openApiGenerate {
invokerPackage = "ai.h2o.mojos.deploy.common.rest"
inputSpec = "$rootDir/common/swagger/v1openapi3/swagger.json"
outputDir = "$buildDir/gen"
generateAliasAsModel = true
globalProperties.set([
"skipFormModel": "false",
])
Expand Down
1 change: 1 addition & 0 deletions common/rest-jdbc-spring-api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ openApiGenerate {
generatorName = 'spring'
inputSpec = "$rootDir/common/swagger/v1/jdbc_swagger.json"
outputDir = "$buildDir/gen"
generateAliasAsModel = true
configOptions.set([
"useSpringBoot3": "true",
"interfaceOnly": "true",
Expand Down
1 change: 1 addition & 0 deletions common/rest-spring-api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ openApiGenerate {
generatorName = 'spring'
inputSpec = "$rootDir/common/swagger/v1openapi3/swagger.json"
outputDir = "$buildDir/gen"
generateAliasAsModel = true
globalProperties.set([
"skipFormModel": "false",
])
Expand Down
1 change: 1 addition & 0 deletions common/rest-vertex-ai-spring-api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ openApiGenerate {
generatorName = 'spring'
inputSpec = "$rootDir/common/swagger/v1/vertex-ai-swagger.json"
outputDir = "$buildDir/gen"
generateAliasAsModel = true
configOptions.set([
"useSpringBoot3": "true",
"interfaceOnly": "true",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ public class MojoFrameToContributionResponseConverter {
* ContributionResponse}.
*/
public ContributionResponse contributionResponseWithNoOutputGroup(MojoFrame shapleyMojoFrame) {
List<List<String>> outputRows =
Stream.generate(ArrayList<String>::new)
List<Row> outputRows =
Stream.generate(Row::new)
.limit(shapleyMojoFrame.getNrows())
.collect(Collectors.toList());
Utils.copyResultFields(shapleyMojoFrame, outputRows);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class MojoFrameToScoreResponseConverter
// Note: assumption is that pipeline supports Prediction interval.
// However for some h2o3 model, even classification model may still set
// this to be true.
private Boolean supportPredictionInterval;
private final Boolean supportPredictionInterval;

public MojoFrameToScoreResponseConverter(boolean supportPredictionInterval) {
this.supportPredictionInterval = supportPredictionInterval;
Expand All @@ -57,8 +57,8 @@ public MojoFrameToScoreResponseConverter() {
@Override
public ScoreResponse apply(MojoFrame mojoFrame, ScoreRequest scoreRequest) {
Set<String> includedFields = getSetOfIncludedFields(scoreRequest);
List<List<String>> outputRows =
Stream.generate(ArrayList<String>::new)
List<Row> outputRows =
Stream.generate(Row::new)
.limit(mojoFrame.getNrows())
.collect(Collectors.toList());
copyFilteredInputFields(scoreRequest, includedFields, outputRows);
Expand All @@ -81,7 +81,7 @@ public ScoreResponse apply(MojoFrame mojoFrame, ScoreRequest scoreRequest) {
* response frame, only one column rows will be populated into the outputRows to ensure backward
* compatible.
*/
private void fillOutputRows(MojoFrame mojoFrame, List<List<String>> outputRows) {
private void fillOutputRows(MojoFrame mojoFrame, List<Row> outputRows) {
List<List<String>> targetRows = getTargetRows(mojoFrame);
for (int rowIdx = 0; rowIdx < mojoFrame.getNrows(); rowIdx++) {
outputRows.get(rowIdx).addAll(targetRows.get(rowIdx));
Expand Down Expand Up @@ -182,9 +182,9 @@ private List<Integer> getTargetFieldIndices(MojoFrame mojoFrame) {
* Extract prediction interval columns rows from MOJO response frame. Note: Assumption is
* prediction interval should already be enabled and response frame has expected structure.
*/
private List<List<String>> getPredictionIntervalRows(MojoFrame mojoFrame, int targetIdx) {
List<List<String>> predictionIntervalRows =
Stream.generate(ArrayList<String>::new)
private List<Row> getPredictionIntervalRows(MojoFrame mojoFrame, int targetIdx) {
List<Row> predictionIntervalRows =
Stream.generate(Row::new)
.limit(mojoFrame.getNrows())
.collect(Collectors.toList());
for (int row = 0; row < mojoFrame.getNrows(); row++) {
Expand Down Expand Up @@ -234,12 +234,12 @@ private int getTargetColIdx(List<String> mojoColumns) {
}

private static void copyFilteredInputFields(
ScoreRequest scoreRequest, Set<String> includedFields, List<List<String>> outputRows) {
ScoreRequest scoreRequest, Set<String> includedFields, List<Row> outputRows) {
if (includedFields.isEmpty()) {
return;
}
boolean generateRowIds = shouldGenerateRowIds(scoreRequest, includedFields);
List<List<String>> inputRows = scoreRequest.getRows();
List<Row> inputRows = scoreRequest.getRows();
for (int row = 0; row < outputRows.size(); row++) {
List<String> inputRow = inputRows.get(row);
List<String> outputRow = outputRows.get(row);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import static java.util.Arrays.asList;

import ai.h2o.mojos.deploy.common.rest.model.Row;
import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import java.util.HashSet;
import java.util.List;

/** Checks that the request is of the correct form matching the corresponding mojo pipeline. */
Expand Down Expand Up @@ -36,15 +38,15 @@ private String getProblemMessageOrNull(ScoreRequest scoreRequest, MojoFrameMeta
if (fields == null || fields.isEmpty()) {
return "List of input fields cannot be empty";
}
List<List<String>> rows = scoreRequest.getRows();
List<Row> rows = scoreRequest.getRows();
if (rows == null || rows.isEmpty()) {
return "List of input data rows cannot be empty";
}
List<String> expectedFields = asList(expectedMeta.getColumnNames());
if (!fields.containsAll(expectedFields)) {
if (!new HashSet<>(fields).containsAll(expectedFields)) {
return String.format(
"Input fields don't contain all the Mojo fields, expected %s actual %s",
expectedFields.toString(), fields.toString());
expectedFields, fields);
}
int i = 0;
for (List<String> row : scoreRequest.getRows()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package ai.h2o.mojos.deploy.common.transform;

import ai.h2o.mojos.deploy.common.rest.model.DataField;
import ai.h2o.mojos.deploy.common.rest.model.Row;
import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
Expand All @@ -30,35 +30,29 @@ public void accept(ScoreRequest scoreRequest, List<DataField> dataFields) {
transformRow(scoreRequest.getFields(), scoreRequest.getRows(), dataFieldMap));
}

private List<List<String>> transformRow(
List<String> fields, List<List<String>> rows, Map<String, DataField> dataFields) {
private List<Row> transformRow(
List<String> fields, List<Row> rows, Map<String, DataField> dataFields) {
return rows.stream()
.map(
row -> {
List<String> transformData =
IntStream.range(0, row.size())
.mapToObj(
fieldIdx -> {
String colName = fields.get(fieldIdx);
String origin = row.get(fieldIdx);
if (dataFields.containsKey(colName)) {
String sanitizeValue =
Utils.sanitizeBoolean(
origin, dataFields.get(colName).getDataType());
if (!sanitizeValue.equals(origin)) {
logger.debug("Value '{}' parsed as '{}'", origin, sanitizeValue);
}
return sanitizeValue;
} else {
logger.debug("Column '{}' can not be found in Input schema", colName);
return origin;
row ->
IntStream.range(0, row.size())
.mapToObj(
fieldIdx -> {
String colName = fields.get(fieldIdx);
String origin = row.get(fieldIdx);
if (dataFields.containsKey(colName)) {
String sanitizeValue =
Utils.sanitizeBoolean(
origin, dataFields.get(colName).getDataType());
if (!sanitizeValue.equals(origin)) {
logger.debug("Value '{}' parsed as '{}'", origin, sanitizeValue);
}
})
.collect(Collectors.toList());
List<String> transformedRow = new ArrayList<>();
transformedRow.addAll(transformData);
return transformedRow;
})
.collect(Collectors.toList());
return sanitizeValue;
} else {
logger.debug("Column '{}' can not be found in Input schema", colName);
return origin;
}
})
.collect(Collectors.toCollection(Row::new))).collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.h2o.mojos.deploy.common.transform;

import ai.h2o.mojos.deploy.common.rest.model.DataField;
import ai.h2o.mojos.deploy.common.rest.model.Row;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import java.util.List;

Expand All @@ -10,7 +11,7 @@ public class Utils {
*
* @param mojoFrame {@link MojoFrame}
*/
public static void copyResultFields(MojoFrame mojoFrame, List<List<String>> outputRows) {
public static void copyResultFields(MojoFrame mojoFrame, List<Row> outputRows) {
String[][] outputColumns = new String[mojoFrame.getNcols()][];
for (int col = 0; col < mojoFrame.getNcols(); col++) {
outputColumns[col] = mojoFrame.getColumn(col).getDataAsStrings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,8 @@ private MojoFrame generateDummyTransformedMojoFrame() {

private ScoreResponse generateDummyResponse() {
ScoreResponse response = new ScoreResponse();
List<List<String>> outputRows =
Stream.generate(ArrayList<String>::new).limit(4).collect(Collectors.toList());
List<Row> outputRows =
Stream.generate(Row::new).limit(4).collect(Collectors.toList());
response.setScore(outputRows);
response.setFields(Arrays.asList("field1"));
return response;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ void transform_BooleanLiteral_Transformed() {
// Given
ScoreRequest scoreRequest = new ScoreRequest();
scoreRequest.setFields(Collections.singletonList("test"));
List<List<String>> rows =
List<Row> rows =
new ArrayList<>(
Arrays.asList(
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>()));
new Row(),
new Row(),
new Row(),
new Row(),
new Row(),
new Row(),
new Row()));
rows.get(0).addAll(Collections.singletonList("true"));
rows.get(1).addAll(Collections.singletonList("False"));
rows.get(2).addAll(Collections.singletonList("TrUE"));
Expand All @@ -65,11 +65,11 @@ void transform_BooleanLiteral_Transformed() {
scoreRequestTransformer.accept(scoreRequest, dataFields);

// Then
List<List<String>> expected =
List<Row> expected =
new ArrayList<>(
Arrays.asList(
new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
new ArrayList<>(), new ArrayList<>(), new ArrayList<>()));
new Row(), new Row(), new Row(), new Row(),
new Row(), new Row(), new Row()));
expected.get(0).addAll(Collections.singletonList("1"));
expected.get(1).addAll(Collections.singletonList("0"));
expected.get(2).addAll(Collections.singletonList("1"));
Expand All @@ -85,7 +85,7 @@ void transform_NonBooleanLiteral_Unchanged() {
// Given
ScoreRequest scoreRequest = new ScoreRequest();
scoreRequest.setFields(Collections.singletonList("test"));
List<List<String>> rows = new ArrayList<>(Arrays.asList(new Row(), new Row()));
List<Row> rows = new ArrayList<>(Arrays.asList(new Row(), new Row()));
rows.get(0).addAll(Collections.singletonList("unchangedFeature1"));
rows.get(1).addAll(Collections.singletonList("unchangedFeature2"));
scoreRequest.setRows(rows);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public static ScoreRequest getRestScoreRequest(
}
}

List<String> row;
Row row;
for (List<String> gcpRow : gcpRequest.getInstances()) {
row = new Row();
for (int i = 0; i < gcpRow.size(); i++) {
Expand All @@ -149,9 +149,9 @@ public static ScoreResponse getGcpScoreResponse(
response.setId(restResponse.getId());
response.setFields(restResponse.getFields());

List<String> row;
ai.h2o.mojos.deploy.common.rest.vertex.ai.model.Row row;
for (List<String> restRow : restResponse.getScore()) {
row = new ArrayList<>();
row = new ai.h2o.mojos.deploy.common.rest.vertex.ai.model.Row();
for (int i = 0; i < restRow.size(); i++) {
row.add(restRow.get(i));
}
Expand Down
4 changes: 2 additions & 2 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ errorproneVersion = 2.23.0
# Docker settings
dockerRepositoryPrefix = harbor.h2o.ai/opsh2oai/h2oai/
dockerIncludePython = true
# Digest of eclipse-temurin:17.0.9_9-jdk-alpine
javaBaseImage = eclipse-temurin@sha256:24643c2dd329ef482ecd042b59cbfb7fe13716342e22674a0abd763559c8a1dd
# Digest of eclipse-temurin:17.0.10_7-jdk-alpine
javaBaseImage = eclipse-temurin@sha256:0e6e494ac4da6509a038b7689250bc7ea68beaf8a5efbca5ed7c8692457b283c

# Increase timeouts to avoid read error from OSS Nexus
# See:
Expand Down
22 changes: 11 additions & 11 deletions settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ dependencyResolutionManagement {
}

rootProject.name = 'dai-deployment-templates'
include 'aws-lambda-scorer:lambda-template'
include 'aws-lambda-scorer:terraform-recipe'
include 'common:jdbc'
//include 'aws-lambda-scorer:lambda-template'
//include 'aws-lambda-scorer:terraform-recipe'
//include 'common:jdbc'
include 'common:rest-java-model'
include 'common:rest-jdbc-spring-api'
//include 'common:rest-jdbc-spring-api'
include 'common:rest-spring-api'
include 'common:transform'
include 'common:kdb-java'
include 'common:rest-vertex-ai-spring-api'
//include 'common:kdb-java'
//include 'common:rest-vertex-ai-spring-api'
include 'local-rest-scorer'
include 'kdb-mojo-scorer'
include 'aws-sagemaker-hosted-scorer'
include 'gcp-cloud-run'
include 'sql-jdbc-scorer'
include 'gcp-vertex-ai-mojo-scorer'
//include 'kdb-mojo-scorer'
//include 'aws-sagemaker-hosted-scorer'
//include 'gcp-cloud-run'
//include 'sql-jdbc-scorer'
//include 'gcp-vertex-ai-mojo-scorer'

0 comments on commit b110c00

Please sign in to comment.