diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d7fefe7..bcaeb1ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,9 @@ limitations under the License. # Changelog -## Unreleased +## 0.5.7 (2024-06-19) + +- New: Open Telemetry propagation support for tracing [//]: <> (put here on external component update with short summary what change or link to changelog) diff --git a/README.md b/README.md index 8f75b394..a0e20e8b 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ The distinct capabilities of PyTriton are summarized in the feature matrix: | ------- | ----------- | | Native Python support | You can create any [Python function](https://triton-inference-server.github.io/pytriton/latest/inference_callables/) and expose it as an HTTP/gRPC API. | | Framework-agnostic | You can run any Python code with any framework of your choice, such as: PyTorch, TensorFlow, or JAX. | -| Performance optimization | You can benefit from [dynamic batching](https://triton-inference-server.github.io/pytriton/latest/inference_callables/decorators/#batch), response cache, model pipelining, [clusters](https://triton-inference-server.github.io/pytriton/latest/guides/deploying_in_clusters/), and GPU/CPU inference. | +| Performance optimization | You can benefit from [dynamic batching](https://triton-inference-server.github.io/pytriton/latest/inference_callables/decorators/#batch), response cache, model pipelining, [clusters](https://triton-inference-server.github.io/pytriton/latest/guides/deploying_in_clusters/), performance [tracing](https://triton-inference-server.github.io/pytriton/latest/guides/distributed_tracing/), and GPU/CPU inference. | Decorators | You can use batching [decorators](https://triton-inference-server.github.io/pytriton/latest/inference_callables/decorators/) to handle batching and other pre-processing tasks for your inference function. | | Easy [installation](https://triton-inference-server.github.io/pytriton/latest/installation/) and setup | You can use a simple and familiar interface based on Flask/FastAPI for easy installation and [setup](https://triton-inference-server.github.io/pytriton/latest/binding_models/). | | [Model clients](https://triton-inference-server.github.io/pytriton/latest/clients) | You can access high-level model clients for HTTP/gRPC requests with configurable options and both synchronous and [asynchronous](https://triton-inference-server.github.io/pytriton/latest/clients/#asynciomodelclient) API. | diff --git a/docs/README.md b/docs/README.md index 6faa09a5..3fc67179 100644 --- a/docs/README.md +++ b/docs/README.md @@ -202,6 +202,8 @@ used to profile models served through PyTriton. We have prepared an example of using the Perf Analyzer to profile the BART PyTorch model. The example code can be found in [examples/perf_analyzer](../examples/perf_analyzer). +Open Telemetry is a set of APIs, libraries, agents, and instrumentation to provide observability for cloud-native software. We have prepared a +[guide](guides/distributed_tracing.md) on how to use Open Telemetry with PyTriton. ## What next?
diff --git a/docs/assets/jaeger_traces_list.png b/docs/assets/jaeger_traces_list.png new file mode 100644 index 00000000..d68b821e Binary files /dev/null and b/docs/assets/jaeger_traces_list.png differ diff --git a/docs/assets/jaeger_traces_list_only_triton.png b/docs/assets/jaeger_traces_list_only_triton.png new file mode 100644 index 00000000..633afbbd Binary files /dev/null and b/docs/assets/jaeger_traces_list_only_triton.png differ diff --git a/docs/assets/jaeger_traces_list_propagation.png b/docs/assets/jaeger_traces_list_propagation.png new file mode 100644 index 00000000..c60bbf85 Binary files /dev/null and b/docs/assets/jaeger_traces_list_propagation.png differ diff --git a/docs/guides/assets/jaeger_trace_details.png b/docs/guides/assets/jaeger_trace_details.png deleted file mode 100644 index deb1d8e8..00000000 Binary files a/docs/guides/assets/jaeger_trace_details.png and /dev/null differ diff --git a/docs/guides/assets/jaeger_traces_list.png b/docs/guides/assets/jaeger_traces_list.png deleted file mode 100644 index eb06a20f..00000000 Binary files a/docs/guides/assets/jaeger_traces_list.png and /dev/null differ diff --git a/docs/guides/assets/jaegger_context_propagation_request_details.png b/docs/guides/assets/jaegger_context_propagation_request_details.png deleted file mode 100644 index d62164ae..00000000 Binary files a/docs/guides/assets/jaegger_context_propagation_request_details.png and /dev/null differ diff --git a/docs/guides/assets/jaegger_context_propagation_traces_list.png b/docs/guides/assets/jaegger_context_propagation_traces_list.png deleted file mode 100644 index 07484d52..00000000 Binary files a/docs/guides/assets/jaegger_context_propagation_traces_list.png and /dev/null differ diff --git a/docs/guides/distributed_tracing.md b/docs/guides/distributed_tracing.md index be7a6a90..8da41b55 100644 --- a/docs/guides/distributed_tracing.md +++ b/docs/guides/distributed_tracing.md @@ -36,51 +36,99 @@ This command will initiate a daemon mode HTTP trace collector listening on port ## PyTriton and Distributed Tracing -With the [OpenTelemetry collector set up](#setting-up-the-opentelemetry-environment), you can now configure the Triton Inference Server tracer to send trace spans to it. This is achieved by specifying the `trace_config` parameter in the TritonConfig: +With the [OpenTelemetry collector set up](#setting-up-the-opentelemetry-environment), you can now configure the Triton Inference Server tracer to send trace spans to it. - +The following example demonstrates how to configure the Triton Inference Server to send traces to the OpenTelemetry collector: + +```python +from pytriton.triton import TritonConfig +config=TritonConfig( + trace_config=[ + "level=TIMESTAMPS", + "rate=1", + "mode=opentelemetry", + "opentelemetry,url=http://localhost:4318/v1/traces", + "opentelemetry,resource=service.name=test_server_with_passthrough", + "opentelemetry,resource=test.key=test.value", + ], +) +``` + +Each parameter in the `trace_config` list corresponds to a specific configuration option: + +- `level=TIMESTAMPS`: Specifies the level of detail in the trace spans. +- `rate=1`: Indicates that all requests should be traced. +- `mode=opentelemetry`: Specifies the tracing mode. +- `opentelemetry,url=http://localhost:4318/v1/traces`: Specifies the URL of the OpenTelemetry collector. +- `opentelemetry,resource=service.name=test_server_with_passthrough`: Specifies the resource name for the service. +- `opentelemetry,resource=test.key=test.value`: Specifies additional resource attributes. 
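If tracing every request adds too much overhead, you can lower the sampling rate in the same way. The sketch below is a variation of the config above that keeps the same collector endpoint but samples only a fraction of requests; the exact `rate` semantics (every N-th request is traced) are defined by the Triton trace API described in the user guide linked below.

```python
from pytriton.triton import TritonConfig

# Variation of the config above: trace only every 100th request instead of all of them.
sampled_config = TritonConfig(
    trace_config=[
        "level=TIMESTAMPS",
        "rate=100",
        "mode=opentelemetry",
        "opentelemetry,url=http://localhost:4318/v1/traces",
        "opentelemetry,resource=service.name=test_server_with_passthrough",
    ],
)
```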
+ +All the supported Triton Inference Server trace API settings are described in the [user guide on tracing](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/trace.md). + + + +You can use this config for the Triton Inference Server as follows: + + ```python import time import numpy as np -from pytriton.decorators import batch from pytriton.model_config import ModelConfig, Tensor -from pytriton.triton import Triton, TritonConfig - - -@batch -def passthrough(sleep): - max_sleep = np.max(sleep).item() - time.sleep(max_sleep) - return {"sleep": sleep} - - -with Triton( - config=TritonConfig( - trace_config=[ - "level=TIMESTAMPS", - "rate=1", - "mode=opentelemetry", - "opentelemetry,url=127.0.0.1:4318/v1/traces", - "opentelemetry,resource=service.name=test_server_with_passthrough", - "opentelemetry,resource=test.key=test.value", - ], - ) -) as triton: - triton.bind( - model_name="passthrough", - infer_func=passthrough, - inputs=[Tensor(name="sleep", dtype=np.float32, shape=(1,))], - outputs=[Tensor(name="sleep", dtype=np.float32, shape=(1,))], - config=ModelConfig(max_batch_size=128), - strict=True, - ) - triton.serve() +from pytriton.triton import Triton + +def passthrough(requests): + responses = [] + for request in requests: + sleep = request.data["sleep"] + error = request.data["error"] + raise_error = np.any(error) + if raise_error: + raise ValueError("Requested Error") + max_sleep = np.max(sleep).item() + + time.sleep(max_sleep) + + responses.append({"sleep": sleep, "error": error}) + return responses + +# config was defined in example above +triton = Triton(config=config) + +triton.bind( + model_name="passthrough", + infer_func=passthrough, + inputs=[Tensor(name="sleep", dtype=np.float32, shape=(1,)), Tensor(name="error", dtype=np.bool_, shape=(1,))], + outputs=[Tensor(name="sleep", dtype=np.float32, shape=(1,)), Tensor(name="error", dtype=np.bool_, shape=(1,))], + config=ModelConfig(max_batch_size=128), + strict=False, +) +triton.run() ``` -All the supported Triton Inference Server trace API settings are described in the [user guide on tracing](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/trace.md). + + + -Now, you can send requests with curl to the Triton Inference Server and analyze the trace visualizations in the Jaeger UI: +Now, you can send requests with curl to the Triton Inference Server: @@ -89,12 +137,209 @@ curl http://127.0.0.1:8000/v2/models/passthrough/generate \ -H "Content-Type: application/json" \ -sS \ -w "\n" \ - -d '{"sleep": 2}' + -d '{"sleep": 0.001, "error": false}' +``` + +The Triton Inference Server will send trace spans to the OpenTelemetry collector, which will visualize the trace in the Jaeger UI. The resulting trace contains several spans generated by Triton and PyTriton: + +- ``InferRequest``: The span representing the entire request processing lifecycle. +- ``passthrough``: The span representing the execution of the `passthrough` model. +- ``compute``: The span representing the computation of the model as seen in the Triton Inference Server. +- ``python_backend_execute``: The span representing the ``execute`` call of the Python backend. +- ``proxy_inference_callable``: The span representing the proxy inference callable execution. + +![Jaeger Traces Propagation](../assets/jaeger_traces_list_only_triton.png) + + +## Custom tracing with `traced_span` method + +PyTriton provides a simplified way to instrument your code with telemetry using the `traced_span` method from the `Request` class.
This method allows you to easily create spans for different parts of your code without needing to directly interact with the OpenTelemetry API. + +### Example + +Here is an example of using the `traced_span` method within a passthrough function. This function processes requests, performing actions such as getting data, sleeping, and appending responses, with each step being traced for telemetry purposes. + +```python +import time +import numpy as np + +def passthrough(requests): + responses = [] + for request in requests: + # Create a traced span for getting data + with request.traced_span("pass-through-get-data"): + sleep = request.data["sleep"] + error = request.data["error"] + raise_error = np.any(error) + if raise_error: + raise ValueError("Requested Error") + max_sleep = np.max(sleep).item() + + # Create a traced span for sleeping + with request.traced_span("pass-through-sleep"): + time.sleep(max_sleep) + + # Create a traced span for appending responses + with request.traced_span("pass-through-append"): + responses.append({"sleep": sleep, "error": error}) + + return responses +``` + + + + +The introduction of three spans (`pass-through-get-data`, `pass-through-sleep`, `pass-through-append`) in the `passthrough` function allows you to track the time spent on each operation. These spans will be sent to the OpenTelemetry collector and visualized in the Jaeger UI. + +![Jaeger Traces Propagation](../assets/jaeger_traces_list.png) + + +### Explanation + +1. **Creating Spans with `traced_span`**: + - For each request, we use the `traced_span` method provided by the `Request` class to create spans for different operations. This method automatically handles the start and end of spans, simplifying the instrumentation process. + +2. **Getting Data**: + - We wrap the data extraction logic in a `traced_span` named `"pass-through-get-data"`. This span captures the process of extracting the `sleep` and `error` data from the request, checking for errors, and determining the maximum sleep time. + +3. **Sleeping**: + - We wrap the sleep operation in a `traced_span` named `"pass-through-sleep"`. This span captures the time spent sleeping. + +4. **Appending Responses**: + - We wrap the response appending logic in a `traced_span` named `"pass-through-append"`. This span captures the process of appending the response. + +### Benefits + +- **Simplicity**: Using the `traced_span` method is straightforward and does not require direct interaction with the OpenTelemetry API, making it easier to instrument your code. +- **Automatic Management**: The `traced_span` method automatically manages the lifecycle of spans, reducing boilerplate code and potential errors. +- **Seamless Integration**: This approach integrates seamlessly with existing PyTriton infrastructure, ensuring consistent telemetry data collection. + + +## Advanced Telemetry Usage with OpenTelemetry API + +In addition to the simple telemetry example provided using the PyTriton API, we can also leverage the direct usage of the OpenTelemetry API for more fine-grained control over tracing and telemetry. This advanced approach provides flexibility and a deeper integration with OpenTelemetry. + +### Example + +Here is an example of using the OpenTelemetry API directly within a passthrough function. This function processes requests, performing actions such as getting data, sleeping, and appending responses, with each step being traced for telemetry purposes. 
+ +```python +from opentelemetry import trace +import time +import numpy as np + +# Initialize a tracer for the current module +tracer = trace.get_tracer(__name__) + +def passthrough(requests): + responses = [] + for request in requests: + # Use the span associated with the request, but do not end it automatically + with trace.use_span(request.span, end_on_exit=False): + # Start a new span for getting data + with tracer.start_as_current_span("pass-through-get-data"): + sleep = request.data["sleep"] + error = request.data["error"] + raise_error = np.any(error) + if raise_error: + raise ValueError("Requested Error") + max_sleep = np.max(sleep).item() + + # Start a new span for sleeping + with tracer.start_as_current_span("pass-through-sleep"): + time.sleep(max_sleep) + + # Start a new span for appending responses + with tracer.start_as_current_span("pass-through-append"): + responses.append({"sleep": sleep, "error": error}) + + return responses +``` + + + + +![Jaeger Traces Propagation for advanced usage](../assets/jaeger_traces_list.png) + +### Explanation + +1. **Initialization of Tracer**: + - We initialize a tracer for the current module using `trace.get_tracer(__name__)`. This tracer will be used to create spans that represent individual operations within the `passthrough` function. + +2. **Using Existing Spans**: + - For each request, we use the span already associated with it by wrapping the processing logic within `trace.use_span(request.span, end_on_exit=False)`. This ensures that our custom spans are nested within the request's span, providing a hierarchical structure to the telemetry data. -![Jaeger Traces List](./assets/jaeger_traces_list.png) +3. **Creating Custom Spans**: + - We create custom spans for different operations (`pass-through-get-data`, `pass-through-sleep`, `pass-through-append`) using `tracer.start_as_current_span`. Each operation is wrapped in its respective span, capturing the execution time and any additional attributes or events we might want to add. + + +### Benefits + +- **Flexible Integration**: Using the OpenTelemetry API directly allows for greater flexibility in how spans are managed and how telemetry data is collected and reported. +- **Seamless Fallback**: The use of `trace.use_span` ensures that if telemetry is not active, the span operations are effectively no-ops, avoiding unnecessary checks and minimizing overhead. -![Jaeger Trace Details](./assets/jaeger_trace_details.png) ## OpenTelemetry Context Propagation @@ -108,9 +353,7 @@ To test this feature, you can use the following Python client based on python [r pip install opentelemetry-api opentelemetry-sdk opentelemetry-instrumentation-requests opentelemetry-exporter-otlp ``` -Then, run the following Python script: - - +First you need to import the required packages and configure the OpenTelemetry context and instrumet requests library: ```python import time @@ -126,6 +369,13 @@ from opentelemetry.instrumentation.requests import RequestsInstrumentor # Enable instrumentation in the requests library. 
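# RequestsInstrumentor patches the requests library so that every outgoing HTTP call
# gets its own client span and carries the W3C traceparent header; this is what lets
# the Triton Inference Server continue the trace started in this script.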
RequestsInstrumentor().instrument() +``` + +The next step is to configure the OpenTelemetry context: + + + +```python # OTLPSpanExporter can be also configured with OTEL_EXPORTER_OTLP_TRACES_ENDPOINT environment variable trace.set_tracer_provider( TracerProvider( @@ -133,25 +383,72 @@ trace.set_tracer_provider( active_span_processor=BatchSpanProcessor(OTLPSpanExporter(endpoint="http://127.0.0.1:4318/v1/traces")), ) ) +``` + + +The final step is to send a request to the Triton Inference Server and propagate the OpenTelemetry context: + +```python tracer = trace.get_tracer(__name__) with tracer.start_as_current_span("outgoing-request"): - time.sleep(1.0) + time.sleep(0.001) response = requests.post( "http://127.0.0.1:8000/v2/models/passthrough/generate", headers={"Content-Type": "application/json"}, - json={"sleep": 2.0}, + json={"sleep": 0.001, "error": False}, ) - time.sleep(1.0) + time.sleep(0.001) print(response.json()) ``` -This script sends a request to the Triton Inference Server and propagates its own OpenTelemetry context. The Triton Inference Server will then forward this context to the OpenTelemetry collector, which will visualize the trace in the Jaeger UI. - -![Jaeger Trace List with Context Propagation](./assets/jaegger_context_propagation_traces_list.png) + + -![Jaeger Trace Details with Context Propagation](./assets/jaegger_context_propagation_request_details.png) +This script sends a request to the Triton Inference Server and propagates its own OpenTelemetry context. The Triton Inference Server will then forward this context to the OpenTelemetry collector, which will visualize the trace in the Jaeger UI. The trace included above contains two spans: `outgoing-request` explicitly created in your script and `POST` created by requests instrumentation. -You can see that the trace spans are now linked across the two services. \ No newline at end of file +![Jaeger Traces Propagation](../assets/jaeger_traces_list_propagation.png) diff --git a/docs/pypi.rst b/docs/pypi.rst index 8f13bc5e..8a531c35 100644 --- a/docs/pypi.rst +++ b/docs/pypi.rst @@ -40,7 +40,9 @@ The distinct capabilities of PyTriton are summarized in the feature matrix: | Performance | You can benefit from `dynamic batching `_, response cache, model | | | pipelining, `clusters `_, and GPU/CPU inference. | +| | guides/deploying_in_clusters/>`_, performance `tracing `_, and GPU/CPU | +| | inference. 
| +------------------------+--------------------------------------------------------------------------------------+ | Decorators | You can use batching `decorators `_ to handle batching and other | diff --git a/pyproject.toml b/pyproject.toml index 544b2acc..4f6fe18a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,11 @@ test = [ "tqdm >= 4.64.1", "psutil ~= 5.9", "py-spy ~= 0.3", + "opentelemetry-api ~= 1.24", # Used for testing Open Telemetry documentiation + "opentelemetry-sdk ~= 1.24", + "opentelemetry-instrumentation-requests ~= 0.45b0", + "opentelemetry-exporter-otlp ~= 1.24", + ] doc = [ "GitPython >= 3.1", diff --git a/pytriton/decorators.py b/pytriton/decorators.py index dea29d9f..b8f8603f 100644 --- a/pytriton/decorators.py +++ b/pytriton/decorators.py @@ -30,6 +30,7 @@ from pytriton.exceptions import PyTritonBadParameterError, PyTritonRuntimeError, PyTritonValidationError from pytriton.model_config.triton_model_config import TritonModelConfig from pytriton.proxy.data import _serialize_byte_tensor +from pytriton.proxy.telemetry import start_span_from_span class _WrappedWithWrapper(NamedTuple): @@ -190,6 +191,8 @@ def batch(wrapped, instance, args, kwargs): ValueError: If the output tensors have different than expected batch sizes. Expected batch size is calculated as a sum of batch sizes of all requests. """ + telemetry_name = "pytriton-batch-decorator-span" + req_list = args[0] input_names = req_list[0].keys() @@ -205,7 +208,12 @@ def batch(wrapped, instance, args, kwargs): args = args[1:] new_kwargs = dict(kwargs) new_kwargs.update(inputs) - outputs = wrapped(*args, **new_kwargs) + spans = [start_span_from_span(request.span, telemetry_name) for request in req_list if request.span is not None] + try: + outputs = wrapped(*args, **new_kwargs) + finally: + for span in spans: + span.end() def _split_result(_result): outputs = convert_output(_result, wrapped, instance) diff --git a/pytriton/models/model.py b/pytriton/models/model.py index ac715255..3c0890a5 100644 --- a/pytriton/models/model.py +++ b/pytriton/models/model.py @@ -16,13 +16,14 @@ import base64 import copy import enum +import json import logging import os import pathlib import shutil import threading import typing -from typing import Callable, Optional, Sequence, Union +from typing import Callable, List, Optional, Sequence, Union from pytriton.decorators import TritonContext from pytriton.exceptions import PyTritonValidationError @@ -69,7 +70,7 @@ def _inject_triton_context(triton_context: TritonContext, model_callable: Callab class Model: """Model definition.""" - SCRIPT_FILES_TO_COPY = ["communication.py", "data.py", "model.py", "types.py"] + SCRIPT_FILES_TO_COPY = ["communication.py", "data.py", "model.py", "types.py", "telemetry.py"] def __init__( self, @@ -82,6 +83,7 @@ def __init__( workspace: Workspace, triton_context: TritonContext, strict: bool, + trace_config: Optional[List[str]] = None, ): """Create Python model with required data. @@ -95,6 +97,7 @@ def __init__( workspace: workspace for storing artifacts triton_context: Triton context strict: Enable strict validation of model outputs + trace_config: List of trace config parameters Raises: PyTritonValidationError if one or more of provided values are incorrect. 
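A usage note on the `pytriton/decorators.py` change above: when tracing is configured, the `@batch` decorator now opens a child span named `pytriton-batch-decorator-span` for every traced request around the wrapped call, so batched inference callables get per-request timing without any telemetry code of their own. A minimal sketch (the model and tensor names are illustrative only):

```python
from pytriton.decorators import batch


@batch
def passthrough(sleep):
    # No telemetry calls are needed here: the @batch decorator starts a
    # "pytriton-batch-decorator-span" for each request that carries a span
    # and ends those spans once this function returns (or raises).
    return {"sleep": sleep}
```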
@@ -107,6 +110,7 @@ def __init__( self._requests_respones_connectors = [] self._observers_lock = threading.Lock() self._strict = strict + self._trace_config = trace_config self.infer_functions = [inference_fn] if isinstance(inference_fn, Callable) else inference_fn if not isinstance(self.infer_functions, (Sequence, Callable)): @@ -271,6 +275,9 @@ def _get_triton_model_config(self) -> TritonModelConfig: ModelConfig object with configuration for Python model deployment """ if not self._triton_model_config: + backend_parameters = {"workspace-path": self._workspace.path.as_posix()} + if self._trace_config: + backend_parameters["trace-config"] = base64.b64encode(json.dumps(self._trace_config).encode()).decode() triton_model_config = TritonModelConfig( model_name=self.model_name, model_version=self.model_version, @@ -278,7 +285,7 @@ def _get_triton_model_config(self) -> TritonModelConfig: batcher=self.config.batcher, max_batch_size=self.config.max_batch_size, decoupled=self.config.decoupled, - backend_parameters={"workspace-path": self._workspace.path.as_posix()}, + backend_parameters=backend_parameters, instance_group={DeviceKind.KIND_CPU: len(self.infer_functions)}, ) inputs = [] diff --git a/pytriton/proxy/data.py b/pytriton/proxy/data.py index 559afdf9..d8201738 100644 --- a/pytriton/proxy/data.py +++ b/pytriton/proxy/data.py @@ -42,6 +42,7 @@ import numpy as np +from .telemetry import get_span_dict, start_span_from_remote from .types import Request, Requests, Response, Responses LOGGER = logging.getLogger(__name__) @@ -853,12 +854,14 @@ def serialize_requests(self, requests: Requests) -> bytes: Serialized requests """ serialized_requests = self._serialize_named_tensors_lists(requests) - requests = { - "requests": [ - {"data": serialized_request, "parameters": request.parameters} - for request, serialized_request in zip(requests, serialized_requests) - ] - } + requests_list = [] + for request, serialized_request in zip(requests, serialized_requests): + serialized_request = {"data": serialized_request, "parameters": request.parameters} + if request.span is not None: + serialized_request["span"] = get_span_dict(request.span) + requests_list.append(serialized_request) + + requests = {"requests": requests_list} requests = json.dumps(requests).encode("utf-8") return requests @@ -875,14 +878,18 @@ def deserialize_requests(self, requests_payload: bytes) -> Requests: requests_data = [request["data"] for request in requests["requests"]] requests_data = self._deserialized_named_tensors_lists(requests_data) - requests = [ - Request( - data=request_data, - parameters=request.get("parameters"), - ) - for request, request_data in zip(requests["requests"], requests_data) - ] - return requests + deserialized_requests = [] + for request, request_data in zip(requests["requests"], requests_data): + kwargs = {"data": request_data, "parameters": request.get("parameters")} + # FIXME: move span creation above just after json.loads + if "span" in request: + span_dict = request["span"] + span = start_span_from_remote(span_dict, "proxy_inference_callable") + kwargs["span"] = span + request_wrapped = Request(**kwargs) + deserialized_requests.append(request_wrapped) + + return deserialized_requests def free_requests_resources(self, requests_payload: bytes): """Free resources used by requests.""" @@ -981,12 +988,14 @@ def serialize_requests(self, requests: Requests) -> bytes: Serialized requests """ serialized_requests = self._serialize_named_tensors_lists(requests) - requests = { - "requests": [ - {"data": 
serialized_request, "parameters": request.parameters} - for request, serialized_request in zip(requests, serialized_requests) - ] - } + requests_list = [] + for request, serialized_request in zip(requests, serialized_requests): + serialized_request = {"data": serialized_request, "parameters": request.parameters} + if request.span is not None: + serialized_request["span"] = get_span_dict(request.span) + requests_list.append(serialized_request) + + requests = {"requests": requests_list} return json.dumps(requests).encode("utf-8") def deserialize_requests(self, requests_payload: bytes) -> Requests: @@ -999,16 +1008,23 @@ def deserialize_requests(self, requests_payload: bytes) -> Requests: List of deserialized requests """ requests = json.loads(requests_payload) - return [ - Request( - data={ - input_name: self._tensor_store.get(tensor_id) - for input_name, tensor_id in request.get("data", {}).items() - }, - parameters=request.get("parameters"), - ) - for request in requests["requests"] - ] + deserialized_requests = [] + for request in requests["requests"]: + kwargs = {} + if "span" in request: + span_dict = request["span"] + span = start_span_from_remote(span_dict, "proxy_inference_callable") + kwargs["span"] = span + request_data = { + input_name: self._tensor_store.get(tensor_id) + for input_name, tensor_id in request.get("data", {}).items() + } + kwargs["data"] = request_data + kwargs["parameters"] = request.get("parameters") + request_wrapped = Request(**kwargs) + deserialized_requests.append(request_wrapped) + + return deserialized_requests def free_requests_resources(self, requests_payload: bytes): """Free resources used by requests.""" diff --git a/pytriton/proxy/inference.py b/pytriton/proxy/inference.py index 05de153f..dccb2c61 100644 --- a/pytriton/proxy/inference.py +++ b/pytriton/proxy/inference.py @@ -26,6 +26,7 @@ from pytriton.exceptions import PyTritonUnrecoverableError from pytriton.proxy.communication import PyTritonResponseFlags, RequestsServerClient +from pytriton.proxy.telemetry import end_span from pytriton.proxy.types import Requests, Responses, ResponsesNoneOrError, Scope from pytriton.proxy.validators import TritonResultsValidator @@ -168,6 +169,7 @@ async def handle_requests(self, scope, requests_payload: bytes, send): requests_id = scope["requests_id"] queue = self._responses_queues[requests_id] = asyncio.Queue() loop = asyncio.get_running_loop() + requests = None def _wait_for_inference_fn(timeout_s: float): with self._run_inference_condition: @@ -208,6 +210,10 @@ def _wait_for_inference_fn(timeout_s: float): flags = PyTritonResponseFlags.ERROR | PyTritonResponseFlags.EOS await send(scope, flags, error_msg) finally: + if requests is not None: + for request in requests: + end_span(request.span) + self._serializer_deserializer.free_requests_resources(requests_payload) self._responses_queues.pop(requests_id) LOGGER.debug(f"Finished handling requests for {scope['requests_id'].hex()}") diff --git a/pytriton/proxy/model.py b/pytriton/proxy/model.py index ff0b70c0..d815a380 100644 --- a/pytriton/proxy/model.py +++ b/pytriton/proxy/model.py @@ -43,6 +43,7 @@ Base64SerializerDeserializer, TensorStoreSerializerDeserializer, ) +from .telemetry import TracableModel # pytype: disable=import-error from .types import Request, Response, ResponsesOrError # pytype: disable=import-error LOGGER = logging.getLogger(__name__) @@ -112,28 +113,41 @@ def shutdown(self): """ self._server.shutdown() - def push(self, requests_id: bytes, triton_requests): + def push(self, requests_id: bytes, 
triton_requests, spans=None): """Push requests to TritonRequestsServer queue. Args: requests_id: id of requests triton_requests: list of Triton requests + spans: list of OpenTelemetry spans """ self._server.wait_till_running() # wait until loop is up and running, raise RuntimeError if server is stopping or not launched yet - return asyncio.run_coroutine_threadsafe( - self._send_requests(requests_id, triton_requests), self._server.server_loop - ) + kwargs = {"requests_id": requests_id, "triton_requests": triton_requests} + if spans is not None: + kwargs["spans"] = spans + return asyncio.run_coroutine_threadsafe(self._send_requests(**kwargs), self._server.server_loop) - def _wrap_request(self, triton_request, inputs) -> Request: + def _wrap_request(self, triton_request, inputs, span=None) -> Request: request = {} for input_name in inputs: input_tensor = pb_utils.get_input_tensor_by_name(triton_request, input_name) if input_tensor is not None: request[input_name] = input_tensor.as_numpy() - return Request(data=request, parameters=json.loads(triton_request.parameters())) - - async def _send_requests(self, requests_id: bytes, triton_requests) -> ConcurrentFuture: - requests = [self._wrap_request(triton_request, self._model_inputs_names) for triton_request in triton_requests] + kwargs = {} + if span is not None: + kwargs["span"] = span + return Request(data=request, parameters=json.loads(triton_request.parameters()), **kwargs) + + async def _send_requests(self, requests_id: bytes, triton_requests, spans=None) -> ConcurrentFuture: + requests = triton_requests + if spans is None: + spans = [None] * len(triton_requests) + requests_with_spans = zip(triton_requests, spans) + + requests = [ + self._wrap_request(triton_request, self._model_inputs_names, span) + for triton_request, span in requests_with_spans + ] requests_payload = self._serializer_deserializer.serialize_requests(requests) # will return when socket.send_multipart returns responses_future = ConcurrentFuture() @@ -339,6 +353,7 @@ def __init__(self): self._frontend = None self._requests = None self._id_counter = 0 + self._tracable_model = None def initialize(self, args): """Triton Inference Server Python Backend API called only once when the model is being loaded. 
@@ -375,6 +390,10 @@ def initialize(self, args): workspace_path = pathlib.Path(model_config["parameters"]["workspace-path"]["string_value"]) + self._tracable_model = TracableModel() + if "trace-config" in model_config["parameters"]: + self._tracable_model.configure_tracing(model_config["parameters"]["trace-config"]["string_value"]) + LOGGER.debug(f"Model instance name: {self._model_instance_name}") LOGGER.debug(f"Decoupled model: {self._decoupled_model}") LOGGER.debug(f"Workspace path: {workspace_path}") @@ -443,6 +462,7 @@ def execute(self, triton_requests): pb_utils.TritonModelException: when model execution fails """ try: + spans = self._tracable_model.start_requests_spans(triton_requests) def _generate_id(): self._id_counter = (self._id_counter + 1) % 2**32 @@ -454,17 +474,21 @@ def _generate_id(): self._requests[requests_id] = triton_requests # TODO: add this future to container to avoid garbage collection - handle_responses_task_async_future = self._requests_server.push(requests_id, triton_requests) + handle_responses_task_async_future = self._requests_server.push(requests_id, triton_requests, spans) if not self._decoupled_model: handle_responses_concurrent_future = handle_responses_task_async_future.result() triton_responses_or_error = handle_responses_concurrent_future.result() + self._tracable_model.end_requests_spans(spans, triton_responses_or_error) + if triton_responses_or_error is not None and isinstance(triton_responses_or_error, Exception): raise triton_responses_or_error else: triton_responses_or_error = None + self._tracable_model.end_requests_spans(spans, triton_responses_or_error) + return triton_responses_or_error except Exception: msg = traceback.format_exc() diff --git a/pytriton/proxy/telemetry.py b/pytriton/proxy/telemetry.py new file mode 100644 index 00000000..677b6e03 --- /dev/null +++ b/pytriton/proxy/telemetry.py @@ -0,0 +1,339 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Telemetry handling module. + +This module contains optional import for Open Telemetry and functions to handle it. 
+""" + +import base64 +import importlib.util +import json +import logging +from contextlib import contextmanager +from typing import Dict, Generator, List + +# Open Telemetry is not mandatory for PyTriton, but it can be used for tracing +# The import in functions breaks telemetry span handling at runtime +try: + import opentelemetry.baggage # pytype: disable=import-error + import opentelemetry.trace # pytype: disable=import-error + import opentelemetry.trace.propagation.tracecontext # pytype: disable=import-error + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter # pytype: disable=import-error + from opentelemetry.sdk.resources import Resource # pytype: disable=import-error + from opentelemetry.sdk.trace import TracerProvider # pytype: disable=import-error + from opentelemetry.sdk.trace.export import ( # pytype: disable=import-error + BatchSpanProcessor, # pytype: disable=import-error + ) + + # from opentelemetry import trace, context + from opentelemetry.trace import ( # pytype: disable=import-error + NonRecordingSpan, + SpanContext, + Status, + StatusCode, + TraceFlags, + ) + from opentelemetry.trace.propagation.tracecontext import ( # pytype: disable=import-error + TraceContextTextMapPropagator, + ) + +except ImportError: + pass + + +LOGGER = logging.getLogger(__name__) + + +_open_telemetry_tracer = None + + +def set_telemetry_tracer(tracer): + """Set tracer for Open Telemetry. + + Sets the global tracer used by the proxy on the inference callable side of communication. + + See the trace_config parameter of TritonConfig to also enable tracing for Triton. + The function can only be called once; a second call raises an exception. + + Args: + tracer: Tracer object for Open Telemetry + + Raises: + ValueError for second and all further calls + """ + global _open_telemetry_tracer + if _open_telemetry_tracer is not None: + raise ValueError("Telemetry tracer is already set") + LOGGER.debug(f"Setting telemetry tracer: {tracer}") + _open_telemetry_tracer = tracer + + +def get_telemetry_tracer(): + """Return telemetry tracer set by set_telemetry_tracer.""" + global _open_telemetry_tracer + return _open_telemetry_tracer + + +def get_span_dict(span): + """Serialize Open Telemetry span for sending over proxy bus.""" + headers = {} + with opentelemetry.trace.use_span(span, end_on_exit=False): + ctx = opentelemetry.baggage.set_baggage("zmq", "baggage") + opentelemetry.trace.propagation.tracecontext.TraceContextTextMapPropagator().inject(headers, ctx) + return headers + + +def start_span_from_remote(span_dict: Dict[str, int], name: str): + """Create new Open Telemetry span from remote span deserialized from proxy. + + The span ownership goes to the caller, which MUST call span end to register + the event in the Open Telemetry server. + + Args: + span_dict: dictionary with fields trace_id and span_id or None + name: name of new span started + + Returns: + Open Telemetry span or None if telemetry is not configured or span_dict is None. + """ + global _open_telemetry_tracer + if _open_telemetry_tracer is not None: + ctx = opentelemetry.trace.propagation.tracecontext.TraceContextTextMapPropagator().extract(span_dict) + return _open_telemetry_tracer.start_span(name, context=ctx) + else: + return None + + +def start_span_from_span(span, name): + """Create new Open Telemetry span from existing span. + + The span ownership goes to the caller, which MUST call span end to register + the event in the Open Telemetry server.
+ + Args: + span: Open Telemetry span + name: name of new span started + + Returns: + Open Telemetry span + """ + span_context = SpanContext( + trace_id=span.context.trace_id, + span_id=span.context.span_id, + is_remote=True, + trace_flags=TraceFlags(0x01), + ) + ctx = opentelemetry.trace.set_span_in_context(NonRecordingSpan(span_context)) + tracer = get_telemetry_tracer() + return tracer.start_span(name, context=ctx) + + +def parse_trace_config(trace_config_list: List[str]): + """Parse Triton Open Telemetry config. + + The TritonConfig trace_config can be passed here to obtain Open Telemetry resource and + URL to connect to server. + + Example of configuration: + ``` + trace_config=[ + "mode=opentelemetry", + "opentelemetry,url=", + "opentelemetry,resource=service.name=", + "opentelemetry,resource=test.key=test.value", + ] + ``` + Elements: + - List MUST contain mode to indicate opentelemetry support. + - List MUST contain url to allow opening connecion to Open Telemetry server + - List SHOULD contain service.name to improve logging + - List SHOULD contain additional keys like test.key. + + Args: + trace_config_list: list of configuration variable for Tritonconfig + """ + if not any("mode=opentelemetry" in config for config in trace_config_list): + raise ValueError("Only opentelemetry mode is supported") + url_entry = next((config for config in trace_config_list if "opentelemetry,url=" in config), None) + if url_entry is None: + raise ValueError("opentelemetry,url is required") + url = url_entry.split("opentelemetry,url=")[1] + + resource_attributes = {} + for config in trace_config_list: + if config.startswith("opentelemetry,resource="): + resource_str = config.split("opentelemetry,resource=")[1] + resource_parts = resource_str.split(",") + for part in resource_parts: + key, val = part.split("=") + resource_attributes[key] = val + + LOGGER.debug(f"OpenTelemetry URL: {url}") + LOGGER.debug(f"Resource Attributes: {resource_attributes}") + + resource = Resource(attributes=resource_attributes) + return url, resource + + +@contextmanager +def traced_span(request, span_name, **kwargs) -> Generator[None, None, None]: + """Context manager handles opening span for request. + + This context manager opens Open Telemetry span for request. The span is + automatically closed when context manager exits. + + Example of use in inference callable: + ``` + def inference_callable(requests): + responses = [] + for request in requests: + with traced_span(request, "pass-through-get-data"): + # Execute compute for single request + ``` + + Args: + request: Request passed to inference callable + span_name: Name of span to yield + **kwargs: Additional arguments passed to Open Telemetry tracer + """ + global _open_telemetry_tracer + span = request.span + if span is not None: + with opentelemetry.trace.use_span(span, end_on_exit=False, record_exception=False): + with _open_telemetry_tracer.start_as_current_span(span_name, **kwargs): + yield + else: + yield + + +def build_proxy_tracer_from_triton_config(trace_config): + """Build OpenTelemetry tracer from TritonConfig trace_config. 
+ + Args: + trace_config: list of trace configuration variables + + Returns: + OpenTelemetry tracer + """ + raise_if_no_telemetry() + LOGGER.debug(f"Building OpenTelmetry tracer from config: {trace_config}") + url, resource = parse_trace_config(trace_config) + LOGGER.debug(f"Creating OpenTelemetry tracer with URL: {url}") + opentelemetry.trace.set_tracer_provider( + TracerProvider( + resource=resource, + active_span_processor=BatchSpanProcessor(OTLPSpanExporter(endpoint=url)), + ) + ) + + tracer = opentelemetry.trace.get_tracer(__name__) + return tracer + + +def raise_if_no_telemetry(): + """Raise ImportError if OpenTelemetry is not installed.""" + # Import added to trigger error for missing package + if importlib.util.find_spec("opentelemetry.trace") is None: + pip = "pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp" + raise ImportError(f"OpenTelemetry is not installed. Please install it using '{pip}'.") + + +def end_span(span, error=None): + """End Open Telemetry span and set status if error is provided. + + Args: + span: Open Telemetry span + error: error message to set in span status + """ + if span is not None: + if error is not None: + span.set_status(Status(StatusCode.ERROR, error)) + else: + span.set_status(Status(StatusCode.OK)) + span.end() + + +class TracableModel: + """Model class with tracing support. + + This class is base class for model with tracing support. It provides + methods to start and end span for each inference call. + """ + + def __init__(self): + """Initialize TracableModel.""" + self._open_telemetry_tracer = None + + def configure_tracing(self, trace_config): + """Configure tracing for model. + + This method configures OpenTelemetry tracing for model. The trace_config + is list of configuration variables passed by TritonConfig. + + Args: + trace_config: list of trace configuration variables + """ + try: + raise_if_no_telemetry() + + trace_config_json = base64.b64decode(trace_config).decode("utf-8") + trace_config_list = json.loads(trace_config_json) + LOGGER.debug(f"Configuring tracing with {trace_config_list}") + + url, resource = parse_trace_config(trace_config_list) + + opentelemetry.trace.set_tracer_provider(TracerProvider(resource=resource)) + trace_provider = opentelemetry.trace.get_tracer_provider() + self._open_telemetry_tracer = trace_provider.get_tracer("pbe") + trace_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=url))) + + except Exception as e: + raise ValueError(f"Failed to configure tracing: {e}") from e + + def start_requests_spans(self, triton_requests): + """Start spans for requests. + + This method starts spans for each request in triton_requests. + + Args: + triton_requests: list of Triton requests + """ + if self._open_telemetry_tracer is not None: + spans = [] + for triton_request in triton_requests: + context = triton_request.trace().get_context() + if context is None: + context = "{}" + ctx = TraceContextTextMapPropagator().extract(carrier=json.loads(context)) + span = self._open_telemetry_tracer.start_span("python_backend_execute", context=ctx) + spans.append(span) + return spans + return None + + def end_requests_spans(self, spans, triton_responses_or_error): + """End spans for requests. + + This method ends spans for each request in triton_requests. 
+ + Args: + spans: list of spans for requests + triton_responses_or_error: list of Triton responses or error + """ + if self._open_telemetry_tracer is not None: + status = Status(StatusCode.OK) + if triton_responses_or_error is not None and isinstance(triton_responses_or_error, Exception): + status = Status(StatusCode.ERROR, str(triton_responses_or_error)) + for span in spans: + span.set_status(status) + span.end() diff --git a/pytriton/proxy/types.py b/pytriton/proxy/types.py index 1534e421..dd24e603 100644 --- a/pytriton/proxy/types.py +++ b/pytriton/proxy/types.py @@ -18,6 +18,8 @@ import numpy as np +from .telemetry import traced_span + @dataclasses.dataclass class Request: @@ -27,6 +29,8 @@ class Request: """Input data for the request.""" parameters: Optional[Dict[str, Union[str, int, bool]]] = None """Parameters for the request.""" + span: Optional[Any] = None + """Telemetry span for request""" def __getitem__(self, input_name: str) -> np.ndarray: """Get input data.""" @@ -60,6 +64,14 @@ def values(self): """Iterate over input data.""" return self.data.values() + def traced_span(self, span_name): + """Yields Open Telemetry a span for the request. + + Args: + span_name (str): Name of the span + """ + return traced_span(self, span_name) + Requests = List[Request] diff --git a/pytriton/triton.py b/pytriton/triton.py index 86411a13..2572d4ba 100644 --- a/pytriton/triton.py +++ b/pytriton/triton.py @@ -57,6 +57,7 @@ from pytriton.model_config.tensor import Tensor from pytriton.models.manager import ModelManager from pytriton.models.model import Model, ModelConfig, ModelEvent +from pytriton.proxy.telemetry import build_proxy_tracer_from_triton_config, get_telemetry_tracer, set_telemetry_tracer from pytriton.server.python_backend_config import PythonBackendConfig from pytriton.server.triton_server import TritonServer from pytriton.server.triton_server_config import TritonServerConfig @@ -251,6 +252,10 @@ def _cast_value(_field, _value): is_optional = typing_inspect.is_optional_type(field_type) if is_optional: field_type = field_type.__args__[0] + if hasattr(field_type, "__origin__") and field_type.__origin__ is list: + return list(_value) if _value is not None else None + elif isinstance(_value, str) and isinstance(field_type, type) and issubclass(field_type, list): + return _value.split(",") return field_type(_value) config_with_casted_values = { @@ -275,7 +280,29 @@ def from_env(cls) -> "TritonConfig": TritonConfig class instantiated from environment variables. 
""" prefix = "PYTRITON_TRITON_CONFIG_" - config = {name[len(prefix) :].lower(): value for name, value in os.environ.items() if name.startswith(prefix)} + config = {} + list_pattern = re.compile(r"^(.+?)_(\d+)$") + + for name, value in os.environ.items(): + if name.startswith(prefix): + key = name[len(prefix) :].lower() + match = list_pattern.match(key) + if match: + list_key, index = match.groups() + index = int(index) + if list_key not in config: + config[list_key] = [] + if len(config[list_key]) <= index: + config[list_key].extend([None] * (index + 1 - len(config[list_key]))) + config[list_key][index] = value + else: + config[key] = value + + # Remove None values from lists (in case of non-sequential indexes) + for key in config: + if isinstance(config[key], list): + config[key] = [item for item in config[key] if item is not None] + return cls.from_dict(config) @@ -385,6 +412,7 @@ def bind( model_version: int = 1, config: Optional[ModelConfig] = None, strict: bool = False, + trace_config: Optional[List[str]] = None, ) -> None: """Create a model with given name and inference callable binding into Triton Inference Server. @@ -401,8 +429,27 @@ def bind( model_version: Version of model config: Model configuration for Triton Inference Server deployment strict: Enable strict validation between model config outputs and inference function result + trace_config: List of trace config parameters """ self._validate_model_name(model_name) + model_kwargs = {} + if trace_config is None: + triton_config = getattr(self, "_config", None) + if triton_config is not None: + trace_config = getattr(triton_config, "trace_config", None) + if trace_config is not None: + LOGGER.info(f"Using trace config from TritonConfig: {trace_config}") + model_kwargs["trace_config"] = trace_config + else: + model_kwargs["trace_config"] = trace_config + telemetry_tracer = get_telemetry_tracer() + + # Automatically set telemetry tracer if not set at the proxy side + if telemetry_tracer is None and trace_config is not None: + LOGGER.info("Setting telemetry tracer from TritonConfig") + telemetry_tracer = build_proxy_tracer_from_triton_config(trace_config) + set_telemetry_tracer(telemetry_tracer) + model = Model( model_name=model_name, model_version=model_version, @@ -413,6 +460,7 @@ def bind( workspace=self._workspace, triton_context=self._triton_context, strict=strict, + **model_kwargs, ) model.on_model_event(self._on_model_event)