diff --git a/src/prefect/_internal/concurrency/services.py b/src/prefect/_internal/concurrency/services.py index 6e33aca96be4..ad54c076299c 100644 --- a/src/prefect/_internal/concurrency/services.py +++ b/src/prefect/_internal/concurrency/services.py @@ -225,7 +225,9 @@ def drain(self, at_exit: bool = False) -> None: return future.result() @classmethod - def drain_all(cls, timeout: Optional[float] = None) -> Union[Awaitable, None]: + def drain_all( + cls, timeout: Optional[float] = None, at_exit=True + ) -> Union[Awaitable, None]: """ Stop all instances of the service and wait for all remaining work to be completed. @@ -237,7 +239,7 @@ def drain_all(cls, timeout: Optional[float] = None) -> Union[Awaitable, None]: instances = tuple(cls._instances.values()) for instance in instances: - futures.append(instance._drain()) + futures.append(instance._drain(at_exit=at_exit)) if get_running_loop() is not None: return ( @@ -376,10 +378,10 @@ def _get_size(self, item: T) -> int: @contextlib.contextmanager def drain_on_exit(service: QueueService): yield - service.drain_all() + service.drain_all(at_exit=True) @contextlib.asynccontextmanager async def drain_on_exit_async(service: QueueService): yield - await service.drain_all() + await service.drain_all(at_exit=True) diff --git a/src/prefect/telemetry/instrumentation.py b/src/prefect/telemetry/instrumentation.py index f1f458b785c1..7f6479238a26 100644 --- a/src/prefect/telemetry/instrumentation.py +++ b/src/prefect/telemetry/instrumentation.py @@ -7,9 +7,7 @@ from opentelemetry import metrics, trace from opentelemetry._logs import set_logger_provider -from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler from opentelemetry.sdk._logs.export import SimpleLogRecordProcessor from opentelemetry.sdk.metrics import MeterProvider @@ -19,6 +17,7 @@ from .logging import set_log_handler from .processors import InFlightSpanProcessor +from .services import QueueingLogExporter, QueueingSpanExporter if TYPE_CHECKING: from opentelemetry.sdk._logs import LoggerProvider @@ -83,11 +82,10 @@ def _setup_trace_provider( resource: Resource, headers: dict[str, str], telemetry_url: str ) -> TracerProvider: trace_provider = TracerProvider(resource=resource) - otlp_span_exporter = OTLPSpanExporter( - endpoint=_url_join(telemetry_url, "v1/traces"), - headers=headers, + queueing_span_exporter = QueueingSpanExporter.instance( + _url_join(telemetry_url, "v1/traces"), tuple(headers.items()) ) - trace_provider.add_span_processor(InFlightSpanProcessor(otlp_span_exporter)) + trace_provider.add_span_processor(InFlightSpanProcessor(queueing_span_exporter)) trace.set_tracer_provider(trace_provider) return trace_provider @@ -112,11 +110,12 @@ def _setup_logger_provider( resource: Resource, headers: dict[str, str], telemetry_url: str ) -> LoggerProvider: logger_provider = LoggerProvider(resource=resource) - otlp_exporter = OTLPLogExporter( - endpoint=_url_join(telemetry_url, "v1/logs"), - headers=headers, + queueing_log_exporter = QueueingLogExporter.instance( + _url_join(telemetry_url, "v1/logs"), tuple(headers.items()) + ) + logger_provider.add_log_record_processor( + SimpleLogRecordProcessor(queueing_log_exporter) ) - logger_provider.add_log_record_processor(SimpleLogRecordProcessor(otlp_exporter)) set_logger_provider(logger_provider) log_handler = LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider) diff --git a/src/prefect/telemetry/services.py b/src/prefect/telemetry/services.py new file mode 100644 index 000000000000..6d9d7e47bca6 --- /dev/null +++ b/src/prefect/telemetry/services.py @@ -0,0 +1,67 @@ +from abc import abstractmethod +from typing import Union + +from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk._logs import LogData +from opentelemetry.sdk._logs.export import LogExporter +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExporter + +from prefect._internal.concurrency.services import BatchedQueueService + + +class BaseQueueingExporter(BatchedQueueService): + _max_batch_size = 512 + _min_interval = 2.0 + _otlp_exporter: Union[SpanExporter, LogExporter] + + def export(self, batch: list[Union[ReadableSpan, LogData]]) -> None: + for item in batch: + self.send(item) + + @abstractmethod + def _export_batch(self, items: list[Union[ReadableSpan, LogData]]) -> None: + pass + + async def _handle_batch(self, items: list[Union[ReadableSpan, LogData]]) -> None: + try: + self._export_batch(items) + except Exception as e: + self._logger.exception(f"Failed to export batch: {e}") + raise + + def shutdown(self) -> None: + if self._stopped: + return + + self.drain() + self._otlp_exporter.shutdown() + + +class QueueingSpanExporter(BaseQueueingExporter, SpanExporter): + _otlp_exporter: OTLPSpanExporter + + def __init__(self, endpoint: str, headers: tuple[tuple[str, str]]): + super().__init__() + self._otlp_exporter = OTLPSpanExporter( + endpoint=endpoint, + headers=dict(headers), + ) + + def _export_batch(self, items: list[ReadableSpan]) -> None: + self._otlp_exporter.export(items) + + +class QueueingLogExporter(BaseQueueingExporter, LogExporter): + _otlp_exporter: OTLPLogExporter + + def __init__(self, endpoint: str, headers: tuple[tuple[str, str]]): + super().__init__() + self._otlp_exporter = OTLPLogExporter( + endpoint=endpoint, + headers=dict(headers), + ) + + def _export_batch(self, items: list[LogData]) -> None: + self._otlp_exporter.export(items) diff --git a/tests/telemetry/test_instrumentation.py b/tests/telemetry/test_instrumentation.py index c56b8bb58229..c8bb65b2104c 100644 --- a/tests/telemetry/test_instrumentation.py +++ b/tests/telemetry/test_instrumentation.py @@ -4,7 +4,6 @@ import pytest from opentelemetry import metrics, trace from opentelemetry._logs._internal import get_logger_provider -from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler from opentelemetry.sdk._logs.export import SimpleLogRecordProcessor @@ -18,6 +17,16 @@ ) from prefect.telemetry.logging import get_log_handler from prefect.telemetry.processors import InFlightSpanProcessor +from prefect.telemetry.services import QueueingLogExporter, QueueingSpanExporter + + +@pytest.fixture +def shutdown_telemetry(): + yield + + provider = trace.get_tracer_provider() + if isinstance(provider, TracerProvider): + provider.shutdown() def test_extract_account_and_workspace_id_valid_url( @@ -88,12 +97,10 @@ def test_trace_provider( span_processor = trace_provider._active_span_processor._span_processors[0] assert isinstance(span_processor, InFlightSpanProcessor) - assert ( - span_processor.span_exporter._endpoint # type: ignore - == ( - f"https://api.prefect.cloud/api/accounts/{telemetry_account_id}/" - f"workspaces/{telemetry_workspace_id}/telemetry/v1/traces" - ) + assert isinstance(span_processor.span_exporter, QueueingSpanExporter) + assert span_processor.span_exporter._otlp_exporter._endpoint == ( + f"https://api.prefect.cloud/api/accounts/{telemetry_account_id}/" + f"workspaces/{telemetry_workspace_id}/telemetry/v1/traces" ) assert trace.get_tracer_provider() == trace_provider @@ -147,14 +154,11 @@ def test_logger_provider( exporter = processor._exporter # type: ignore assert isinstance(processor, SimpleLogRecordProcessor) - assert isinstance(exporter, OTLPLogExporter) + assert isinstance(exporter, QueueingLogExporter) - assert ( - exporter._endpoint # type: ignore - == ( - f"https://api.prefect.cloud/api/accounts/{telemetry_account_id}/" - f"workspaces/{telemetry_workspace_id}/telemetry/v1/logs" - ) + assert exporter._otlp_exporter._endpoint == ( + f"https://api.prefect.cloud/api/accounts/{telemetry_account_id}/" + f"workspaces/{telemetry_workspace_id}/telemetry/v1/logs" ) assert get_logger_provider() == logger_provider