Skip to content

Commit

Permalink
Add basic support for running app
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway committed Mar 2, 2024
1 parent a0f4747 commit 7d54441
Show file tree
Hide file tree
Showing 26 changed files with 352 additions and 105 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
include environments/*
include apps/*
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
42 changes: 42 additions & 0 deletions apps/cellpose/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from hypha_launcher.utils.container import ContainerEngine
from pathlib import Path
from pyotritonclient import get_config, execute
from functools import partial
import logging

logger = logging.getLogger(__name__)

container_engine = ContainerEngine()
TRITON_IMAGE = "docker://nvcr.io/nvidia/tritonserver:23.03-py3"


async def hypha_startup(server):
# get current dir
current_dir = Path(__file__).parent
host_port = "9302"
logger.info(f"Pulling triton image {TRITON_IMAGE}")
container_engine.pull_image(TRITON_IMAGE)
logger.info(f"Starting triton server at port {host_port}")
container_engine.run_command(
f'bash -c "tritonserver --model-repository=/models --log-verbose=3 --log-info=1 --log-warning=1 --log-error=1 --model-control-mode=poll --exit-on-error=false --repository-poll-secs=10 --allow-grpc=False --http-port={host_port}"', # noqa
TRITON_IMAGE,
ports={host_port: host_port},
volumes={str(current_dir): "/models"},
)

server_url = f"http://localhost:{host_port}"
logger.info(f"Triton server is running at {server_url}")

svc = await server.register_service({
"name": "CellPose",
"id": "cellpose",
"config": {
"visibility": "public"
},
"train": partial(execute, server_url=server_url, model_name="cellpose-train"),
"train_config": partial(get_config, server_url=server_url, model_name="cellpose-train"),
"predict": partial(execute, server_url=server_url, model_name="cellpose-predict"),
"predict_config": partial(get_config, server_url=server_url, model_name="cellpose-predict"),
})

logger.info(f"CellPose service is registered as `{svc['id']}`")
5 changes: 5 additions & 0 deletions apps/cellpose/manifest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: Cellpose
id: cellpose
description: Cellpose is a generalist algorithm for cell and nucleus segmentation
runtime: python
entrypoint: main.py
228 changes: 228 additions & 0 deletions apps/imagej/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import logging
import sys, os
import imagej
import scyjava as sj
import asyncio
import traceback
import numpy as np
import xarray as xr
from jpype import JOverride, JImplements


logger = logging.getLogger(__name__)
os.environ["JAVA_HOME"] = os.sep.join(sys.executable.split(os.sep)[:-2] + ["jre"])

def capture_console(ij, print=True):
logs = {}
logs["stdout"] = []
logs["stderr"] = []

@JImplements("org.scijava.console.OutputListener")
class JavaOutputListener:
@JOverride
def outputOccurred(self, e):
source = e.getSource().toString
output = e.getOutput()

if print:
if source == "STDOUT":
sys.stdout.write(output)
logs["stdout"].append(output)
elif source == "STDERR":
sys.stderr.write(output)
logs["stderr"].append(output)
else:
output = "[{}] {}".format(source, output)
sys.stderr.write(output)
logs["stderr"].append(output)

ij.py._outputMapper = JavaOutputListener()
ij.console().addOutputListener(ij.py._outputMapper)
return logs


def format_logs(logs):
output = ""
if logs["stdout"]:
output += "STDOUT:\n"
output += "\n".join(logs["stdout"])
output += "\n"
if logs["stderr"]:
output += "STDERR:\n"
output += "\n".join(logs["stderr"])
output += "\n"
return output


def get_module_info(ij, custom_script, name=None):
name = name or "scijava_script"
ScriptInfo = sj.jimport("org.scijava.script.ScriptInfo")
StringReader = sj.jimport("java.io.StringReader")
moduleinfo = ScriptInfo(ij.getContext(), name, StringReader(custom_script))
inputs = {}
outputs = {}

for inp in ij.py.from_java(moduleinfo.inputs()):
input_type = str(inp.getType().getName())
input_name = str(inp.getName())
print(input_type, input_name)
inputs[input_name] = {"name": input_name, "type": input_type}

for outp in ij.py.from_java(moduleinfo.outputs()):
output_type = str(outp.getType().getName())
output_name = str(outp.getName())
outputs[output_name] = {"name": output_name, "type": output_type}

return {"id": moduleinfo.getIdentifier(), "outputs": outputs, "inputs": inputs}


def check_size(array):
result_bytes = array.tobytes()
if len(result_bytes) > 20000000: # 20MB
raise Exception(
f"The data is too large ({len(result_bytes)} bytes) to be transfered."
)


async def execute(config, context=None):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, run_imagej, config)


def run_imagej(config):
headless = config.get("headless", False)
ij = imagej.init(os.environ["IMAGEJ_DIR"], headless=headless)
try:
WindowManager = sj.jimport("ij.WindowManager")
ImagePlus = sj.jimport("ij.ImagePlus")
logs = capture_console(ij)
script = config.get("script")
lang = config.get("lang", "ijm")
assert script is not None, "script is required"
module_info = get_module_info(ij, script)
inputs_info = module_info["inputs"]
outputs_info = module_info["outputs"]
inputs = config.get("inputs", {})
select_outputs = config.get("select_outputs")
args = {}
for k in inputs:
if isinstance(inputs[k], (np.ndarray, np.generic, dict)):
if isinstance(inputs[k], (np.ndarray, np.generic)):
if inputs[k].ndim == 2:
dims = ["x", "y"]
elif inputs[k].ndim == 3 and inputs[k].shape[2] in [1, 3, 4]:
dims = ["x", "y", "c"]
elif inputs[k].ndim == 3 and inputs[k].shape[0] in [1, 3, 4]:
dims = ["c", "x", "y"]
elif inputs[k].ndim == 3:
dims = ["z", "x", "y"]
elif inputs[k].ndim == 4:
dims = ["z", "x", "y", "c"]
elif inputs[k].ndim == 5:
dims = ["t", "z", "x", "y", "c"]
else:
raise Exception(f"Unsupported ndim: {inputs[k].ndim}")
inputs[k] = {"data": inputs[k], "dims": dims}

img = inputs[k]
assert isinstance(
img, dict
), f"input {k} must be a dictionary or a numpy array"
assert "data" in img, f"data is required for {k}"
assert "dims" in img, f"dims is required for {k}"
da = xr.DataArray(
data=img["data"],
dims=img["dims"],
attrs=img.get("attrs", {}),
name=k,
)
inputs[k] = ij.py.to_java(da)
if lang == "ijm":
# convert to ImagePlus
inputs[k] = ij.convert().convert(inputs[k], ImagePlus)
if inputs[k]:
inputs[k].setTitle(k)
# Display the image
if not headless:
inputs[k].show()
else:
raise NotImplementedError(
"Don't know how to display the image (only ijm is supported)."
)
if k in inputs_info:
args[k] = ij.py.to_java(inputs[k])

# Run the script
macro_result = ij.py.run_script(lang, script, args)
results = {}
if select_outputs is None:
select_outputs = list(outputs_info.keys())
for k in select_outputs:
if k in outputs_info:
results[k] = macro_result.getOutput(k)
if results[k] and not isinstance(results[k], (int, str, float, bool)):
try:
results[k] = ij.py.from_java(results[k]).to_numpy()
check_size(results[k])
except Exception:
# TODO: This is needed due to a bug in pyimagej for converting java string
if str(type(results[k])) == "<java class 'java.lang.String'>":
results[k] = str(results[k])
else:
results[k] = {
"type": str(type(results[k])),
"text": str(results[k]),
}
else:
# If the output name is not in the script annotation,
# Try to get the image from the WindowManager by title
img = WindowManager.getImage(k)
if not img:
raise Exception(f"Output not found: {k}\n{format_logs(logs)}")
results[k] = ij.py.from_java(img).to_numpy()
check_size(results[k])
except Exception as exp:
raise exp
finally:
ij.dispose()

return {"outputs": results, "logs": logs}



test_macro = """
#@ String name
#@ int age
#@ String city
#@output Object greeting
greeting = "Hi " + name + ". You are " + age + " years old, and live in " + city + "."
"""

async def hypha_startup(server):
try:
print("Testing the imagej service...")
ret = await execute(
{
"script": test_macro,
"inputs": {"name": "Tom", "age": 20, "city": "Shanghai"},
}
)
outputs = ret["outputs"]
assert (
outputs["greeting"] == "Hi Tom. You are 20 years old, and live in Shanghai."
)
except Exception:
print(traceback.format_exc())
sys.exit(1)

logger.info("Starting the imagej service...")
svc = await server.register_service(
{
"id": "imagej-service",
"type": "imagej-service",
"config": {"require_context": True, "visibility": "public"},
"execute": execute,
}
)

logger.info(f"ImageJ service is registered as `{svc['id']}`")
5 changes: 5 additions & 0 deletions apps/imagej/manifest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: ImageJ
id: imagej
description: ImageJ is a public domain Java image processing program inspired by NIH Image for the Macintosh.
runtime: conda
entrypoint: main.py
11 changes: 9 additions & 2 deletions bioengine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from bioengine.app_loader import load_apps
import logging

logger = logging.getLogger(__name__)

def connect_server(server_url):
raise NotImplementedError("This function is not implemented yet.")

def register_bioengine(server):
raise NotImplementedError("This function is not implemented yet.")
async def register_bioengine_apps(server):
for app in load_apps():
logger.info(f"Registering service {app.name}")
await app.run(server)
logger.info(f"Service {app.name} registered.")
12 changes: 3 additions & 9 deletions bioengine/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,24 @@
import argparse
import asyncio
import subprocess
import os
# import os

def start_server(args):
# get current file path so we can get the path of apps under the same directory
current_dir = os.path.dirname(os.path.abspath(__file__))
# current_dir = os.path.dirname(os.path.abspath(__file__))
command = [
sys.executable,
"-m",
"hypha.server",
f"--host={args.host}",
f"--port={args.port}",
f"--public-base-url={args.public_base_url}",
"--startup-functions=bioengine:register_bioengine"
"--startup-functions=bioengine:register_bioengine_apps"
]
subprocess.run(command)

def connect_server(args):
from bioengine import connect_server
if args.login_required:
os.environ["BIOIMAGEIO_LOGIN_REQUIRED"] = "true"
server_url = args.server_url
loop = asyncio.get_event_loop()
loop.create_task(connect_server(server_url))
Expand All @@ -33,10 +31,6 @@ def main():

subparsers = parser.add_subparsers()

# Init command
parser_init = subparsers.add_parser("init")
parser_init.set_defaults(func=init)

# Start server command
parser_start_server = subparsers.add_parser("start-server")
parser_start_server.add_argument("--host", type=str, default="0.0.0.0")
Expand Down
45 changes: 45 additions & 0 deletions bioengine/app_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from pydantic import BaseModel
from pathlib import Path
import logging
from yaml import safe_load
from typing import Callable, Optional
import sys
import importlib.util

logger = logging.getLogger(__name__)

# define runtime type enum for app runtime
class AppRuntime(str):
python = "python"
pyodide = "pyodide"
triton = "triton"

class AppInfo(BaseModel):
name: str
id: str
description: str
runtime: AppRuntime
entrypoint: Optional[str] = None

async def run(self, server):
if self.runtime == AppRuntime.python:
assert self.entrypoint

file_path = Path(__file__).parent.parent / "apps" / self.id / self.entrypoint
module_name = 'bioengine.apps.' + self.id.replace('-', '_')
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
assert hasattr(module, "hypha_startup")
await module.hypha_startup(server)

def load_apps():
current_dir = Path(__file__).parent
apps_dir = current_dir.parent / "apps"
# list folders under apps_dir
for app_dir in apps_dir.iterdir():
if app_dir.is_dir():
manifest_file = app_dir / "manifest.yaml"
if manifest_file.exists():
manifest = safe_load(manifest_file.read_text())
yield AppInfo.parse_obj(manifest)
Empty file removed bioengine/services/__init__.py
Empty file.
Loading

0 comments on commit 7d54441

Please sign in to comment.