Skip to content

Commit

Permalink
🐛⚗️Prometheus instrumentation incorrectly setup (#6398)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg authored Sep 19, 2024
1 parent b8592ca commit 1327f2e
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,27 @@


from fastapi import FastAPI
from prometheus_client import CollectorRegistry
from prometheus_fastapi_instrumentator import Instrumentator


def setup_prometheus_instrumentation(app: FastAPI) -> Instrumentator:
# NOTE: use that registry to prevent having a global one
app.state.prometheus_registry = registry = CollectorRegistry(auto_describe=True)
instrumentator = Instrumentator(
should_instrument_requests_inprogress=False, # bug in https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/317
inprogress_labels=False,
registry=registry,
).instrument(app)

instrumentator = (
Instrumentator(
should_instrument_requests_inprogress=True, inprogress_labels=False
)
.instrument(app)
.expose(app, include_in_schema=False)
)
async def _on_startup() -> None:
instrumentator.expose(app, include_in_schema=False)

def _unregister():
for collector in list(instrumentator.registry._collector_to_names.keys()):
instrumentator.registry.unregister(collector)
def _unregister() -> None:
# NOTE: avoid registering collectors multiple times when running unittests consecutively (https://stackoverflow.com/a/62489287)
for collector in list(registry._collector_to_names.keys()): # noqa: SLF001
registry.unregister(collector)

# avoid registering collectors multiple times when running unittests consecutively (https://stackoverflow.com/a/62489287)
app.add_event_handler("startup", _on_startup)
app.add_event_handler("shutdown", _unregister)
return instrumentator
25 changes: 9 additions & 16 deletions packages/service-library/src/servicelib/fastapi/tracing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
""" Adds fastapi middleware for tracing using opentelemetry instrumentation.
"""

import logging

from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.fastapi import (
FastAPIInstrumentor, # pylint: disable=no-name-in-module
)
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
Expand All @@ -21,24 +20,19 @@

def setup_tracing(
app: FastAPI, tracing_settings: TracingSettings, service_name: str
) -> FastAPIInstrumentor | None:
) -> None:
if (
not tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_ENDPOINT
and not tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_PORT
):
log.warning("Skipping opentelemetry tracing setup")
return None
if (
not tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_ENDPOINT
or not tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_PORT
):
raise RuntimeError(
f"Variable opentelemetry_collector_endpoint [{tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_ENDPOINT}] or opentelemetry_collector_port [{tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_PORT}] unset. Tracing options incomplete."
)
return

