Skip to content

Commit

Permalink
Encode network's functionality within fixture names
Browse files Browse the repository at this point in the history
Fixtures are bioimageio models with dummy networks that do simple
operations such as add one to the input. To test that, and to know what
a particular model does, we encode it within the name of the fixture
e.g. modelAddOne
  • Loading branch information
thodkatz committed Oct 11, 2024
1 parent 42210b7 commit e5f786a
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 418 deletions.
6 changes: 0 additions & 6 deletions proto/inference.proto
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,3 @@ service FlightControl {
rpc Shutdown(Empty) returns (Empty) {}
}

message CreateModelSessionChunkedRequest {
oneof data {
ModelInfo info = 1;
Blob chunk = 2;
}
}
130 changes: 60 additions & 70 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
from os import getenv
from pathlib import Path
from random import randint
from typing import List, Tuple
from typing import List

import numpy as np
import pytest
import torch
import xarray as xr
from bioimageio.core import AxisId
from bioimageio.spec import save_bioimageio_package_to_stream
from bioimageio.spec.model import v0_4
Expand Down Expand Up @@ -103,24 +102,24 @@ def assert_threads_cleanup():


@pytest.fixture(params=[WeightsFormat.PYTORCH, WeightsFormat.TORCHSCRIPT])
def bioimage_model_explicit_siso(request) -> Tuple[io.BytesIO, xr.DataArray]:
def bioimage_model_explicit_add_one_siso_v5(request) -> io.BytesIO:
input_axes = [
BatchAxis(),
ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]),
SpaceInputAxis(id=AxisId("x"), size=10),
SpaceInputAxis(id=AxisId("y"), size=10),
SpaceInputAxis(id=AxisId("y"), size=20),
]
input_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10)
input_test_tensor = np.arange(1 * 2 * 10 * 20, dtype="float32").reshape(1, 2, 10, 20)
if request.param == WeightsFormat.PYTORCH:
return _bioimage_model_dummy_v5_siso_pytorch(input_axes, input_test_tensor)
return _bioimage_model_dummy_add_one_siso_pytorch_v5(input_axes, input_test_tensor)
elif request.param == WeightsFormat.TORCHSCRIPT:
return _bioimage_model_dummy_v5_siso_torchscript(input_axes, input_test_tensor)
return _bioimage_model_dummy_add_one_siso_torchscript_v5(input_axes, input_test_tensor)
else:
raise NotImplementedError(f"{request.param}")


@pytest.fixture(params=[WeightsFormat.PYTORCH, WeightsFormat.TORCHSCRIPT])
def bioimage_model_param_siso(request) -> Tuple[io.BytesIO, xr.DataArray]:
def bioimage_model_param_add_one_siso_v5(request) -> io.BytesIO:
input_test_tensor = np.arange(1 * 2 * 10 * 20, dtype="float32").reshape(1, 2, 10, 20)
input_axes = [
BatchAxis(),
Expand All @@ -129,15 +128,15 @@ def bioimage_model_param_siso(request) -> Tuple[io.BytesIO, xr.DataArray]:
SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=20, step=3)),
]
if request.param == WeightsFormat.PYTORCH:
return _bioimage_model_dummy_v5_siso_pytorch(input_axes, input_test_tensor)
return _bioimage_model_dummy_add_one_siso_pytorch_v5(input_axes, input_test_tensor)
elif request.param == WeightsFormat.TORCHSCRIPT:
return _bioimage_model_dummy_v5_siso_torchscript(input_axes, input_test_tensor)
return _bioimage_model_dummy_add_one_siso_torchscript_v5(input_axes, input_test_tensor)
else:
raise NotImplementedError(f"{request.param}")


@pytest.fixture
def bioimage_model_miso() -> Tuple[io.BytesIO, xr.DataArray]:
def bioimage_model_add_one_miso_v5() -> io.BytesIO:
"""
Mocked bioimageio prediction pipeline with three inputs single output
"""
Expand Down Expand Up @@ -188,16 +187,15 @@ def bioimage_model_miso() -> Tuple[io.BytesIO, xr.DataArray]:
test_tensor=FileDescr(source=Path(test_tensor3_file.name)),
)

