-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
26 changed files
with
352 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
include environments/* | ||
include apps/* |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from hypha_launcher.utils.container import ContainerEngine | ||
from pathlib import Path | ||
from pyotritonclient import get_config, execute | ||
from functools import partial | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
container_engine = ContainerEngine() | ||
TRITON_IMAGE = "docker://nvcr.io/nvidia/tritonserver:23.03-py3" | ||
|
||
|
||
async def hypha_startup(server): | ||
# get current dir | ||
current_dir = Path(__file__).parent | ||
host_port = "9302" | ||
logger.info(f"Pulling triton image {TRITON_IMAGE}") | ||
container_engine.pull_image(TRITON_IMAGE) | ||
logger.info(f"Starting triton server at port {host_port}") | ||
container_engine.run_command( | ||
f'bash -c "tritonserver --model-repository=/models --log-verbose=3 --log-info=1 --log-warning=1 --log-error=1 --model-control-mode=poll --exit-on-error=false --repository-poll-secs=10 --allow-grpc=False --http-port={host_port}"', # noqa | ||
TRITON_IMAGE, | ||
ports={host_port: host_port}, | ||
volumes={str(current_dir): "/models"}, | ||
) | ||
|
||
server_url = f"http://localhost:{host_port}" | ||
logger.info(f"Triton server is running at {server_url}") | ||
|
||
svc = await server.register_service({ | ||
"name": "CellPose", | ||
"id": "cellpose", | ||
"config": { | ||
"visibility": "public" | ||
}, | ||
"train": partial(execute, server_url=server_url, model_name="cellpose-train"), | ||
"train_config": partial(get_config, server_url=server_url, model_name="cellpose-train"), | ||
"predict": partial(execute, server_url=server_url, model_name="cellpose-predict"), | ||
"predict_config": partial(get_config, server_url=server_url, model_name="cellpose-predict"), | ||
}) | ||
|
||
logger.info(f"CellPose service is registered as `{svc['id']}`") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
name: Cellpose | ||
id: cellpose | ||
description: Cellpose is a generalist algorithm for cell and nucleus segmentation | ||
runtime: python | ||
entrypoint: main.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
import logging | ||
import sys, os | ||
import imagej | ||
import scyjava as sj | ||
import asyncio | ||
import traceback | ||
import numpy as np | ||
import xarray as xr | ||
from jpype import JOverride, JImplements | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
os.environ["JAVA_HOME"] = os.sep.join(sys.executable.split(os.sep)[:-2] + ["jre"]) | ||
|
||
def capture_console(ij, print=True): | ||
logs = {} | ||
logs["stdout"] = [] | ||
logs["stderr"] = [] | ||
|
||
@JImplements("org.scijava.console.OutputListener") | ||
class JavaOutputListener: | ||
@JOverride | ||
def outputOccurred(self, e): | ||
source = e.getSource().toString | ||
output = e.getOutput() | ||
|
||
if print: | ||
if source == "STDOUT": | ||
sys.stdout.write(output) | ||
logs["stdout"].append(output) | ||
elif source == "STDERR": | ||
sys.stderr.write(output) | ||
logs["stderr"].append(output) | ||
else: | ||
output = "[{}] {}".format(source, output) | ||
sys.stderr.write(output) | ||
logs["stderr"].append(output) | ||
|
||
ij.py._outputMapper = JavaOutputListener() | ||
ij.console().addOutputListener(ij.py._outputMapper) | ||
return logs | ||
|
||
|
||
def format_logs(logs): | ||
output = "" | ||
if logs["stdout"]: | ||
output += "STDOUT:\n" | ||
output += "\n".join(logs["stdout"]) | ||
output += "\n" | ||
if logs["stderr"]: | ||
output += "STDERR:\n" | ||
output += "\n".join(logs["stderr"]) | ||
output += "\n" | ||
return output | ||
|
||
|
||
def get_module_info(ij, custom_script, name=None): | ||
name = name or "scijava_script" | ||
ScriptInfo = sj.jimport("org.scijava.script.ScriptInfo") | ||
StringReader = sj.jimport("java.io.StringReader") | ||
moduleinfo = ScriptInfo(ij.getContext(), name, StringReader(custom_script)) | ||
inputs = {} | ||
outputs = {} | ||
|
||
for inp in ij.py.from_java(moduleinfo.inputs()): | ||
input_type = str(inp.getType().getName()) | ||
input_name = str(inp.getName()) | ||
print(input_type, input_name) | ||
inputs[input_name] = {"name": input_name, "type": input_type} | ||
|
||
for outp in ij.py.from_java(moduleinfo.outputs()): | ||
output_type = str(outp.getType().getName()) | ||
output_name = str(outp.getName()) | ||
outputs[output_name] = {"name": output_name, "type": output_type} | ||
|
||
return {"id": moduleinfo.getIdentifier(), "outputs": outputs, "inputs": inputs} | ||
|
||
|
||
def check_size(array): | ||
result_bytes = array.tobytes() | ||
if len(result_bytes) > 20000000: # 20MB | ||
raise Exception( | ||
f"The data is too large ({len(result_bytes)} bytes) to be transfered." | ||
) | ||
|
||
|
||
async def execute(config, context=None): | ||
loop = asyncio.get_event_loop() | ||
return await loop.run_in_executor(None, run_imagej, config) | ||
|
||
|
||
def run_imagej(config): | ||
headless = config.get("headless", False) | ||
ij = imagej.init(os.environ["IMAGEJ_DIR"], headless=headless) | ||
try: | ||
WindowManager = sj.jimport("ij.WindowManager") | ||
ImagePlus = sj.jimport("ij.ImagePlus") | ||
logs = capture_console(ij) | ||
script = config.get("script") | ||
lang = config.get("lang", "ijm") | ||
assert script is not None, "script is required" | ||
module_info = get_module_info(ij, script) | ||
inputs_info = module_info["inputs"] | ||
outputs_info = module_info["outputs"] | ||
inputs = config.get("inputs", {}) | ||
select_outputs = config.get("select_outputs") | ||
args = {} | ||
for k in inputs: | ||
if isinstance(inputs[k], (np.ndarray, np.generic, dict)): | ||
if isinstance(inputs[k], (np.ndarray, np.generic)): | ||
if inputs[k].ndim == 2: | ||
dims = ["x", "y"] | ||
elif inputs[k].ndim == 3 and inputs[k].shape[2] in [1, 3, 4]: | ||
dims = ["x", "y", "c"] | ||
elif inputs[k].ndim == 3 and inputs[k].shape[0] in [1, 3, 4]: | ||
dims = ["c", "x", "y"] | ||
elif inputs[k].ndim == 3: | ||
dims = ["z", "x", "y"] | ||
elif inputs[k].ndim == 4: | ||
dims = ["z", "x", "y", "c"] | ||
elif inputs[k].ndim == 5: | ||
dims = ["t", "z", "x", "y", "c"] | ||
else: | ||
raise Exception(f"Unsupported ndim: {inputs[k].ndim}") | ||
inputs[k] = {"data": inputs[k], "dims": dims} | ||
|
||
img = inputs[k] | ||
assert isinstance( | ||
img, dict | ||
), f"input {k} must be a dictionary or a numpy array" | ||
assert "data" in img, f"data is required for {k}" | ||
assert "dims" in img, f"dims is required for {k}" | ||
da = xr.DataArray( | ||
data=img["data"], | ||
dims=img["dims"], | ||
attrs=img.get("attrs", {}), | ||
name=k, | ||
) | ||
inputs[k] = ij.py.to_java(da) | ||
if lang == "ijm": | ||
# convert to ImagePlus | ||
inputs[k] = ij.convert().convert(inputs[k], ImagePlus) | ||
if inputs[k]: | ||
inputs[k].setTitle(k) | ||
# Display the image | ||
if not headless: | ||
inputs[k].show() | ||
else: | ||
raise NotImplementedError( | ||
"Don't know how to display the image (only ijm is supported)." | ||
) | ||
if k in inputs_info: | ||
args[k] = ij.py.to_java(inputs[k]) | ||
|
||
# Run the script | ||
macro_result = ij.py.run_script(lang, script, args) | ||
results = {} | ||
if select_outputs is None: | ||
select_outputs = list(outputs_info.keys()) | ||
for k in select_outputs: | ||
if k in outputs_info: | ||
results[k] = macro_result.getOutput(k) | ||
if results[k] and not isinstance(results[k], (int, str, float, bool)): | ||
try: | ||
results[k] = ij.py.from_java(results[k]).to_numpy() | ||
check_size(results[k]) | ||
except Exception: | ||
# TODO: This is needed due to a bug in pyimagej for converting java string | ||
if str(type(results[k])) == "<java class 'java.lang.String'>": | ||
results[k] = str(results[k]) | ||
else: | ||
results[k] = { | ||
"type": str(type(results[k])), | ||
"text": str(results[k]), | ||
} | ||
else: | ||
# If the output name is not in the script annotation, | ||
# Try to get the image from the WindowManager by title | ||
img = WindowManager.getImage(k) | ||
if not img: | ||
raise Exception(f"Output not found: {k}\n{format_logs(logs)}") | ||
results[k] = ij.py.from_java(img).to_numpy() | ||
check_size(results[k]) | ||
except Exception as exp: | ||
raise exp | ||
finally: | ||
ij.dispose() | ||
|
||
return {"outputs": results, "logs": logs} | ||
|
||
|
||
|
||
test_macro = """ | ||
#@ String name | ||
#@ int age | ||
#@ String city | ||
#@output Object greeting | ||
greeting = "Hi " + name + ". You are " + age + " years old, and live in " + city + "." | ||
""" | ||
|
||
async def hypha_startup(server): | ||
try: | ||
print("Testing the imagej service...") | ||
ret = await execute( | ||
{ | ||
"script": test_macro, | ||
"inputs": {"name": "Tom", "age": 20, "city": "Shanghai"}, | ||
} | ||
) | ||
outputs = ret["outputs"] | ||
assert ( | ||
outputs["greeting"] == "Hi Tom. You are 20 years old, and live in Shanghai." | ||
) | ||
except Exception: | ||
print(traceback.format_exc()) | ||
sys.exit(1) | ||
|
||
logger.info("Starting the imagej service...") | ||
svc = await server.register_service( | ||
{ | ||
"id": "imagej-service", | ||
"type": "imagej-service", | ||
"config": {"require_context": True, "visibility": "public"}, | ||
"execute": execute, | ||
} | ||
) | ||
|
||
logger.info(f"ImageJ service is registered as `{svc['id']}`") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
name: ImageJ | ||
id: imagej | ||
description: ImageJ is a public domain Java image processing program inspired by NIH Image for the Macintosh. | ||
runtime: conda | ||
entrypoint: main.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,13 @@ | ||
from bioengine.app_loader import load_apps | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def connect_server(server_url): | ||
raise NotImplementedError("This function is not implemented yet.") | ||
|
||
def register_bioengine(server): | ||
raise NotImplementedError("This function is not implemented yet.") | ||
async def register_bioengine_apps(server): | ||
for app in load_apps(): | ||
logger.info(f"Registering service {app.name}") | ||
await app.run(server) | ||
logger.info(f"Service {app.name} registered.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from pydantic import BaseModel | ||
from pathlib import Path | ||
import logging | ||
from yaml import safe_load | ||
from typing import Callable, Optional | ||
import sys | ||
import importlib.util | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# define runtime type enum for app runtime | ||
class AppRuntime(str): | ||
python = "python" | ||
pyodide = "pyodide" | ||
triton = "triton" | ||
|
||
class AppInfo(BaseModel): | ||
name: str | ||
id: str | ||
description: str | ||
runtime: AppRuntime | ||
entrypoint: Optional[str] = None | ||
|
||
async def run(self, server): | ||
if self.runtime == AppRuntime.python: | ||
assert self.entrypoint | ||
|
||
file_path = Path(__file__).parent.parent / "apps" / self.id / self.entrypoint | ||
module_name = 'bioengine.apps.' + self.id.replace('-', '_') | ||
spec = importlib.util.spec_from_file_location(module_name, file_path) | ||
module = importlib.util.module_from_spec(spec) | ||
spec.loader.exec_module(module) | ||
assert hasattr(module, "hypha_startup") | ||
await module.hypha_startup(server) | ||
|
||
def load_apps(): | ||
current_dir = Path(__file__).parent | ||
apps_dir = current_dir.parent / "apps" | ||
# list folders under apps_dir | ||
for app_dir in apps_dir.iterdir(): | ||
if app_dir.is_dir(): | ||
manifest_file = app_dir / "manifest.yaml" | ||
if manifest_file.exists(): | ||
manifest = safe_load(manifest_file.read_text()) | ||
yield AppInfo.parse_obj(manifest) |
Empty file.
Oops, something went wrong.