Skip to content
This repository has been archived by the owner on Nov 29, 2024. It is now read-only.

Commit

Permalink
add media score endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jih147 committed Jun 12, 2023
1 parent dca546a commit bcb2668
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 0 deletions.
46 changes: 46 additions & 0 deletions common/swagger/v2/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,40 @@ paths:
description: Implementation not supported
content: {}
x-codegen-request-body-name: payload
/model/media-score:
post:
tags:
- scoring
summary: Score model with provided media files
description: Computes score of provided data making use of provided media files.
operationId: getMediaScore
requestBody:
content:
multipart/form-data:
schema:
type: object
properties:
mediaScoreRequest:
$ref: '#/components/schemas/MediaScoreRequest'
files:
type: array
items:
type: string
format: binary
required:
- mediaScoreRequest
- files
responses:
'200':
description: Successful scoring operation
content:
application/json:
schema:
$ref: '#/components/schemas/ScoreResponse'
'501':
description: Implementation not supported
'400':
description: Invalid payload
components:
schemas:
Model:
Expand Down Expand Up @@ -365,6 +399,18 @@ components:
type: array
items:
$ref: '#/components/schemas/Row'
MediaScoreRequest:
allOf:
- $ref: '#/components/schemas/ScoreRequest'
- properties:
mediaFields:
description: >
An array holding the names of all fields which are expected to contain media files.
Contents of these fields will be replaced by corresponding uploaded files where the
expected values in the column must be the file names of the uploaded files.
type: array
items:
type: string
securitySchemes:
api_key:
type: apiKey
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package ai.h2o.mojos.deploy.local.rest.Converter;

import org.springframework.core.convert.converter.Converter;
import org.springframework.core.io.Resource;
import org.springframework.web.multipart.MultipartFile;

public class MultipartConverter implements Converter<MultipartFile, Resource> {

@Override
public Resource convert(MultipartFile source) {
return source.getResource();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package ai.h2o.mojos.deploy.local.rest.Converter;

import ai.h2o.mojos.deploy.common.rest.v2.model.ScoreMediaRequest;
import com.google.gson.Gson;
import org.springframework.core.convert.converter.Converter;

public class ScoreMediaRequestConverter implements Converter<String, ScoreMediaRequest> {

@Override
public ScoreMediaRequest convert(String input) {
return new Gson().fromJson(input, ScoreMediaRequest.class);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package ai.h2o.mojos.deploy.local.rest.config;

import ai.h2o.mojos.deploy.local.rest.converter.MultipartConverter;
import ai.h2o.mojos.deploy.local.rest.converter.ScoreMediaRequestConverter;
import org.springframework.context.annotation.Configuration;
import org.springframework.format.FormatterRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

public class WebConfig implements WebMvcConfigurer {

@Override
public void addFormatters(FormatterRegistry registry) {
registry.addConverter(new MultipartConverter());
registry.addConverter(new ScoreMediaRequestConverter());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
Expand Down Expand Up @@ -156,6 +157,13 @@ public ResponseEntity<ScoreRequest> getSampleRequest() {
return ResponseEntity.ok(sampleRequestBuilder.build(scorer.getPipeline().getInputMeta()));
}

@Override
public ResponseEntity<ScoreResponse> getScoreMedia(
ScoreMediaRequest scoreMediaRequest, List<Resource> files) {
log.info("Received score media request");
return ResponseEntity.status(HttpStatus.NOT_IMPLEMENTED).build();
}

private String getScorerModelId() {
try {
String res = System.getenv(MODEL_ID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

Expand All @@ -32,6 +33,7 @@
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.server.ResponseStatusException;
Expand Down Expand Up @@ -187,4 +189,20 @@ void verifyScore_Fails_ReturnsException() {
assertEquals(HttpStatus.SERVICE_UNAVAILABLE, ((ResponseStatusException) ex).getStatus());
}
}

@Test
void verifyScoreMedia_ReturnsUnimplemented() {
// Given
MojoScorer scorer = mock(MojoScorer.class);
ScoreMediaRequest request = mock(ScoreMediaRequest.class);
List<Resource> files = new ArrayList<>();
ModelsApiControllerV1Exp controller = new ModelsApiController(scorer, sampleRequestBuilder);

// When
ResponseEntity<ScoreResponse> response =
controller.getScoreMedia(request, files);

// Then
assertEquals(response.getStatusCode(), HttpStatus.NOT_IMPLEMENTED);
}
}

0 comments on commit bcb2668

Please sign in to comment.