dummy_model = _DummyNetwork()
expected_output = _dummy_network_output
dummy_network = _DummyNetworkMultipleInputAddOne()
with tempfile.NamedTemporaryFile(suffix=".pts", delete=False) as weights_file:
torch.save(dummy_model.state_dict(), weights_file.name)
torch.save(dummy_network.state_dict(), weights_file.name)
weights = WeightsDescr(
pytorch_state_dict=PytorchStateDictWeightsDescr(
source=Path(weights_file.name),
architecture=ArchitectureFromLibraryDescr(
import_from="tests.conftest",
callable=Identifier(f"{_DummyNetwork.__name__}"),
callable=Identifier(f"{_DummyNetworkMultipleInputAddOne.__name__}"),
),
pytorch_version=Version("1.1.1"),
)
Expand All @@ -222,15 +220,14 @@ def bioimage_model_miso() -> Tuple[io.BytesIO, xr.DataArray]:
)

model_bytes = _bioimage_model_v5(weights=weights, inputs=[input1, input2, input3], outputs=[output_tensor])
return model_bytes, expected_output
return model_bytes


def _bioimage_model_dummy_v5_siso_torchscript(
def _bioimage_model_dummy_add_one_siso_torchscript_v5(
input_axes: List[InputAxis], input_test_tensor: np.ndarray
) -> Tuple[io.BytesIO, xr.DataArray]:
dummy_model = _DummyNetwork()
expected_output = _dummy_network_output
traced_model = torch.jit.trace(dummy_model, example_inputs=torch.from_numpy(input_test_tensor))
) -> io.BytesIO:
dummy_network = _DummyNetworkSingleInputAddOne()
traced_model = torch.jit.trace(dummy_network, example_inputs=torch.from_numpy(input_test_tensor))
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as model_file:
traced_model.save(model_file.name)
weights = WeightsDescr(
Expand All @@ -245,31 +242,27 @@ def _bioimage_model_dummy_v5_siso_torchscript(
SpaceOutputAxis(id=AxisId("y"), size=20),
]

return (
_bioimage_model_v5_siso(
weights=weights,
input_axes=input_axes,
output_axes=output_axes,
input_test_tensor=input_test_tensor,
output_test_tensor=output_test_tensor,
),
expected_output,
return _bioimage_model_siso_v5(
weights=weights,
input_axes=input_axes,
output_axes=output_axes,
input_test_tensor=input_test_tensor,
output_test_tensor=output_test_tensor,
)


def _bioimage_model_dummy_v5_siso_pytorch(
def _bioimage_model_dummy_add_one_siso_pytorch_v5(
input_axes: List[InputAxis], input_test_tensor: np.ndarray
) -> Tuple[io.BytesIO, xr.DataArray]:
dummy_model = _DummyNetwork()
expected_output = _dummy_network_output
) -> io.BytesIO:
dummy_network = _DummyNetworkSingleInputAddOne()
with tempfile.NamedTemporaryFile(suffix=".pts", delete=False) as weights_file:
torch.save(dummy_model.state_dict(), weights_file.name)
torch.save(dummy_network.state_dict(), weights_file.name)
weights = WeightsDescr(
pytorch_state_dict=PytorchStateDictWeightsDescr(
source=Path(weights_file.name),
architecture=ArchitectureFromLibraryDescr(
import_from="tests.conftest",
callable=Identifier(f"{_DummyNetwork.__name__}"),
callable=Identifier(f"{_DummyNetworkSingleInputAddOne.__name__}"),
),
pytorch_version=Version("1.1.1"),
)
Expand All @@ -283,19 +276,16 @@ def _bioimage_model_dummy_v5_siso_pytorch(
SpaceOutputAxis(id=AxisId("y"), size=20),
]

return (
_bioimage_model_v5_siso(
weights=weights,
input_axes=input_axes,
output_axes=output_axes,
input_test_tensor=input_test_tensor,
output_test_tensor=output_test_tensor,
),
expected_output,
return _bioimage_model_siso_v5(
weights=weights,
input_axes=input_axes,
output_axes=output_axes,
input_test_tensor=input_test_tensor,
output_test_tensor=output_test_tensor,
)


def _bioimage_model_v5_siso(
def _bioimage_model_siso_v5(
weights: WeightsDescr,
input_axes: List[InputAxis],
output_axes: List[OutputAxis],
Expand Down Expand Up @@ -346,46 +336,44 @@ def _bioimage_model_v5(


@pytest.fixture(params=[WeightsFormat.PYTORCH, WeightsFormat.TORCHSCRIPT])
def bioimage_model_v4(request) -> Tuple[io.BytesIO, xr.DataArray]:
def bioimage_model_add_one_v4(request) -> io.BytesIO:
if request.param == WeightsFormat.PYTORCH:
return _bioimage_model_dummy_v4_siso_pytorch()
return _bioimage_model_dummy_add_one_siso_pytorch_v4()
elif request.param == WeightsFormat.TORCHSCRIPT:
return _bioimage_model_dummy_v4_siso_torchscript()
return _bioimage_model_dummy_add_one_siso_torchscript_v4()
else:
raise NotImplementedError(f"{request.param}")


def _bioimage_model_dummy_v4_siso_pytorch() -> Tuple[io.BytesIO, xr.DataArray]:
dummy_model = _DummyNetwork()
dummy_model_expected_output = _dummy_network_output
input_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10)
output_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10)
traced_model = torch.jit.trace(dummy_model, example_inputs=torch.from_numpy(input_test_tensor))
def _bioimage_model_dummy_add_one_siso_pytorch_v4() -> io.BytesIO:
dummy_network = _DummyNetworkSingleInputAddOne()
input_test_tensor = np.arange(1 * 2 * 10 * 20, dtype="float32").reshape(1, 2, 10, 20)
output_test_tensor = np.arange(1 * 2 * 10 * 20, dtype="float32").reshape(1, 2, 10, 20)
traced_model = torch.jit.trace(dummy_network, example_inputs=torch.from_numpy(input_test_tensor))
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as weights_file:
traced_model.save(weights_file.name)
weights = v0_4.WeightsDescr(torchscript=v0_4.TorchscriptWeightsDescr(source=Path(weights_file.name)))
model_bytes = _bioimage_model_v4_siso(
model_bytes = _bioimage_model_siso_v4(
weights=weights, input_test_tensor=input_test_tensor, output_test_tensor=output_test_tensor
)
return model_bytes, dummy_model_expected_output
return model_bytes


def _bioimage_model_dummy_v4_siso_torchscript() -> Tuple[io.BytesIO, xr.DataArray]:
dummy_model = _DummyNetwork()
dummy_model_expected_output = _dummy_network_output
input_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10)
output_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10)
traced_model = torch.jit.trace(dummy_model, example_inputs=torch.from_numpy(input_test_tensor))
def _bioimage_model_dummy_add_one_siso_torchscript_v4() -> io.BytesIO:
dummy_network = _DummyNetworkSingleInputAddOne()
input_test_tensor = np.arange(1 * 2 * 10 * 20, dtype="float32").reshape(1, 2, 10, 20)
output_test_tensor = np.arange(1 * 2 * 10 * 20, dtype="float32").reshape(1, 2, 10, 20)
traced_model = torch.jit.trace(dummy_network, example_inputs=torch.from_numpy(input_test_tensor))
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as model_file:
traced_model.save(model_file.name)
weights = v0_4.WeightsDescr(torchscript=v0_4.TorchscriptWeightsDescr(source=Path(model_file.name)))
model_bytes = _bioimage_model_v4_siso(
model_bytes = _bioimage_model_siso_v4(
weights=weights, input_test_tensor=input_test_tensor, output_test_tensor=output_test_tensor
)
return model_bytes, dummy_model_expected_output
return model_bytes


def _bioimage_model_v4_siso(
def _bioimage_model_siso_v4(
weights: v0_4.WeightsDescr, input_test_tensor: np.ndarray, output_test_tensor: np.ndarray
) -> io.BytesIO:
input_tensor = v0_4.InputTensorDescr(
Expand Down Expand Up @@ -422,9 +410,11 @@ def _bioimage_model_v4_siso(
return model_bytes


_dummy_network_output = xr.DataArray(np.arange(2 * 10 * 10).reshape(1, 2, 10, 10), dims=["batch", "channel", "x", "y"])
class _DummyNetworkSingleInputAddOne(nn.Module):
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor + 1


class _DummyNetwork(nn.Module):
def forward(self, *args):
return torch.from_numpy(_dummy_network_output.values)
class _DummyNetworkMultipleInputAddOne(nn.Module):
def forward(self, *tensors) -> torch.Tensor:
return tensors[0] + 1
Loading

0 comments on commit e5f786a

Please sign in to comment.