From e7f0d2f8bac7a43c93d5e41a9dcc486308b250da Mon Sep 17 00:00:00 2001 From: jackjii79 Date: Mon, 29 Jan 2024 13:26:35 -0800 Subject: [PATCH] fix ci (#378) --- .github/workflows/codeql-analysis.yml | 5 +- .github/workflows/snyk-scan.yml | 4 +- Jenkinsfile | 22 +- .../lambda-template/build.gradle | 8 + .../mojos/deploy/aws/lambda/MojoScorer.java | 8 +- aws-sagemaker-hosted-scorer/build.gradle | 4 +- .../hosted/config/ScorerConfiguration.java | 14 +- .../controller/ModelsApiController.java | 11 +- build.gradle | 7 +- common/jdbc/build.gradle | 3 + common/rest-java-model/build.gradle | 59 +- common/rest-jdbc-spring-api/build.gradle | 41 +- common/rest-spring-api/build.gradle | 53 +- common/rest-vertex-ai-spring-api/build.gradle | 41 +- common/swagger/v1/jdbc_swagger.json | 253 ++++++++ common/swagger/v1/swagger.json | 496 ++++++++++++++ common/swagger/v1/vertex-ai-swagger.json | 251 ++++++++ common/swagger/v1exp/swagger.json | 97 +++ common/swagger/v1openapi3/swagger.json | 604 ++++++++++++++++++ common/transform/build.gradle | 27 +- ...ntributionRequestToMojoFrameConverter.java | 9 +- ...oFrameToContributionResponseConverter.java | 20 +- .../MojoFrameToScoreResponseConverter.java | 105 ++- .../deploy/common/transform/MojoScorer.java | 102 ++- .../common/transform/RequestChecker.java | 5 +- .../transform/SampleRequestBuilder.java | 9 +- .../ScoreRequestToMojoFrameConverter.java | 3 +- .../transform/ScoreRequestTransformer.java | 72 +-- .../common/transform/ShapleyLoadOption.java | 11 +- .../mojos/deploy/common/transform/Utils.java | 13 +- ...meToContributionResponseConverterTest.java | 104 +-- ...MojoFrameToScoreResponseConverterTest.java | 238 ++++--- .../MojoPipelineToModelInfoConverterTest.java | 5 +- .../common/transform/MojoScorerTest.java | 184 +++--- .../common/transform/RequestCheckerTest.java | 1 - .../ScoreRequestToMojoFrameConverterTest.java | 4 +- .../ScoreRequestTransformerTest.java | 28 +- gcp-cloud-run/build.gradle | 4 +- .../gcp/vertex/ai/GcpVertexAiApplication.java | 4 +- .../ai/config/EnvironmentConfiguration.java | 7 +- .../vertex/ai/config/ScorerConfiguration.java | 14 +- .../ai/controller/ModelsApiController.java | 64 +- gradle.properties | 34 +- gradle/java.gradle | 1 - gradle/java_no_style.gradle | 4 +- gradle/mixins/dependencies.gradle | 17 +- gradle/mixins/spotless.gradle | 12 - init.gradle | 12 + kdb-mojo-scorer/build.gradle | 8 +- local-rest-scorer/build.gradle | 20 +- .../rest/config/ScorerConfiguration.java | 14 +- .../rest/controller/ModelsApiController.java | 57 +- .../controller/ModelsMediaController.java | 26 - .../converter/ScoreMediaRequestConverter.java | 2 +- .../deploy/local/rest/error/ErrorUtil.java | 4 +- .../rest/error/ModelsExceptionHandler.java | 51 +- .../controller/ModelsApiControllerTest.java | 48 +- settings.gradle | 15 + sql-jdbc-scorer/build.gradle | 3 +- 59 files changed, 2517 insertions(+), 825 deletions(-) create mode 100644 common/swagger/v1/jdbc_swagger.json create mode 100644 common/swagger/v1/swagger.json create mode 100644 common/swagger/v1/vertex-ai-swagger.json create mode 100644 common/swagger/v1exp/swagger.json create mode 100644 common/swagger/v1openapi3/swagger.json delete mode 100644 gradle/mixins/spotless.gradle create mode 100644 init.gradle delete mode 100644 local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsMediaController.java diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 1ba1cf7d..388772fc 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -35,7 +35,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL @@ -63,8 +63,7 @@ jobs: - if: matrix.language == 'java' name: Build Java uses: gradle/gradle-build-action@v2 - with: - arguments: assemble + run: ./gradlew --init-script init.gradle assemble - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/snyk-scan.yml b/.github/workflows/snyk-scan.yml index 6a37d580..0aeff08a 100644 --- a/.github/workflows/snyk-scan.yml +++ b/.github/workflows/snyk-scan.yml @@ -28,7 +28,7 @@ jobs: - uses: snyk/actions/setup@master - uses: actions/setup-java@v3 with: - java-version: "8" + java-version: "17" distribution: 'adopt' - name: Snyk scan for Java dependencies @@ -76,7 +76,7 @@ jobs: - uses: snyk/actions/setup@master - uses: actions/setup-java@v3 with: - java-version: "8" + java-version: "17" distribution: 'adopt' - name: Snyk scan for Java dependencies - ${{ matrix.depsfiles }} diff --git a/Jenkinsfile b/Jenkinsfile index 21c85a48..e88f9dfd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -4,7 +4,7 @@ import ai.h2o.ci.Utils -JAVA_IMAGE = 'harbor.h2o.ai/dockerhub-proxy/library/openjdk:8u222-jdk-slim' +JAVA_IMAGE = 'harbor.h2o.ai/dockerhub-proxy/library/openjdk:17-jdk-slim' NODE_LABEL = 'docker' DOCKERHUB_CREDS = 'dockerhub' HARBOR_URL = "http://harbor.h2o.ai/" @@ -43,9 +43,9 @@ pipeline { description: 'Whether to also push distribution ZIP archive to S3.', ) booleanParam( - name: 'PUSH_TO_VORVAN', - defaultValue: false, - description: 'Whether to also push Docker images to h2o.ai maintained gcr.io repo Vorvan.', + name: 'PUSH_TO_VORVAN', + defaultValue: false, + description: 'Whether to also push Docker images to h2o.ai maintained gcr.io repo Vorvan.', ) } @@ -76,7 +76,7 @@ pipeline { script { versionText = getVersion() echo "Version: ${versionText}" - sh "./gradlew check" + sh "./gradlew --init-script init.gradle check" } } } @@ -100,7 +100,7 @@ pipeline { steps { timeout(time: 60, unit: 'MINUTES') { script { - sh "./gradlew distributionZip" + sh "./gradlew --init-script init.gradle distributionZip" if (isReleaseVersion(versionText)) { utilsLib.appendBuildDescription("Release ${versionText}") } @@ -169,7 +169,7 @@ pipeline { def imageTags = isMasterBranch() || isReleaseBranch() ? "${versionText},${gitCommitHash}" : "${gitCommitHash}" withDockerCredentials(DOCKERHUB_CREDS, "FROM_") { withDockerCredentials("harbor.h2o.ai", "TO_") { - sh "./gradlew jib \ + sh "./gradlew --init-script init.gradle jib \ -Djib.to.auth.username=${TO_DOCKER_USERNAME} \ -Djib.to.auth.password=${TO_DOCKER_PASSWORD} \ -Djib.from.auth.username=${FROM_DOCKER_USERNAME} \ @@ -205,7 +205,7 @@ pipeline { def imageTags = isMasterBranch() || isReleaseBranch() ? "${versionText},${gitCommitHash}" : "${gitCommitHash}" withDockerCredentials(DOCKERHUB_CREDS, "FROM_") { withDockerCredentials(DOCKERHUB_CREDS, "TO_") { - sh "./gradlew jib \ + sh "./gradlew --init-script init.gradle jib \ -Djib.to.auth.username=${TO_DOCKER_USERNAME} \ -Djib.to.auth.password=${TO_DOCKER_PASSWORD} \ -Djib.from.auth.username=${FROM_DOCKER_USERNAME} \ @@ -242,7 +242,7 @@ pipeline { withGCRCredentials(VORVAN_CRED) { def gcrCreds = readFile("${GCR_JSON_KEY}") withEnv(['TO_DOCKER_USERNAME=_json_key', "TO_DOCKER_PASSWORD=${gcrCreds}"]) { - sh "./gradlew jib \ + sh "./gradlew --init-script init.gradle jib \ -Djib.from.auth.username=${FROM_DOCKER_USERNAME} \ -Djib.from.auth.password=${FROM_DOCKER_PASSWORD} \ -Djib.to.tags=${imageTags} \ @@ -264,8 +264,8 @@ pipeline { */ def getVersion() { def version = sh( - script: "./gradlew -q -Dorg.gradle.internal.launcher.welcomeMessageEnabled=false printVersion", - returnStdout: true).trim() + script: "./gradlew --init-script init.gradle -q -Dorg.gradle.internal.launcher.welcomeMessageEnabled=false printVersion", + returnStdout: true).trim().tokenize("\n").last() if (!version) { error "Version must be set" } diff --git a/aws-lambda-scorer/lambda-template/build.gradle b/aws-lambda-scorer/lambda-template/build.gradle index 779200db..45efa486 100644 --- a/aws-lambda-scorer/lambda-template/build.gradle +++ b/aws-lambda-scorer/lambda-template/build.gradle @@ -1,3 +1,7 @@ +plugins { + id 'org.springframework.boot' +} + apply from: project(":").file('gradle/java.gradle') dependencies { @@ -16,6 +20,10 @@ dependencies { testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine' } +bootJar { + enabled=false +} + test { useJUnitPlatform() } diff --git a/aws-lambda-scorer/lambda-template/src/main/java/ai/h2o/mojos/deploy/aws/lambda/MojoScorer.java b/aws-lambda-scorer/lambda-template/src/main/java/ai/h2o/mojos/deploy/aws/lambda/MojoScorer.java index 2b7e09a1..a93ba630 100644 --- a/aws-lambda-scorer/lambda-template/src/main/java/ai/h2o/mojos/deploy/aws/lambda/MojoScorer.java +++ b/aws-lambda-scorer/lambda-template/src/main/java/ai/h2o/mojos/deploy/aws/lambda/MojoScorer.java @@ -35,10 +35,10 @@ public final class MojoScorer { private static final Object pipelineLock = new Object(); private static MojoPipeline pipeline; - private final ScoreRequestToMojoFrameConverter requestConverter - = new ScoreRequestToMojoFrameConverter(); - private final MojoFrameToScoreResponseConverter responseConverter - = new MojoFrameToScoreResponseConverter(); + private final ScoreRequestToMojoFrameConverter requestConverter = + new ScoreRequestToMojoFrameConverter(); + private final MojoFrameToScoreResponseConverter responseConverter = + new MojoFrameToScoreResponseConverter(); private final RequestChecker requestChecker = new RequestChecker(new SampleRequestBuilder()); /** Processes a single {@link ScoreRequest} in the given AWS Lambda {@link Context}. */ diff --git a/aws-sagemaker-hosted-scorer/build.gradle b/aws-sagemaker-hosted-scorer/build.gradle index 4d4dfa26..77d30014 100644 --- a/aws-sagemaker-hosted-scorer/build.gradle +++ b/aws-sagemaker-hosted-scorer/build.gradle @@ -12,8 +12,8 @@ dependencies { implementation group: 'io.springfox', name: 'springfox-boot-starter', version: springFoxVersion implementation group: 'com.google.guava', name: 'guava', version: guavaVersion implementation group: 'org.springframework.boot', name: 'spring-boot-starter-web' - implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-core', version: '9.0.63' - implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-websocket', version: '9.0.63' + implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-core', version: tomcatVersion + implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-websocket', version: tomcatVersion } test { diff --git a/aws-sagemaker-hosted-scorer/src/main/java/ai/h2o/mojos/deploy/sagemaker/hosted/config/ScorerConfiguration.java b/aws-sagemaker-hosted-scorer/src/main/java/ai/h2o/mojos/deploy/sagemaker/hosted/config/ScorerConfiguration.java index a6e4ce3c..bcf6b203 100644 --- a/aws-sagemaker-hosted-scorer/src/main/java/ai/h2o/mojos/deploy/sagemaker/hosted/config/ScorerConfiguration.java +++ b/aws-sagemaker-hosted-scorer/src/main/java/ai/h2o/mojos/deploy/sagemaker/hosted/config/ScorerConfiguration.java @@ -64,12 +64,12 @@ public MojoScorer mojoScorer( ScoreRequestTransformer scoreRequestTransformer, CsvToMojoFrameConverter csvConverter) { return new MojoScorer( - requestConverter, - responseConverter, - contributionRequestConverter, - contributionResponseConverter, - modelInfoConverter, - scoreRequestTransformer, - csvConverter); + requestConverter, + responseConverter, + contributionRequestConverter, + contributionResponseConverter, + modelInfoConverter, + scoreRequestTransformer, + csvConverter); } } diff --git a/aws-sagemaker-hosted-scorer/src/main/java/ai/h2o/mojos/deploy/sagemaker/hosted/controller/ModelsApiController.java b/aws-sagemaker-hosted-scorer/src/main/java/ai/h2o/mojos/deploy/sagemaker/hosted/controller/ModelsApiController.java index 9344af61..96995f3e 100644 --- a/aws-sagemaker-hosted-scorer/src/main/java/ai/h2o/mojos/deploy/sagemaker/hosted/controller/ModelsApiController.java +++ b/aws-sagemaker-hosted-scorer/src/main/java/ai/h2o/mojos/deploy/sagemaker/hosted/controller/ModelsApiController.java @@ -13,7 +13,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -25,10 +24,9 @@ @Controller public class ModelsApiController implements ModelApi { - private static final String UNIMPLEMENTED_MESSAGE - = "Shapley values are not implemented yet"; - private static final List SUPPORTED_CAPABILITIES - = Arrays.asList(CapabilityType.SCORE); + private static final String UNIMPLEMENTED_MESSAGE = "Shapley values are not implemented yet"; + private static final List SUPPORTED_CAPABILITIES = + Arrays.asList(CapabilityType.SCORE); private static final Logger log = LoggerFactory.getLogger(ModelsApiController.class); private final MojoScorer scorer; @@ -101,8 +99,7 @@ public ResponseEntity getScoreByFile(String file) { } @Override - public ResponseEntity getContribution( - ContributionRequest request) { + public ResponseEntity getContribution(ContributionRequest request) { // TODO: to be implemented in the future log.info(" Unsupported operation: " + UNIMPLEMENTED_MESSAGE); return ResponseEntity.status(HttpStatus.NOT_IMPLEMENTED).build(); diff --git a/build.gradle b/build.gradle index ca8cecd0..b1170a66 100644 --- a/build.gradle +++ b/build.gradle @@ -8,15 +8,14 @@ buildscript { dependencies { classpath group: 'org.springframework.boot', name: 'spring-boot-gradle-plugin', version: springBootPluginVersion - classpath group: 'gradle.plugin.org.hidetake', name: 'gradle-swagger-generator-plugin', - version: swaggerGradlePluginVersion - classpath group: 'com.github.jengelman.gradle.plugins', name: 'shadow', + classpath group: 'gradle.plugin.com.github.johnrengelman', name: 'shadow', version: shadowJarVersion -// classpath group: 'com.diffplug.spotless', name: 'spotless-plugin-gradle', version: spotlessPluginVersion classpath group: 'net.ltgt.gradle', name: 'gradle-errorprone-plugin', version: errorpronePluginVersion classpath group: 'com.google.cloud.tools.jib', name: 'com.google.cloud.tools.jib.gradle.plugin', version: jibPluginVersion + classpath group: 'org.openapitools', name: 'openapi-generator-gradle-plugin', + version: openApiGeneratorGradlePluginVersion } } diff --git a/common/jdbc/build.gradle b/common/jdbc/build.gradle index e118178f..b218f4f7 100644 --- a/common/jdbc/build.gradle +++ b/common/jdbc/build.gradle @@ -15,6 +15,9 @@ dependencies { implementation group: 'org.apache.spark', name: 'spark-sql_2.12' implementation group: 'org.apache.spark', name: 'spark-mllib_2.12' implementation group: 'com.typesafe', name:'config' + implementation group: 'org.slf4j', name: 'slf4j-api', version: '1.7.36' + implementation group: 'ch.qos.logback', name: 'logback-classic', version: '1.0.13' + implementation group: 'ch.qos.logback', name: 'logback-core', version: '1.0.13' testImplementation group: 'org.scalatest', name: 'scalatest_2.12', version: '3.0.5' testRuntimeOnly group:'org.scala-lang.modules', name: 'scala-xml_2.12', version: '1.1.1' diff --git a/common/rest-java-model/build.gradle b/common/rest-java-model/build.gradle index 109e4b14..60135688 100644 --- a/common/rest-java-model/build.gradle +++ b/common/rest-java-model/build.gradle @@ -1,35 +1,48 @@ plugins { - id 'org.hidetake.swagger.generator' + id 'org.springframework.boot' + id 'org.openapi.generator' } apply from: project(":").file('gradle/java.gradle') dependencies { implementation group: 'com.google.code.gson', name: 'gson' - implementation group: 'javax.annotation', name: 'javax.annotation-api' implementation group: 'io.swagger.core.v3', name: 'swagger-annotations' - swaggerCodegen group: 'io.swagger.codegen.v3', name: 'swagger-codegen-cli' + implementation group: 'org.springframework.boot', name: 'spring-boot-starter-web' + implementation group: 'org.springframework.boot', name: 'spring-boot-starter-validation' + implementation group: 'org.openapitools', name: 'jackson-databind-nullable' } -swaggerSources { - model { - inputFile = file('../swagger/v1/swagger.yaml') - code { - language = 'java' - configFile = file('../swagger/v1/swagger_codegen.json') - components = [models: true] - dependsOn validation - } - } - modelV1Exp { - inputFile = file('../swagger/v1exp/swagger.yaml') - code { - language = 'java' - configFile = file('../swagger/v1exp/swagger_codegen.json') - components = [models: true] - dependsOn validation +openApiValidate { + inputSpec = "$rootDir/common/swagger/v1openapi3/swagger.json" + recommend = true +} + +openApiGenerate { + generatorName = 'spring' + packageName = "ai.h2o.mojos.deploy.common.rest" + invokerPackage = "ai.h2o.mojos.deploy.common.rest" + inputSpec = "$rootDir/common/swagger/v1openapi3/swagger.json" + outputDir = "$buildDir/gen" + globalProperties.set([ + "skipFormModel": "false", + ]) + configOptions.set([ + "useSpringBoot3": "true", + "interfaceOnly": "true", + "basePackage": "ai.h2o.mojos.deploy.common.rest", + "modelPackage": "ai.h2o.mojos.deploy.common.rest.model", + ]) +} + +bootJar { + enabled=false +} + +compileJava.dependsOn tasks.openApiValidate, tasks.openApiGenerate +sourceSets { + main { + java { + srcDir("$buildDir/gen/src/main/java") } } } - -compileJava.dependsOn swaggerSources.model.code, swaggerSources.modelV1Exp.code -sourceSets.main.java.srcDirs "${swaggerSources.model.code.outputDir}/src/main/java", "${swaggerSources.modelV1Exp.code.outputDir}/src/main/java" diff --git a/common/rest-jdbc-spring-api/build.gradle b/common/rest-jdbc-spring-api/build.gradle index 26c0e249..640a6bd4 100644 --- a/common/rest-jdbc-spring-api/build.gradle +++ b/common/rest-jdbc-spring-api/build.gradle @@ -1,6 +1,6 @@ plugins { id 'org.springframework.boot' - id 'org.hidetake.swagger.generator' + id 'org.openapi.generator' } apply from: project(":").file('gradle/java.gradle') @@ -8,19 +8,26 @@ dependencies { implementation group: 'io.swagger.core.v3', name: 'swagger-annotations' implementation group: 'org.springframework.boot', name: 'spring-boot-starter-web' implementation group: 'org.springframework.boot', name: 'spring-boot-starter-validation' - swaggerCodegen group: 'io.swagger.codegen.v3', name: 'swagger-codegen-cli' + implementation group: 'org.openapitools', name: 'jackson-databind-nullable' } -swaggerSources { - api { - inputFile = file('../swagger/v1/jdbc-swagger.yaml') - code { - language = 'spring' - configFile = file('../swagger/v1/jdbc_swagger_codegen.json') - components = [models: true, apis: true] - dependsOn validation - } - } +openApiValidate { + inputSpec = "$rootDir/common/swagger/v1/jdbc_swagger.json" + recommend = true +} + +openApiGenerate { + generatorName = 'spring' + inputSpec = "$rootDir/common/swagger/v1/jdbc_swagger.json" + outputDir = "$buildDir/gen" + configOptions.set([ + "useSpringBoot3": "true", + "interfaceOnly": "true", + "basePackage": "ai.h2o.mojos.deploy.common.rest.jdbc", + "configPackage": "ai.h2o.mojos.deploy.common.rest.jdbc.config", + "apiPackage": "ai.h2o.mojos.deploy.common.rest.jdbc.api", + "modelPackage": "ai.h2o.mojos.deploy.common.rest.jdbc.model", + ]) } jar { @@ -32,5 +39,11 @@ bootJar { enabled=false } -compileJava.dependsOn swaggerSources.api.code -sourceSets.main.java.srcDir "${swaggerSources.api.code.outputDir}/src/main/java" \ No newline at end of file +compileJava.dependsOn tasks.openApiValidate, tasks.openApiGenerate +sourceSets { + main { + java { + srcDir("$buildDir/gen/src/main/java") + } + } +} diff --git a/common/rest-spring-api/build.gradle b/common/rest-spring-api/build.gradle index f81c3ed8..7bb85789 100644 --- a/common/rest-spring-api/build.gradle +++ b/common/rest-spring-api/build.gradle @@ -1,6 +1,6 @@ plugins { id 'org.springframework.boot' - id 'org.hidetake.swagger.generator' + id 'org.openapi.generator' } apply from: project(":").file('gradle/java.gradle') @@ -8,28 +8,29 @@ dependencies { implementation group: 'io.swagger.core.v3', name: 'swagger-annotations' implementation group: 'org.springframework.boot', name: 'spring-boot-starter-web' implementation group: 'org.springframework.boot', name: 'spring-boot-starter-validation' - swaggerCodegen group: 'io.swagger.codegen.v3', name: 'swagger-codegen-cli' + implementation group: 'org.openapitools', name: 'jackson-databind-nullable' } -swaggerSources { - api { - inputFile = file('../swagger/v1/swagger.yaml') - code { - language = 'spring' - configFile = file('../swagger/v1/swagger_codegen.json') - components = [models: true, apis: true] - dependsOn validation - } - } - apiV1Exp { - inputFile = file('../swagger/v1exp/swagger.yaml') - code { - language = 'spring' - configFile = file('../swagger/v1exp/swagger_codegen.json') - components = [models: true, apis: true] - dependsOn validation - } - } +openApiValidate { + inputSpec = "$rootDir/common/swagger/v1openapi3/swagger.json" + recommend = true +} + +openApiGenerate { + generatorName = 'spring' + inputSpec = "$rootDir/common/swagger/v1openapi3/swagger.json" + outputDir = "$buildDir/gen" + globalProperties.set([ + "skipFormModel": "false", + ]) + configOptions.set([ + "useSpringBoot3": "true", + "interfaceOnly": "true", + "basePackage": "ai.h2o.mojos.deploy.common.rest", + "configPackage": "ai.h2o.mojos.deploy.common.rest.config", + "apiPackage": "ai.h2o.mojos.deploy.common.rest.api", + "modelPackage": "ai.h2o.mojos.deploy.common.rest.model", + ]) } jar { @@ -41,5 +42,11 @@ bootJar { enabled=false } -compileJava.dependsOn swaggerSources.api.code, swaggerSources.apiV1Exp.code -sourceSets.main.java.srcDirs "${swaggerSources.api.code.outputDir}/src/main/java", "${swaggerSources.apiV1Exp.code.outputDir}/src/main/java" +compileJava.dependsOn tasks.openApiValidate, tasks.openApiGenerate +sourceSets { + main { + java { + srcDir("$buildDir/gen/src/main/java") + } + } +} diff --git a/common/rest-vertex-ai-spring-api/build.gradle b/common/rest-vertex-ai-spring-api/build.gradle index 6e7fdd12..f636917a 100644 --- a/common/rest-vertex-ai-spring-api/build.gradle +++ b/common/rest-vertex-ai-spring-api/build.gradle @@ -1,6 +1,6 @@ plugins { id 'org.springframework.boot' - id 'org.hidetake.swagger.generator' + id 'org.openapi.generator' } apply from: project(":").file('gradle/java.gradle') @@ -8,19 +8,26 @@ dependencies { implementation group: 'io.swagger.core.v3', name: 'swagger-annotations' implementation group: 'org.springframework.boot', name: 'spring-boot-starter-web' implementation group: 'org.springframework.boot', name: 'spring-boot-starter-validation' - swaggerCodegen group: 'io.swagger.codegen.v3', name: 'swagger-codegen-cli' + implementation group: 'org.openapitools', name: 'jackson-databind-nullable' } -swaggerSources { - api { - inputFile = file('../swagger/v1/vertex-ai-swagger.yaml') - code { - language = 'spring' - configFile = file('../swagger/v1/vertex_ai_swagger_codegen.json') - components = [models: true, apis: true] - dependsOn validation - } - } +openApiValidate { + inputSpec = "$rootDir/common/swagger/v1/vertex-ai-swagger.json" + recommend = true +} + +openApiGenerate { + generatorName = 'spring' + inputSpec = "$rootDir/common/swagger/v1/vertex-ai-swagger.json" + outputDir = "$buildDir/gen" + configOptions.set([ + "useSpringBoot3": "true", + "interfaceOnly": "true", + "basePackage": "ai.h2o.mojos.deploy.common.rest.vertex.ai", + "configPackage": "ai.h2o.mojos.deploy.common.rest.vertex.ai.config", + "apiPackage": "ai.h2o.mojos.deploy.common.rest.vertex.ai.api", + "modelPackage": "ai.h2o.mojos.deploy.common.rest.vertex.ai.model", + ]) } jar { @@ -32,5 +39,11 @@ bootJar { enabled=false } -compileJava.dependsOn swaggerSources.api.code -sourceSets.main.java.srcDir "${swaggerSources.api.code.outputDir}/src/main/java" +compileJava.dependsOn tasks.openApiValidate, tasks.openApiGenerate +sourceSets { + main { + java { + srcDir("$buildDir/gen/src/main/java") + } + } +} diff --git a/common/swagger/v1/jdbc_swagger.json b/common/swagger/v1/jdbc_swagger.json new file mode 100644 index 00000000..e38ae92a --- /dev/null +++ b/common/swagger/v1/jdbc_swagger.json @@ -0,0 +1,253 @@ +{ + "swagger": "2.0", + "info": { + "description": "This is a definition of the REST API for scoring from H2O. This API is intended to be used within DAI and eventually across all H2O scoring systems", + "version": "1.0.0", + "title": "Scoring API - v1", + "termsOfService": "", + "contact": { + "email": "support@h2o.ai" + }, + "license": { + "name": "License", + "url": "http://www.h2o.ai" + } + }, + "host": "localhost", + "basePath": "/", + "schemes": [ + "https", + "http" + ], + "paths": { + "/model/id": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Returns model id", + "description": "Returns unique id of the model loaded in the server and used for scoring", + "operationId": "getModelId", + "produces": [ + "text/plain" + ], + "parameters": [], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "type": "string" + } + } + }, + "security": [ + { + "api_key": [] + } + ] + } + }, + "/model/schema": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Describe a model", + "description": "Returns information about the model used for scoring, e.g., input schema.", + "operationId": "getModelInfo", + "produces": [ + "application/json" + ], + "parameters": [], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/Model" + } + } + } + } + }, + "/model/score": { + "post": { + "tags": [ + "scoring" + ], + "summary": "Score on given rows", + "description": "Computes score of the rows sent in the body of the post request", + "operationId": "getScore", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/ScoreRequest" + } + } + ], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/ScoreResponse" + } + }, + "400": { + "description": "Invalid payload" + } + } + }, + "get": { + "tags": [ + "scoring" + ], + "summary": "Score on given file", + "description": "Computes score of the rows in the file specified by the path in the query parameter", + "operationId": "getScoreByGet", + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "query", + "name": "sqlQuery", + "type": "string" + }, + { + "in": "query", + "name": "outputTable", + "type": "string" + }, + { + "in": "query", + "name": "idColumn", + "type": "string", + "required": false + }, + { + "in": "query", + "name": "saveMethod", + "type": "string", + "enum": [ + "preview", + "append", + "overwrite", + "ignore", + "error" + ] + } + ], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/ScoreResponse" + } + }, + "400": { + "description": "Invalid payload" + } + } + } + } + }, + "securityDefinitions": { + "api_key": { + "type": "apiKey", + "name": "api_key", + "in": "header" + } + }, + "definitions": { + "Model": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "schema": { + "$ref": "#/definitions/ModelSchema" + } + } + }, + "ScoreRequest": { + "type": "object", + "required": [ + "sqlQuery", + "outputTable", + "saveMethod" + ], + "properties": { + "sqlQuery": { + "description": "A string that is the SQL statement to be used for querying the configured database\n", + "type": "string" + }, + "outputTable": { + "description": "A string containing the name of the sql table to which the scored data will be written to.\n", + "type": "string" + }, + "idColumn": { + "description": "A string containing column name matching the ID column/Primary key of table being queried\n", + "type": "string" + }, + "saveMethod": { + "description": "String dictating how the resultant scored data should be written. can be one of: preview - does not write to configured sql database, but returns sample which can be reviewed, append - if output table already exists, data will be appended to that table, overwrite - if the output table already exists, the table will be overwritten with new data, ignore - if output table already exists, the new data will be dropped and nothing will be written, error - if the output table already exists, the new data will be dropped and an error will be thrown\n", + "type": "string", + "enum": [ + "preview", + "append", + "overwrite", + "ignore", + "error" + ] + } + } + }, + "ScoreResponse": { + "type": "object", + "properties": { + "id": { + "description": "A unique id of the model used for scoring.\n", + "type": "string" + }, + "success": { + "description": "A boolean value dictating whether the scoring request was successful\n", + "type": "boolean" + }, + "previewScores": { + "description": "first few rows, up to 5, of the resultant scored dataset\n", + "type": "array", + "items": { + "type": "string" + } + }, + "previewColumns": { + "description": "names of output columns for scoring preview\n", + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "ModelSchema": { + "type": "object", + "properties": { + "inputFields": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + } +} \ No newline at end of file diff --git a/common/swagger/v1/swagger.json b/common/swagger/v1/swagger.json new file mode 100644 index 00000000..2e599b34 --- /dev/null +++ b/common/swagger/v1/swagger.json @@ -0,0 +1,496 @@ +{ + "swagger": "2.0", + "info": { + "description": "This is a definition of the REST API for scoring from H2O. This API is intended to be used within DAI and eventually across all H2O scoring systems", + "version": "1.2.0", + "title": "Scoring API - v1", + "termsOfService": "", + "contact": { + "email": "support@h2o.ai" + }, + "license": { + "name": "License", + "url": "http://www.h2o.ai" + } + }, + "host": "localhost", + "basePath": "/", + "schemes": [ + "https", + "http" + ], + "paths": { + "/model/id": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Returns model id", + "description": "Returns unique id of the model loaded in the server and used for scoring", + "operationId": "getModelId", + "produces": [ + "text/plain" + ], + "parameters": [], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "type": "string" + } + } + }, + "security": [ + { + "api_key": [] + } + ] + } + }, + "/model/schema": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Describe a model", + "description": "Returns information about the model used for scoring, e.g., input schema.", + "operationId": "getModelInfo", + "produces": [ + "application/json" + ], + "parameters": [], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/Model" + } + } + } + } + }, + "/model/sample_request": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Sample scoring request", + "description": "Builds a sample scoring request that would pass all validations", + "operationId": "getSampleRequest", + "produces": [ + "application/json" + ], + "parameters": [], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/ScoreRequest" + } + } + } + } + }, + "/model/capabilities": { + "get": { + "tags": [ + "metadata" + ], + "summary": "List capabilities supported by the scorer.", + "description": "Returns the capabilities supported by the scorer.", + "operationId": "getCapabilities", + "produces": [ + "application/json" + ], + "parameters": [], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/CapabilityType" + } + } + } + } + } + }, + "/model/score": { + "post": { + "tags": [ + "scoring" + ], + "summary": "Score on given rows", + "description": "Computes score of the rows sent in the body of the post request", + "operationId": "getScore", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/ScoreRequest" + } + } + ], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/ScoreResponse" + } + }, + "400": { + "description": "Invalid payload" + } + } + }, + "get": { + "tags": [ + "scoring" + ], + "summary": "Score on given file", + "description": "Computes score of the rows in the file specified by the path in the query parameter", + "operationId": "getScoreByFile", + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "query", + "name": "file", + "type": "string" + } + ], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/ScoreResponse" + } + }, + "400": { + "description": "Invalid payload" + } + } + } + }, + "/model/contribution": { + "post": { + "tags": [ + "contribution" + ], + "summary": "Contribution score or Shapley values on given rows", + "description": "Computes contribution score with the rows sent in the body of the post request", + "operationId": "getContribution", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/ContributionRequest" + } + } + ], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/ContributionResponse" + } + }, + "400": { + "description": "Invalid payload" + }, + "501": { + "description": "Implementation not supported" + } + } + } + } + }, + "securityDefinitions": { + "api_key": { + "type": "apiKey", + "name": "api_key", + "in": "header" + } + }, + "definitions": { + "Model": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "properties": { + "type": "object", + "properties": { + "scoringType": { + "type": "string", + "$ref": "#/definitions/ScoringType" + }, + "scoringResponLabels": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "schema": { + "$ref": "#/definitions/ModelSchema" + } + } + }, + "Row": { + "type": "array", + "items": { + "type": "string" + } + }, + "PredictionInterval": { + "type": "object", + "properties": { + "fields": { + "$ref": "#/definitions/Row" + }, + "rows": { + "type": "array", + "items": { + "$ref": "#/definitions/Row" + } + } + } + }, + "ContributionRequest": { + "type": "object", + "required": [ + "requestShapleyValueType" + ], + "properties": { + "requestShapleyValueType": { + "description": "The string to say what type of Shap values are needed. `ORIGINAL` implies Shap values of original features are requested, `TRANSFORMED` implies that Shap values of transformed features are requested.\n", + "$ref": "#/definitions/ShapleyType" + }, + "fields": { + "description": "An array holding the names of fields in the order of appearance in the `rows` property. The length of `fields` has to match length of each row in `rows`. No duplicates are allowed.\n", + "type": "array", + "items": { + "type": "string" + } + }, + "rows": { + "description": "An array of rows consisting the actual input data for scoring, one scoring request per row.\n", + "type": "array", + "items": { + "$ref": "#/definitions/Row" + } + } + } + }, + "ContributionResponse": { + "type": "object", + "properties": { + "features": { + "description": "An array holding the names of fields in the order of appearance in the rows of the `contributions` property.\n", + "type": "array", + "items": { + "type": "string" + } + }, + "contributionGroups": { + "description": "An array of rows consisting of the shapley contributions output corresponding to an output group.\n", + "type": "array", + "items": { + "$ref": "#/definitions/ContributionGroup" + } + } + } + }, + "ContributionGroup": { + "type": "object", + "properties": { + "outputGroup": { + "description": "Name of the output group. It will be populated only for multinomial models. Shapley values are not supported for third party models yet, hence this field will not be populated.\n", + "type": "string" + }, + "contributions": { + "description": "An array of rows consisting of the shapley contributions output corresponding to columns in the fields\n", + "type": "array", + "items": { + "$ref": "#/definitions/Row" + } + } + } + }, + "ScoreRequest": { + "type": "object", + "properties": { + "requestShapleyValueType": { + "description": "The string to say what type of Shap values are needed. `ORIGINAL` implies Shap values of original features are requested, `TRANSFORMED` implies that Shap values of transformed features are requested.\n", + "$ref": "#/definitions/ShapleyType" + }, + "includeFieldsInOutput": { + "description": "An array holding the list of field names to be copied from the input request row to the corresponding scoring output. It is an error to specify a field name not present in the `fields` property, except when it is equal to the `idField` property. In the latter case, the row id would be generated and returned in the response. Note that the order of items in `includeFieldsInOutput` is ignored and the specified fields are returned in the order of appearance in the input request row.\n", + "type": "array", + "items": { + "type": "string" + } + }, + "noFieldNamesInOutput": { + "description": "If set to `true`. The scorer will not fill response column names in the `fields` field. This is can be useful to maintain compatibility with older scorer versions or to save bandwidth.\n", + "type": "boolean" + }, + "idField": { + "description": "Name of the field that holds a row id, e.g., a value that uniquely identifies each row of the request. The caller may specify a name of the field that is not present in fields. In which case, the scorer is allowed to generate a UUID to identify each row (e.g., for logging and monitoring purposes). To retrieve such a generated id as a part of the response, simply name it in the `includeFieldsInOutput`.\n", + "type": "string" + }, + "fields": { + "description": "An array holding the names of fields in the order of appearance in the `rows` property. The length of `fields` has to match length of each row in `rows`. No duplicates are allowed.\n", + "type": "array", + "items": { + "type": "string" + } + }, + "rows": { + "description": "An array of rows consisting the actual input data for scoring, one scoring request per row.\n", + "type": "array", + "items": { + "$ref": "#/definitions/Row" + } + }, + "requestPredictionIntervals": { + "type": "boolean", + "description": "If set to `true`, the scorer will try to fill field `predictionIntervals` in response if it is supported.\n" + } + } + }, + "ScoreResponse": { + "type": "object", + "properties": { + "id": { + "description": "A unique id of the model used for scoring.\n", + "type": "string" + }, + "fields": { + "description": "An array holding the names of fields in the order of appearance in the rows of the `score` property. This field is not populated if requested by setting the `noFieldNamesInOutput` request field to `true`.\n", + "type": "array", + "items": { + "type": "string" + } + }, + "score": { + "description": "An array of rows consisting the actual scoring output. The order of rows corresponds to the order of the input request rows. Each row contains any copied input fields first (in the order of appearance in the input row). If the `idField` was specified and also listed in the `includeFieldsInOutput` but not provided in `fields`, a unique id will be generated and positioned right after all the other fields copied from the input. The scoring output follows.\n", + "type": "array", + "items": { + "$ref": "#/definitions/Row" + } + }, + "featureShapleyContributions": { + "type": "object", + "description": "An object with features and shapley values that was requested by the client. This field will not be populated if the Shapley values are not available for a model.\n", + "$ref": "#/definitions/ContributionResponse" + }, + "predictionIntervals": { + "type": "object", + "description": "Prediction interval consist of an array of interval bound names and rows of array of bounds per bound name. Setting `requestPredictionIntervals` to true will enable populating the field. The field will be empty or an error response returned if prediction intervals are not returned or supported by the model.\n", + "$ref": "#/definitions/PredictionInterval" + } + } + }, + "DataField": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataType": { + "type": "string", + "enum": [ + "Bool", + "Int32", + "Int64", + "Float32", + "Float64", + "Str", + "Time64" + ] + }, + "example": { + "type": "string" + } + } + }, + "ModelSchema": { + "type": "object", + "properties": { + "inputFields": { + "type": "array", + "items": { + "$ref": "#/definitions/DataField" + } + }, + "targetFields": { + "type": "array", + "items": { + "$ref": "#/definitions/DataField" + } + }, + "outputFields": { + "type": "array", + "items": { + "$ref": "#/definitions/DataField" + } + } + } + }, + "ShapleyType": { + "type": "string", + "enum": [ + "ORIGINAL", + "TRANSFORMED", + "NONE" + ] + }, + "ScoringType": { + "type": "string", + "enum": [ + "REGRESSION", + "CLASSIFICATION", + "BINOMIAL" + ] + }, + "CapabilityType": { + "type": "string", + "enum": [ + "SCORE", + "SCORE_PREDICTION_INTERVAL", + "CONTRIBUTION_ORIGINAL", + "CONTRIBUTION_TRANSFORMED", + "MEDIA", + "TEST_TIME_AUGMENTATION" + ] + } + } +} \ No newline at end of file diff --git a/common/swagger/v1/vertex-ai-swagger.json b/common/swagger/v1/vertex-ai-swagger.json new file mode 100644 index 00000000..97a78504 --- /dev/null +++ b/common/swagger/v1/vertex-ai-swagger.json @@ -0,0 +1,251 @@ +{ + "swagger": "2.0", + "info": { + "description": "This is a definition of the REST API for scoring from H2O. This API is intended to be used within DAI and eventually across all H2O scoring systems", + "version": "1.0.0", + "title": "Scoring API - v1", + "termsOfService": "", + "contact": { + "email": "support@h2o.ai" + }, + "license": { + "name": "License", + "url": "http://www.h2o.ai" + } + }, + "host": "localhost", + "basePath": "/", + "schemes": [ + "https", + "http" + ], + "paths": { + "/model/id": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Returns model id", + "description": "Returns unique id of the model loaded in the server and used for scoring", + "operationId": "getModelId", + "produces": [ + "text/plain" + ], + "parameters": [], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "type": "string" + } + } + }, + "security": [ + { + "api_key": [] + } + ] + } + }, + "/model/score": { + "post": { + "tags": [ + "scoring" + ], + "summary": "Score on given rows", + "description": "Computes score of the rows sent in the body of the post request", + "operationId": "getScore", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/ScoreRequest" + } + } + ], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/ScoreResponse" + } + }, + "400": { + "description": "Invalid payload" + } + } + } + } + }, + "securityDefinitions": { + "api_key": { + "type": "apiKey", + "name": "api_key", + "in": "header" + } + }, + "definitions": { + "Model": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "properties": { + "type": "object", + "properties": { + "scoringType": { + "type": "string", + "enum": [ + "REGRESSION", + "CLASSIFICATION", + "BINOMIAL" + ] + }, + "scoringResponLabels": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "schema": { + "$ref": "#/definitions/ModelSchema" + } + } + }, + "Row": { + "type": "array", + "items": { + "type": "string" + } + }, + "Parameter": { + "type": "object", + "properties": { + "fields": { + "description": "An array holding the names of fields in the order of appearance in the `rows` property. The length of `fields` has to match length of each row in `rows`. No duplicates are allowed.\n", + "type": "array", + "items": { + "type": "string" + } + }, + "includeFieldsInOutput": { + "description": "An array holding the list of field names to be copied from the input request row to the corresponding scoring output. It is an error to specify a field name not present in the `fields` property, except when it is equal to the `idField` property. In the latter case, the row id would be generated and returned in the response. Note that the order of items in `includeFieldsInOutput` is ignored and the specified fields are returned in the order of appearance in the input request row.\n", + "type": "array", + "items": { + "type": "string" + } + }, + "noFieldNamesInOutput": { + "description": "If set to `true`. The scorer will not fill response column names in the `fields` field. This is can be useful to maintain compatibility with older scorer versions or to save bandwidth.\n", + "type": "boolean" + }, + "idField": { + "description": "Name of the field that holds a row id, e.g., a value that uniquely identifies each row of the request. The caller may specify a name of the field that is not present in fields. In which case, the scorer is allowed to generate a UUID to identify each row (e.g., for logging and monitoring purposes). To retrieve such a generated id as a part of the response, simply name it in the `includeFieldsInOutput`.\n", + "type": "string" + } + } + }, + "ScoreRequest": { + "type": "object", + "properties": { + "parameters": { + "description": "An object holding the fields and other optional parameters.\n", + "$ref": "#/definitions/Parameter" + }, + "instances": { + "description": "An array of rows consisting the actual input data for scoring, one scoring request per row.\n", + "type": "array", + "items": { + "$ref": "#/definitions/Row" + } + } + } + }, + "ScoreResponse": { + "type": "object", + "properties": { + "id": { + "description": "A unique id of the model used for scoring.\n", + "type": "string" + }, + "fields": { + "description": "An array holding the names of fields in the order of appearance in the rows of the `score` property. This field is not populated if requested by setting the `noFieldNamesInOutput` request field to `true`.\n", + "type": "array", + "items": { + "type": "string" + } + }, + "predictions": { + "description": "An array of rows consisting the actual scoring output. The order of rows corresponds to the order of the input request rows. Each row contains any copied input fields first (in the order of appearance in the input row). If the `idField` was specified and also listed in the `includeFieldsInOutput` but not provided in `fields`, a unique id will be generated and positioned right after all the other fields copied from the input. The scoring output follows.\n", + "type": "array", + "items": { + "$ref": "#/definitions/Row" + } + } + } + }, + "DataField": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataType": { + "type": "string", + "enum": [ + "Bool", + "Int32", + "Int64", + "Float32", + "Float64", + "Str", + "Time64" + ] + }, + "example": { + "type": "string" + } + } + }, + "ModelSchema": { + "type": "object", + "properties": { + "inputFields": { + "type": "array", + "items": { + "$ref": "#/definitions/DataField" + } + }, + "targetFields": { + "type": "array", + "items": { + "$ref": "#/definitions/DataField" + } + }, + "outputFields": { + "type": "array", + "items": { + "$ref": "#/definitions/DataField" + } + } + } + } + } +} \ No newline at end of file diff --git a/common/swagger/v1exp/swagger.json b/common/swagger/v1exp/swagger.json new file mode 100644 index 00000000..72ad3cb6 --- /dev/null +++ b/common/swagger/v1exp/swagger.json @@ -0,0 +1,97 @@ +{ + "openapi": "3.0.0", + "info": { + "description": "This is an extension of the REST API for scoring from H2O. This API is intended to be used for scoring within H2O.ai.", + "version": "1.2.0-exp", + "title": "Scoring API - v1 experimental", + "contact": { + "email": "support@h2o.ai" + }, + "license": { + "name": "License", + "url": "http://www.h2o.ai" + } + }, + "servers": [ + { + "url": "/" + } + ], + "paths": { + "/model/media-score": { + "post": { + "tags": [ + "scoring" + ], + "summary": "Score model with provided media files", + "description": "Computes score of provided data making use of provided media files.", + "operationId": "getMediaScore", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "scoreMediaRequest": { + "$ref": "#/components/schemas/scoreMediaRequest" + }, + "files": { + "type": "array", + "items": { + "type": "string", + "format": "binary" + } + } + }, + "required": [ + "scoreMediaRequest", + "files" + ] + } + } + } + }, + "responses": { + "200": { + "description": "Successful scoring operation", + "content": { + "application/json": { + "schema": { + "$ref": "../v1/swagger.yaml#/definitions/ScoreResponse" + } + } + } + }, + "400": { + "description": "Invalid payload" + }, + "501": { + "description": "Implementation not supported" + } + } + } + } + }, + "components": { + "schemas": { + "scoreMediaRequest": { + "allOf": [ + { + "$ref": "../v1/swagger.yaml#/definitions/ScoreRequest" + }, + { + "properties": { + "mediaFields": { + "description": "An array holding the names of all fields which are expected to contain media files. Contents of these fields will be replaced by corresponding uploaded files where the expected values in the column must be the file names of the uploaded files.\n", + "type": "array", + "items": { + "type": "string" + } + } + } + } + ] + } + } + } +} \ No newline at end of file diff --git a/common/swagger/v1openapi3/swagger.json b/common/swagger/v1openapi3/swagger.json new file mode 100644 index 00000000..67e979eb --- /dev/null +++ b/common/swagger/v1openapi3/swagger.json @@ -0,0 +1,604 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Scoring API - v1.2.0-openapi3", + "description": "This is a definition of the REST API for scoring from H2O. This API is intended to be used within DAI and eventually across all H2O scoring systems. This API combines both v1 and v1Exp in OPENAPI 3.0 spec.", + "contact": { + "email": "support@h2o.ai" + }, + "license": { + "name": "License", + "url": "http://www.h2o.ai" + }, + "version": "1.2.0-openapi3" + }, + "servers": [ + { + "url": "/" + } + ], + "paths": { + "/model/id": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Returns model id", + "description": "Returns unique id of the model loaded in the server and used for scoring", + "operationId": "getModelId", + "responses": { + "200": { + "description": "Successful operation", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + } + }, + "security": [ + { + "api_key": [] + } + ] + } + }, + "/model/schema": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Describe a model", + "description": "Returns information about the model used for scoring, e.g., input schema.", + "operationId": "getModelInfo", + "responses": { + "200": { + "description": "Successful operation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Model" + } + } + } + } + } + } + }, + "/model/sample_request": { + "get": { + "tags": [ + "metadata" + ], + "summary": "Sample scoring request", + "description": "Builds a sample scoring request that would pass all validations", + "operationId": "getSampleRequest", + "responses": { + "200": { + "description": "Successful operation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScoreRequest" + } + } + } + } + } + } + }, + "/model/capabilities": { + "get": { + "tags": [ + "metadata" + ], + "summary": "List capabilities supported by the scorer.", + "description": "Returns the capabilities supported by the scorer.", + "operationId": "getCapabilities", + "responses": { + "200": { + "description": "Successful operation", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CapabilityType" + } + } + } + } + } + } + } + }, + "/model/score": { + "get": { + "tags": [ + "scoring" + ], + "summary": "Score on given file", + "description": "Computes score of the rows in the file specified by the path in the query parameter", + "operationId": "getScoreByFile", + "parameters": [ + { + "name": "file", + "in": "query", + "required": false, + "style": "form", + "explode": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful operation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScoreResponse" + } + } + } + }, + "400": { + "description": "Invalid payload", + "content": {} + } + } + }, + "post": { + "tags": [ + "scoring" + ], + "summary": "Score on given rows", + "description": "Computes score of the rows sent in the body of the post request", + "operationId": "getScore", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScoreRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful operation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScoreResponse" + } + } + } + }, + "400": { + "description": "Invalid payload", + "content": {} + }, + "500": { + "description": "Failure operation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + }, + "x-codegen-request-body-name": "payload" + } + }, + "/model/contribution": { + "post": { + "tags": [ + "contribution" + ], + "summary": "Contribution score or Shapley values on given rows", + "description": "Computes contribution score with the rows sent in the body of the post request", + "operationId": "getContribution", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ContributionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful operation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ContributionResponse" + } + } + } + }, + "400": { + "description": "Invalid payload", + "content": {} + }, + "500": { + "description": "Failure operation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + }, + "501": { + "description": "Implementation not supported", + "content": {} + } + }, + "x-codegen-request-body-name": "payload" + } + }, + "/model/media-score": { + "post": { + "tags": [ + "scoring" + ], + "summary": "Score model with provided media files", + "description": "Computes score of provided data making use of provided media files.", + "operationId": "getMediaScore", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "scoreMediaRequest": { + "$ref": "#/components/schemas/scoreMediaRequest" + }, + "files": { + "type": "array", + "items": { + "type": "string", + "format": "binary" + } + } + }, + "required": [ + "scoreMediaRequest", + "files" + ] + } + } + } + }, + "responses": { + "200": { + "description": "Successful scoring operation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScoreResponse" + } + } + } + }, + "400": { + "description": "Invalid payload" + }, + "501": { + "description": "Implementation not supported" + } + } + } + } + }, + "components": { + "schemas": { + "Model": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "properties": { + "$ref": "#/components/schemas/Model_properties" + }, + "schema": { + "$ref": "#/components/schemas/ModelSchema" + } + } + }, + "Row": { + "type": "array", + "properties": { + "length": { + "type": "integer" + } + }, + "items": { + "type": "string" + } + }, + "ContributionRequest": { + "required": [ + "requestShapleyValueType" + ], + "type": "object", + "properties": { + "requestShapleyValueType": { + "$ref": "#/components/schemas/ShapleyType" + }, + "fields": { + "type": "array", + "description": "An array holding the names of fields in the order of appearance in the `rows` property. The length of `fields` has to match length of each row in `rows`. No duplicates are allowed.\n", + "items": { + "type": "string" + } + }, + "rows": { + "type": "array", + "description": "An array of rows consisting the actual input data for scoring, one scoring request per row.\n", + "items": { + "$ref": "#/components/schemas/Row" + } + } + } + }, + "ContributionResponse": { + "type": "object", + "properties": { + "features": { + "type": "array", + "description": "An array holding the names of fields in the order of appearance in the rows of the `contributions` property.\n", + "items": { + "type": "string" + } + }, + "contributionGroups": { + "type": "array", + "description": "An array of rows consisting of the shapley contributions output corresponding to an output group.\n", + "items": { + "$ref": "#/components/schemas/ContributionGroup" + } + } + } + }, + "ContributionGroup": { + "type": "object", + "properties": { + "outputGroup": { + "type": "string", + "description": "Name of the output group. It will be populated only for multinomial models. Shapley values are not supported for third party models yet, hence this field will not be populated.\n" + }, + "contributions": { + "type": "array", + "description": "An array of rows consisting of the shapley contributions output corresponding to columns in the fields\n", + "items": { + "$ref": "#/components/schemas/Row" + } + } + } + }, + "ScoreRequest": { + "type": "object", + "properties": { + "requestShapleyValueType": { + "$ref": "#/components/schemas/ShapleyType" + }, + "includeFieldsInOutput": { + "type": "array", + "description": "An array holding the list of field names to be copied from the input request row to the corresponding scoring output. It is an error to specify a field name not present in the `fields` property, except when it is equal to the `idField` property. In the latter case, the row id would be generated and returned in the response. Note that the order of items in `includeFieldsInOutput` is ignored and the specified fields are returned in the order of appearance in the input request row.\n", + "items": { + "type": "string" + } + }, + "noFieldNamesInOutput": { + "type": "boolean", + "description": "If set to `true`. The scorer will not fill response column names in the `fields` field. This is can be useful to maintain compatibility with older scorer versions or to save bandwidth.\n" + }, + "idField": { + "type": "string", + "description": "Name of the field that holds a row id, e.g., a value that uniquely identifies each row of the request. The caller may specify a name of the field that is not present in fields. In which case, the scorer is allowed to generate a UUID to identify each row (e.g., for logging and monitoring purposes). To retrieve such a generated id as a part of the response, simply name it in the `includeFieldsInOutput`.\n" + }, + "fields": { + "type": "array", + "description": "An array holding the names of fields in the order of appearance in the `rows` property. The length of `fields` has to match length of each row in `rows`. No duplicates are allowed.\n", + "items": { + "type": "string" + } + }, + "rows": { + "type": "array", + "description": "An array of rows consisting the actual input data for scoring, one scoring request per row.\n", + "items": { + "$ref": "#/components/schemas/Row" + } + }, + "requestPredictionIntervals": { + "type": "boolean", + "description": "If set to `true`, the scorer will try to fill field `predictionIntervals` in response if it is supported.\n" + } + } + }, + "ScoreResponse": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "A unique id of the model used for scoring.\n" + }, + "fields": { + "type": "array", + "description": "An array holding the names of fields in the order of appearance in the rows of the `score` property. This field is not populated if requested by setting the `noFieldNamesInOutput` request field to `true`.\n", + "items": { + "type": "string" + } + }, + "score": { + "type": "array", + "description": "An array of rows consisting the actual scoring output. The order of rows corresponds to the order of the input request rows. Each row contains any copied input fields first (in the order of appearance in the input row). If the `idField` was specified and also listed in the `includeFieldsInOutput` but not provided in `fields`, a unique id will be generated and positioned right after all the other fields copied from the input. The scoring output follows.\n", + "items": { + "$ref": "#/components/schemas/Row" + } + }, + "featureShapleyContributions": { + "$ref": "#/components/schemas/ContributionResponse" + }, + "predictionIntervals": { + "$ref": "#/components/schemas/PredictionInterval" + } + } + }, + "ErrorResponse": { + "type": "object", + "properties": { + "detail": { + "type": "string", + "description": "A string message containing the detail error message.\n" + } + } + }, + "DataField": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataType": { + "type": "string", + "enum": [ + "Bool", + "Int32", + "Int64", + "Float32", + "Float64", + "Str", + "Time64" + ] + }, + "example": { + "type": "string" + } + } + }, + "ModelSchema": { + "type": "object", + "properties": { + "inputFields": { + "type": "array", + "items": { + "$ref": "#/components/schemas/DataField" + } + }, + "targetFields": { + "type": "array", + "items": { + "$ref": "#/components/schemas/DataField" + } + }, + "outputFields": { + "type": "array", + "items": { + "$ref": "#/components/schemas/DataField" + } + } + } + }, + "ShapleyType": { + "type": "string", + "enum": [ + "ORIGINAL", + "TRANSFORMED", + "NONE" + ] + }, + "ScoringType": { + "type": "string", + "enum": [ + "REGRESSION", + "CLASSIFICATION", + "BINOMIAL" + ] + }, + "CapabilityType": { + "type": "string", + "enum": [ + "SCORE", + "SCORE_PREDICTION_INTERVAL", + "CONTRIBUTION_ORIGINAL", + "CONTRIBUTION_TRANSFORMED" + ] + }, + "Model_properties": { + "type": "object", + "properties": { + "scoringType": { + "$ref": "#/components/schemas/ScoringType" + }, + "scoringResponLabels": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "PredictionInterval": { + "type": "object", + "description": "Prediction interval consist of an array of interval bound names and rows of array of bounds per bound name. Setting `requestPredictionIntervals` to true will enable populating the field. The field will be empty or an error response returned if prediction intervals are not returned or supported by the model.", + "properties": { + "fields": { + "$ref": "#/components/schemas/Row" + }, + "rows": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Row" + } + } + } + }, + "scoreMediaRequest": { + "allOf": [ + { + "$ref": "#/components/schemas/ScoreRequest" + }, + { + "properties": { + "mediaFields": { + "description": "An array holding the names of all fields which are expected to contain media files. Contents of these fields will be replaced by corresponding uploaded files where the expected values in the column must be the file names of the uploaded files.\n", + "type": "array", + "items": { + "type": "string" + } + } + } + } + ] + } + }, + "securitySchemes": { + "api_key": { + "type": "apiKey", + "name": "api_key", + "in": "header" + } + } + } +} \ No newline at end of file diff --git a/common/transform/build.gradle b/common/transform/build.gradle index e3d1da7b..1434aec7 100644 --- a/common/transform/build.gradle +++ b/common/transform/build.gradle @@ -1,12 +1,16 @@ +plugins { + id 'org.springframework.boot' +} + apply from: project(":").file('gradle/java.gradle') dependencies { implementation project(':common:rest-java-model') + implementation group: 'io.swagger.core.v3', name: 'swagger-annotations' implementation group: 'ai.h2o', name: 'mojo2-runtime-api' implementation group: 'ai.h2o', name: 'mojo2-runtime-impl' implementation group: 'com.google.guava', name: 'guava' implementation group: 'org.slf4j', name: 'slf4j-api' - implementation group: 'org.yaml', name: 'snakeyaml' // FIXME(MM): this should not be required, since the dependency should be provided // by mojo2-runtime-impl. The problem is that mojo2 does not expose that dependency // as compile time dependency for consumers. @@ -15,15 +19,22 @@ dependencies { // end of fixme testImplementation group: 'com.google.truth.extensions', name: 'truth-java8-extension' - testImplementation group: 'org.mockito', name: 'mockito-inline' - testImplementation group: 'org.mockito', name : 'mockito-core' - testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter' - testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api' - testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-params' - testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine' - testImplementation group: 'org.junit-pioneer', name: 'junit-pioneer' + testImplementation group: 'org.mockito', name: 'mockito-inline', version: mockitoInlineVersion + testImplementation group: 'org.mockito', name : 'mockito-core', version: mockitoVersion + testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: mockitoVersion + testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api', version: jupiterVersion + testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-params', version: jupiterVersion + testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine', version: jupiterVersion + testImplementation group: 'org.junit-pioneer', name: 'junit-pioneer', version: jupiterPioneerVersion +} + +bootJar { + enabled=false } test { useJUnitPlatform() + + jvmArgs '--add-opens=java.base/java.util=ALL-UNNAMED' + jvmArgs '--add-opens=java.base/java.lang=ALL-UNNAMED' } diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ContributionRequestToMojoFrameConverter.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ContributionRequestToMojoFrameConverter.java index 78b194b6..fdd60e96 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ContributionRequestToMojoFrameConverter.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ContributionRequestToMojoFrameConverter.java @@ -1,7 +1,6 @@ package ai.h2o.mojos.deploy.common.transform; import ai.h2o.mojos.deploy.common.rest.model.ContributionRequest; -import ai.h2o.mojos.deploy.common.rest.model.Row; import ai.h2o.mojos.runtime.frame.MojoFrame; import ai.h2o.mojos.runtime.frame.MojoFrameBuilder; import ai.h2o.mojos.runtime.frame.MojoRowBuilder; @@ -9,16 +8,16 @@ import java.util.function.BiFunction; /** - * Converts the original API request object - * {@link ContributionRequest} into the input {@link MojoFrame}. + * Converts the original API request object {@link ContributionRequest} into the input {@link + * MojoFrame}. */ public class ContributionRequestToMojoFrameConverter - implements BiFunction { + implements BiFunction { @Override public MojoFrame apply(ContributionRequest scoreRequest, MojoFrameBuilder frameBuilder) { List fields = scoreRequest.getFields(); if (scoreRequest.getRows() != null) { - for (Row row : scoreRequest.getRows()) { + for (List row : scoreRequest.getRows()) { MojoRowBuilder rowBuilder = frameBuilder.getMojoRowBuilder(); for (int i = 0; i < row.size(); i++) { rowBuilder.setValue(fields.get(i), row.get(i)); diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToContributionResponseConverter.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToContributionResponseConverter.java index 7e332270..fef44d63 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToContributionResponseConverter.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToContributionResponseConverter.java @@ -3,9 +3,7 @@ import ai.h2o.mojos.deploy.common.rest.model.ContributionGroup; import ai.h2o.mojos.deploy.common.rest.model.ContributionResponse; import ai.h2o.mojos.deploy.common.rest.model.Row; -import ai.h2o.mojos.runtime.frame.MojoColumn; import ai.h2o.mojos.runtime.frame.MojoFrame; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -20,9 +18,10 @@ public class MojoFrameToContributionResponseConverter { * Converts the resulting predicted {@link MojoFrame} into the API response object {@link * ContributionResponse}. */ - public ContributionResponse contributionResponseWithNoOutputGroup( - MojoFrame shapleyMojoFrame) { - List outputRows = Stream.generate(Row::new).limit(shapleyMojoFrame.getNrows()) + public ContributionResponse contributionResponseWithNoOutputGroup(MojoFrame shapleyMojoFrame) { + List> outputRows = + Stream.generate(ArrayList::new) + .limit(shapleyMojoFrame.getNrows()) .collect(Collectors.toList()); Utils.copyResultFields(shapleyMojoFrame, outputRows); @@ -45,7 +44,7 @@ public ContributionResponse contributionResponseWithNoOutputGroup( * ContributionResponse grouped by the strings called as outputgroupNames}. */ public ContributionResponse contributionResponseWithOutputGroup( - MojoFrame shapleyMojoFrame, List outputGroupNames) { + MojoFrame shapleyMojoFrame, List outputGroupNames) { int rowCount = shapleyMojoFrame.getNrows(); List columnNames = Arrays.asList(shapleyMojoFrame.getColumnNames()); @@ -55,8 +54,7 @@ public ContributionResponse contributionResponseWithOutputGroup( boolean isFirstOutputGroup = true; for (String outputGroupName : outputGroupNames) { - ContributionGroup contributionGroup - = createContributionGroup(rowCount, outputGroupName); + ContributionGroup contributionGroup = createContributionGroup(rowCount, outputGroupName); Pattern pattern = Pattern.compile("\\." + outputGroupName); // note: columnNames from mojo contains a combination of featureName and outputGroupName @@ -71,7 +69,7 @@ public ContributionResponse contributionResponseWithOutputGroup( } String[] columnDataFromMojo = shapleyMojoFrame.getColumn(i).getDataAsStrings(); for (int k = 0; k < rowCount; k++) { - Row existingRow = contributionGroup.getContributions().get(k); + List existingRow = contributionGroup.getContributions().get(k); existingRow.add(columnDataFromMojo[k]); } } @@ -86,8 +84,8 @@ public ContributionResponse contributionResponseWithOutputGroup( private ContributionGroup createContributionGroup(int rowCount, String outputGroupName) { ContributionGroup contributionGroups = new ContributionGroup(); contributionGroups.setOutputGroup(outputGroupName); - contributionGroups.setContributions(Stream.generate(Row::new) - .limit(rowCount).collect(Collectors.toList())); + contributionGroups.setContributions( + Stream.generate(Row::new).limit(rowCount).collect(Collectors.toList())); return contributionGroups; } } diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverter.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverter.java index a56922e3..59a6a60c 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverter.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverter.java @@ -49,23 +49,25 @@ public MojoFrameToScoreResponseConverter() { /** * Transform MOJO response frame into ScoreResponse. + * * @param mojoFrame mojo response frame. * @param scoreRequest score request. * @return score response. */ @Override - public ScoreResponse apply( - MojoFrame mojoFrame, ScoreRequest scoreRequest) { + public ScoreResponse apply(MojoFrame mojoFrame, ScoreRequest scoreRequest) { Set includedFields = getSetOfIncludedFields(scoreRequest); - List outputRows = - Stream.generate(Row::new).limit(mojoFrame.getNrows()).collect(Collectors.toList()); + List> outputRows = + Stream.generate(ArrayList::new) + .limit(mojoFrame.getNrows()) + .collect(Collectors.toList()); copyFilteredInputFields(scoreRequest, includedFields, outputRows); fillOutputRows(mojoFrame, outputRows); ScoreResponse response = new ScoreResponse(); response.setScore(outputRows); - if (!Boolean.TRUE.equals(scoreRequest.isNoFieldNamesInOutput())) { + if (!Boolean.TRUE.equals(scoreRequest.getNoFieldNamesInOutput())) { List outputFieldNames = getFilteredInputFieldNames(scoreRequest, includedFields); outputFieldNames.addAll(getTargetField(mojoFrame)); response.setFields(outputFieldNames); @@ -75,31 +77,27 @@ public ScoreResponse apply( } /** - * Populate target column rows into outputRows. - * When prediction interval is returned from MOJO - * response frame, only one column rows will - * be populated into the outputRows to ensure - * backward compatible. + * Populate target column rows into outputRows. When prediction interval is returned from MOJO + * response frame, only one column rows will be populated into the outputRows to ensure backward + * compatible. */ - private void fillOutputRows( - MojoFrame mojoFrame, List outputRows) { - List targetRows = getTargetRows(mojoFrame); + private void fillOutputRows(MojoFrame mojoFrame, List> outputRows) { + List> targetRows = getTargetRows(mojoFrame); for (int rowIdx = 0; rowIdx < mojoFrame.getNrows(); rowIdx++) { outputRows.get(rowIdx).addAll(targetRows.get(rowIdx)); } } /** - * Populate Prediction Interval value into response field. - * Only when score request set requestPredictionIntervals be true - * and MOJO pipeline support prediction interval. + * Populate Prediction Interval value into response field. Only when score request set + * requestPredictionIntervals be true and MOJO pipeline support prediction interval. */ private void fillWithPredictionInterval( MojoFrame mojoFrame, ScoreRequest scoreRequest, ScoreResponse scoreResponse) { - if (Boolean.TRUE.equals(scoreRequest.isRequestPredictionIntervals())) { + if (Boolean.TRUE.equals(scoreRequest.getRequestPredictionIntervals())) { if (!supportPredictionInterval) { throw new IllegalStateException( - "Unexpected error, prediction interval should be supported, but actually not"); + "Unexpected error, prediction interval should be supported, but actually not"); } PredictionInterval predictionInterval = new PredictionInterval().fields(new Row()).rows(Collections.emptyList()); @@ -116,16 +114,14 @@ private void fillWithPredictionInterval( } /** - * Extract target column rows from MOJO response frame. - * Note: To ensure backward compatibility, - * if prediction interval is enabled then extracts only one - * column rows from response columns. + * Extract target column rows from MOJO response frame. Note: To ensure backward compatibility, if + * prediction interval is enabled then extracts only one column rows from response columns. */ - private List getTargetRows(MojoFrame mojoFrame) { - List taretRows = Stream - .generate(Row::new) - .limit(mojoFrame.getNrows()) - .collect(Collectors.toList()); + private List> getTargetRows(MojoFrame mojoFrame) { + List> taretRows = + Stream.generate(ArrayList::new) + .limit(mojoFrame.getNrows()) + .collect(Collectors.toList()); for (int row = 0; row < mojoFrame.getNrows(); row++) { for (int col : getTargetFieldIndices(mojoFrame)) { String cell = mojoFrame.getColumn(col).getDataAsStrings()[row]; @@ -136,13 +132,10 @@ private List getTargetRows(MojoFrame mojoFrame) { } /** - * Extract target columns from MOJO response frame. - * When prediction interval is enabled, extracts only one - * column from MOJO frame, otherwise all columns names - * will be extracted. + * Extract target columns from MOJO response frame. When prediction interval is enabled, extracts + * only one column from MOJO frame, otherwise all columns names will be extracted. */ - private List getTargetField( - MojoFrame mojoFrame) { + private List getTargetField(MojoFrame mojoFrame) { if (mojoFrame.getNcols() > 0) { List targetColumns = Arrays.asList(mojoFrame.getColumnNames()); if (supportPredictionInterval) { @@ -150,8 +143,7 @@ private List getTargetField( if (targetIdx < 0) { log.debug( "singular target column does not exist in MOJO response frame," - + " this could be a classification model." - ); + + " this could be a classification model."); } else { return targetColumns.subList(targetIdx, targetIdx + 1); } @@ -163,10 +155,9 @@ private List getTargetField( } /** - * Extract target columns indices from MOJO response frame. - * When prediction interval is enabled, extracts only one - * column index from MOJO frame, otherwise all - * columns indices will be extracted. + * Extract target columns indices from MOJO response frame. When prediction interval is enabled, + * extracts only one column index from MOJO frame, otherwise all columns indices will be + * extracted. */ private List getTargetFieldIndices(MojoFrame mojoFrame) { if (mojoFrame.getNcols() > 0) { @@ -176,8 +167,7 @@ private List getTargetFieldIndices(MojoFrame mojoFrame) { if (targetIdx < 0) { log.debug( "singular target column does not exist in MOJO response frame," - + " this could be a classification model." - ); + + " this could be a classification model."); } else { return Collections.singletonList(targetIdx); } @@ -189,15 +179,14 @@ private List getTargetFieldIndices(MojoFrame mojoFrame) { } /** - * Extract prediction interval columns rows from MOJO response frame. - * Note: Assumption is prediction interval should already be enabled - * and response frame has expected structure. + * Extract prediction interval columns rows from MOJO response frame. Note: Assumption is + * prediction interval should already be enabled and response frame has expected structure. */ - private List getPredictionIntervalRows(MojoFrame mojoFrame, int targetIdx) { - List predictionIntervalRows = Stream - .generate(Row::new) - .limit(mojoFrame.getNrows()) - .collect(Collectors.toList()); + private List> getPredictionIntervalRows(MojoFrame mojoFrame, int targetIdx) { + List> predictionIntervalRows = + Stream.generate(ArrayList::new) + .limit(mojoFrame.getNrows()) + .collect(Collectors.toList()); for (int row = 0; row < mojoFrame.getNrows(); row++) { for (int col = 0; col < mojoFrame.getNcols(); col++) { if (col == targetIdx) { @@ -211,9 +200,8 @@ private List getPredictionIntervalRows(MojoFrame mojoFrame, int targetIdx) } /** - * Extract prediction interval columns names from MOJO response frame. - * Note: Assumption is prediction interval should already be enabled - * and response frame has expected structure. + * Extract prediction interval columns names from MOJO response frame. Note: Assumption is + * prediction interval should already be enabled and response frame has expected structure. */ private Row getPredictionIntervalFields(MojoFrame mojoFrame, int targetIdx) { Row row = new Row(); @@ -225,9 +213,8 @@ private Row getPredictionIntervalFields(MojoFrame mojoFrame, int targetIdx) { } /** - * Extract target column index from list of column names. - * Note: Assumption is a singular target column should be found. - * Otherwise, the output indicates this a classification model. + * Extract target column index from list of column names. Note: Assumption is a singular target + * column should be found. Otherwise, the output indicates this a classification model. */ private int getTargetColIdx(List mojoColumns) { if (mojoColumns.size() == 1) { @@ -247,15 +234,15 @@ private int getTargetColIdx(List mojoColumns) { } private static void copyFilteredInputFields( - ScoreRequest scoreRequest, Set includedFields, List outputRows) { + ScoreRequest scoreRequest, Set includedFields, List> outputRows) { if (includedFields.isEmpty()) { return; } boolean generateRowIds = shouldGenerateRowIds(scoreRequest, includedFields); - List inputRows = scoreRequest.getRows(); + List> inputRows = scoreRequest.getRows(); for (int row = 0; row < outputRows.size(); row++) { - Row inputRow = inputRows.get(row); - Row outputRow = outputRows.get(row); + List inputRow = inputRows.get(row); + List outputRow = outputRows.get(row); List inputFields = scoreRequest.getFields(); for (int col = 0; col < inputFields.size(); col++) { if (includedFields.contains(inputFields.get(col))) { diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java index 1762db62..387e8334 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java @@ -24,20 +24,19 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * H2O DAI mojo scorer. * - *

The scorer code is shared for all mojo deployments and is only parameterized by the - * {@code mojo.path} property to define the mojo to use. - * {@code shapley.enable} property to enable shapley contribution. + *

The scorer code is shared for all mojo deployments and is only parameterized by the {@code + * mojo.path} property to define the mojo to use. {@code shapley.enable} property to enable shapley + * contribution. */ public class MojoScorer { - private static final String ENABLE_SHAPLEY_CONTRIBUTION_MESSAGE - = "shapley.types.enabled property has to be set to one of [TRANSFORMED, ORIGINAL, ALL]" + private static final String ENABLE_SHAPLEY_CONTRIBUTION_MESSAGE = + "shapley.types.enabled property has to be set to one of [TRANSFORMED, ORIGINAL, ALL]" + " or shapley.enable property has to be set to true in the runtime configuration " + "to obtain Shapley contribution"; @@ -48,8 +47,8 @@ public class MojoScorer { public static final boolean supportPredictionInterval = checkIfPredictionIntervalSupport(); private static final MojoPipeline pipeline = supportPredictionInterval - ? loadMojoPipelineFromFile(buildPipelineConfigWithPredictionInterval()) - : loadMojoPipelineFromFile(); + ? loadMojoPipelineFromFile(buildPipelineConfigWithPredictionInterval()) + : loadMojoPipelineFromFile(); private final ShapleyLoadOption enabledShapleyTypes; private final boolean shapleyEnabled; private static MojoPipeline pipelineTransformedShapley; @@ -100,24 +99,20 @@ public MojoScorer( * @return response {@link ScoreResponse} */ public ScoreResponse score(ScoreRequest request) { - if (Boolean.TRUE.equals(request.isRequestPredictionIntervals()) + if (Boolean.TRUE.equals(request.getRequestPredictionIntervals()) && !supportPredictionInterval) { throw new IllegalArgumentException( - "requestPredictionIntervals set to true, but model does not support it" - ); + "requestPredictionIntervals set to true, but model does not support it"); } scoreRequestTransformer.accept(request, getModelInfo().getSchema().getInputFields()); - MojoFrame requestFrame = scoreRequestConverter - .apply(request, pipeline.getInputFrameBuilder()); + MojoFrame requestFrame = scoreRequestConverter.apply(request, pipeline.getInputFrameBuilder()); MojoFrame responseFrame = doScore(requestFrame); - ScoreResponse response = scoreResponseConverter.apply( - responseFrame, request); + ScoreResponse response = scoreResponseConverter.apply(responseFrame, request); response.id(pipeline.getUuid()); ShapleyType requestShapleyType = request.getRequestShapleyValueType(); - if (requestShapleyType == null - || requestShapleyType == ShapleyType.NONE) { + if (requestShapleyType == null || requestShapleyType == ShapleyType.NONE) { return response; } @@ -130,8 +125,7 @@ public ScoreResponse score(ScoreRequest request) { throw new IllegalArgumentException( String.format( "Requested Shapley type %s not enabled for this scorer. Expected: %s", - requestShapleyType, - enabledShapleyTypes)); + requestShapleyType, enabledShapleyTypes)); } try { @@ -154,14 +148,14 @@ public ScoreResponse score(ScoreRequest request) { } private ContributionResponse originalFeatureContribution(ScoreRequest request) { - MojoFrame requestFrame = scoreRequestConverter - .apply(request, pipelineOriginalShapley.getInputFrameBuilder()); + MojoFrame requestFrame = + scoreRequestConverter.apply(request, pipelineOriginalShapley.getInputFrameBuilder()); return contribution(doShapleyContrib(requestFrame, true)); } private ContributionResponse transformedFeatureContribution(ScoreRequest request) { - MojoFrame requestFrame = scoreRequestConverter - .apply(request, pipelineTransformedShapley.getInputFrameBuilder()); + MojoFrame requestFrame = + scoreRequestConverter.apply(request, pipelineTransformedShapley.getInputFrameBuilder()); return contribution(doShapleyContrib(requestFrame, false)); } @@ -180,27 +174,29 @@ public ContributionResponse computeContribution(ContributionRequest request) { ShapleyType requestedShapleyType = request.getRequestShapleyValueType(); if (!ShapleyLoadOption.requestedTypeEnabled( - enabledShapleyTypes, requestedShapleyType.toString())) { + enabledShapleyTypes, requestedShapleyType.toString())) { throw new IllegalArgumentException( - String.format( - "Requested Shapley type %s not enabled for this scorer. Expected: %s", - requestedShapleyType, enabledShapleyTypes)); + String.format( + "Requested Shapley type %s not enabled for this scorer. Expected: %s", + requestedShapleyType, enabledShapleyTypes)); } MojoFrame requestFrame; switch (requestedShapleyType) { case TRANSFORMED: - requestFrame = contributionRequestConverter - .apply(request, pipelineTransformedShapley.getInputFrameBuilder()); + requestFrame = + contributionRequestConverter.apply( + request, pipelineTransformedShapley.getInputFrameBuilder()); return contribution(doShapleyContrib(requestFrame, false)); case ORIGINAL: - requestFrame = contributionRequestConverter - .apply(request, pipelineOriginalShapley.getInputFrameBuilder()); + requestFrame = + contributionRequestConverter.apply( + request, pipelineOriginalShapley.getInputFrameBuilder()); return contribution(doShapleyContrib(requestFrame, true)); default: throw new IllegalArgumentException( - "Only ORIGINAL or TRANSFORMED are accepted enums values of Shapley values"); + "Only ORIGINAL or TRANSFORMED are accepted enums values of Shapley values"); } } @@ -211,23 +207,22 @@ private ContributionResponse contribution(MojoFrame contributionFrame) { List outputGroupNames = getOutputGroups(outputMeta); if (ScoringType.CLASSIFICATION.equals(scoringType)) { - return contributionResponseConverter - .contributionResponseWithOutputGroup(contributionFrame, outputGroupNames); + return contributionResponseConverter.contributionResponseWithOutputGroup( + contributionFrame, outputGroupNames); } else { - return contributionResponseConverter - .contributionResponseWithNoOutputGroup(contributionFrame); + return contributionResponseConverter.contributionResponseWithNoOutputGroup(contributionFrame); } } private List getOutputGroups(MojoFrameMeta outputMeta) { int numberOutputColumns = outputMeta.getColumns().size(); - List outputClass = new ArrayList<>(); + List outputClass = new ArrayList<>(); for (int i = 0; i < numberOutputColumns; i++) { String outputClassName = outputMeta.getColumnName(i); // the MOJO API will provide list of target labels in the future // Link: https://github.com/h2oai/mojo2/issues/1366 String[] outputClassNameSplit = outputClassName.split("\\."); - String refinedOutputClass = outputClassNameSplit[outputClassNameSplit.length - 1 ]; + String refinedOutputClass = outputClassNameSplit[outputClassNameSplit.length - 1]; outputClass.add(refinedOutputClass); } return outputClass; @@ -257,8 +252,7 @@ public ScoreResponse scoreCsv(String csvFilePath) throws IOException { requestFrame = csvConverter.apply(csvStream, pipeline.getInputFrameBuilder()); } MojoFrame responseFrame = doScore(requestFrame); - ScoreResponse response = scoreResponseConverter.apply( - responseFrame, new ScoreRequest()); + ScoreResponse response = scoreResponseConverter.apply(responseFrame, new ScoreRequest()); response.id(pipeline.getUuid()); return response; } @@ -291,17 +285,16 @@ private static MojoFrame doScore(MojoFrame requestFrame) { * Method to get shapley contribution for an incoming request of type {@link ScoreRequest}. * * @param requestFrame {@link MojoFrame} - * @param isOriginal {@link boolean} Simple boolean to specify if the shapley contribution - * has to be performed for original features - * or transformed features + * @param isOriginal {@link boolean} Simple boolean to specify if the shapley contribution has to + * be performed for original features or transformed features * @return response {@link MojoFrame} */ private static MojoFrame doShapleyContrib(MojoFrame requestFrame, boolean isOriginal) { log.debug( - "Input has {} rows, {} columns: {}", - requestFrame.getNrows(), - requestFrame.getNcols(), - Arrays.toString(requestFrame.getColumnNames())); + "Input has {} rows, {} columns: {}", + requestFrame.getNrows(), + requestFrame.getNcols(), + Arrays.toString(requestFrame.getColumnNames())); MojoFrame shapleyResponseFrame; if (isOriginal) { shapleyResponseFrame = pipelineOriginalShapley.transform(requestFrame); @@ -309,10 +302,10 @@ private static MojoFrame doShapleyContrib(MojoFrame requestFrame, boolean isOrig shapleyResponseFrame = pipelineTransformedShapley.transform(requestFrame); } log.debug( - "Response has {} rows, {} columns: {}", - shapleyResponseFrame.getNrows(), - shapleyResponseFrame.getNcols(), - Arrays.toString(shapleyResponseFrame.getColumnNames())); + "Response has {} rows, {} columns: {}", + shapleyResponseFrame.getNrows(), + shapleyResponseFrame.getNcols(), + Arrays.toString(shapleyResponseFrame.getColumnNames())); return shapleyResponseFrame; } @@ -339,10 +332,9 @@ public boolean isPredictionIntervalSupport() { /** * Method to load mojo pipelines for shapley scoring based on configuration * - *

Order of operations to preserve backwards compatibility: - * 1. if property or env var shapley.types.enabled is set, load pipelines based on that - * 2. if shapley.enabled is true load all pipelines - * + *

Order of operations to preserve backwards compatibility: 1. if property or env var + * shapley.types.enabled is set, load pipelines based on that 2. if shapley.enabled is true load + * all pipelines */ private void loadMojoPipelinesForShapley() { if (ShapleyLoadOption.NONE == enabledShapleyTypes) { diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/RequestChecker.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/RequestChecker.java index 81cfd854..00044180 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/RequestChecker.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/RequestChecker.java @@ -2,7 +2,6 @@ import static java.util.Arrays.asList; -import ai.h2o.mojos.deploy.common.rest.model.Row; import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest; import ai.h2o.mojos.runtime.frame.MojoFrameMeta; import java.util.List; @@ -37,7 +36,7 @@ private String getProblemMessageOrNull(ScoreRequest scoreRequest, MojoFrameMeta if (fields == null || fields.isEmpty()) { return "List of input fields cannot be empty"; } - List rows = scoreRequest.getRows(); + List> rows = scoreRequest.getRows(); if (rows == null || rows.isEmpty()) { return "List of input data rows cannot be empty"; } @@ -48,7 +47,7 @@ private String getProblemMessageOrNull(ScoreRequest scoreRequest, MojoFrameMeta expectedFields.toString(), fields.toString()); } int i = 0; - for (Row row : scoreRequest.getRows()) { + for (List row : scoreRequest.getRows()) { if (row.size() != fields.size()) { return String.format("Not enough elements in row %d (zero-indexed)", i); } diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/SampleRequestBuilder.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/SampleRequestBuilder.java index fb803946..9aae8647 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/SampleRequestBuilder.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/SampleRequestBuilder.java @@ -1,7 +1,5 @@ package ai.h2o.mojos.deploy.common.transform; -import static java.util.Arrays.asList; - import ai.h2o.mojos.deploy.common.rest.model.Row; import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest; import ai.h2o.mojos.runtime.api.MojoColumnMeta; @@ -19,9 +17,10 @@ public class SampleRequestBuilder { /** Builds a valid {@link ScoreRequest} based on the given mojo input {@link MojoFrameMeta}. */ public ScoreRequest build(MojoFrameMeta inputMeta) { ScoreRequest request = new ScoreRequest(); - final List fields = inputMeta.getColumns().stream() - .map(MojoColumnMeta::getColumnName) - .collect(Collectors.toList()); + final List fields = + inputMeta.getColumns().stream() + .map(MojoColumnMeta::getColumnName) + .collect(Collectors.toList()); request.setFields(fields); Row row = new Row(); for (MojoColumn.Type type : inputMeta.getColumnTypes()) { diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestToMojoFrameConverter.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestToMojoFrameConverter.java index de2253b6..362b9fdc 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestToMojoFrameConverter.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestToMojoFrameConverter.java @@ -1,6 +1,5 @@ package ai.h2o.mojos.deploy.common.transform; -import ai.h2o.mojos.deploy.common.rest.model.Row; import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest; import ai.h2o.mojos.runtime.frame.MojoFrame; import ai.h2o.mojos.runtime.frame.MojoFrameBuilder; @@ -18,7 +17,7 @@ public MojoFrame apply(ScoreRequest scoreRequest, MojoFrameBuilder frameBuilder) List fields = scoreRequest.getFields(); if (scoreRequest.getRows() != null) { - for (Row row : scoreRequest.getRows()) { + for (List row : scoreRequest.getRows()) { MojoRowBuilder rowBuilder = frameBuilder.getMojoRowBuilder(); for (int i = 0; i < row.size(); i++) { rowBuilder.setValue(fields.get(i), row.get(i)); diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestTransformer.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestTransformer.java index af324ec2..8f848b30 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestTransformer.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestTransformer.java @@ -1,10 +1,9 @@ package ai.h2o.mojos.deploy.common.transform; import ai.h2o.mojos.deploy.common.rest.model.DataField; -import ai.h2o.mojos.deploy.common.rest.model.Row; import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest; import com.google.common.collect.ImmutableMap; - +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; @@ -15,8 +14,8 @@ import org.slf4j.LoggerFactory; /** - * Transform scoring request rows, - * specifically, convert boolean literal string into 1, 0 respectively. + * Transform scoring request rows, specifically, convert boolean literal string into 1, 0 + * respectively. */ public class ScoreRequestTransformer implements BiConsumer> { @@ -25,40 +24,41 @@ public class ScoreRequestTransformer implements BiConsumer dataFields) { Map dataFieldMap = - dataFields.stream().collect( - ImmutableMap.toImmutableMap(DataField::getName, Function.identity() - ) - ); + dataFields.stream() + .collect(ImmutableMap.toImmutableMap(DataField::getName, Function.identity())); scoreRequest.setRows( - transformRow(scoreRequest.getFields(), scoreRequest.getRows(), dataFieldMap) - ); + transformRow(scoreRequest.getFields(), scoreRequest.getRows(), dataFieldMap)); } - private List transformRow( - List fields, List rows, Map dataFields - ) { - return rows.stream().map(row -> { - List transformData = IntStream.range(0, row.size()).mapToObj( - fieldIdx -> { - String colName = fields.get(fieldIdx); - String origin = row.get(fieldIdx); - if (dataFields.containsKey(colName)) { - String sanitizeValue = Utils.sanitizeBoolean( - origin, dataFields.get(colName).getDataType() - ); - if (!sanitizeValue.equals(origin)) { - logger.debug("Value '{}' parsed as '{}'", origin, sanitizeValue); - } - return sanitizeValue; - } else { - logger.warn("Column '{}' can not be found in Input schema", colName); - return origin; - } - } - ).collect(Collectors.toList()); - Row transformedRow = new Row(); - transformedRow.addAll(transformData); - return transformedRow; - }).collect(Collectors.toList()); + private List> transformRow( + List fields, List> rows, Map dataFields) { + return rows.stream() + .map( + row -> { + List transformData = + IntStream.range(0, row.size()) + .mapToObj( + fieldIdx -> { + String colName = fields.get(fieldIdx); + String origin = row.get(fieldIdx); + if (dataFields.containsKey(colName)) { + String sanitizeValue = + Utils.sanitizeBoolean( + origin, dataFields.get(colName).getDataType()); + if (!sanitizeValue.equals(origin)) { + logger.debug("Value '{}' parsed as '{}'", origin, sanitizeValue); + } + return sanitizeValue; + } else { + logger.warn("Column '{}' can not be found in Input schema", colName); + return origin; + } + }) + .collect(Collectors.toList()); + List transformedRow = new ArrayList<>(); + transformedRow.addAll(transformData); + return transformedRow; + }) + .collect(Collectors.toList()); } } diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ShapleyLoadOption.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ShapleyLoadOption.java index 3db9ea56..2a3c1a27 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ShapleyLoadOption.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/ShapleyLoadOption.java @@ -1,8 +1,6 @@ package ai.h2o.mojos.deploy.common.transform; -/** - * Enum defining options for loading the mojo to enable Shapley predictions. - */ +/** Enum defining options for loading the mojo to enable Shapley predictions. */ public enum ShapleyLoadOption { ALL, NONE, @@ -16,6 +14,7 @@ public enum ShapleyLoadOption { /** * Checks whether Shapley scoring is permitted. + * * @param option {@link ShapleyLoadOption} * @return {@link Boolean} */ @@ -33,6 +32,7 @@ public static boolean isEnabled(ShapleyLoadOption option) { /** * Checks whether requested type of Shapley value scoring is permitted. + * * @param requested {@link String} * @return {@link Boolean} */ @@ -45,12 +45,11 @@ public static boolean requestedTypeEnabled(ShapleyLoadOption option, String requ /** * Extracts configuration from system properties or environment variables. + * * @return {@link ShapleyLoadOption} */ public static ShapleyLoadOption fromEnvironment() { - return shapleyEnabledFromEnvironment() - ? ALL - : shapleyTypeFromEnvironment(); + return shapleyEnabledFromEnvironment() ? ALL : shapleyTypeFromEnvironment(); } private static boolean shapleyEnabledFromEnvironment() { diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/Utils.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/Utils.java index 675d5551..cf8a55eb 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/Utils.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/Utils.java @@ -1,9 +1,7 @@ package ai.h2o.mojos.deploy.common.transform; import ai.h2o.mojos.deploy.common.rest.model.DataField; -import ai.h2o.mojos.deploy.common.rest.model.Row; import ai.h2o.mojos.runtime.frame.MojoFrame; - import java.util.List; public class Utils { @@ -12,13 +10,13 @@ public class Utils { * * @param mojoFrame {@link MojoFrame} */ - public static void copyResultFields(MojoFrame mojoFrame, List outputRows) { + public static void copyResultFields(MojoFrame mojoFrame, List> outputRows) { String[][] outputColumns = new String[mojoFrame.getNcols()][]; for (int col = 0; col < mojoFrame.getNcols(); col++) { outputColumns[col] = mojoFrame.getColumn(col).getDataAsStrings(); } for (int row = 0; row < mojoFrame.getNrows(); row++) { - Row outputRow = outputRows.get(row); + List outputRow = outputRows.get(row); for (String[] resultColumn : outputColumns) { outputRow.add(resultColumn[row]); } @@ -27,13 +25,12 @@ public static void copyResultFields(MojoFrame mojoFrame, List outputRows) { /** * Sanitize boolean string literal values true / false (case insensitive) into 1 / 0 respectively. + * * @return sanitized string. */ public static String sanitizeBoolean(String value, DataField.DataTypeEnum dataType) { - if ( - dataType.equals(DataField.DataTypeEnum.FLOAT32) - || dataType.equals(DataField.DataTypeEnum.FLOAT64) - ) { + if (dataType.equals(DataField.DataTypeEnum.FLOAT32) + || dataType.equals(DataField.DataTypeEnum.FLOAT64)) { if ("true".equalsIgnoreCase(value)) { return "1"; } else if ("false".equalsIgnoreCase(value)) { diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToContributionResponseConverterTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToContributionResponseConverterTest.java index 6c950898..7c326c28 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToContributionResponseConverterTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToContributionResponseConverterTest.java @@ -28,15 +28,16 @@ import org.junit.jupiter.params.provider.MethodSource; class MojoFrameToContributionResponseConverterTest { - private final MojoFrameToContributionResponseConverter converter - = new MojoFrameToContributionResponseConverter(); + private final MojoFrameToContributionResponseConverter converter = + new MojoFrameToContributionResponseConverter(); @Test void convertEmptyRowsResponse_succeeds() { // Given - MojoFrame mojoFrame = new MojoFrameBuilder( - MojoFrameMeta.getEmpty(), Collections.emptyList(), Collections.emptyMap()) - .toMojoFrame(); + MojoFrame mojoFrame = + new MojoFrameBuilder( + MojoFrameMeta.getEmpty(), Collections.emptyList(), Collections.emptyMap()) + .toMojoFrame(); // When ContributionResponse result = converter.contributionResponseWithNoOutputGroup(mojoFrame); @@ -55,38 +56,45 @@ void convertMoreTypesResponse_succeeds(String[][] contributions) { MojoColumn.Type[] types = {Str, Float32, Float64, Bool, Int32, Int64}; // When - ContributionResponse result = converter.contributionResponseWithNoOutputGroup(buildMojoFrame( - Stream.of(types).map(Object::toString).toArray(String[]::new), types, contributions)); + ContributionResponse result = + converter.contributionResponseWithNoOutputGroup( + buildMojoFrame( + Stream.of(types).map(Object::toString).toArray(String[]::new), + types, + contributions)); // Then assertThat(result.getContributionGroups().size()).isEqualTo(1); assertThat(result.getContributionGroups().get(0).getContributions()) - .containsExactly(Stream.of(contributions) - .map(MojoFrameToContributionResponseConverterTest::asRow).toArray()); + .containsExactly( + Stream.of(contributions) + .map(MojoFrameToContributionResponseConverterTest::asRow) + .toArray()); assertThat(result.getFeatures()) - .containsExactly("Str", "Float32", "Float64", "Bool", "Int32", "Int64") + .containsExactly("Str", "Float32", "Float64", "Bool", "Int32", "Int64") .inOrder(); } @SuppressWarnings("unused") private static Stream provideValues_convertMoreTypesResponse_succeeds() { return Stream.of( - Arguments.of((Object) new String[][] {{"str", "1.1", "2.2", "1", "123", "123456789"}}), - Arguments.of((Object) new String[][] {{null, null, null, null, null, null}})); + Arguments.of((Object) new String[][] {{"str", "1.1", "2.2", "1", "123", "123456789"}}), + Arguments.of((Object) new String[][] {{null, null, null, null, null, null}})); } @Test void convertEmptyRowsResponse_withEmptyOutputGroup_succeeds() { // Given - MojoFrame mojoFrame = new MojoFrameBuilder( - MojoFrameMeta.getEmpty(), Collections.emptyList(), Collections.emptyMap()) + MojoFrame mojoFrame = + new MojoFrameBuilder( + MojoFrameMeta.getEmpty(), Collections.emptyList(), Collections.emptyMap()) .toMojoFrame(); List outputGroupNames = new ArrayList<>(); // When - ContributionResponse result = converter - .contributionResponseWithOutputGroup(mojoFrame, outputGroupNames); + ContributionResponse result = + converter.contributionResponseWithOutputGroup(mojoFrame, outputGroupNames); // Then assertThat(result.getContributionGroups().size()).isEqualTo(0); @@ -101,14 +109,15 @@ void convertSingleFeatureResponse_succeeds() { String[][] contributions = {{"23.6"}}; // When - ContributionResponse result = converter - .contributionResponseWithNoOutputGroup(buildMojoFrame(features, types, contributions)); + ContributionResponse result = + converter.contributionResponseWithNoOutputGroup( + buildMojoFrame(features, types, contributions)); // Then assertThat(result.getContributionGroups().size()).isEqualTo(1); assertThat(result.getContributionGroups().get(0).getContributions().size()).isEqualTo(1); assertThat(result.getContributionGroups().get(0).getContributions()) - .containsExactly(asRow("23.6")); + .containsExactly(asRow("23.6")); assertThat(result.getContributionGroups().get(0).getOutputGroup()).isNull(); assertThat(result.getFeatures()).containsExactly("feature"); } @@ -123,15 +132,15 @@ void convertSingleFeatureResponse_withOneOutputGroup_succeeds() { List outputGroupNames = Collections.singletonList("test"); // When - ContributionResponse result = converter - .contributionResponseWithOutputGroup( - buildMojoFrame(features, types, contributions), outputGroupNames); + ContributionResponse result = + converter.contributionResponseWithOutputGroup( + buildMojoFrame(features, types, contributions), outputGroupNames); // Then assertThat(result.getContributionGroups().size()).isEqualTo(1); assertThat(result.getContributionGroups().get(0).getContributions().size()).isEqualTo(1); assertThat(result.getContributionGroups().get(0).getContributions()) - .containsExactly(asRow("122.2")); + .containsExactly(asRow("122.2")); assertThat(result.getContributionGroups().get(0).getOutputGroup()).isEqualTo("test"); assertThat(result.getFeatures()).containsExactly("feature"); } @@ -145,7 +154,8 @@ void convertSingleFeatureResponse_withManyOutputGroup_succeeds() { List outputGroupNames = Arrays.asList("test1", "test2"); // When - ContributionResponse result = converter.contributionResponseWithOutputGroup( + ContributionResponse result = + converter.contributionResponseWithOutputGroup( buildMojoFrame(features, types, contributions), outputGroupNames); // Then @@ -153,12 +163,12 @@ void convertSingleFeatureResponse_withManyOutputGroup_succeeds() { assertThat(result.getContributionGroups().get(0).getContributions().size()).isEqualTo(1); assertThat(result.getContributionGroups().get(0).getContributions()) - .containsExactly(asRow("122.2")); + .containsExactly(asRow("122.2")); assertThat(result.getContributionGroups().get(0).getOutputGroup()).isEqualTo("test1"); assertThat(result.getContributionGroups().get(1).getContributions().size()).isEqualTo(1); assertThat(result.getContributionGroups().get(1).getContributions()) - .containsExactly(asRow("34.6")); + .containsExactly(asRow("34.6")); assertThat(result.getContributionGroups().get(1).getOutputGroup()).isEqualTo("test2"); assertThat(result.getFeatures()).containsExactly("feature"); @@ -174,7 +184,8 @@ void convertSingleFeatureResponse_withManyRowsAndOutputGroups_succeeds() { List outputGroupNames = Arrays.asList("test1", "test2"); // When - ContributionResponse result = converter.contributionResponseWithOutputGroup( + ContributionResponse result = + converter.contributionResponseWithOutputGroup( buildMojoFrame(features, types, contributions), outputGroupNames); // Then @@ -182,12 +193,12 @@ void convertSingleFeatureResponse_withManyRowsAndOutputGroups_succeeds() { assertThat(result.getContributionGroups().get(0).getContributions().size()).isEqualTo(2); assertThat(result.getContributionGroups().get(0).getContributions()) - .containsExactly(asRow("122.2"), asRow("90.2")); + .containsExactly(asRow("122.2"), asRow("90.2")); assertThat(result.getContributionGroups().get(0).getOutputGroup()).isEqualTo("test1"); assertThat(result.getContributionGroups().get(1).getContributions().size()).isEqualTo(2); assertThat(result.getContributionGroups().get(1).getContributions()) - .containsExactly(asRow("34.6"), asRow("45.6")); + .containsExactly(asRow("34.6"), asRow("45.6")); assertThat(result.getContributionGroups().get(1).getOutputGroup()).isEqualTo("test2"); assertThat(result.getFeatures()).containsExactly("feature"); @@ -199,49 +210,48 @@ void convertSingleFeatureResponse_withManyFeaturesAndOutputGroups_succeeds() { String[] features = {"feature1.test1", "feature1.test2", "feature2.test1", "feature2.test2"}; MojoColumn.Type[] types = {Float32, Float32, Float32, Float32}; String[][] contributions = { - {"122.2", "34.6", "90.9", "78.0"}, - {"90.2", "45.6", "56.9", "56.0"}}; + {"122.2", "34.6", "90.9", "78.0"}, + {"90.2", "45.6", "56.9", "56.0"} + }; List outputGroupNames = Arrays.asList("test1", "test2"); // When - ContributionResponse result = converter - .contributionResponseWithOutputGroup( - buildMojoFrame(features, types, contributions), outputGroupNames); + ContributionResponse result = + converter.contributionResponseWithOutputGroup( + buildMojoFrame(features, types, contributions), outputGroupNames); // Then assertThat(result.getContributionGroups().size()).isEqualTo(2); assertThat(result.getContributionGroups().get(0).getContributions().size()).isEqualTo(2); assertThat(result.getContributionGroups().get(0).getContributions()) - .containsExactly(asRow("122.2", "90.9"), asRow("90.2", "56.9")); + .containsExactly(asRow("122.2", "90.9"), asRow("90.2", "56.9")); assertThat(result.getContributionGroups().get(0).getOutputGroup()).isEqualTo("test1"); assertThat(result.getContributionGroups().get(1).getContributions().size()).isEqualTo(2); assertThat(result.getContributionGroups().get(1).getContributions()) - .containsExactly(asRow("34.6", "78.0"), asRow("45.6", "56.0")); + .containsExactly(asRow("34.6", "78.0"), asRow("45.6", "56.0")); assertThat(result.getContributionGroups().get(1).getOutputGroup()).isEqualTo("test2"); assertThat(result.getFeatures()).containsExactly("feature1", "feature2"); } - private static MojoFrame buildMojoFrame(String[] fields, - MojoColumn.Type[] types, - String[][] values) { + private static MojoFrame buildMojoFrame( + String[] fields, MojoColumn.Type[] types, String[][] values) { return buildMojoFrame(fields, types, values, (rb, type, col, value) -> rb.setValue(col, value)); } private static MojoFrame buildMojoFrame( - String[] fields, MojoColumn.Type[] types, - T[][] values, - MojoFrameToScoreResponseConverterTest.RowBuilderSetter setter) { - final List columns = MojoColumnMeta.toColumns( - fields, - types, - MojoColumn.Kind.Output); + String[] fields, + MojoColumn.Type[] types, + T[][] values, + MojoFrameToScoreResponseConverterTest.RowBuilderSetter setter) { + final List columns = + MojoColumnMeta.toColumns(fields, types, MojoColumn.Kind.Output); final MojoFrameMeta meta = new MojoFrameMeta(columns); final MojoFrameBuilder frameBuilder = - new MojoFrameBuilder(meta, Collections.emptyList(), Collections.emptyMap()); + new MojoFrameBuilder(meta, Collections.emptyList(), Collections.emptyMap()); for (T[] row : values) { MojoRowBuilder rowBuilder = frameBuilder.getMojoRowBuilder(); int col = 0; diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverterTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverterTest.java index 90d4ea7a..fcfc7e51 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverterTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoFrameToScoreResponseConverterTest.java @@ -19,13 +19,11 @@ import ai.h2o.mojos.runtime.frame.MojoFrameBuilder; import ai.h2o.mojos.runtime.frame.MojoFrameMeta; import ai.h2o.mojos.runtime.frame.MojoRowBuilder; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Stream; - import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -36,8 +34,7 @@ class MojoFrameToScoreResponseConverterTest { @Test void convertEmptyRowsResponse_succeeds() { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); ScoreRequest scoreRequest = new ScoreRequest(); MojoFrame mojoFrame = new MojoFrameBuilder( @@ -55,16 +52,14 @@ void convertEmptyRowsResponse_succeeds() { @Test void convertSingleFieldResponse_succeeds() { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); String[] fields = {"field"}; Type[] types = {Str}; String[][] values = {{"value"}}; ScoreRequest scoreRequest = new ScoreRequest(); // When - ScoreResponse result = converter.apply( - buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).containsExactly(asRow("value")); @@ -74,8 +69,7 @@ void convertSingleFieldResponse_succeeds() { @Test void convertSingleFieldResponse_withoutFieldNames_succeeds() { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); String[] fields = {"field"}; Type[] types = {Str}; String[][] values = {{"value"}}; @@ -83,8 +77,7 @@ void convertSingleFieldResponse_withoutFieldNames_succeeds() { scoreRequest.setNoFieldNamesInOutput(true); // When - ScoreResponse result = converter.apply( - buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).containsExactly(asRow("value")); @@ -94,8 +87,7 @@ void convertSingleFieldResponse_withoutFieldNames_succeeds() { @Test void convertIncludesOneField_succeeds() { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); String[] fields = {"outputField"}; Type[] types = {Str}; String[][] values = {{"outputValue"}}; @@ -105,8 +97,7 @@ void convertIncludesOneField_succeeds() { scoreRequest.addRowsItem(asRow("inputValue")); // When - ScoreResponse result = converter.apply( - buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).containsExactly(asRow("inputValue", "outputValue")); @@ -116,8 +107,7 @@ void convertIncludesOneField_succeeds() { @Test void convertIncludesSomeFields_succeeds() { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); String[] fields = {"outputField1", "outputField2"}; Type[] types = {Str, Str}; String[][] values = {{"outputValue1", "outputValue2"}}; @@ -127,8 +117,7 @@ void convertIncludesSomeFields_succeeds() { scoreRequest.addRowsItem(asRow("inputValue1", "omittedValue", "inputValue3")); // When - ScoreResponse result = converter.apply( - buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()) @@ -141,8 +130,7 @@ void convertIncludesSomeFields_succeeds() { @Test void convertIncludePresentIdField_succeeds() { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); String[] fields = {"outputField"}; Type[] types = {Str}; String[][] values = {{"outputValue"}}; @@ -153,8 +141,7 @@ void convertIncludePresentIdField_succeeds() { scoreRequest.setIdField("id"); // When - ScoreResponse result = converter.apply( - buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).containsExactly(asRow("inputValue", "testId", "outputValue")); @@ -164,8 +151,7 @@ void convertIncludePresentIdField_succeeds() { @Test void convertIncludeMissingIdField_generateUuid() { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); String[] fields = {"outputField"}; Type[] types = {Str}; String[][] values = {{"outputValue"}}; @@ -176,8 +162,7 @@ void convertIncludeMissingIdField_generateUuid() { scoreRequest.setIdField("id"); // When - ScoreResponse result = converter.apply( - buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()).hasSize(1); @@ -193,21 +178,19 @@ void convertIncludeMissingIdField_generateUuid() { @Test void convertMoreRowsResponse_succeeds() { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); String[] fields = {"field"}; Type[] types = {Str}; String[][] values = {{"value1"}, {"value2"}, {"value3"}}; ScoreRequest scoreRequest = new ScoreRequest(); // When - ScoreResponse result = converter.apply( - buildMojoFrame(fields, types, values), scoreRequest); + ScoreResponse result = converter.apply(buildMojoFrame(fields, types, values), scoreRequest); // Then assertThat(result.getScore()) - .containsExactly(Stream.of(values) - .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()) + .containsExactly( + Stream.of(values).map(MojoFrameToScoreResponseConverterTest::asRow).toArray()) .inOrder(); assertThat(result.getFields()).containsExactly("field"); } @@ -216,33 +199,31 @@ void convertMoreRowsResponse_succeeds() { @MethodSource("provideValues_convertMoreTypesResponse_succeeds") void convertMoreTypesResponse_succeeds(String[][] values) { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); Type[] types = {Str, Float32, Float64, Bool, Int32, Int64}; ScoreRequest scoreRequest = new ScoreRequest(); // When ScoreResponse result = converter.apply( - buildMojoFrame( - Stream.of(types).map(Object::toString).toArray(String[]::new), types, values), - scoreRequest); + buildMojoFrame( + Stream.of(types).map(Object::toString).toArray(String[]::new), types, values), + scoreRequest); // Then assertThat(result.getScore()) - .containsExactly(Stream.of(values) - .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); + .containsExactly( + Stream.of(values).map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); assertThat(result.getFields()) - .containsExactly("Str", "Float32", "Float64", "Bool", "Int32", "Int64") - .inOrder(); + .containsExactly("Str", "Float32", "Float64", "Bool", "Int32", "Int64") + .inOrder(); } @ParameterizedTest @MethodSource("provideValues_convertMoreTypesResponse_actualValues_succeeds") void convertMoreTypesResponse_actualValues_succeeds(Object[][] values, String[][] expValues) { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); Type[] types = {Str, Float32, Float64, Bool, Int32, Int64}; ScoreRequest scoreRequest = new ScoreRequest(); @@ -253,7 +234,7 @@ void convertMoreTypesResponse_actualValues_succeeds(Object[][] values, String[][ Stream.of(types).map(Object::toString).toArray(String[]::new), types, values, - MojoFrameToScoreResponseConverterTest::setJavaValue), + MojoFrameToScoreResponseConverterTest::setJavaValue), scoreRequest); // Then @@ -270,41 +251,43 @@ void convertMoreTypesResponse_actualValues_succeeds(Object[][] values, String[][ void convertMoreTypesResponse_enablePredictionIntervalSameType_succeeds( Object[][] values, String[][] expValues) { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(true); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(true); Type[] types = {Float64, Float64, Float64}; ScoreRequest scoreRequest = new ScoreRequest().requestPredictionIntervals(true); // When ScoreResponse result = converter.apply( - buildMojoFrame( - new String[]{"result.upper", "result", "result.lower"}, - types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), - scoreRequest); + buildMojoFrame( + new String[] {"result.upper", "result", "result.lower"}, + types, + values, + MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); // Then assertThat(result.getScore()) .containsExactly( - Stream.of(expValues) - .map(input -> Arrays.asList(input).subList(1, 2)) - .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); - assertThat(result.getFields()) - .containsExactly("result") - .inOrder(); + Stream.of(expValues) + .map(input -> Arrays.asList(input).subList(1, 2)) + .map(MojoFrameToScoreResponseConverterTest::asRow) + .toArray()); + assertThat(result.getFields()).containsExactly("result").inOrder(); assertThat(result.getPredictionIntervals().getFields()) - .containsExactly("result.upper", "result.lower") - .inOrder(); + .containsExactly("result.upper", "result.lower") + .inOrder(); assertThat(result.getPredictionIntervals().getRows()) .containsExactly( - Stream.of(expValues) - .map(input -> { - List intervalRow = new ArrayList<>(2); - intervalRow.add(input[0]); - intervalRow.add(input[2]); - return intervalRow; - }) - .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); + Stream.of(expValues) + .map( + input -> { + List intervalRow = new ArrayList<>(2); + intervalRow.add(input[0]); + intervalRow.add(input[2]); + return intervalRow; + }) + .map(MojoFrameToScoreResponseConverterTest::asRow) + .toArray()); } @ParameterizedTest @@ -312,30 +295,29 @@ void convertMoreTypesResponse_enablePredictionIntervalSameType_succeeds( void convertMoreTypesResponse_disablePredictionIntervalSameType_succeeds( Object[][] values, String[][] expValues) { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(true); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(true); Type[] types = {Float64, Float64, Float64}; ScoreRequest scoreRequest = new ScoreRequest(); // When ScoreResponse result = converter.apply( - buildMojoFrame( - new String[]{"result.upper", "result", "result.lower"}, - types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), - scoreRequest); + buildMojoFrame( + new String[] {"result.upper", "result", "result.lower"}, + types, + values, + MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); // Then assertThat(result.getScore()) .containsExactly( - Stream.of(expValues) - .map(input -> Arrays.asList(input).subList(1, 2)) - .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); - assertThat(result.getFields()) - .containsExactly("result") - .inOrder(); - assertThat(result.getPredictionIntervals()) - .isNull(); + Stream.of(expValues) + .map(input -> Arrays.asList(input).subList(1, 2)) + .map(MojoFrameToScoreResponseConverterTest::asRow) + .toArray()); + assertThat(result.getFields()).containsExactly("result").inOrder(); + assertThat(result.getPredictionIntervals()).isNull(); } @ParameterizedTest @@ -343,29 +325,28 @@ void convertMoreTypesResponse_disablePredictionIntervalSameType_succeeds( void convertMoreTypesResponse_disablePredictionIntervalNotSupportSameType_succeeds( Object[][] values, String[][] expValues) { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); Type[] types = {Float64, Float64, Float64}; ScoreRequest scoreRequest = new ScoreRequest(); // When ScoreResponse result = converter.apply( - buildMojoFrame( - new String[]{"result.upper", "result", "result.lower"}, - types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), - scoreRequest); + buildMojoFrame( + new String[] {"result.upper", "result", "result.lower"}, + types, + values, + MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); // Then assertThat(result.getScore()) .containsExactly( - Stream.of(expValues) - .map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); + Stream.of(expValues).map(MojoFrameToScoreResponseConverterTest::asRow).toArray()); assertThat(result.getFields()) - .containsExactly("result.upper", "result", "result.lower") - .inOrder(); - assertThat(result.getPredictionIntervals()) - .isNull(); + .containsExactly("result.upper", "result", "result.lower") + .inOrder(); + assertThat(result.getPredictionIntervals()).isNull(); } @ParameterizedTest @@ -373,16 +354,18 @@ void convertMoreTypesResponse_disablePredictionIntervalNotSupportSameType_succee void convertMoreTypesResponse_enablePredictionIntervalDiffType_fails( Object[][] values, Type[] types) { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(false); + final MojoFrameToScoreResponseConverter converter = + new MojoFrameToScoreResponseConverter(false); ScoreRequest scoreRequest = new ScoreRequest().requestPredictionIntervals(true); // When & Then try { converter.apply( buildMojoFrame( - Stream.of(types).map(Object::toString).toArray(String[]::new), - types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), + Stream.of(types).map(Object::toString).toArray(String[]::new), + types, + values, + MojoFrameToScoreResponseConverterTest::setJavaValue), scoreRequest); } catch (Exception e) { assertThat(e instanceof IllegalStateException).isTrue(); @@ -394,56 +377,57 @@ void convertMoreTypesResponse_enablePredictionIntervalDiffType_fails( void convertMoreTypesResponse_disablePredictionIntervalNotSupportDiffType_succeeds( Object[][] values, Type[] types) { // Given - final MojoFrameToScoreResponseConverter converter - = new MojoFrameToScoreResponseConverter(); + final MojoFrameToScoreResponseConverter converter = new MojoFrameToScoreResponseConverter(); ScoreRequest scoreRequest = new ScoreRequest(); // When ScoreResponse result = converter.apply( - buildMojoFrame( - Stream.of(types).map(Object::toString).toArray(String[]::new), - types, values, MojoFrameToScoreResponseConverterTest::setJavaValue), - scoreRequest); + buildMojoFrame( + Stream.of(types).map(Object::toString).toArray(String[]::new), + types, + values, + MojoFrameToScoreResponseConverterTest::setJavaValue), + scoreRequest); // Then assertThat(result.getFields()) - .containsExactly(Stream.of(types).map(Object::toString).toArray()) - .inOrder(); - assertThat(result.getPredictionIntervals()) - .isNull(); + .containsExactly(Stream.of(types).map(Object::toString).toArray()) + .inOrder(); + assertThat(result.getPredictionIntervals()).isNull(); } private static Stream provideValues_predictionIntervalEnabledResponse_succeeds() { return Stream.of( - Arguments.of( - new Double[][]{{3.5, -1.0, 2.0}, {3.3, 11.9, 10.3}}, - new String[][]{{"3.5", "-1.0", "2.0"}, {"3.3", "11.9", "10.3"}}), - Arguments.of( - new Double[][]{{2.7, 3.4, 5.9}, {1.1, 2.2, 3.3}}, - new String[][]{{"2.7", "3.4", "5.9"}, {"1.1", "2.2", "3.3"}}) - ); + Arguments.of( + new Double[][] {{3.5, -1.0, 2.0}, {3.3, 11.9, 10.3}}, + new String[][] {{"3.5", "-1.0", "2.0"}, {"3.3", "11.9", "10.3"}}), + Arguments.of( + new Double[][] {{2.7, 3.4, 5.9}, {1.1, 2.2, 3.3}}, + new String[][] {{"2.7", "3.4", "5.9"}, {"1.1", "2.2", "3.3"}})); } private static Stream provideValue_predictionIntervalEnabledResponse_fails() { return Stream.of( - Arguments.of(new Double[][]{{12.2, 11.221},{1.1, 99.1}}, new Type[]{Float64, Float64}), - Arguments.of(new Double[][]{{10.1}, {121.1}}, new Type[]{Float64}), - Arguments.of(new Double[][]{}, new Type[]{}), - Arguments.of(new Object[][]{ - {"abc", null, 12}, {"bbc", 12.4f, 15}}, new Type[]{Str, Float32, Int32}), - Arguments.of(new Object[][]{ - {90L, 1.21f, 12}, {11L, 12.4f, 15}}, new Type[]{Int64, Float32, Int32}), - Arguments.of(new Object[][]{ - {false, true, false}, {true, null, false}}, new Type[]{Bool, Bool, Bool}) - ); + Arguments.of(new Double[][] {{12.2, 11.221}, {1.1, 99.1}}, new Type[] {Float64, Float64}), + Arguments.of(new Double[][] {{10.1}, {121.1}}, new Type[] {Float64}), + Arguments.of(new Double[][] {}, new Type[] {}), + Arguments.of( + new Object[][] {{"abc", null, 12}, {"bbc", 12.4f, 15}}, + new Type[] {Str, Float32, Int32}), + Arguments.of( + new Object[][] {{90L, 1.21f, 12}, {11L, 12.4f, 15}}, + new Type[] {Int64, Float32, Int32}), + Arguments.of( + new Object[][] {{false, true, false}, {true, null, false}}, + new Type[] {Bool, Bool, Bool})); } @SuppressWarnings("unused") private static Stream provideValues_convertMoreTypesResponse_succeeds() { return Stream.of( - Arguments.of((Object) new String[][] {{"str", "1.1", "2.2", "1", "123", "123456789"}}), - Arguments.of((Object) new String[][] {{null, null, null, null, null, null}})); + Arguments.of((Object) new String[][] {{"str", "1.1", "2.2", "1", "123", "123456789"}}), + Arguments.of((Object) new String[][] {{null, null, null, null, null, null}})); } @SuppressWarnings("unused") diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoPipelineToModelInfoConverterTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoPipelineToModelInfoConverterTest.java index 649f1387..0b2bdd19 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoPipelineToModelInfoConverterTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoPipelineToModelInfoConverterTest.java @@ -137,7 +137,9 @@ private static DataField[] toDataFields(String[] inputNames, DataTypeEnum[] inpu return result; } - /** Dummy test {@link MojoPipeline} just to be able to test the transformation. */ + /** + * Dummy test {@link MojoPipeline} just to be able to test the transformation. + */ private static class DummyPipeline extends MojoPipeline { private final MojoFrameMeta inputMeta; private final MojoFrameMeta outputMeta; @@ -199,7 +201,6 @@ public void setListener(BasePipelineListener listener) { @Override public void printPipelineInfo(PrintStream out) { - } } } diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoScorerTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoScorerTest.java index df03cc8e..b1de0e40 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoScorerTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/MojoScorerTest.java @@ -23,7 +23,6 @@ import ai.h2o.mojos.runtime.frame.MojoFrameBuilder; import ai.h2o.mojos.runtime.frame.MojoFrameMeta; import ai.h2o.mojos.runtime.frame.MojoRowBuilder; - import java.io.File; import java.io.PrintStream; import java.util.ArrayList; @@ -32,11 +31,9 @@ import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; - import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -50,13 +47,20 @@ class MojoScorerTest { private static final String TEST_UUID = "TEST_UUID"; private static MockedStatic pipelineSettings = null; - @Mock private ScoreRequestToMojoFrameConverter scoreRequestConverter; - @Mock private MojoFrameToScoreResponseConverter scoreResponseConverter; - @Mock private MojoFrameToContributionResponseConverter contributionResponseConverter; - @Mock private ContributionRequestToMojoFrameConverter contributionRequestConverter; - @Mock private MojoPipelineToModelInfoConverter modelInfoConverter; - @Mock private ScoreRequestTransformer scoreRequestTransformer; - @Mock private CsvToMojoFrameConverter csvConverter; + @Mock + private ScoreRequestToMojoFrameConverter scoreRequestConverter; + @Mock + private MojoFrameToScoreResponseConverter scoreResponseConverter; + @Mock + private MojoFrameToContributionResponseConverter contributionResponseConverter; + @Mock + private ContributionRequestToMojoFrameConverter contributionRequestConverter; + @Mock + private MojoPipelineToModelInfoConverter modelInfoConverter; + @Mock + private ScoreRequestTransformer scoreRequestTransformer; + @Mock + private CsvToMojoFrameConverter csvConverter; @BeforeAll static void setup() { @@ -69,13 +73,17 @@ private static void mockDummyPipeline() { pipelineSettings.close(); } MojoPipeline dummyPipeline = - new DummyPipeline(TEST_UUID, MojoFrameMeta.getEmpty(), MojoFrameMeta.getEmpty()); + new DummyPipeline(TEST_UUID, MojoFrameMeta.getEmpty(), MojoFrameMeta.getEmpty()); pipelineSettings = Mockito.mockStatic(MojoPipelineService.class); - pipelineSettings.when(() -> MojoPipelineService - .loadPipeline(new File(MOJO_PIPELINE_PATH))).thenReturn(dummyPipeline); - pipelineSettings.when(() -> MojoPipelineService - .loadPipeline(Mockito.eq(new File(MOJO_PIPELINE_PATH)), any(PipelineConfig.class))) - .thenReturn(dummyPipeline); + pipelineSettings + .when(() -> MojoPipelineService.loadPipeline(new File(MOJO_PIPELINE_PATH))) + .thenReturn(dummyPipeline); + pipelineSettings + .when( + () -> + MojoPipelineService.loadPipeline( + Mockito.eq(new File(MOJO_PIPELINE_PATH)), any(PipelineConfig.class))) + .thenReturn(dummyPipeline); } @AfterAll @@ -100,11 +108,9 @@ void verifyScoreRequestWithoutShapley_ShapleyDisabled_Succeeds() { request.addFieldsItem("field1"); request.addRowsItem(toRow("text")); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -121,17 +127,14 @@ void verifyScoreRequestWithTransformedShapley_ShapleyDisabled_Fails() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.TRANSFORMED); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); // When & Then - assertThrows( - IllegalArgumentException.class, () -> scorer.score(request)); + assertThrows(IllegalArgumentException.class, () -> scorer.score(request)); } @Test @@ -143,17 +146,14 @@ void verifyScoreRequestWithOriginalShapley_ShapleyDisabled_Fails() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.TRANSFORMED); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); // When & Then - assertThrows( - IllegalArgumentException.class, () -> scorer.score(request)); + assertThrows(IllegalArgumentException.class, () -> scorer.score(request)); } @Test @@ -168,8 +168,7 @@ void verifyContributionWithOriginalShapley_ShapleyDisabled_Fails() { MojoScorer scorer = dummyScorer(); // When & Then - assertThrows(IllegalArgumentException.class, () -> scorer - .computeContribution(request)); + assertThrows(IllegalArgumentException.class, () -> scorer.computeContribution(request)); } @Test @@ -183,8 +182,7 @@ void verifyContributionWithTransformedShapley_ShapleyDisabled_Fails() { MojoScorer scorer = dummyScorer(); // When & Then - assertThrows(IllegalArgumentException.class, () -> scorer - .computeContribution(request)); + assertThrows(IllegalArgumentException.class, () -> scorer.computeContribution(request)); } @Test @@ -199,8 +197,7 @@ void verifyContributionWithOriginalShapley_TransformedShapleyEnabled_Fails() { MojoScorer scorer = dummyScorer(); // When & Then - assertThrows(IllegalArgumentException.class, () -> scorer - .computeContribution(request)); + assertThrows(IllegalArgumentException.class, () -> scorer.computeContribution(request)); } @Test @@ -215,8 +212,7 @@ void verifyContributionWithTransformedShapley_OriginalShapleyEnabled_Fails() { MojoScorer scorer = dummyScorer(); // When & Then - assertThrows(IllegalArgumentException.class, () -> scorer - .computeContribution(request)); + assertThrows(IllegalArgumentException.class, () -> scorer.computeContribution(request)); } @Test @@ -228,11 +224,9 @@ void verifyScoreRequestWithoutShapley_ShapleyEnabled_Succeeds() { request.addFieldsItem("field1"); request.addRowsItem(toRow("text")); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -250,11 +244,9 @@ void verifyScoreRequestWithTransformedShapley_ShapleyEnabled_Succeeds() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.TRANSFORMED); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -272,11 +264,9 @@ void verifyScoreRequestWithOriginalShapley_ShapleyEnabled_Succeeds() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.ORIGINAL); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -293,11 +283,9 @@ void verifyScoreRequestWithoutShapley_ShapleyOptionAll_Succeeds() { request.addFieldsItem("field1"); request.addRowsItem(toRow("text")); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -315,11 +303,9 @@ void verifyScoreRequestWithTransformedShapley_ShapleyOptionAll_Succeeds() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.TRANSFORMED); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -337,11 +323,9 @@ void verifyScoreRequestWithOriginalShapley_ShapleyOptionAll_Succeeds() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.ORIGINAL); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -358,11 +342,9 @@ void verifyScoreRequestWithoutShapley_ShapleyOptionTransformed_Succeeds() { request.addFieldsItem("field1"); request.addRowsItem(toRow("text")); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -380,11 +362,9 @@ void verifyScoreRequestWithTransformedShapley_ShapleyOptionTransformed_Succeeds( request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.TRANSFORMED); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -402,17 +382,14 @@ void verifyScoreRequestWithOriginalShapley_ShapleyOptionTransformed_Fails() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.ORIGINAL); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); // When & Then - assertThrows( - IllegalArgumentException.class, () -> scorer.score(request)); + assertThrows(IllegalArgumentException.class, () -> scorer.score(request)); } @Test @@ -424,11 +401,9 @@ void verifyScoreRequestWithoutShapley_ShapleyOptionOriginal_Succeeds() { request.addFieldsItem("field1"); request.addRowsItem(toRow("text")); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -446,17 +421,14 @@ void verifyScoreRequestWithTransformedShapley_ShapleyOptionOriginal_Fails() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.TRANSFORMED); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); // When & Then - assertThrows( - IllegalArgumentException.class, () -> scorer.score(request)); + assertThrows(IllegalArgumentException.class, () -> scorer.score(request)); } @Test @@ -469,11 +441,9 @@ void verifyScoreRequestWithOriginalShapley_ShapleyOptionOriginal_Succeeds() { request.addRowsItem(toRow("text")); request.setRequestShapleyValueType(ShapleyType.ORIGINAL); MojoFrame dummyMojoFrame = generateDummyTransformedMojoFrame(); - given(scoreRequestConverter.apply(any(), any())) - .willReturn(dummyMojoFrame); + given(scoreRequestConverter.apply(any(), any())).willReturn(dummyMojoFrame); ScoreResponse dummyResponse = generateDummyResponse(); - given(scoreResponseConverter.apply(any(), any())) - .willReturn(dummyResponse); + given(scoreResponseConverter.apply(any(), any())).willReturn(dummyResponse); MojoScorer scorer = dummyScorer(); @@ -492,7 +462,7 @@ private MojoFrame generateDummyTransformedMojoFrame() { columns.add(MojoColumnMeta.newOutput("Prediction.0", MojoColumn.Type.Float64)); final MojoFrameMeta meta = new MojoFrameMeta(columns); final MojoFrameBuilder frameBuilder = - new MojoFrameBuilder(meta, Collections.emptyList(), Collections.emptyMap()); + new MojoFrameBuilder(meta, Collections.emptyList(), Collections.emptyMap()); MojoRowBuilder mojoRowBuilder = frameBuilder.getMojoRowBuilder(); mojoRowBuilder.setValue("Prediction.0", "0.64"); frameBuilder.addRow(mojoRowBuilder); @@ -504,8 +474,8 @@ private MojoFrame generateDummyTransformedMojoFrame() { private ScoreResponse generateDummyResponse() { ScoreResponse response = new ScoreResponse(); - List outputRows = - Stream.generate(Row::new).limit(4).collect(Collectors.toList()); + List> outputRows = + Stream.generate(ArrayList::new).limit(4).collect(Collectors.toList()); response.setScore(outputRows); response.setFields(Arrays.asList("field1")); return response; @@ -520,7 +490,9 @@ private Model generateDummyModelInfo() { return model; } - /** Dummy pipeline {@link MojoPipeline} just to mock the static methods used inside scoring. */ + /** + * Dummy pipeline {@link MojoPipeline} just to mock the static methods used inside scoring. + */ private static class DummyPipeline extends MojoPipeline { private final MojoFrameMeta inputMeta; private final MojoFrameMeta outputMeta; @@ -531,10 +503,6 @@ private DummyPipeline(String uuid, MojoFrameMeta inputMeta, MojoFrameMeta output this.outputMeta = outputMeta; } - static DummyPipeline ofMeta(MojoFrameMeta inputMeta, MojoFrameMeta outputMeta) { - return new DummyPipeline(TEST_UUID, inputMeta, outputMeta); - } - @Override public MojoFrameMeta getInputMeta() { return inputMeta; @@ -579,19 +547,17 @@ public void setListener(BasePipelineListener listener) { @Override public void printPipelineInfo(PrintStream out) { - } } public MojoScorer dummyScorer() { return new MojoScorer( - scoreRequestConverter, - scoreResponseConverter, - contributionRequestConverter, - contributionResponseConverter, - modelInfoConverter, - scoreRequestTransformer, - csvConverter - ); + scoreRequestConverter, + scoreResponseConverter, + contributionRequestConverter, + contributionResponseConverter, + modelInfoConverter, + scoreRequestTransformer, + csvConverter); } } diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/RequestCheckerTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/RequestCheckerTest.java index a560d4d6..d7e55f40 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/RequestCheckerTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/RequestCheckerTest.java @@ -14,7 +14,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; - import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestToMojoFrameConverterTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestToMojoFrameConverterTest.java index ce7b4ce6..3c734e6a 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestToMojoFrameConverterTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestToMojoFrameConverterTest.java @@ -119,8 +119,8 @@ void convertMoreRowsRequest_succeeds() { String[] values = {"value1", "value2", "value3"}; ScoreRequest request = new ScoreRequest(); request.addFieldsItem("field1"); - request.rows(Stream.of(values) - .map(ScoreRequestToMojoFrameConverterTest::asRow).collect(toList())); + request.rows( + Stream.of(values).map(ScoreRequestToMojoFrameConverterTest::asRow).collect(toList())); // When MojoFrame result = diff --git a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestTransformerTest.java b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestTransformerTest.java index 241e13b2..8807ff4e 100644 --- a/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestTransformerTest.java +++ b/common/transform/src/test/java/ai/h2o/mojos/deploy/common/transform/ScoreRequestTransformerTest.java @@ -5,12 +5,10 @@ import ai.h2o.mojos.deploy.common.rest.model.DataField; import ai.h2o.mojos.deploy.common.rest.model.Row; import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; - import org.junit.jupiter.api.Test; public class ScoreRequestTransformerTest { @@ -40,9 +38,16 @@ void transform_BooleanLiteral_Transformed() { // Given ScoreRequest scoreRequest = new ScoreRequest(); scoreRequest.setFields(Collections.singletonList("test")); - List rows = new ArrayList<>(Arrays.asList( - new Row(), new Row(), new Row(), new Row(), new Row(), new Row(), new Row() - )); + List> rows = + new ArrayList<>( + Arrays.asList( + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>())); rows.get(0).addAll(Collections.singletonList("true")); rows.get(1).addAll(Collections.singletonList("False")); rows.get(2).addAll(Collections.singletonList("TrUE")); @@ -60,9 +65,11 @@ void transform_BooleanLiteral_Transformed() { scoreRequestTransformer.accept(scoreRequest, dataFields); // Then - List expected = new ArrayList<>(Arrays.asList( - new Row(), new Row(), new Row(), new Row(), new Row(), new Row(), new Row() - )); + List> expected = + new ArrayList<>( + Arrays.asList( + new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + new ArrayList<>(), new ArrayList<>(), new ArrayList<>())); expected.get(0).addAll(Collections.singletonList("1")); expected.get(1).addAll(Collections.singletonList("0")); expected.get(2).addAll(Collections.singletonList("1")); @@ -78,7 +85,7 @@ void transform_NonBooleanLiteral_Unchanged() { // Given ScoreRequest scoreRequest = new ScoreRequest(); scoreRequest.setFields(Collections.singletonList("test")); - List rows = new ArrayList<>(Arrays.asList(new Row(), new Row())); + List> rows = new ArrayList<>(Arrays.asList(new Row(), new Row())); rows.get(0).addAll(Collections.singletonList("unchangedFeature1")); rows.get(1).addAll(Collections.singletonList("unchangedFeature2")); scoreRequest.setRows(rows); @@ -91,7 +98,8 @@ void transform_NonBooleanLiteral_Unchanged() { scoreRequestTransformer.accept(scoreRequest, dataFields); // Then - List expected = new ArrayList<>(Arrays.asList(new Row(), new Row())); + List> expected = + new ArrayList<>(Arrays.asList(new ArrayList<>(), new ArrayList<>())); expected.get(0).addAll(Collections.singletonList("unchangedFeature1")); expected.get(1).addAll(Collections.singletonList("unchangedFeature2")); assertEquals(expected, scoreRequest.getRows()); diff --git a/gcp-cloud-run/build.gradle b/gcp-cloud-run/build.gradle index ebd8318f..6e587e42 100644 --- a/gcp-cloud-run/build.gradle +++ b/gcp-cloud-run/build.gradle @@ -8,8 +8,8 @@ dependencies { implementation project(':local-rest-scorer') implementation group: 'org.slf4j', name: 'slf4j-api' implementation group: 'com.google.cloud', name: 'google-cloud-storage' - implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-core', version: '9.0.63' - implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-websocket', version: '9.0.63' + implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-core', version: tomcatVersion + implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-websocket', version: tomcatVersion } test { diff --git a/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/GcpVertexAiApplication.java b/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/GcpVertexAiApplication.java index 0e6dc91d..bb1d923e 100644 --- a/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/GcpVertexAiApplication.java +++ b/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/GcpVertexAiApplication.java @@ -9,8 +9,8 @@ @SpringBootApplication public class GcpVertexAiApplication { /** - * Wrapper application for running local rest scorer in Google Vertex AI. - * Downloads pipeline.mojo and license.sig files from GCS before starting rest server. + * Wrapper application for running local rest scorer in Google Vertex AI. Downloads pipeline.mojo + * and license.sig files from GCS before starting rest server. * * @param args N/A, application only requires environment variables */ diff --git a/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/config/EnvironmentConfiguration.java b/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/config/EnvironmentConfiguration.java index e11d1629..59817464 100644 --- a/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/config/EnvironmentConfiguration.java +++ b/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/config/EnvironmentConfiguration.java @@ -41,17 +41,16 @@ public void configureScoringEnvironment() { private void downloadArtifactsFromGcs(Map env) { downloadFileFromGcs(getFromEnv(env, "MOJO_GCS_PATH"), Paths.get(MOJO_DOWNLOAD_PATH)); - + // Only download license key file if license key string was not provided if (env.getOrDefault("DRIVERLESS_AI_LICENSE_KEY", "").isEmpty()) { downloadFileFromGcs(getFromEnv(env, "LICENSE_GCS_PATH"), Paths.get(LICENSE_DOWNLOAD_PATH)); } - + // Only use pre-processing script if one is provided if (!env.getOrDefault("PREPROCESSING_SCRIPT_PATH", "").isEmpty()) { downloadFileFromGcs( - getFromEnv(env, "PREPROCESSING_SCRIPT_PATH"), - Paths.get(PREPROCESSING_SCRIPT_PATH)); + getFromEnv(env, "PREPROCESSING_SCRIPT_PATH"), Paths.get(PREPROCESSING_SCRIPT_PATH)); } } diff --git a/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/config/ScorerConfiguration.java b/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/config/ScorerConfiguration.java index b30f5795..648a9f05 100644 --- a/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/config/ScorerConfiguration.java +++ b/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/config/ScorerConfiguration.java @@ -64,12 +64,12 @@ public MojoScorer mojoScorer( ScoreRequestTransformer scoreRequestTransformer, CsvToMojoFrameConverter csvConverter) { return new MojoScorer( - requestConverter, - responseConverter, - contributionRequestConverter, - contributionResponseConverter, - modelInfoConverter, - scoreRequestTransformer, - csvConverter); + requestConverter, + responseConverter, + contributionRequestConverter, + contributionResponseConverter, + modelInfoConverter, + scoreRequestTransformer, + csvConverter); } } diff --git a/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/controller/ModelsApiController.java b/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/controller/ModelsApiController.java index 1c2c7df2..977cd50f 100644 --- a/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/controller/ModelsApiController.java +++ b/gcp-vertex-ai-mojo-scorer/src/main/java/ai/h2o/mojos/deploy/gcp/vertex/ai/controller/ModelsApiController.java @@ -14,6 +14,8 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.UUID; import org.codehaus.jackson.map.ObjectMapper; @@ -48,8 +50,7 @@ public ResponseEntity getModelId() { @Override public ResponseEntity getScore( - ai.h2o.mojos.deploy.common.rest.vertex.ai.model.ScoreRequest gcpRequest - ) { + ai.h2o.mojos.deploy.common.rest.vertex.ai.model.ScoreRequest gcpRequest) { try { log.info("Got scoring request"); // Convert GCP request to REST request @@ -63,7 +64,7 @@ public ResponseEntity getScore( return ResponseEntity.badRequest().build(); } } - + /** * Converts GCP Vertex AI request to REST module request. * @@ -71,22 +72,21 @@ public ResponseEntity getScore( * Vertex AI request to be converted */ public static ScoreRequest getRestScoreRequest( - ai.h2o.mojos.deploy.common.rest.vertex.ai.model.ScoreRequest gcpRequest - ) { + ai.h2o.mojos.deploy.common.rest.vertex.ai.model.ScoreRequest gcpRequest) { ScoreRequest request = new ScoreRequest(); - + if (gcpRequest.getParameters().getIncludeFieldsInOutput() != null) { request.setIncludeFieldsInOutput(gcpRequest.getParameters().getIncludeFieldsInOutput()); } - - request.setNoFieldNamesInOutput(gcpRequest.getParameters().isNoFieldNamesInOutput()); + + request.setNoFieldNamesInOutput(gcpRequest.getParameters().getNoFieldNamesInOutput()); request.setIdField(gcpRequest.getParameters().getIdField()); request.setFields(gcpRequest.getParameters().getFields()); - + // Check if a pre-processing script was provided, if so, use it first Map env = System.getenv(); String preProccessingScript = env.getOrDefault("PREPROCESSING_SCRIPT_PATH", ""); - + if (!preProccessingScript.isEmpty()) { // Write data to file to be injested by preprocessing script String fileName = UUID.randomUUID().toString() + ".json"; @@ -96,7 +96,7 @@ public static ScoreRequest getRestScoreRequest( } catch (IOException e) { log.error("Failed writing JSON file: {}", e.getMessage()); } - + // Run preprocessing script on request data try (FileReader fileReader = new FileReader("/tmp/" + fileName); JsonReader reader = new JsonReader(fileReader)) { @@ -105,10 +105,11 @@ public static ScoreRequest getRestScoreRequest( processBuilder.redirectErrorStream(true); Process process = processBuilder.start(); process.waitFor(); - + Gson gson = new GsonBuilder().setPrettyPrinting().create(); - gcpRequest = gson.fromJson(reader, - ai.h2o.mojos.deploy.common.rest.vertex.ai.model.ScoreRequest.class); + gcpRequest = + gson.fromJson( + reader, ai.h2o.mojos.deploy.common.rest.vertex.ai.model.ScoreRequest.class); } catch (JsonSyntaxException e) { log.error("Malformed JSON when reading from file: {}", e.getMessage()); } catch (Exception e) { @@ -121,46 +122,43 @@ public static ScoreRequest getRestScoreRequest( } } } - - Row row; - for (ai.h2o.mojos.deploy.common.rest.vertex.ai.model.Row gcpRow: - gcpRequest.getInstances() - ) { + + List row; + for (List gcpRow : gcpRequest.getInstances()) { row = new Row(); for (int i = 0; i < gcpRow.size(); i++) { row.add(gcpRow.get(i)); } - + request.addRowsItem(row); } - + return request; } - + /** * Converts REST module response to GCP Vertex AI response. * - * @param restResponse {@link ai.h2o.mojos.deploy.common.rest.model.ScoreResponse} REST - * module response to convert + * @param restResponse {@link ai.h2o.mojos.deploy.common.rest.model.ScoreResponse} REST module + * response to convert */ public static ScoreResponse getGcpScoreResponse( - ai.h2o.mojos.deploy.common.rest.model.ScoreResponse restResponse - ) { + ai.h2o.mojos.deploy.common.rest.model.ScoreResponse restResponse) { ScoreResponse response = new ScoreResponse(); - + response.setId(restResponse.getId()); response.setFields(restResponse.getFields()); - - ai.h2o.mojos.deploy.common.rest.vertex.ai.model.Row row; - for (Row restRow: restResponse.getScore()) { - row = new ai.h2o.mojos.deploy.common.rest.vertex.ai.model.Row(); + + List row; + for (List restRow : restResponse.getScore()) { + row = new ArrayList<>(); for (int i = 0; i < restRow.size(); i++) { row.add(restRow.get(i)); } - + response.addPredictionsItem(row); } - + return response; } } diff --git a/gradle.properties b/gradle.properties index 421bdcd3..99a044a4 100644 --- a/gradle.properties +++ b/gradle.properties @@ -5,25 +5,25 @@ version = 1.1.21-SNAPSHOT # Internal dependencies: h2oVersion = 3.40.0.3 -mojoRuntimeVersion = 2.8.3 +mojoRuntimeVersion = 2.8.5 # External dependencies: awsLambdaCoreVersion = 1.2.0 awsLambdaEventsVersion = 2.2.3 awsSdkS3Version = 1.11.445 -javaxAnnotationVersion = 1.3.2 gsonVersion = 2.8.9 -jupiterPioneerVersion = 1.9.1 -jupiterVersion = 5.7.2 -mockitoVersion = 4.11.0 -snakeYamlVersion = 2.2 +jupiterPioneerVersion = 2.2.0 +jupiterVersion = 5.10.1 +mockitoVersion = 5.8.0 +mockitoInlineVersion = 5.2.0 springFoxVersion = 3.0.0 -swaggerCodegenVersion = 3.0.46 -swaggerCoreVersion = 2.2.11 -swaggerCoreSpringVersion = 1.6.11 -shadowJarVersion = 4.0.4 -slf4jVersion = 1.7.36 -log4jVersion = 2.22.0 +swaggerCodegenVersion = 3.0.52 +swaggerParserVersion = 2.1.14 +swaggerCoreVersion = 2.2.20 +swaggerCoreSpringVersion = 1.6.12 +shadowJarVersion = 7.1.0 +slf4jVersion = 2.0.9 +logbackVersion = 1.4.14 apacheCommonsCliVersion = 1.4 truthVersion = 0.42 guavaVersion = 32.0.0-jre @@ -32,20 +32,22 @@ sparkVersion = 2.4.4 scalaVersion = 2.12.15 sparklingWaterVersion = 3.30.1.3-1-3.0 configVersion = 1.3.4 -tomcatEmbedVersion = 9.0.75 +openApiJacksonNullableVersion = 0.2.6 +jakartaServletVersion = 6.0.0 +tomcatVersion = 9.0.75 # External plugins: -springBootPluginVersion = 2.7.12 +springBootPluginVersion = 3.2.0 swaggerGradlePluginVersion = 2.19.2 -spotlessPluginVersion = 3.24.2 errorpronePluginVersion = 3.1.0 jibPluginVersion = 2.7.1 +openApiGeneratorGradlePluginVersion = 7.0.1 # External tools: checkStyleVersion = 8.21 googleJavaFormatVersion = 1.7 errorproneJavacVersion = 9+181-r4173-1 -errorproneVersion = 2.3.3 +errorproneVersion = 2.23.0 # Docker settings dockerRepositoryPrefix = harbor.h2o.ai/opsh2oai/h2oai/ diff --git a/gradle/java.gradle b/gradle/java.gradle index 6f13bcf3..185c5f5a 100644 --- a/gradle/java.gradle +++ b/gradle/java.gradle @@ -3,5 +3,4 @@ apply from: project(":").file('gradle/java_no_style.gradle') apply from: project(":").file('gradle/mixins/checkstyle.gradle') -apply from: project(":").file('gradle/mixins/spotless.gradle') apply from: project(":").file('gradle/mixins/errorprone.gradle') diff --git a/gradle/java_no_style.gradle b/gradle/java_no_style.gradle index e64264ec..8fa91f0f 100644 --- a/gradle/java_no_style.gradle +++ b/gradle/java_no_style.gradle @@ -5,5 +5,5 @@ apply plugin: 'java' apply from: project(":").file('gradle/mixins/dependencies.gradle') -sourceCompatibility = 1.8 -targetCompatibility = 1.8 +sourceCompatibility = JavaVersion.VERSION_17 +targetCompatibility = JavaVersion.VERSION_17 diff --git a/gradle/mixins/dependencies.gradle b/gradle/mixins/dependencies.gradle index 8bd47d08..3e238012 100644 --- a/gradle/mixins/dependencies.gradle +++ b/gradle/mixins/dependencies.gradle @@ -21,12 +21,11 @@ dependencyManagement { dependency group: 'com.google.truth.extensions', name: 'truth-java8-extension', version: truthVersion dependency group: 'com.google.cloud', name: 'google-cloud-storage', version: googleStorageVersion dependency group: 'io.springfox', name: 'springfox-boot-starter', version: springFoxVersion - // https://nvd.nist.gov/vuln/detail/CVE-2022-1471 - dependency group: 'org.yaml', name: 'snakeyaml', version: snakeYamlVersion - dependency group: 'io.swagger', name: 'swagger-annotations', version: swaggerCoreSpringVersion + dependency group: 'org.springframework.boot', name: 'spring-boot-starter-tomcat', version: springBootPluginVersion dependency group: 'io.swagger.core.v3', name: 'swagger-annotations', version: swaggerCoreVersion + dependency group: 'io.swagger.parser.v3', name: 'swagger-parser', version: swaggerParserVersion dependency group: 'io.swagger.codegen.v3', name: 'swagger-codegen-cli', version: swaggerCodegenVersion - dependency group: 'javax.annotation', name: 'javax.annotation-api', version: javaxAnnotationVersion + dependency group: 'org.openapitools', name: 'jackson-databind-nullable', version: openApiJacksonNullableVersion dependencySet(group: 'org.junit.jupiter', version: jupiterVersion) { entry 'junit-jupiter-api' entry 'junit-jupiter-engine' @@ -38,14 +37,14 @@ dependencyManagement { entry 'mockito-junit-jupiter' entry 'mockito-inline' } + dependency group: 'org.junit-pioneer', name: 'junit-pioneer', version: jupiterPioneerVersion dependency group: 'commons-cli', name: 'commons-cli', version: apacheCommonsCliVersion dependency group: 'org.slf4j', name: 'slf4j-api', version: slf4jVersion - - dependencySet(group: 'org.apache.logging.log4j', version: log4jVersion) { - entry 'log4j-api' - entry 'log4j-core' - entry 'log4j-slf4j-impl' + dependencySet(group: 'ch.qos.logback', version: logbackVersion) { + entry 'logback-classic' + entry 'logback-core' } + dependency group: 'jakarta.servlet', name: 'jakarta.servlet-api', version: jakartaServletVersion dependency group: 'org.apache.spark', name: 'spark-core_2.12', version: sparkVersion dependency group: 'org.apache.spark', name: 'spark-sql_2.12', version: sparkVersion diff --git a/gradle/mixins/spotless.gradle b/gradle/mixins/spotless.gradle deleted file mode 100644 index 131ada16..00000000 --- a/gradle/mixins/spotless.gradle +++ /dev/null @@ -1,12 +0,0 @@ -// Defines shared Gradle project Spotless formatter configuration. - -//apply plugin: 'com.diffplug.gradle.spotless' -// -//spotless { -// java { -// target fileTree('src') { -// include '**/*.java' -// } -// googleJavaFormat(googleJavaFormatVersion) -// } -//} diff --git a/init.gradle b/init.gradle new file mode 100644 index 00000000..0f4ade92 --- /dev/null +++ b/init.gradle @@ -0,0 +1,12 @@ +initscript { + // Temporary workaround https://github.com/gradle/gradle/issues/24390 + gradle.allprojects { + buildscript { + configurations.all { + resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-core:2.14.2' + resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-databind:2.14.2' + resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-base:2.14.2' + } + } + } +} diff --git a/kdb-mojo-scorer/build.gradle b/kdb-mojo-scorer/build.gradle index f594fc96..f0dfd43e 100644 --- a/kdb-mojo-scorer/build.gradle +++ b/kdb-mojo-scorer/build.gradle @@ -8,10 +8,10 @@ dependencies { implementation group: 'ai.h2o', name: 'mojo2-runtime-api' implementation group: 'ai.h2o', name: 'mojo2-runtime-impl' implementation group: 'commons-cli', name: 'commons-cli' - implementation group: 'org.slf4j', name: 'slf4j-api' - implementation group: 'org.apache.logging.log4j', name: 'log4j-api' - implementation group: 'org.apache.logging.log4j', name: 'log4j-core' - implementation group: 'org.apache.logging.log4j', name: 'log4j-slf4j-impl' + implementation group: 'org.slf4j', name: 'slf4j-api', version: '1.7.36' + implementation group: 'org.apache.logging.log4j', name: 'log4j-api', version: '2.22.0' + implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version: '2.22.0' + implementation group: 'org.apache.logging.log4j', name: 'log4j-slf4j-impl', version: '2.22.0' testImplementation group: 'com.google.truth.extensions', name: 'truth-java8-extension' testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api' diff --git a/local-rest-scorer/build.gradle b/local-rest-scorer/build.gradle index 58fda926..f97fb6d8 100644 --- a/local-rest-scorer/build.gradle +++ b/local-rest-scorer/build.gradle @@ -5,6 +5,7 @@ plugins { apply from: project(":").file('gradle/java.gradle') dependencies { + runtimeOnly group: 'org.springframework.boot', name: 'spring-boot-properties-migrator' implementation project(':common:rest-spring-api') implementation project(':common:transform') implementation group: 'ai.h2o', name: 'h2o-genmodel' @@ -14,23 +15,26 @@ dependencies { implementation group: 'ai.h2o', name: 'mojo2-runtime-impl' implementation group: 'io.springfox', name: 'springfox-boot-starter', version: springFoxVersion implementation group: 'org.springframework.boot', name: 'spring-boot-starter-web' - implementation group: 'org.apache.tomcat.embed', name: 'tomcat-embed-core', version: tomcatEmbedVersion + implementation group: 'org.springframework.boot', name: 'spring-boot-starter-tomcat' implementation group: 'com.google.guava', name: 'guava', version: guavaVersion - implementation group: 'org.yaml', name: 'snakeyaml' + implementation group: 'jakarta.servlet', name: 'jakarta.servlet-api' testImplementation group: 'org.springframework.boot', name: 'spring-boot-starter-test' testImplementation group: 'com.google.truth.extensions', name: 'truth-java8-extension' - testImplementation group: 'org.mockito', name: 'mockito-inline' - testImplementation group: 'org.mockito', name : 'mockito-core' - testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api' - testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-params' - testImplementation group: 'org.junit-pioneer', name: 'junit-pioneer' - testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine' + testImplementation group: 'org.mockito', name: 'mockito-inline', version: mockitoInlineVersion + testImplementation group: 'org.mockito', name : 'mockito-core', version: mockitoVersion + testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api', version: jupiterVersion + testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-params', version: jupiterVersion + testImplementation group: 'org.junit-pioneer', name: 'junit-pioneer', version: jupiterPioneerVersion + testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine', version: jupiterVersion } test { useJUnitPlatform() + + jvmArgs '--add-opens=java.base/java.util=ALL-UNNAMED' + jvmArgs '--add-opens=java.base/java.lang=ALL-UNNAMED' } bootRun { diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/config/ScorerConfiguration.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/config/ScorerConfiguration.java index 7d1bbb70..cf064de9 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/config/ScorerConfiguration.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/config/ScorerConfiguration.java @@ -64,12 +64,12 @@ public MojoScorer mojoScorer( ScoreRequestTransformer scoreRequestTransformer, CsvToMojoFrameConverter csvConverter) { return new MojoScorer( - requestConverter, - responseConverter, - contributionRequestConverter, - contributionResponseConverter, - modelInfoConverter, - scoreRequestTransformer, - csvConverter); + requestConverter, + responseConverter, + contributionRequestConverter, + contributionResponseConverter, + modelInfoConverter, + scoreRequestTransformer, + csvConverter); } } diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java index f6fd49fb..10d28128 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java @@ -5,6 +5,7 @@ import ai.h2o.mojos.deploy.common.rest.model.ContributionRequest; import ai.h2o.mojos.deploy.common.rest.model.ContributionResponse; import ai.h2o.mojos.deploy.common.rest.model.Model; +import ai.h2o.mojos.deploy.common.rest.model.ScoreMediaRequest; import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest; import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse; import ai.h2o.mojos.deploy.common.transform.MojoScorer; @@ -16,13 +17,14 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.io.Resource; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Controller; +import org.springframework.web.multipart.MultipartFile; import org.springframework.web.server.ResponseStatusException; @Controller @@ -49,9 +51,17 @@ public class ModelsApiController implements ModelApi { public ModelsApiController(MojoScorer scorer, SampleRequestBuilder sampleRequestBuilder) { this.scorer = scorer; this.sampleRequestBuilder = sampleRequestBuilder; - this.supportedCapabilities = assembleSupportedCapabilities( - scorer.getEnabledShapleyTypes(), scorer.isPredictionIntervalSupport() - ); + this.supportedCapabilities = + assembleSupportedCapabilities( + scorer.getEnabledShapleyTypes(), scorer.isPredictionIntervalSupport()); + } + + @Override + public ResponseEntity getMediaScore( + ScoreMediaRequest request, List files) { + log.info("Got score media request"); + throw new ResponseStatusException( + HttpStatus.NOT_IMPLEMENTED, "score media files is not implemented"); } @Override @@ -82,14 +92,16 @@ public ResponseEntity getScore(ScoreRequest request) { log.debug(" - failure cause: ", e); throw new ResponseStatusException( HttpStatus.BAD_REQUEST, - String.format("Invalid scoring request due to: %s", e.getMessage()), e); + String.format("Invalid scoring request due to: %s", e.getMessage()), + e); } catch (Exception e) { log.error("Failed scoring request due to: {}", e.getMessage()); log.debug(" - request content: ", request); log.debug(" - failure cause: ", e); throw new ResponseStatusException( ErrorUtil.translateErrorCode(e), - String.format("Failed scoring request due to: %s", e.getMessage()), e); + String.format("Failed scoring request due to: %s", e.getMessage()), + e); } } @@ -107,47 +119,51 @@ public ResponseEntity getScoreByFile(String file) { log.debug(" - failure cause: ", e); throw new ResponseStatusException( HttpStatus.INTERNAL_SERVER_ERROR, - String.format("Failed loading CSV file due to: %s", e.getMessage()), e); + String.format("Failed loading CSV file due to: %s", e.getMessage()), + e); } catch (IllegalArgumentException e) { log.error("Invalid scoring request for CSV file {} request due to: {}", file, e.getMessage()); log.debug(" - failure cause: ", e); throw new ResponseStatusException( HttpStatus.BAD_REQUEST, - String.format("Invalid scoring request for CSV file due to: %s", e.getMessage()), e); + String.format("Invalid scoring request for CSV file due to: %s", e.getMessage()), + e); } catch (Exception e) { log.error("Failed scoring CSV file {} due to: {}", file, e.getMessage()); log.debug(" - failure cause: ", e); throw new ResponseStatusException( ErrorUtil.translateErrorCode(e), - String.format("Failed scoring CSV file due to: %s", e.getMessage()), e); + String.format("Failed scoring CSV file due to: %s", e.getMessage()), + e); } } @Override - public ResponseEntity getContribution( - ContributionRequest request) { + public ResponseEntity getContribution(ContributionRequest request) { try { log.info("Got shapley contribution request"); - ContributionResponse contributionResponse - = scorer.computeContribution(request); + ContributionResponse contributionResponse = scorer.computeContribution(request); return ResponseEntity.ok(contributionResponse); } catch (UnsupportedOperationException e) { log.error("Unsupported operation due to: {}", e.getMessage()); throw new ResponseStatusException( HttpStatus.NOT_IMPLEMENTED, - String.format("Unsupported operation due to: %s", e.getMessage()), e); + String.format("Unsupported operation due to: %s", e.getMessage()), + e); } catch (IllegalArgumentException e) { log.error("Invalid shapley contribution request due to: {}", e.getMessage()); log.debug(" - failure cause: ", e); throw new ResponseStatusException( HttpStatus.BAD_REQUEST, - String.format("Invalid shapley contribution request due to: %s", e.getMessage()), e); + String.format("Invalid shapley contribution request due to: %s", e.getMessage()), + e); } catch (Exception e) { log.error("Failed shapley contribution request due to: {}", e.getMessage()); log.debug(" - failure cause: ", e); throw new ResponseStatusException( ErrorUtil.translateErrorCode(e), - String.format("Failed shapley contribution request due to: %s", e.getMessage()), e); + String.format("Failed shapley contribution request due to: %s", e.getMessage()), + e); } } @@ -173,10 +189,11 @@ private static List assembleSupportedCapabilities( } switch (enabledShapleyTypes) { case ALL: - result.addAll(Arrays.asList( - CapabilityType.SCORE, - CapabilityType.CONTRIBUTION_ORIGINAL, - CapabilityType.CONTRIBUTION_TRANSFORMED)); + result.addAll( + Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_ORIGINAL, + CapabilityType.CONTRIBUTION_TRANSFORMED)); break; case ORIGINAL: result.addAll(Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_ORIGINAL)); diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsMediaController.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsMediaController.java deleted file mode 100644 index 0cc7aa75..00000000 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsMediaController.java +++ /dev/null @@ -1,26 +0,0 @@ -package ai.h2o.mojos.deploy.local.rest.controller; - -import ai.h2o.mojos.deploy.common.rest.v1exp.api.ModelApi; -import ai.h2o.mojos.deploy.common.rest.v1exp.model.ScoreMediaRequest; -import ai.h2o.mojos.deploy.common.rest.v1exp.model.ScoreResponse; -import java.util.List; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.core.io.Resource; -import org.springframework.http.HttpStatus; -import org.springframework.http.ResponseEntity; -import org.springframework.stereotype.Controller; -import org.springframework.web.server.ResponseStatusException; - -@Controller -public class ModelsMediaController implements ModelApi { - private static final Logger log = LoggerFactory.getLogger(ModelsMediaController.class); - - @Override - public ResponseEntity getMediaScore( - ScoreMediaRequest request, List files) { - log.info("Got score media request"); - throw new ResponseStatusException(HttpStatus.NOT_IMPLEMENTED, - "score media files is not implemented"); - } -} diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/converter/ScoreMediaRequestConverter.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/converter/ScoreMediaRequestConverter.java index b30f1181..647bb5a7 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/converter/ScoreMediaRequestConverter.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/converter/ScoreMediaRequestConverter.java @@ -1,6 +1,6 @@ package ai.h2o.mojos.deploy.local.rest.converter; -import ai.h2o.mojos.deploy.common.rest.v1exp.model.ScoreMediaRequest; +import ai.h2o.mojos.deploy.common.rest.model.ScoreMediaRequest; import com.google.gson.Gson; import org.springframework.core.convert.converter.Converter; diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/error/ErrorUtil.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/error/ErrorUtil.java index 535c3207..d7ce2070 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/error/ErrorUtil.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/error/ErrorUtil.java @@ -4,9 +4,7 @@ public class ErrorUtil { - /** - * Translate exception type into error response code. - */ + /** Translate exception type into error response code. */ public static HttpStatus translateErrorCode(Exception exception) { if (exception instanceof IllegalArgumentException) { return HttpStatus.BAD_REQUEST; diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/error/ModelsExceptionHandler.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/error/ModelsExceptionHandler.java index e19d7726..cca1828a 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/error/ModelsExceptionHandler.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/error/ModelsExceptionHandler.java @@ -15,68 +15,51 @@ public class ModelsExceptionHandler extends ResponseEntityExceptionHandler { private static final Logger log = LoggerFactory.getLogger(ModelsExceptionHandler.class); - /** - * Custom Exception handler for ResponseStatusException type. - */ + /** Custom Exception handler for ResponseStatusException type. */ @ExceptionHandler(ResponseStatusException.class) public ResponseEntity handleResponseStatusException( ResponseStatusException exception, WebRequest request) { log.error("Runtime exception occurred : {}", exception.getMessage(), exception); - return ResponseEntity - .status(exception.getStatus()) - .body(ImmutableMap - .builder() - .put("detail", exception.getMessage()) - .build() - ); + return ResponseEntity.status(exception.getStatusCode()) + .body(ImmutableMap.builder().put("detail", exception.getMessage()).build()); } - /** - * Custom Exception handler for illegal request exception type. - */ + /** Custom Exception handler for illegal request exception type. */ @ExceptionHandler(IllegalArgumentException.class) public ResponseEntity handleIllegalArgumentException( IllegalArgumentException exception, WebRequest request) { log.error("Illegal request exception occurred : {}", exception.getMessage(), exception); - return ResponseEntity - .badRequest() + return ResponseEntity.badRequest() .body(ImmutableMap.builder().put("detail", exception.getMessage()).build()); } - /** - * Custom Exception handler for illegal state exception type. - */ + /** Custom Exception handler for illegal state exception type. */ @ExceptionHandler(IllegalStateException.class) public ResponseEntity handleIllegalStateException( IllegalStateException exception, WebRequest request) { log.error("Illegal state exception occurred : {}", exception.getMessage(), exception); - return ResponseEntity - .status(HttpStatus.SERVICE_UNAVAILABLE) + return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE) .body(ImmutableMap.builder().put("detail", exception.getMessage()).build()); } - /** - * Custom Exception handler for unsupported exception type. - */ + /** Custom Exception handler for unsupported exception type. */ @ExceptionHandler(UnsupportedOperationException.class) public ResponseEntity handleUnsupportedException( UnsupportedOperationException exception, WebRequest request) { - log.error("Unsupported request exception occurred {} : {}", - request, exception.getMessage(), exception); - return ResponseEntity - .status(HttpStatus.NOT_IMPLEMENTED) + log.error( + "Unsupported request exception occurred {} : {}", + request, + exception.getMessage(), + exception); + return ResponseEntity.status(HttpStatus.NOT_IMPLEMENTED) .body(ImmutableMap.builder().put("detail", exception.getMessage()).build()); } - /** - * Custom Exception handler for all Exception type. - */ + /** Custom Exception handler for all Exception type. */ @ExceptionHandler(Exception.class) - public ResponseEntity handleAllException( - Exception exception, WebRequest request) { + public ResponseEntity handleAllException(Exception exception, WebRequest request) { log.error("Unexpected exception occurred : {}", exception.getMessage(), exception); - return ResponseEntity - .internalServerError() + return ResponseEntity.internalServerError() .body(ImmutableMap.builder().put("detail", exception.getMessage()).build()); } } diff --git a/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java index 64f5a716..70ddfdb9 100644 --- a/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java +++ b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java @@ -9,27 +9,24 @@ import ai.h2o.mojos.deploy.common.rest.model.CapabilityType; import ai.h2o.mojos.deploy.common.rest.model.Model; +import ai.h2o.mojos.deploy.common.rest.model.ScoreMediaRequest; import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest; import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse; -import ai.h2o.mojos.deploy.common.rest.v1exp.model.ScoreMediaRequest; import ai.h2o.mojos.deploy.common.transform.MojoScorer; import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder; import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption; import ai.h2o.mojos.runtime.MojoPipeline; import ai.h2o.mojos.runtime.api.MojoPipelineService; import ai.h2o.mojos.runtime.api.PipelineConfig; - import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; - import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junitpioneer.jupiter.SetEnvironmentVariable; - import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.Mockito; @@ -37,11 +34,13 @@ import org.springframework.core.io.Resource; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; +import org.springframework.web.multipart.MultipartFile; import org.springframework.web.server.ResponseStatusException; @ExtendWith(MockitoExtension.class) class ModelsApiControllerTest { - @Mock private SampleRequestBuilder sampleRequestBuilder; + @Mock + private SampleRequestBuilder sampleRequestBuilder; @BeforeAll static void setup() throws IOException { @@ -53,9 +52,12 @@ static void setup() throws IOException { private static void mockMojoPipeline(File tmpModel) { MojoPipeline mojoPipeline = Mockito.mock(MojoPipeline.class); MockedStatic theMock = Mockito.mockStatic(MojoPipelineService.class); - theMock.when(() -> MojoPipelineService - .loadPipeline(Mockito.eq(new File(tmpModel.getAbsolutePath())), any(PipelineConfig.class))) - .thenReturn(mojoPipeline); + theMock + .when( + () -> + MojoPipelineService.loadPipeline( + Mockito.eq(new File(tmpModel.getAbsolutePath())), any(PipelineConfig.class))) + .thenReturn(mojoPipeline); } @Test @@ -79,10 +81,11 @@ void verifyCapabilities_DefaultShapley_ReturnsExpected() { @Test void verifyCapabilities_AllShapleyEnabled_ReturnsExpected() { // Given - List expectedCapabilities = Arrays.asList( - CapabilityType.SCORE, - CapabilityType.CONTRIBUTION_ORIGINAL, - CapabilityType.CONTRIBUTION_TRANSFORMED); + List expectedCapabilities = + Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_ORIGINAL, + CapabilityType.CONTRIBUTION_TRANSFORMED); MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ALL); when(scorer.isPredictionIntervalSupport()).thenReturn(false); @@ -99,9 +102,8 @@ void verifyCapabilities_AllShapleyEnabled_ReturnsExpected() { @Test void verifyCapabilities_OriginalShapleyEnabled_ReturnsExpected() { // Given - List expectedCapabilities = Arrays.asList( - CapabilityType.SCORE, - CapabilityType.CONTRIBUTION_ORIGINAL); + List expectedCapabilities = + Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_ORIGINAL); MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ORIGINAL); when(scorer.isPredictionIntervalSupport()).thenReturn(false); @@ -118,9 +120,8 @@ void verifyCapabilities_OriginalShapleyEnabled_ReturnsExpected() { @Test void verifyCapabilities_TransformedShapleyEnabled_ReturnsExpected() { // Given - List expectedCapabilities = Arrays.asList( - CapabilityType.SCORE, - CapabilityType.CONTRIBUTION_TRANSFORMED); + List expectedCapabilities = + Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED); MojoScorer scorer = mock(MojoScorer.class); when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED); when(scorer.isPredictionIntervalSupport()).thenReturn(false); @@ -187,7 +188,7 @@ void verifyScore_Fails_ReturnsException() { } catch (Exception ex) { assertTrue(ex instanceof ResponseStatusException); assertTrue(ex.getCause() instanceof IllegalStateException); - assertEquals(HttpStatus.SERVICE_UNAVAILABLE, ((ResponseStatusException) ex).getStatus()); + assertEquals(HttpStatus.SERVICE_UNAVAILABLE, ((ResponseStatusException) ex).getStatusCode()); } } @@ -195,8 +196,11 @@ void verifyScore_Fails_ReturnsException() { void verifyScoreMedia_ReturnsUnimplemented() { // Given ScoreMediaRequest request = mock(ScoreMediaRequest.class); - List files = new ArrayList<>(); - ModelsMediaController controller = new ModelsMediaController(); + List files = new ArrayList<>(); + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED); + when(scorer.isPredictionIntervalSupport()).thenReturn(false); + ModelsApiController controller = new ModelsApiController(scorer, new SampleRequestBuilder()); // When & Then try { @@ -204,7 +208,7 @@ void verifyScoreMedia_ReturnsUnimplemented() { fail("exception is expected, but fail to raise"); } catch (Exception ex) { assertTrue(ex instanceof ResponseStatusException); - assertEquals(HttpStatus.NOT_IMPLEMENTED, ((ResponseStatusException) ex).getStatus()); + assertEquals(HttpStatus.NOT_IMPLEMENTED, ((ResponseStatusException) ex).getStatusCode()); } } } diff --git a/settings.gradle b/settings.gradle index 825f8aba..29b59ed7 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1,3 +1,18 @@ +pluginManagement { + plugins { + id 'com.google.cloud.tools.jib' version "${jibPluginVersion}" + id 'com.github.johnrengelman.shadow' version "${shadowJarVersion}" + id 'org.springframework.boot' version "${springBootPluginVersion}" + id 'org.openapi.generator' version "${openApiGeneratorGradlePluginVersion}" + } +} + +dependencyResolutionManagement { + repositories { + mavenCentral() + } +} + rootProject.name = 'dai-deployment-templates' include 'aws-lambda-scorer:lambda-template' include 'aws-lambda-scorer:terraform-recipe' diff --git a/sql-jdbc-scorer/build.gradle b/sql-jdbc-scorer/build.gradle index 242422c3..3e568f0b 100644 --- a/sql-jdbc-scorer/build.gradle +++ b/sql-jdbc-scorer/build.gradle @@ -16,8 +16,7 @@ dependencies { exclude group: 'spring-boot-starter-logging' } - testImplementation 'junit:junit:4.12' - testImplementation group: 'junit', name: 'junit', version: '4.12' + testImplementation group: 'junit', name: 'junit', version: jupiterVersion } configurations {