Refactor model info and shapes
The previous design of ModelInfo was coupled to the representation of the ModelSession message used for the gRPC communication. The new design provides a more feature-rich interface for working with the concept of a shape, and the conversion of the ModelInfo that the server needs to transfer is localized in the converters module.
thodkatz committed Aug 6, 2024
1 parent 609cccf commit 0c1cf06
Showing 5 changed files with 363 additions and 100 deletions.
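
The refactor replaces the tuple-based Named*Shape dataclasses in tiktorch.converters with shape classes that live next to the session process. A minimal sketch of the new construction style, based on the calls exercised in the updated tests below (the from_values keyword names are taken from those tests; the axis letters, sizes, and the tensor name "raw" are illustrative):

```python
# Sketch of the new shape interface, as exercised by the updated tests below.
# Assumes tiktorch.server.session.process exposes these classes with the
# from_values signatures used in tests/test_converters.py.
from tiktorch.server.session.process import (
    AxisWithValue,
    ParameterizedShape,
    ShapeWithHalo,
    ShapeWithReference,
)

explicit_input = AxisWithValue(axes="yx", values=(256, 256))
parametrized_input = ParameterizedShape.from_values(axes="yx", steps=(16, 16), min_shape=(64, 64))
explicit_output = ShapeWithHalo.from_values(shape=(256, 256), halo=(8, 8), axes="yx")
implicit_output = ShapeWithReference.from_values(
    axes="yx", halo=(8, 8), offset=(0, 0), scale=(1.0, 1.0), reference_tensor="raw"
)
```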
48 changes: 14 additions & 34 deletions tests/test_converters.py
@@ -4,9 +4,6 @@
from numpy.testing import assert_array_equal

from tiktorch.converters import (
NamedExplicitOutputShape,
NamedImplicitOutputShape,
NamedParametrizedShape,
input_shape_to_pb_input_shape,
numpy_to_pb_tensor,
output_shape_to_pb_output_shape,
@@ -15,6 +12,7 @@
xarray_to_pb_tensor,
)
from tiktorch.proto import inference_pb2
from tiktorch.server.session.process import AxisWithValue, ParameterizedShape, ShapeWithHalo, ShapeWithReference


def _numpy_to_pb_tensor(arr):
@@ -186,32 +184,12 @@ def test_should_same_data(self, shape):


class TestShapeConversions:
def to_named_explicit_shape(self, shape, axes, halo):
return NamedExplicitOutputShape(
halo=[(name, dim) for name, dim in zip(axes, halo)], shape=[(name, dim) for name, dim in zip(axes, shape)]
)

def to_named_implicit_shape(self, axes, halo, offset, scales, reference_tensor):
return NamedImplicitOutputShape(
halo=[(name, dim) for name, dim in zip(axes, halo)],
offset=[(name, dim) for name, dim in zip(axes, offset)],
scale=[(name, scale) for name, scale in zip(axes, scales)],
reference_tensor=reference_tensor,
)

def to_named_paramtrized_shape(self, min_shape, axes, step):
return NamedParametrizedShape(
min_shape=[(name, dim) for name, dim in zip(axes, min_shape)],
step_shape=[(name, dim) for name, dim in zip(axes, step)],
)

@pytest.mark.parametrize(
"shape,axes,halo",
[((42,), "x", (0,)), ((42, 128, 5), "abc", (1, 1, 1)), ((5, 4, 3, 2, 1, 42), "btzyxc", (1, 2, 3, 4, 5, 24))],
[((42,), "x", (0,)), ((42, 128, 5), "xyz", (1, 1, 1)), ((5, 4, 3, 2, 1, 42), "btzyxc", (1, 2, 3, 4, 5, 24))],
)
def test_explicit_output_shape(self, shape, axes, halo):
named_shape = self.to_named_explicit_shape(shape, axes, halo)
pb_shape = output_shape_to_pb_output_shape(named_shape)
pb_shape = output_shape_to_pb_output_shape(ShapeWithHalo.from_values(shape=shape, halo=halo, axes=axes))

assert pb_shape.shapeType == 0
assert pb_shape.referenceTensor == ""
@@ -223,11 +201,13 @@ def test_explicit_output_shape(self, shape, axes, halo):

@pytest.mark.parametrize(
"axes,halo,offset,scales,reference_tensor",
[("x", (0,), (10,), (1.0,), "forty-two"), ("abc", (1, 1, 1), (1, 2, 3), (1.0, 2.0, 3.0), "helloworld")],
[("x", (0,), (10,), (1.0,), "forty-two"), ("xyz", (1, 1, 1), (1, 2, 3), (1.0, 2.0, 3.0), "helloworld")],
)
def test_implicit_output_shape(self, axes, halo, offset, scales, reference_tensor):
named_shape = self.to_named_implicit_shape(axes, halo, offset, scales, reference_tensor)
pb_shape = output_shape_to_pb_output_shape(named_shape)
shape = ShapeWithReference.from_values(
axes=axes, halo=halo, offset=offset, scale=scales, reference_tensor=reference_tensor
)
pb_shape = output_shape_to_pb_output_shape(shape)

assert pb_shape.shapeType == 1
assert pb_shape.referenceTensor == reference_tensor
@@ -248,11 +228,10 @@ def test_output_shape_raises(self):

@pytest.mark.parametrize(
"shape,axes",
[((42,), "x"), ((42, 128, 5), "abc"), ((5, 4, 3, 2, 1, 42), "btzyxc")],
[((42,), "x"), ((42, 128, 5), "xyz"), ((5, 4, 3, 2, 1, 42), "btzyxc")],
)
def test_explicit_input_shape(self, shape, axes):
named_shape = [(name, dim) for name, dim in zip(axes, shape)]
pb_shape = input_shape_to_pb_input_shape(named_shape)
pb_shape = input_shape_to_pb_input_shape(AxisWithValue(axes=axes, values=shape))

assert pb_shape.shapeType == 0
assert [(d.name, d.size) for d in pb_shape.shape.namedInts] == [(name, size) for name, size in zip(axes, shape)]
@@ -261,13 +240,14 @@ def test_explicit_input_shape(self, shape, axes):
"min_shape,axes,step",
[
((42,), "x", (5,)),
((42, 128, 5), "abc", (1, 2, 3)),
((42, 128, 5), "xyz", (1, 2, 3)),
((5, 4, 3, 2, 1, 42), "btzyxc", (15, 24, 33, 42, 51, 642)),
],
)
def test_parametrized_input_shape(self, min_shape, axes, step):
named_shape = self.to_named_paramtrized_shape(min_shape, axes, step)
pb_shape = input_shape_to_pb_input_shape(named_shape)
pb_shape = input_shape_to_pb_input_shape(
ParameterizedShape.from_values(axes=axes, steps=step, min_shape=min_shape)
)

assert pb_shape.shapeType == 1
assert [(d.name, d.size) for d in pb_shape.shape.namedInts] == [
58 changes: 58 additions & 0 deletions tests/test_server/test_session/test_process.py
@@ -0,0 +1,58 @@
import pytest

from tiktorch.server.session.process import AxisWithValue, ParameterizedShape


@pytest.mark.parametrize(
"min_shape, step, axes, expected",
[
((512, 512), (10, 10), "yx", (512, 512)),
((256, 512), (10, 10), "yx", (256, 512)),
((256, 256), (2, 2), "yx", (512, 512)),
((128, 256), (2, 2), "yx", (384, 512)),
((64, 64, 64), (1, 1, 1), "zyx", (64, 64, 64)),
((2, 64, 64), (1, 1, 1), "zyx", (2, 64, 64)),
((2, 2, 64), (1, 1, 1), "zyx", (2, 2, 64)),
((2, 2, 32), (1, 1, 1), "zyx", (34, 34, 64)),
((42, 10, 512, 512), (0, 0, 10, 10), "tcyx", (42, 10, 512, 512)),
],
)
def test_enforce_min_shape(min_shape, step, axes, expected):
shape = ParameterizedShape.from_values(min_shape, step, axes)
assert shape.get_total_shape().values == expected


def test_param_shape_set_custom_multiplier():
min_shape = (512, 512, 256)
step = (2, 2, 2)
axes = "zyx"

shape = ParameterizedShape.from_values(min_shape, step, axes)
shape.multiplier = 2
assert shape.get_total_shape().values == (516, 516, 260)

assert shape.get_total_shape(4).values == (520, 520, 264)
assert shape.multiplier == 4

with pytest.raises(ValueError):
shape.multiplier = -1


@pytest.mark.parametrize(
"sizes, axes, spatial_axes, spatial_sizes",
[
((512, 512), "yx", "yx", (512, 512)),
((1, 256, 512), "tyx", "yx", (256, 512)),
((256, 1, 512), "ytx", "yx", (256, 512)),
((128, 256, 1), "yxt", "yx", (128, 256)),
((64, 64, 64), "zyx", "zyx", (64, 64, 64)),
((1, 2, 64, 64), "bzyx", "zyx", (2, 64, 64)),
((1, 2, 3, 64), "zbyx", "zyx", (1, 3, 64)),
((1, 2, 3, 4), "zybx", "zyx", (1, 2, 4)),
((1, 2, 3, 4, 5), "tczyx", "zyx", (3, 4, 5)),
],
)
def test_spatial_axes(sizes, axes, spatial_axes, spatial_sizes):
shape = AxisWithValue(axes, sizes)
assert shape.spatial_values == spatial_sizes
assert shape.spatial_axes == spatial_axes
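
The multiplier arithmetic exercised by test_param_shape_set_custom_multiplier above amounts to growing the minimum shape by multiplier steps along each axis. A tiny standalone illustration of that arithmetic (the helper below is hypothetical, not part of the commit):

```python
# Hypothetical helper mirroring the arithmetic checked above:
# total shape = min_shape + multiplier * step, with negative multipliers rejected.
def total_shape(min_shape, step, multiplier):
    if multiplier < 0:
        raise ValueError("multiplier must be non-negative")
    return tuple(m + multiplier * s for m, s in zip(min_shape, step))

assert total_shape((512, 512, 256), (2, 2, 2), 2) == (516, 516, 260)
assert total_shape((512, 512, 256), (2, 2, 2), 4) == (520, 520, 264)
```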
55 changes: 43 additions & 12 deletions tiktorch/converters.py
@@ -1,10 +1,19 @@
import dataclasses
from typing import List, Tuple, Union
from typing import List, Tuple

import numpy as np
import xarray as xr

from tiktorch.proto import inference_pb2
from tiktorch.server.session.process import (
AxisWithValue,
InputShapes,
ModelInfo,
OutputShapes,
ParameterizedShape,
ShapeWithHalo,
ShapeWithReference,
)

# pairs of axis-shape for a single tensor
NamedInt = Tuple[str, int]
@@ -33,6 +42,28 @@ class NamedImplicitOutputShape:
halo: NamedShape


def info2session(session_id: str, model_info: ModelInfo) -> inference_pb2.ModelSession:
inputAxes = "".join(input_shape.axes for input_shape in model_info.input_shapes.values())
outputAxes = "".join(output_shape.axes for output_shape in model_info.output_shapes.values())
pb_input_shapes = [input_shape_to_pb_input_shape(shape) for shape in model_info.input_shapes.values()]
pb_output_shapes = [output_shape_to_pb_output_shape(shape) for shape in model_info.output_shapes.values()]
return inference_pb2.ModelSession(
id=session_id,
name=model_info.name,
inputAxes=inputAxes,
outputAxes=outputAxes,
inputShapes=pb_input_shapes,
outputShapes=pb_output_shapes,
hasTraining=False,
inputNames=list(model_info.input_shapes.keys()),
outputNames=list(model_info.output_shapes.keys()),
)


def session2info(model_session: inference_pb2.ModelSession) -> ModelInfo:
pass


def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor:
if axistags:
shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)]
@@ -46,44 +77,44 @@ def xarray_to_pb_tensor(array: xr.DataArray) -> inference_pb2.Tensor:
return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array.data))


def name_int_tuples_to_pb_NamedInts(name_int_tuples) -> inference_pb2.NamedInts:
def name_int_tuples_to_pb_NamedInts(name_int_tuples: AxisWithValue[int]) -> inference_pb2.NamedInts:
return inference_pb2.NamedInts(
namedInts=[inference_pb2.NamedInt(size=dim, name=name) for name, dim in name_int_tuples]
)


def name_float_tuples_to_pb_NamedFloats(name_float_tuples) -> inference_pb2.NamedFloats:
def name_float_tuples_to_pb_NamedFloats(name_float_tuples: AxisWithValue[float]) -> inference_pb2.NamedFloats:
return inference_pb2.NamedFloats(
namedFloats=[inference_pb2.NamedFloat(size=dim, name=name) for name, dim in name_float_tuples]
)


def input_shape_to_pb_input_shape(input_shape: Union[NamedShape, NamedParametrizedShape]) -> inference_pb2.InputShape:
if isinstance(input_shape, NamedParametrizedShape):
def input_shape_to_pb_input_shape(input_shape: InputShapes) -> inference_pb2.InputShape:
if isinstance(input_shape, ParameterizedShape):
return inference_pb2.InputShape(
shapeType=1,
shape=name_int_tuples_to_pb_NamedInts(input_shape.min_shape),
stepShape=name_int_tuples_to_pb_NamedInts(input_shape.step_shape),
stepShape=name_int_tuples_to_pb_NamedInts(input_shape.steps),
)
else:
elif isinstance(input_shape, AxisWithValue):
return inference_pb2.InputShape(
shapeType=0,
shape=name_int_tuples_to_pb_NamedInts(input_shape),
)
else:
raise ValueError(f"Unexpected shape {input_shape}")


def output_shape_to_pb_output_shape(
output_shape: Union[NamedExplicitOutputShape, NamedImplicitOutputShape]
) -> inference_pb2.InputShape:
if isinstance(output_shape, NamedImplicitOutputShape):
def output_shape_to_pb_output_shape(output_shape: OutputShapes) -> inference_pb2.InputShape:
if isinstance(output_shape, ShapeWithReference):
return inference_pb2.OutputShape(
shapeType=1,
halo=name_int_tuples_to_pb_NamedInts(output_shape.halo),
referenceTensor=output_shape.reference_tensor,
scale=name_float_tuples_to_pb_NamedFloats(output_shape.scale),
offset=name_float_tuples_to_pb_NamedFloats(output_shape.offset),
)
elif isinstance(output_shape, NamedExplicitOutputShape):
elif isinstance(output_shape, ShapeWithHalo):
return inference_pb2.OutputShape(
shapeType=0,
shape=name_int_tuples_to_pb_NamedInts(output_shape.shape),
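
For reference, a hedged usage sketch of the new converter dispatch shown above, combining the from_values constructor from the tests with the protobuf fields the converter populates (the field accesses mirror the assertions in tests/test_converters.py; the concrete sizes are illustrative):

```python
# Convert an explicit output shape (with halo) to its protobuf representation.
from tiktorch.converters import output_shape_to_pb_output_shape
from tiktorch.server.session.process import ShapeWithHalo

pb_shape = output_shape_to_pb_output_shape(
    ShapeWithHalo.from_values(shape=(256, 256), halo=(8, 8), axes="yx")
)
assert pb_shape.shapeType == 0          # explicit shape
assert pb_shape.referenceTensor == ""   # only set for reference-based shapes
assert [(d.name, d.size) for d in pb_shape.shape.namedInts] == [("y", 256), ("x", 256)]
assert [(d.name, d.size) for d in pb_shape.halo.namedInts] == [("y", 8), ("x", 8)]
```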
16 changes: 2 additions & 14 deletions tiktorch/server/grpc/inference_servicer.py
@@ -8,6 +8,7 @@
import xarray

from tiktorch import converters
from tiktorch.converters import info2session
from tiktorch.proto import inference_pb2, inference_pb2_grpc
from tiktorch.server.data_store import IDataStore
from tiktorch.server.device_pool import DeviceStatus, IDevicePool
@@ -52,20 +53,7 @@ def CreateModelSession(
lease.terminate()
raise

pb_input_shapes = [converters.input_shape_to_pb_input_shape(shape) for shape in model_info.input_shapes]
pb_output_shapes = [converters.output_shape_to_pb_output_shape(shape) for shape in model_info.output_shapes]

return inference_pb2.ModelSession(
id=session.id,
name=model_info.name,
inputAxes=model_info.input_axes,
outputAxes=model_info.output_axes,
inputShapes=pb_input_shapes,
hasTraining=False,
outputShapes=pb_output_shapes,
inputNames=model_info.input_names,
outputNames=model_info.output_names,
)
return info2session(session.id, model_info)

def CreateDatasetDescription(
self, request: inference_pb2.CreateDatasetDescriptionRequest, context
