Skip to content

Commit

Permalink
Wrap OTLPLogExporter and OTLPSpanExporter in a QueueService to pull them out of the main thread (#16439)
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald authored Dec 18, 2024
1 parent 581510a commit 997e61d
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 28 deletions.
10 changes: 6 additions & 4 deletions src/prefect/_internal/concurrency/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def drain(self, at_exit: bool = False) -> None:
return future.result()

@classmethod
def drain_all(cls, timeout: Optional[float] = None) -> Union[Awaitable, None]:
def drain_all(
cls, timeout: Optional[float] = None, at_exit=True
) -> Union[Awaitable, None]:
"""
Stop all instances of the service and wait for all remaining work to be
completed.
Expand All @@ -237,7 +239,7 @@ def drain_all(cls, timeout: Optional[float] = None) -> Union[Awaitable, None]:
instances = tuple(cls._instances.values())

for instance in instances:
futures.append(instance._drain())
futures.append(instance._drain(at_exit=at_exit))

if get_running_loop() is not None:
return (
Expand Down Expand Up @@ -376,10 +378,10 @@ def _get_size(self, item: T) -> int:
@contextlib.contextmanager
def drain_on_exit(service: QueueService):
yield
service.drain_all()
service.drain_all(at_exit=True)


@contextlib.asynccontextmanager
async def drain_on_exit_async(service: QueueService):
yield
await service.drain_all()
await service.drain_all(at_exit=True)
19 changes: 9 additions & 10 deletions src/prefect/telemetry/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from opentelemetry import metrics, trace
from opentelemetry._logs import set_logger_provider
from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
from opentelemetry.sdk._logs.export import SimpleLogRecordProcessor
from opentelemetry.sdk.metrics import MeterProvider
Expand All @@ -19,6 +17,7 @@

from .logging import set_log_handler
from .processors import InFlightSpanProcessor
from .services import QueueingLogExporter, QueueingSpanExporter

if TYPE_CHECKING:
from opentelemetry.sdk._logs import LoggerProvider
Expand Down Expand Up @@ -83,11 +82,10 @@ def _setup_trace_provider(
resource: Resource, headers: dict[str, str], telemetry_url: str
) -> TracerProvider:
trace_provider = TracerProvider(resource=resource)
otlp_span_exporter = OTLPSpanExporter(
endpoint=_url_join(telemetry_url, "v1/traces"),
headers=headers,
queueing_span_exporter = QueueingSpanExporter.instance(
_url_join(telemetry_url, "v1/traces"), tuple(headers.items())
)
trace_provider.add_span_processor(InFlightSpanProcessor(otlp_span_exporter))
trace_provider.add_span_processor(InFlightSpanProcessor(queueing_span_exporter))
trace.set_tracer_provider(trace_provider)

return trace_provider
Expand All @@ -112,11 +110,12 @@ def _setup_logger_provider(
resource: Resource, headers: dict[str, str], telemetry_url: str
) -> LoggerProvider:
logger_provider = LoggerProvider(resource=resource)
otlp_exporter = OTLPLogExporter(
endpoint=_url_join(telemetry_url, "v1/logs"),
headers=headers,
queueing_log_exporter = QueueingLogExporter.instance(
_url_join(telemetry_url, "v1/logs"), tuple(headers.items())
)
logger_provider.add_log_record_processor(
SimpleLogRecordProcessor(queueing_log_exporter)
)
logger_provider.add_log_record_processor(SimpleLogRecordProcessor(otlp_exporter))
set_logger_provider(logger_provider)
log_handler = LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider)

Expand Down
67 changes: 67 additions & 0 deletions src/prefect/telemetry/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from abc import abstractmethod
from typing import Union

from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk._logs import LogData
from opentelemetry.sdk._logs.export import LogExporter
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanExporter

from prefect._internal.concurrency.services import BatchedQueueService


class BaseQueueingExporter(BatchedQueueService):
    """
    Batched queue service that sits in front of an OTLP exporter.

    `export` only enqueues the items it is given (via `send`); the batches
    assembled by the service are later passed to `_export_batch` from
    `_handle_batch`. Subclasses assign `_otlp_exporter` and implement
    `_export_batch`.
    """

    # Batching settings consumed by BatchedQueueService (presumably the maximum
    # number of items per batch and the flush interval -- confirm in the base class).
    _max_batch_size = 512
    _min_interval = 2.0
    # The wrapped exporter; set by each subclass' __init__.
    _otlp_exporter: Union[SpanExporter, LogExporter]

    def export(self, batch: list[Union[ReadableSpan, LogData]]) -> None:
        """Enqueue each item of `batch` on the service; no export happens here."""
        for item in batch:
            self.send(item)

    @abstractmethod
    def _export_batch(self, items: list[Union[ReadableSpan, LogData]]) -> None:
        """Pass `items` to the concrete OTLP exporter."""
        pass

    async def _handle_batch(self, items: list[Union[ReadableSpan, LogData]]) -> None:
        """Export one batch, logging (and re-raising) any failure."""
        try:
            self._export_batch(items)
        except Exception as e:
            self._logger.exception(f"Failed to export batch: {e}")
            raise

    def shutdown(self) -> None:
        """
        Drain the queue and shut down the wrapped exporter.

        Does nothing when the service is already stopped.
        """
        if self._stopped:
            return

        # Shut the wrapped exporter down even when draining fails, so it is
        # not left open; the drain error still propagates to the caller.
        try:
            self.drain()
        finally:
            self._otlp_exporter.shutdown()


class QueueingSpanExporter(BaseQueueingExporter, SpanExporter):
    """Span exporter that queues spans and forwards batches to an `OTLPSpanExporter`."""

    _otlp_exporter: OTLPSpanExporter

    def __init__(self, endpoint: str, headers: tuple[tuple[str, str], ...]):
        """
        Build the wrapped `OTLPSpanExporter`.

        Args:
            endpoint: Endpoint URL passed to the wrapped exporter.
            headers: HTTP headers as any number of `(name, value)` pairs;
                converted to a dict for the wrapped exporter.
        """
        super().__init__()
        self._otlp_exporter = OTLPSpanExporter(
            endpoint=endpoint,
            headers=dict(headers),
        )

    def _export_batch(self, items: list[ReadableSpan]) -> None:
        """Send a batch of spans through the wrapped exporter."""
        self._otlp_exporter.export(items)


class QueueingLogExporter(BaseQueueingExporter, LogExporter):
    """Log exporter that queues log records and forwards batches to an `OTLPLogExporter`."""

    _otlp_exporter: OTLPLogExporter

    def __init__(self, endpoint: str, headers: tuple[tuple[str, str], ...]):
        """
        Build the wrapped `OTLPLogExporter`.

        Args:
            endpoint: Endpoint URL passed to the wrapped exporter.
            headers: HTTP headers as any number of `(name, value)` pairs;
                converted to a dict for the wrapped exporter.
        """
        super().__init__()
        self._otlp_exporter = OTLPLogExporter(
            endpoint=endpoint,
            headers=dict(headers),
        )

    def _export_batch(self, items: list[LogData]) -> None:
        """Send a batch of log records through the wrapped exporter."""
        self._otlp_exporter.export(items)
32 changes: 18 additions & 14 deletions tests/telemetry/test_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
from opentelemetry import metrics, trace
from opentelemetry._logs._internal import get_logger_provider
from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
from opentelemetry.sdk._logs.export import SimpleLogRecordProcessor
Expand All @@ -18,6 +17,16 @@
)
from prefect.telemetry.logging import get_log_handler
from prefect.telemetry.processors import InFlightSpanProcessor
from prefect.telemetry.services import QueueingLogExporter, QueueingSpanExporter


@pytest.fixture
def shutdown_telemetry():
    """After the test runs, shut down the global tracer provider if it is a `TracerProvider`."""
    yield

    active_provider = trace.get_tracer_provider()
    if not isinstance(active_provider, TracerProvider):
        # Nothing to shut down for any other provider type.
        return
    active_provider.shutdown()


def test_extract_account_and_workspace_id_valid_url(
Expand Down Expand Up @@ -88,12 +97,10 @@ def test_trace_provider(
span_processor = trace_provider._active_span_processor._span_processors[0]

assert isinstance(span_processor, InFlightSpanProcessor)
assert (
span_processor.span_exporter._endpoint # type: ignore
== (
f"https://api.prefect.cloud/api/accounts/{telemetry_account_id}/"
f"workspaces/{telemetry_workspace_id}/telemetry/v1/traces"
)
assert isinstance(span_processor.span_exporter, QueueingSpanExporter)
assert span_processor.span_exporter._otlp_exporter._endpoint == (
f"https://api.prefect.cloud/api/accounts/{telemetry_account_id}/"
f"workspaces/{telemetry_workspace_id}/telemetry/v1/traces"
)

assert trace.get_tracer_provider() == trace_provider
Expand Down Expand Up @@ -147,14 +154,11 @@ def test_logger_provider(
exporter = processor._exporter # type: ignore

assert isinstance(processor, SimpleLogRecordProcessor)
assert isinstance(exporter, OTLPLogExporter)
assert isinstance(exporter, QueueingLogExporter)

assert (
exporter._endpoint # type: ignore
== (
f"https://api.prefect.cloud/api/accounts/{telemetry_account_id}/"
f"workspaces/{telemetry_workspace_id}/telemetry/v1/logs"
)
assert exporter._otlp_exporter._endpoint == (
f"https://api.prefect.cloud/api/accounts/{telemetry_account_id}/"
f"workspaces/{telemetry_workspace_id}/telemetry/v1/logs"
)

assert get_logger_provider() == logger_provider
Expand Down

0 comments on commit 997e61d

Please sign in to comment.