diff --git a/monailabel/endpoints/infer.py b/monailabel/endpoints/infer.py index fee0232af..f66e2d0a4 100644 --- a/monailabel/endpoints/infer.py +++ b/monailabel/endpoints/infer.py @@ -24,7 +24,7 @@ from requests_toolbelt import MultipartEncoder from monailabel.config import RBAC_USER, settings -from monailabel.datastore.utils.convert import binary_to_image +from monailabel.datastore.utils.convert import binary_to_image, nifti_to_dicom_seg from monailabel.endpoints.user.auth import RBAC, User from monailabel.interfaces.app import MONAILabelApp from monailabel.interfaces.utils.app import app_instance @@ -62,6 +62,7 @@ }, "application/json": {"schema": {"type": "string", "example": "{}"}}, "application/octet-stream": {"schema": {"type": "string", "format": "binary"}}, + "application/dicom": {"schema": {"type": "string", "format": "binary"}}, }, }, }, @@ -72,6 +73,7 @@ class ResultType(str, Enum): image = "image" json = "json" all = "all" + dicom_seg = "dicom_seg" def send_response(datastore, result, output, background_tasks): @@ -93,6 +95,13 @@ def send_response(datastore, result, output, background_tasks): if output == "image": return FileResponse(res_img, media_type=m_type, filename=os.path.basename(res_img)) + if output == "dicom_seg": + res_dicom_seg = result.get("dicom_seg") + if res_dicom_seg is None: + raise HTTPException(status_code=500, detail="Error processing inference") + else: + return FileResponse(res_dicom_seg, media_type="application/dicom", filename=os.path.basename(res_dicom_seg)) + res_fields = dict() res_fields["params"] = (None, json.dumps(res_json), "application/json") if res_img and os.path.exists(res_img): @@ -162,6 +171,22 @@ def run_inference( result = instance.infer(request) if result is None: raise HTTPException(status_code=500, detail="Failed to execute infer") + + # Dicom Seg Integration + if output == "dicom_seg" and image: + dicom_seg_file = None + image_uri = instance.datastore().get_image_uri(image) + if not image_uri: + raise HTTPException(status_code=500, detail="Image not found") + elif p.get("label_info") is None: + raise HTTPException(status_code=404, detail="Parameters for DICOM Seg inference cannot be empty!") + # Transform image uri to id (similar to _to_id in local datastore) + suffixes = [".nii", ".nii.gz", ".nrrd"] + image_path = [image_uri.replace(suffix, "") for suffix in suffixes if image_uri.endswith(suffix)][0] + res_img = result.get("file") if result.get("file") else result.get("label") + dicom_seg_file = nifti_to_dicom_seg(image_path, res_img, p.get("label_info"), use_itk=True) + result["dicom_seg"] = dicom_seg_file + return send_response(instance.datastore(), result, output, background_tasks)