# Set up the tracer provider
resource = Resource(attributes={"service.name": service_name})
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer_provider = trace.get_tracer_provider()
global_tracer_provider = trace.get_tracer_provider()
assert isinstance(global_tracer_provider, TracerProvider) # nosec
tracing_destination: str = f"{tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_ENDPOINT}:{tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_PORT}/v1/traces"
log.info(
"Trying to connect service %s to tracing collector at %s.",
Expand All @@ -48,7 +42,6 @@ def setup_tracing(
# Configure OTLP exporter to send spans to the collector
otlp_exporter = OTLPSpanExporterHTTP(endpoint=tracing_destination)
span_processor = BatchSpanProcessor(otlp_exporter)
# Mypy bug --> https://github.com/open-telemetry/opentelemetry-python/issues/3713
tracer_provider.add_span_processor(span_processor) # type: ignore[attr-defined]
global_tracer_provider.add_span_processor(span_processor)
# Instrument FastAPI
return FastAPIInstrumentor().instrument_app(app) # type: ignore[no-any-return]
FastAPIInstrumentor().instrument_app(app)
11 changes: 11 additions & 0 deletions packages/service-library/src/servicelib/instrumentation.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
from dataclasses import dataclass

from prometheus_client import CollectorRegistry


@dataclass(slots=True, kw_only=True)
class MetricsBase:
subsystem: str
registry: CollectorRegistry


def get_metrics_namespace(application_name: str) -> str:
return application_name.replace("-", "_")
8 changes: 4 additions & 4 deletions services/agent/tests/unit/test_core_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def test_client(initialized_app: FastAPI) -> TestClient:
def test_health_ok(env: None, test_client: TestClient):
response = test_client.get("/health")
assert response.status_code == status.HTTP_200_OK
assert response.json() == None
assert response.json() is None


def test_health_fails_not_started(
env: None, initialized_app: FastAPI, test_client: TestClient
):
task_monitor: TaskMonitor = initialized_app.state.task_monitor
# emulate monitor not being started
task_monitor._was_started = False
task_monitor._was_started = False # noqa: SLF001

response = test_client.get("/health")
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
Expand All @@ -50,8 +50,8 @@ def test_health_fails_hanging_tasks(
task_monitor: TaskMonitor = initialized_app.state.task_monitor

# emulate tasks hanging
for task_data in task_monitor._to_start.values():
task_data._start_time = time() - 1e6
for task_data in task_monitor._to_start.values(): # noqa: SLF001
task_data._start_time = time() - 1e6 # noqa: SLF001

response = test_client.get("/health")
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ async def on_startup() -> None:
metrics_subsystem = (
"dynamic" if app_settings.AUTOSCALING_NODES_MONITORING else "computational"
)
app.state.instrumentation = AutoscalingInstrumentation(
registry=instrumentator.registry, subsystem=metrics_subsystem
app.state.instrumentation = (
AutoscalingInstrumentation( # pylint: disable=unexpected-keyword-arg
registry=instrumentator.registry, subsystem=metrics_subsystem
)
)

async def on_shutdown() -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Final

from prometheus_client import CollectorRegistry, Counter, Histogram
from servicelib.instrumentation import MetricsBase

from ...models import BufferPoolManager, Cluster
from ._constants import (
Expand All @@ -13,11 +14,6 @@
from ._utils import TrackedGauge, create_gauge


@dataclass(slots=True, kw_only=True)
class MetricsBase:
subsystem: str


@dataclass(slots=True, kw_only=True)
class ClusterMetrics(MetricsBase): # pylint: disable=too-many-instance-attributes
active_nodes: TrackedGauge = field(init=False)
Expand All @@ -36,7 +32,12 @@ def __post_init__(self) -> None:
cluster_subsystem = f"{self.subsystem}_cluster"
# Creating and assigning gauges using the field names and the metric definitions
for field_name, definition in CLUSTER_METRICS_DEFINITIONS.items():
gauge = create_gauge(field_name, definition, cluster_subsystem)
gauge = create_gauge(
field_name=field_name,
definition=definition,
subsystem=cluster_subsystem,
registry=self.registry,
)
setattr(self, field_name, gauge)

def update_from_cluster(self, cluster: Cluster) -> None:
Expand Down Expand Up @@ -65,27 +66,31 @@ def __post_init__(self) -> None:
labelnames=EC2_INSTANCE_LABELS,
namespace=METRICS_NAMESPACE,
subsystem=self.subsystem,
registry=self.registry,
)
self.started_instances = Counter(
"started_instances_total",
"Number of EC2 instances that were started",
labelnames=EC2_INSTANCE_LABELS,
namespace=METRICS_NAMESPACE,
subsystem=self.subsystem,
registry=self.registry,
)
self.stopped_instances = Counter(
"stopped_instances_total",
"Number of EC2 instances that were stopped",
labelnames=EC2_INSTANCE_LABELS,
namespace=METRICS_NAMESPACE,
subsystem=self.subsystem,
registry=self.registry,
)
self.terminated_instances = Counter(
"terminated_instances_total",
"Number of EC2 instances that were terminated",
labelnames=EC2_INSTANCE_LABELS,
namespace=METRICS_NAMESPACE,
subsystem=self.subsystem,
registry=self.registry,
)

def instance_started(self, instance_type: str) -> None:
Expand Down Expand Up @@ -123,7 +128,12 @@ def __post_init__(self) -> None:
setattr(
self,
field_name,
create_gauge(field_name, definition, buffer_pools_subsystem),
create_gauge(
field_name=field_name,
definition=definition,
subsystem=buffer_pools_subsystem,
registry=self.registry,
),
)
self.instances_ready_to_pull_seconds = Histogram(
"instances_ready_to_pull_duration_seconds",
Expand All @@ -132,6 +142,7 @@ def __post_init__(self) -> None:
namespace=METRICS_NAMESPACE,
subsystem=buffer_pools_subsystem,
buckets=(10, 20, 30, 40, 50, 60, 120),
registry=self.registry,
)
self.instances_completed_pulling_seconds = Histogram(
"instances_completed_pulling_duration_seconds",
Expand All @@ -150,6 +161,7 @@ def __post_init__(self) -> None:
30 * _MINUTE,
40 * _MINUTE,
),
registry=self.registry,
)

def update_from_buffer_pool_manager(
Expand All @@ -174,8 +186,16 @@ class AutoscalingInstrumentation(MetricsBase):
buffer_machines_pools_metrics: BufferPoolsMetrics = field(init=False)

def __post_init__(self) -> None:
self.cluster_metrics = ClusterMetrics(subsystem=self.subsystem)
self.ec2_client_metrics = EC2ClientMetrics(subsystem=self.subsystem)
self.buffer_machines_pools_metrics = BufferPoolsMetrics(
subsystem=self.subsystem
self.cluster_metrics = ClusterMetrics( # pylint: disable=unexpected-keyword-arg
subsystem=self.subsystem, registry=self.registry
)
self.ec2_client_metrics = (
EC2ClientMetrics( # pylint: disable=unexpected-keyword-arg
subsystem=self.subsystem, registry=self.registry
)
)
self.buffer_machines_pools_metrics = (
BufferPoolsMetrics( # pylint: disable=unexpected-keyword-arg
subsystem=self.subsystem, registry=self.registry
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass, field

from aws_library.ec2._models import EC2InstanceData
from prometheus_client import Gauge
from prometheus_client import CollectorRegistry, Gauge

from ._constants import METRICS_NAMESPACE

Expand All @@ -27,9 +27,11 @@ def update_from_instances(self, instances: Iterable[EC2InstanceData]) -> None:


def create_gauge(
*,
field_name: str,
definition: tuple[str, tuple[str, ...]],
subsystem: str,
registry: CollectorRegistry,
) -> TrackedGauge:
description, labelnames = definition
return TrackedGauge(
Expand All @@ -39,5 +41,6 @@ def create_gauge(
labelnames=labelnames,
namespace=METRICS_NAMESPACE,
subsystem=subsystem,
registry=registry,
)
)
8 changes: 7 additions & 1 deletion services/autoscaling/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,17 @@ def enabled_rabbitmq(
return rabbit_service


_LIFESPAN_TIMEOUT: Final[int] = 10


@pytest.fixture
async def initialized_app(app_environment: EnvVarsDict) -> AsyncIterator[FastAPI]:
settings = ApplicationSettings.create_from_envs()
app = create_app(settings)
async with LifespanManager(app):
# NOTE: the timeout is sometime too small for CI machines, and even larger machines
async with LifespanManager(
app, startup_timeout=_LIFESPAN_TIMEOUT, shutdown_timeout=_LIFESPAN_TIMEOUT
):
yield app


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TypedDict

from aws_library.ec2._models import EC2InstanceData
from prometheus_client import CollectorRegistry
from prometheus_client.metrics import MetricWrapperBase
from simcore_service_autoscaling.modules.instrumentation._constants import (
EC2_INSTANCE_LABELS,
Expand Down Expand Up @@ -40,10 +41,12 @@ def test_update_gauge_sets_old_entries_to_0(
fake_ec2_instance_data: Callable[..., EC2InstanceData]
):
# Create a Gauge with example labels
registry = CollectorRegistry()
tracked_gauge = create_gauge(
"example_gauge",
field_name="example_gauge",
definition=("An example gauge", EC2_INSTANCE_LABELS),
subsystem="whatever",
registry=registry,
)

ec2_instance_type_1 = fake_ec2_instance_data()
Expand Down
Loading

0 comments on commit 1327f2e

Please sign in to comment.