Skip to content

Commit

Permalink
Removed unnecessary interfaces from MLProvider. Infer endpoint can now…
Browse files Browse the repository at this point in the history
… return DICOM Seg files.
  • Loading branch information
Rui-Jesus committed Jun 20, 2023
1 parent a220463 commit e0e8049
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
import pt.ua.dicoogle.sdk.datastructs.SearchResult;
import pt.ua.dicoogle.sdk.datastructs.dim.DimLevel;
import pt.ua.dicoogle.sdk.mlprovider.MLDataset;
import pt.ua.dicoogle.sdk.mlprovider.MLPrediction;
import pt.ua.dicoogle.sdk.mlprovider.MLPredictionRequest;
import pt.ua.dicoogle.sdk.mlprovider.MLInference;
import pt.ua.dicoogle.sdk.mlprovider.MLInferenceRequest;
import pt.ua.dicoogle.sdk.mlprovider.MLProviderInterface;
import pt.ua.dicoogle.sdk.settings.ConfigurationHolder;
import pt.ua.dicoogle.sdk.task.JointQueryTask;
Expand Down Expand Up @@ -837,13 +837,13 @@ public List<Report> indexBlocking(URI path) {
* @param predictionRequest object describing the inference request to forward to the provider
* @return the created task
*/
public Task<MLPrediction> makePredictionOverImage(final String provider, final MLPredictionRequest predictionRequest) {
public Task<MLInference> infer(final String provider, final MLInferenceRequest predictionRequest) {
    // Resolve the requested ML provider by name; unknown providers yield no task.
    final MLProviderInterface mlProvider = this.getMachineLearningProviderByName(provider, true);
    if (mlProvider == null) {
        return null;
    }

    // Delegate the inference to the provider and tag the resulting task with a unique name
    // so concurrent prediction tasks can be told apart.
    final Task<MLInference> inferenceTask = mlProvider.infer(predictionRequest);
    inferenceTask.setName("MLPredictionTask" + UUID.randomUUID());
    return inferenceTask;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.dcm4che3.imageio.plugins.dcm.DicomMetaData;
import org.restlet.data.Status;
import org.slf4j.Logger;
Expand All @@ -13,20 +15,23 @@
import pt.ua.dicoogle.sdk.datastructs.dim.DimLevel;
import pt.ua.dicoogle.sdk.datastructs.dim.Point2D;
import pt.ua.dicoogle.server.web.dicom.WSISopDescriptor;
import pt.ua.dicoogle.sdk.mlprovider.MLPrediction;
import pt.ua.dicoogle.sdk.mlprovider.MLPredictionRequest;
import pt.ua.dicoogle.sdk.mlprovider.MLInference;
import pt.ua.dicoogle.sdk.mlprovider.MLInferenceRequest;
import pt.ua.dicoogle.sdk.task.Task;
import pt.ua.dicoogle.server.web.dicom.ROIExtractor;
import pt.ua.dicoogle.server.web.utils.cache.WSICache;

import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;

public class MakePredictionServlet extends HttpServlet {
Expand Down Expand Up @@ -78,11 +83,11 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)

String provider = body.get("provider").asText();
String modelID = body.get("modelID").asText();
boolean wsi = body.get("wsi").asBoolean();
boolean wsi = body.has("wsi") && body.get("wsi").asBoolean();
DimLevel level = DimLevel.valueOf(body.get("level").asText().toUpperCase());
String dimUID = body.get("uid").asText();

Task<MLPrediction> task;
Task<MLInference> task;

if(wsi){

Expand Down Expand Up @@ -114,11 +119,11 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
task.run();
}

private Task<MLPrediction> sendWSIRequest(String provider, String modelID, String baseSopInstanceUID, String uid,
BulkAnnotation.AnnotationType roiType, List<Point2D> roi, HttpServletResponse response){
private Task<MLInference> sendWSIRequest(String provider, String modelID, String baseSopInstanceUID, String uid,
BulkAnnotation.AnnotationType roiType, List<Point2D> roi, HttpServletResponse response){

ObjectMapper mapper = new ObjectMapper();
MLPredictionRequest predictionRequest = new MLPredictionRequest(true, DimLevel.INSTANCE, uid, modelID);
MLInferenceRequest predictionRequest = new MLInferenceRequest(true, DimLevel.INSTANCE, uid, modelID);
BulkAnnotation annotation = new BulkAnnotation();
annotation.setPoints(roi);
annotation.setAnnotationType(roiType);
Expand All @@ -127,11 +132,11 @@ private Task<MLPrediction> sendWSIRequest(String provider, String modelID, Strin
DicomMetaData dicomMetaData = this.getDicomMetadata(uid);
BufferedImage bi = roiExtractor.extractROI(dicomMetaData, annotation);
predictionRequest.setRoi(bi);
Task<MLPrediction> task = PluginController.getInstance().makePredictionOverImage(provider, predictionRequest);
Task<MLInference> task = PluginController.getInstance().infer(provider, predictionRequest);
if(task != null){
task.onCompletion(() -> {
try {
MLPrediction prediction = task.get();
MLInference prediction = task.get();

if(prediction == null){
log.error("Provider returned null prediction");
Expand Down Expand Up @@ -174,26 +179,64 @@ private Task<MLPrediction> sendWSIRequest(String provider, String modelID, Strin
}
}

private Task<MLPrediction> sendRequest(String provider, String modelID, DimLevel level, String dimUID, HttpServletResponse response){
private Task<MLInference> sendRequest(String provider, String modelID, DimLevel level, String dimUID, HttpServletResponse response){
ObjectMapper mapper = new ObjectMapper();
MLPredictionRequest predictionRequest = new MLPredictionRequest(true, level, dimUID, modelID);
Task<MLPrediction> task = PluginController.getInstance().makePredictionOverImage(provider, predictionRequest);
MLInferenceRequest predictionRequest = new MLInferenceRequest(false, level, dimUID, modelID);
Task<MLInference> task = PluginController.getInstance().infer(provider, predictionRequest);
if(task != null){
task.onCompletion(() -> {
try {
MLPrediction prediction = task.get();
MLInference prediction = task.get();

if(prediction == null){
log.error("Provider returned null prediction");
response.sendError(Status.SERVER_ERROR_INTERNAL.getCode(), "Could not make prediction");
return;
}

response.setContentType("application/json");
PrintWriter out = response.getWriter();
mapper.writeValue(out, prediction);
out.close();
out.flush();
if(prediction.getDicomSEG() != null){
// We have a file to send, got to build a multi part response
String boundary = UUID.randomUUID().toString();
response.setContentType("multipart/form-data; boundary=" + boundary);

ServletOutputStream out = response.getOutputStream();

out.print("--" + boundary);
out.println();
out.print("Content-Disposition: form-data; name=\"params\"");
out.println();
out.print("Content-Type: application/json");
out.println(); out.println();
out.print(mapper.writeValueAsString(prediction));
out.println();
out.print("--" + boundary);
out.println();
out.print("Content-Disposition: form-data; name=\"dicomseg\"; filename=\"dicomseg.dcm\"");
out.println();
out.print("Content-Type: application/octet-stream");
out.println(); out.println();

byte[] targetArray = new byte[prediction.getDicomSEG().available()];
prediction.getDicomSEG().read(targetArray);
out.write(targetArray);
out.flush();
out.close();
} else {
response.setContentType("application/json");
PrintWriter out = response.getWriter();
mapper.writeValue(out, prediction);
out.close();
out.flush();
}

try{
if(!StringUtils.isBlank(prediction.getResourcesFolder())){
FileUtils.deleteDirectory(new File(prediction.getResourcesFolder()));
}
} catch (IOException e){
log.warn("Could not delete temporary file", e);
}

} catch (InterruptedException | ExecutionException e) {
log.error("Could not make prediction", e);
try {
Expand Down Expand Up @@ -222,7 +265,7 @@ private DicomMetaData getDicomMetadata(String sop) throws IOException{
* @param scale to transform coordinates
* @return the ml prediction with the converted coordinates.
*/
private void convertCoordinates(MLPrediction prediction, Point2D tl, double scale){
private void convertCoordinates(MLInference prediction, Point2D tl, double scale){
for(BulkAnnotation ann : prediction.getAnnotations()){
for(Point2D p : ann.getPoints()){
p.setX((p.getX() + tl.getX())/scale);
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
package pt.ua.dicoogle.sdk.mlprovider;

import com.fasterxml.jackson.annotation.JsonIgnore;
import org.dcm4che2.io.DicomInputStream;
import pt.ua.dicoogle.sdk.datastructs.dim.BulkAnnotation;

import java.util.HashMap;
import java.util.List;

/**
* This object maps predictions done by the AI algorithms.
* It can contain a set of metrics and a list of annotations.
* It can contain a set of metrics, annotations and a DICOM SEG file.
*/
public class MLPrediction {
public class MLInference {

private HashMap<String, String> metrics;

private String version;

private List<BulkAnnotation> annotations;

@JsonIgnore
private String resourcesFolder;

@JsonIgnore
private DicomInputStream dicomSEG;

/**
 * @return the named evaluation metrics reported by the provider for this
 *         inference (metric name to value), or {@code null} if none were set
 */
public HashMap<String, String> getMetrics() {
return metrics;
}
Expand All @@ -32,6 +41,22 @@ public void setAnnotations(List<BulkAnnotation> annotations) {
this.annotations = annotations;
}

/**
 * @return path of the temporary folder holding resources produced for this
 *         inference, or {@code null} if none; excluded from JSON serialization
 *         (the field is {@code @JsonIgnore}d)
 */
public String getResourcesFolder() {
return resourcesFolder;
}

/**
 * @param resourcesFolder path of the temporary folder holding resources
 *        produced for this inference; callers may later delete this folder
 *        once the result has been consumed
 */
public void setResourcesFolder(String resourcesFolder) {
this.resourcesFolder = resourcesFolder;
}

/**
 * @return stream over the DICOM SEG file produced by the inference, or
 *         {@code null} when the provider returned no segmentation object;
 *         excluded from JSON serialization (the field is {@code @JsonIgnore}d)
 */
public DicomInputStream getDicomSEG() {
return dicomSEG;
}

/**
 * @param dicomSEG stream over the DICOM SEG file produced by the inference;
 *        ownership passes to this object's consumer, who is responsible for
 *        reading and closing the stream
 */
public void setDicomSEG(DicomInputStream dicomSEG) {
this.dicomSEG = dicomSEG;
}

/**
 * @return version identifier reported for this inference
 *         (presumably the model/provider version — confirm with providers),
 *         or {@code null} if not set
 */
public String getVersion() {
return version;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import java.awt.image.BufferedImage;

public class MLPredictionRequest {
public class MLInferenceRequest {

private boolean isWsi;

Expand All @@ -16,7 +16,7 @@ public class MLPredictionRequest {

private String modelID;

public MLPredictionRequest(boolean isWsi, DimLevel level, String dimID, String modelID) {
public MLInferenceRequest(boolean isWsi, DimLevel level, String dimID, String modelID) {
this.isWsi = isWsi;
this.level = level;
this.dimID = dimID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ public abstract class MLProviderInterface implements DicooglePlugin {
*/
public abstract MLModel createModel();


/**
* This method deploys a model
*/
public abstract void deployModel();

/**
* This method lists the models created on this provider.
*/
Expand All @@ -63,21 +57,6 @@ public abstract class MLProviderInterface implements DicooglePlugin {
*/
public abstract boolean stopTraining(String trainingTaskID);

/**
* This method creates a endpoint that exposes a service
*/
public abstract void createEndpoint();

/**
* This method lists all available endpoints
*/
public abstract List<MLEndpoint> listEndpoints();

/**
* This method deletes a endpoint
*/
public abstract void deleteEndpoint();

/**
* This method deletes a model
*/
Expand All @@ -87,14 +66,14 @@ public abstract class MLProviderInterface implements DicooglePlugin {
* Order a prediction over a single object.
* The object can be a series instance, a sop instance or a 2D/3D ROI.
*
* @param predictionRequest object that defines this prediction request
* @param inferRequest object that defines this inference request
*/
public abstract Task<MLPrediction> makePrediction(MLPredictionRequest predictionRequest);
public abstract Task<MLInference> infer(MLInferenceRequest inferRequest);

/**
* This method makes a bulk prediction using the selected model
* This method makes a bulk inference request using the selected model
*/
public abstract void makeBulkPrediction();
public abstract void batchInfer();

/**
 * @return the set of data types this provider declares it can accept
 */
public Set<ML_DATA_TYPE> getAcceptedDataTypes() {
return acceptedDataTypes;
}

0 comments on commit e0e8049

Please sign in to comment.