Fix access to model config dictionary and usage in fill_optionals decorator
nv-blazejkubiak committed Feb 21, 2023
1 parent e3cc225 commit a013c49
Showing 5 changed files with 133 additions and 23 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,16 @@ limitations under the License.

# Changelog

+## 0.1.3 (2023-02-20)
+- Fixed getting model config in `fill_optionals` decorator.
+
+[//]: <> (put here on external component update with short summary what change or link to changelog)
+- Version of external components used during testing:
+  - [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
+  - Other component versions depend on the used framework and Triton Inference Server containers versions.
+    Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
+    for a detailed summary.

## 0.1.2 (2023-02-14)
- Fixed wheel build to support installations on operating systems with glibc version 2.31 or higher.
- Updated the documentation on custom builds of the package.
2 changes: 1 addition & 1 deletion pytriton/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# noqa: D104
-__version__ = "0.1.2"
+__version__ = "0.1.3"
from pytriton import client # noqa: F401
from pytriton import model_config # noqa: F401
from pytriton import triton # noqa: F401
66 changes: 52 additions & 14 deletions pytriton/decorators.py
@@ -18,6 +18,7 @@
import operator
import typing
from bisect import bisect_left
+from collections.abc import MutableMapping
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
@@ -62,11 +63,52 @@ def _get_wrapt_stack(wrapped) -> List[_WrappedWithWrapper]:
    return stack


+class ModelConfigDict(MutableMapping):
+    """Dictionary for storing model configs for inference callable."""
+
+    def __init__(self):
+        """Create ModelConfigDict object."""
+        self._data: Dict[str, TritonModelConfig] = {}
+        self._keys: List[Callable] = []
+
+    def __getitem__(self, infer_callable: Callable) -> TritonModelConfig:
+        """Get model config for inference callable."""
+        key = self._get_model_config_key(infer_callable)
+        return self._data[key]
+
+    def __setitem__(self, infer_callable: Callable, item: TritonModelConfig):
+        """Set model config for inference callable."""
+        self._keys.append(infer_callable)
+        key = self._get_model_config_key(infer_callable)
+        self._data[key] = item
+
+    def __delitem__(self, infer_callable: Callable):
+        """Delete model config for inference callable."""
+        key = self._get_model_config_key(infer_callable)
+        del self._data[key]
+
+    def __len__(self):
+        """Get number of inference callable keys."""
+        return len(self._data)
+
+    def __iter__(self):
+        """Iterate over inference callable keys."""
+        return iter(self._keys)
+
+    @staticmethod
+    def _get_model_config_key(infer_callable: Callable) -> str:
+        """Prepares TritonModelConfig dictionary key for function/callable."""
+        dict_key = infer_callable
+        if inspect.ismethod(dict_key) and dict_key.__name__ == "__call__":
+            dict_key = dict_key.__self__
+        return str(dict_key)
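
The class keys entries by `str()` of the callable and collapses a bound `__call__` onto its owning instance. A minimal sketch of the resulting semantics (illustrative only; the `Infer` class and the stored string values are stand-ins for real inference callables and `TritonModelConfig` objects):

```python
class Infer:
    def __call__(self, requests):
        return requests

first, second = Infer(), Infer()

configs = ModelConfigDict()
configs[first] = "config-for-first"    # stored under str(first)
configs[second] = "config-for-second"  # distinct instance -> distinct key

# A bound __call__ resolves to its owning instance, so both lookups
# return the same entry -- the behaviour this commit relies on.
assert configs[first.__call__] == configs[first]
assert configs[first] != configs[second]
assert len(configs) == 2
```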


@dataclasses.dataclass
class TritonContext:
"""Triton context definition class."""

model_configs: Dict[str, TritonModelConfig] = dataclasses.field(default_factory=dict)
model_configs: ModelConfigDict = dataclasses.field(default_factory=ModelConfigDict)


def get_triton_context(wrapped, instance) -> TritonContext:
@@ -82,17 +124,14 @@ def get_triton_context(wrapped, instance) -> TritonContext:


def get_model_config(wrapped, instance) -> TritonModelConfig:
"""Retrieves triton model config from callable.
"""Retrieves instance of TritonModelConfig from callable.
It is internally used in convert_output function to get output list from model.
You can use this in custom decorators if you need access to model_config information.
If you use @triton_context decorator you do not need this function (you can get model_config from triton_context).
If you use @triton_context decorator you do not need this function (you can get model_config directly
from triton_context passing function/callable to dictionary getter).
"""
dict_key = wrapped
if inspect.ismethod(dict_key) and dict_key.__name__ == "__call__":
dict_key = dict_key.__self__
dict_key = str(dict_key)
return get_triton_context(wrapped, instance).model_configs[dict_key]
return get_triton_context(wrapped, instance).model_configs[wrapped]
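
As the docstring suggests, a custom decorator can call this helper directly. A minimal sketch under that reading (the decorator name and the logging are illustrative, not part of pytriton):

```python
import wrapt

@wrapt.decorator
def log_model_name(wrapped, instance, args, kwargs):
    # After this commit the callable itself is the lookup key;
    # no manual str()-based keying is required.
    model_config = get_model_config(wrapped, instance)
    print(f"Running inference for model: {model_config.model_name}")
    return wrapped(*args, **kwargs)
```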


def convert_output(
@@ -279,8 +318,8 @@ def infer_fun(**inputs):
    so the other decorators (e.g. @group_by_keys) can make bigger consistent groups.
    """

-    def _verify_defaults(_triton_context: TritonContext, model_callable: Callable):
-        inputs = {spec.name: spec for spec in _triton_context.model_configs[str(model_callable)].inputs}
+    def _verify_defaults(model_config: TritonModelConfig):
+        inputs = {spec.name: spec for spec in model_config.inputs}
        not_matching_default_names = sorted(set(defaults) - set(inputs))
        if not_matching_default_names:
            raise PyTritonBadParameterError(f"Could not find {', '.join(not_matching_default_names)} inputs")
@@ -323,14 +362,13 @@ def _shape_match(_have_shape, _expected_shape):

    @wrapt.decorator
    def _wrapper(wrapped, instance, args, kwargs):
-        _triton_context = get_triton_context(wrapped, instance)
-
-        _verify_defaults(_triton_context, wrapped)
+        model_config = get_model_config(wrapped, instance)
+        _verify_defaults(model_config)
        # verification if not after group wrappers is in group wrappers

        (requests,) = args

-        model_supports_batching = get_model_config(wrapped, instance).batching
+        model_supports_batching = model_config.batching
        for request in requests:
            batch_size = get_inference_request_batch_size(request) if model_supports_batching else None
            for default_key, default_value in defaults.items():
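For context, a minimal usage sketch of the fixed decorator (the input name `b`, its default, and the callable are illustrative; as the unit tests later in this diff show, a `TritonContext` carrying the model config must be injected before the call):

```python
import numpy as np

@fill_optionals(b=np.array([-5, -6]))
def infer_fn(requests):
    # fill_optionals has already inserted the default for "b" into any
    # request that did not supply it, expanded to the request's batch size.
    for request in requests:
        assert "b" in request
    return requests
```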
7 changes: 1 addition & 6 deletions pytriton/models/model.py
@@ -14,7 +14,6 @@
"""Model base class."""
import copy
import enum
-import inspect
import logging
import pathlib
import shutil
@@ -155,11 +154,7 @@ def setup(self) -> None:
        if not self._inference_handlers:
            triton_model_config = self._get_triton_model_config()
            for i, infer_function in enumerate(self.infer_functions):
-                dict_key = infer_function
-                if inspect.ismethod(dict_key) and dict_key.__name__ == "__call__":
-                    dict_key = dict_key.__self__
-                dict_key = str(dict_key)
-                self.triton_context.model_configs[dict_key] = copy.deepcopy(triton_model_config)
+                self.triton_context.model_configs[infer_function] = copy.deepcopy(triton_model_config)
                _inject_triton_context(self.triton_context, infer_function)
                inference_handler = InferenceHandler(
                    model_callable=infer_function,
71 changes: 69 additions & 2 deletions tests/unit/test_decorators.py
@@ -23,6 +23,7 @@
    InferenceRequest,
    InferenceRequests,
    InputNames,
+    ModelConfigDict,
    TritonContext,
    batch,
    fill_optionals,
@@ -62,11 +63,51 @@

def _prepare_and_inject_context_with_config(config, fun):
    context = TritonContext()
-    context.model_configs[str(fun)] = config
+    context.model_configs[fun] = config
    _inject_triton_context(context, fun)
    return context


+def test_get_model_config_key():
+    def fn():
+        pass
+
+    def fn2():
+        pass
+
+    class CallableClass:
+        def __call__(self):
+            pass
+
+        def method(self):
+            pass
+
+    inst = CallableClass()
+    inst2 = CallableClass()
+
+    assert ModelConfigDict._get_model_config_key(fn) == str(fn)
+    assert ModelConfigDict._get_model_config_key(inst) == str(inst)
+    assert ModelConfigDict._get_model_config_key(inst.method) == str(inst.method)
+    assert ModelConfigDict._get_model_config_key(inst.__call__) == str(inst)
+
+    config_dict = ModelConfigDict()
+    config_dict[fn] = TritonModelConfig(model_name="fn")
+    config_dict[fn2] = TritonModelConfig(model_name="fn2")
+    assert config_dict[fn] == TritonModelConfig(model_name="fn")
+    assert config_dict[fn] != config_dict[fn2]
+
+    config_dict[inst] = TritonModelConfig(model_name="inst")
+    config_dict[inst2] = TritonModelConfig(model_name="inst2")
+    assert config_dict[inst] == TritonModelConfig(model_name="inst")
+    assert config_dict[inst] != config_dict[inst2]
+
+    keys = {fn, fn2, inst, inst2}
+    keys1 = set(config_dict.keys())
+    keys2 = set(iter(config_dict))
+    assert keys == keys1
+    assert keys == keys2


def _prepare_context_for_input(inputs, fun):
    a_input = inputs[0]["a"]
    b_input = inputs[0]["b"]
@@ -76,7 +117,7 @@ def _prepare_context_for_input(inputs, fun):

    config = TritonModelConfig("a", inputs=[a_spec, b_spec], outputs=[a_spec, b_spec])
    context = TritonContext()
-    context.model_configs[str(fun)] = config
+    context.model_configs[fun] = config

    return context

@@ -497,6 +538,32 @@ def _fn(**_requests):
    _fn([{"a": np.zeros((1,))}, {"a": np.zeros((1,))}])


+def test_fill_optionals_in_instance_callable():
+    class MyModel:
+        @fill_optionals(a=np.array([-1, -2]), b=np.array([-5, -6]))
+        def __call__(self, inputs):
+            for req in inputs:
+                assert "a" in req and "b" in req
+                assert req["a"].shape[0] == req["b"].shape[0]
+            assert np.all(inputs[1]["a"] == np.array([[-1, -2], [-1, -2]]))
+            assert np.all(inputs[-1]["b"] == np.array([[-5, -6], [-5, -6], [-5, -6]]))
+            return inputs
+
+    model = MyModel()
+
+    _prepare_and_inject_context_with_config(
+        TritonModelConfig(
+            model_name="foo",
+            inputs=[TensorSpec("a", shape=(2,), dtype=np.int64), TensorSpec("b", shape=(2,), dtype=np.int64)],
+            outputs=[TensorSpec("a", shape=(2,), dtype=np.int64), TensorSpec("b", shape=(2,), dtype=np.int64)],
+        ),
+        model.__call__,
+    )
+
+    results = model(input_requests)
+    assert len(results) == len(input_requests)


def test_fill_optionals():
    @fill_optionals(a=np.array([-1, -2]), b=np.array([-5, -6]))
    def fill_fun(inputs):
