Commit

Performance improvement

pziecina-nv authored and piotrm-nvidia committed Apr 12, 2024
1 parent 701203c commit fe39f62
Showing 6 changed files with 191 additions and 71 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,14 @@ limitations under the License.

# Changelog

## Unreleased

- Fix: Performance improvements

[//]: <> (put here on external component update with short summary what change or link to changelog)

- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.44.0](https://github.com/triton-inference-server/server/releases/tag/v2.44.0)

## 0.5.4 (2024-04-09)

- New: Python 3.12 support
8 changes: 6 additions & 2 deletions pytriton/models/model.py
@@ -17,6 +17,7 @@
import copy
import enum
import logging
import os
import pathlib
import shutil
import threading
@@ -30,7 +31,7 @@
from pytriton.model_config.tensor import Tensor
from pytriton.model_config.triton_model_config import DeviceKind, ResponseCache, TensorSpec, TritonModelConfig
from pytriton.proxy.communication import get_config_from_handshake_server
from pytriton.proxy.data import TensorStoreSerializerDeserializer
from pytriton.proxy.data import Base64SerializerDeserializer, TensorStoreSerializerDeserializer
from pytriton.proxy.inference import InferenceHandler, InferenceHandlerEvent, RequestsResponsesConnector
from pytriton.proxy.validators import TritonResultsValidator
from pytriton.utils.workspace import Workspace
@@ -119,7 +120,10 @@ def __init__(

self.config = config
self._workspace = workspace
self._serializer_deserializer = TensorStoreSerializerDeserializer()
if os.environ.get("PYTRITON_NO_TENSORSTORE"):
self._serializer_deserializer = Base64SerializerDeserializer()
else:
self._serializer_deserializer = TensorStoreSerializerDeserializer()
self._triton_model_config: Optional[TritonModelConfig] = None
self._model_events_observers: typing.List[ModelEventsHandler] = []

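The model.py change above adds an escape hatch: when the PYTRITON_NO_TENSORSTORE environment variable is set to any non-empty value, the proxy falls back to the base64/JSON serializer instead of the shared-memory TensorStore one. Below is a minimal sketch of how a user could opt into that fallback, assuming the standard PyTriton binding API; the model name, tensor names, and the identity inference function are illustrative only.

import os

import numpy as np
from pytriton.decorators import batch
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton

# Hypothetical opt-out: any non-empty value makes the proxy use
# Base64SerializerDeserializer instead of TensorStoreSerializerDeserializer.
# The variable is read when the model object is constructed, so set it before bind().
os.environ["PYTRITON_NO_TENSORSTORE"] = "1"


@batch
def _identity(**inputs):
    # echo each INPUT_* tensor back under the matching OUTPUT_* name
    return {name.replace("INPUT", "OUTPUT"): tensor for name, tensor in inputs.items()}


with Triton() as triton:
    triton.bind(
        model_name="identity",
        infer_func=_identity,
        inputs=[Tensor(name="INPUT_1", dtype=np.float32, shape=(-1,))],
        outputs=[Tensor(name="OUTPUT_1", dtype=np.float32, shape=(-1,))],
        config=ModelConfig(max_batch_size=8),
    )
    triton.serve()
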
4 changes: 0 additions & 4 deletions pytriton/proxy/communication.py
@@ -179,7 +179,6 @@ def _all_responses_processed():
requests_id, flags, responses_payload = await asyncio.wait_for(
self._socket.recv_multipart(), flag_check_interval_s
)
SERVER_LOGGER.debug(f"Received response {requests_id.hex()} {flags.hex()} {responses_payload}")
flags = int.from_bytes(flags, byteorder="big")
responses_queue = self._responses_queues[requests_id]
responses_queue.put_nowait((flags, responses_payload)) # queue have no max_size
@@ -266,7 +265,6 @@ async def send_requests(self, requests_id: bytes, requests_payload: bytes) -> as
# sending in same loop, thus thread as handle_messages
# send_multipart doesn't return anything, as it copies requests_payload
await self._socket.send_multipart([requests_id, requests_payload])
SERVER_LOGGER.debug(f"Sent requests {requests_id.hex()}")

return handle_responses_task

@@ -278,13 +276,11 @@ async def _handle_responses(self, scope, responses_queue: asyncio.Queue):
responses_queue: queue with responses payload from InferenceHandler
"""
requests_id = scope["requests_id"]
SERVER_LOGGER.debug(f"Started handling responses {requests_id.hex()}")
try:
return await self._handle_responses_fn(scope, responses_queue)
finally:
self._responses_queues.pop(requests_id)
self._handle_responses_tasks.pop(requests_id)
SERVER_LOGGER.debug(f"Finished handling responses {requests_id.hex()}")


class RequestsServerClient:
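
The communication.py diff only strips per-request debug logging from the send/receive hot path. This matters because an f-string passed to a logger is formatted unconditionally, before the logger decides to drop the record, and here the messages interpolated request ids and whole response payloads. A rough, self-contained sketch of that cost (generic logging code, not the proxy's), comparing an eager f-string log against a guarded lazy one:

import logging
import timeit

logging.basicConfig(level=logging.INFO)  # DEBUG records will be discarded
logger = logging.getLogger("hotpath-demo")  # hypothetical logger name

payload = bytes(64 * 1024)
requests_id = b"\x00" * 8


def eager():
    # the f-string (including the payload repr) is built even though the
    # record is never emitted
    logger.debug(f"Received response {requests_id.hex()} {payload!r}")


def lazy():
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Received response %s %r", requests_id.hex(), payload)


print("eager f-string:", timeit.timeit(eager, number=10_000))
print("guarded/lazy:  ", timeit.timeit(lazy, number=10_000))
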
136 changes: 134 additions & 2 deletions pytriton/proxy/data.py
@@ -18,6 +18,7 @@

import abc
import atexit
import base64
import ctypes
import ctypes.util
import dataclasses
@@ -29,6 +30,7 @@
import multiprocessing.managers
import multiprocessing.popen_spawn_posix
import multiprocessing.shared_memory
import os
import pathlib
import signal
import struct
@@ -379,6 +381,14 @@ def _run_server(cls, registry, address, authkey, serializer, writer, initializer
PR_SET_PDEATHSIG = 1 # noqa
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM) # terminate process when parent **thread** dies

if bool(os.environ.get("PYTRITON_VIZTRACER")):
from viztracer import VizTracer # type: ignore # pytype: disable=import-error

cls._tracer = VizTracer(log_async=True, log_gc=True, tracer_entries=10000000, pid_suffix=True)
cls._tracer.register_exit()
cls._tracer.start()

super()._run_server(
registry, address, authkey, serializer, writer, initializer, initargs
) # pytype: disable=attribute-error
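
The hunk above adds opt-in profiling of the spawned TensorStore manager process: bool(os.environ.get("PYTRITON_VIZTRACER")) is truthy for any non-empty value, in which case a VizTracer instance is started, registered to flush on exit, and given pid_suffix=True so each process writes its own trace file. A minimal usage sketch, assuming viztracer is installed separately (it is not a PyTriton dependency) and that the output follows VizTracer's default per-PID result file naming:

import os

# Hypothetical opt-in: enable tracing in the child process spawned for the
# tensor store; the environment is inherited, so setting the variable in the
# parent before the server starts is enough.
os.environ["PYTRITON_VIZTRACER"] = "1"

# ... build and serve the PyTriton model as usual ...
# After shutdown, open the per-PID trace with VizTracer's viewer, e.g.:
#   vizviewer result_<pid>.json
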
@@ -662,8 +672,6 @@ def release_block(self, tensor_id: str):
Args:
tensor_id: id of tensor to release
"""
LOGGER.debug(f"Releasing shared memory block for tensor {tensor_id}")

tensor_ref = None
with self._handled_blocks_lock:
tensor_ref = self._handled_blocks.pop(tensor_id, None)
@@ -832,6 +840,130 @@ def free_responses_resources(self, responses_payload: bytes):
pass


class Base64SerializerDeserializer(BaseRequestsResponsesSerializerDeserializer):
"""Serializer/deserializer for requests/responses using base64 implementation."""

def serialize_requests(self, requests: Requests) -> bytes:
"""Serialize requests.
Args:
requests: list of requests to serialize
Returns:
Serialized requests
"""
serialized_requests = self._serialize_named_tensors_lists(requests)
requests = {
"requests": [
{"data": serialized_request, "parameters": request.parameters}
for request, serialized_request in zip(requests, serialized_requests)
]
}
requests = json.dumps(requests).encode("utf-8")
return requests

def deserialize_requests(self, requests_payload: bytes) -> Requests:
"""Deserialize requests.
Args:
requests_payload: serialized requests
Returns:
List of deserialized requests
"""
requests = json.loads(requests_payload)
requests_data = [request["data"] for request in requests["requests"]]
requests_data = self._deserialized_named_tensors_lists(requests_data)

requests = [
Request(
data=request_data,
parameters=request.get("parameters"),
)
for request, request_data in zip(requests["requests"], requests_data)
]
return requests

def free_requests_resources(self, requests_payload: bytes):
"""Free resources used by requests."""
pass

def serialize_responses(self, responses: Responses) -> bytes:
"""Serialize responses.
Args:
responses: list of responses to serialize
Returns:
Serialized responses
"""
responses = self._serialize_named_tensors_lists(responses)
responses = {"responses": [{"data": response} for response in responses]}
return json.dumps(responses).encode("utf-8")

def deserialize_responses(self, responses_payload: bytes) -> Responses:
"""Deserialize responses.
Args:
responses_payload: serialized responses
Returns:
List of deserialized responses
"""
if responses_payload:
responses = json.loads(responses_payload)
responses = [response["data"] for response in responses["responses"]]
responses = self._deserialized_named_tensors_lists(responses)
return [Response(data=response) for response in responses]
else:
return []

def free_responses_resources(self, responses_payload: bytes):
"""Free resources used by responses."""
pass

def _serialize_named_tensors_lists(self, named_tensors_lists):
def _encode(_tensor):
frames = serialize_numpy_with_struct_header(_tensor)
return [base64.b64encode(frame).decode("utf-8") for frame in frames]

return [
{tensor_name: _encode(tensor) for tensor_name, tensor in tensors.items()} for tensors in named_tensors_lists
]

def _deserialized_named_tensors_lists(self, named_tensors_lists):
def _decode(decoded_tensor):
frames = [base64.b64decode(frame.encode("utf-8")) for frame in decoded_tensor]
return deserialize_numpy_with_struct_header(frames)

return [
{tensor_name: _decode(encoded_tensor) for tensor_name, encoded_tensor in tensors.items()}
for tensors in named_tensors_lists
]

def start(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
"""Start Dummy implementation.
Args:
url: address of data store
authkey: authentication key required to setup connection. If not provided, current process authkey will be used
"""
pass

def connect(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
"""Connect to Dummy implementation.
Args:
url: address of data store
authkey: authentication key required to setup connection. If not provided, current process authkey will be used
"""
pass

def close(self):
"""Close Dummy implementation."""
pass


class TensorStoreSerializerDeserializer(BaseRequestsResponsesSerializerDeserializer):
"""Serializer/deserializer for requests/responses using TensorStore."""

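The new Base64SerializerDeserializer trades the shared-memory transport for a fully self-contained wire format: each tensor is turned into frames (header plus buffer), the frames are base64-encoded, and everything is packed into one JSON document per batch of requests or responses. The sketch below reproduces that round trip with plain numpy/json/base64 code; encode_tensor and decode_tensor are illustrative stand-ins for the library's serialize_numpy_with_struct_header / deserialize_numpy_with_struct_header helpers.

import base64
import json

import numpy as np


def encode_tensor(tensor: np.ndarray) -> list:
    # one frame for metadata, one for the raw buffer, both base64-encoded
    header = json.dumps({"dtype": tensor.dtype.str, "shape": tensor.shape}).encode("utf-8")
    frames = [header, tensor.tobytes()]
    return [base64.b64encode(frame).decode("utf-8") for frame in frames]


def decode_tensor(frames: list) -> np.ndarray:
    header, payload = [base64.b64decode(frame.encode("utf-8")) for frame in frames]
    meta = json.loads(header)
    return np.frombuffer(payload, dtype=np.dtype(meta["dtype"])).reshape(meta["shape"])


# round trip of a single request with one named tensor
request = {"INPUT_1": np.arange(6, dtype=np.float32).reshape(2, 3)}
payload = json.dumps(
    {"requests": [{"data": {name: encode_tensor(t) for name, t in request.items()}, "parameters": None}]}
).encode("utf-8")

decoded = json.loads(payload)["requests"][0]["data"]
restored = {name: decode_tensor(frames) for name, frames in decoded.items()}
assert np.array_equal(restored["INPUT_1"], request["INPUT_1"])
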
19 changes: 7 additions & 12 deletions pytriton/proxy/inference.py
@@ -176,7 +176,7 @@ def _wait_for_inference_fn(timeout_s: float):
)

try:
requests = await self.preprocess(scope, requests_payload)
requests = self.preprocess(scope, requests_payload)

if self._run_inference_fn is None:
await loop.run_in_executor(None, _wait_for_inference_fn, self.INFERENCE_FN_REGISTER_WAIT_TIME_S)
@@ -196,7 +196,7 @@ def _wait_for_inference_fn(timeout_s: float):
await send(scope, flags, error_msg)
break

responses_payload = await self.postprocess(scope, responses_or_error)
responses_payload = self.postprocess(scope, responses_or_error)

await send(scope, flags, responses_payload)
if flags & PyTritonResponseFlags.EOS:
@@ -208,11 +208,11 @@ def _wait_for_inference_fn(timeout_s: float):
flags = PyTritonResponseFlags.ERROR | PyTritonResponseFlags.EOS
await send(scope, flags, error_msg)
finally:
await loop.run_in_executor(None, self._serializer_deserializer.free_requests_resources, requests_payload)
self._serializer_deserializer.free_requests_resources(requests_payload)
self._responses_queues.pop(requests_id)
LOGGER.debug(f"Finished handling requests for {scope['requests_id'].hex()}")

async def preprocess(self, scope: Scope, requests_payload: bytes) -> Requests:
def preprocess(self, scope: Scope, requests_payload: bytes) -> Requests:
"""Preprocess requests before running inference on them.
Currently, this method only deserializes requests.
@@ -224,11 +224,9 @@ async def preprocess(self, scope: Scope, requests_payload: bytes) -> Requests:
Returns:
deserialized requests
"""
LOGGER.debug(f"Preprocessing requests for {scope['requests_id'].hex()}")
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, self._serializer_deserializer.deserialize_requests, requests_payload)
return self._serializer_deserializer.deserialize_requests(requests_payload)

async def postprocess(self, scope: Scope, responses: Responses) -> bytes:
def postprocess(self, scope: Scope, responses: Responses) -> bytes:
"""Postprocess responses before sending them back to Triton.
Currently, this method only serializes responses.
@@ -240,12 +238,10 @@ async def postprocess(self, scope: Scope, responses: Responses) -> bytes:
Returns:
serialized responses
"""
LOGGER.debug(f"Postprocessing responses for {scope['requests_id'].hex()}")
if responses is None:
return b""
else:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, self._serializer_deserializer.serialize_responses, responses)
return self._serializer_deserializer.serialize_responses(responses)

def register_inference_hook(self, run_inference_fn: typing.Callable[[Scope, Requests], concurrent.futures.Future]):
"""Register inference hook.
@@ -278,7 +274,6 @@ def send(self, scope: Scope, flags: PyTritonResponseFlags, responses: ResponsesN
responses: responses to send back to server
"""
requests_id = scope["requests_id"]
LOGGER.debug(f"Pushing responses for {scope['requests_id'].hex()} into responses queue ({flags}, {responses})")
queue = self._responses_queues[requests_id]
loop = self._requests_server_client.loop
# use no_wait as there is no limit for responses queues
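
The inference.py changes turn preprocess and postprocess back into plain synchronous calls and drop the loop.run_in_executor hops (plus more per-request debug logging). With the TensorStore path, deserialization is mostly metadata work on shared memory, so the thread-pool hand-off costs more than the work it offloads. A generic micro-benchmark sketch of that overhead (not the handler's code; absolute numbers are machine-dependent):

import asyncio
import time


def cheap_deserialize(payload: bytes) -> bytes:
    return payload[:8]  # stand-in for metadata-only deserialization


async def offloaded(payload: bytes, n: int) -> float:
    loop = asyncio.get_running_loop()
    start = time.perf_counter()
    for _ in range(n):
        await loop.run_in_executor(None, cheap_deserialize, payload)
    return time.perf_counter() - start


async def inline(payload: bytes, n: int) -> float:
    start = time.perf_counter()
    for _ in range(n):
        cheap_deserialize(payload)
    return time.perf_counter() - start


async def main() -> None:
    payload = bytes(1024)
    print("run_in_executor:", await offloaded(payload, 10_000))
    print("inline call:    ", await inline(payload, 10_000))


asyncio.run(main())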
