diff --git a/yente/app.py b/yente/app.py index 24d61262..0b205884 100644 --- a/yente/app.py +++ b/yente/app.py @@ -19,6 +19,7 @@ from yente.data import refresh_catalog from yente.search.indexer import update_index_threaded from yente.provider import close_provider +from yente.middleware import OpenTracingMiddleware log = get_logger("yente") ExceptionHandler = Callable[[Request, Any], Coroutine[Any, Any, Response]] @@ -108,6 +109,7 @@ def create_app() -> FastAPI: lifespan=lifespan, ) app.middleware("http")(request_middleware) + app.add_middleware(OpenTracingMiddleware) app.add_middleware( CORSMiddleware, allow_origins=["*"], diff --git a/yente/middleware/__init__.py b/yente/middleware/__init__.py new file mode 100644 index 00000000..11e095fe --- /dev/null +++ b/yente/middleware/__init__.py @@ -0,0 +1,3 @@ +from .open_tracing import OpenTracingMiddleware + +__all__ = ["OpenTracingMiddleware", "get_trace_context"] diff --git a/yente/middleware/open_tracing.py b/yente/middleware/open_tracing.py new file mode 100644 index 00000000..ac1d672f --- /dev/null +++ b/yente/middleware/open_tracing.py @@ -0,0 +1,126 @@ +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.datastructures import Headers +from starlette.requests import Request +from starlette.responses import Response +from typing import Any, Tuple, List +import secrets +from typing import List +from structlog.contextvars import get_contextvars, bind_contextvars + +VENDOR_CODE = ( + "yente" # It's available! https://w3c.github.io/tracestate-ids-registry/#registry +) + + +def parent_id() -> str: + return secrets.token_hex(8) + + +class TraceParent: + __slots__ = ["version", "trace_id", "parent_id", "trace_flags"] + + def __init__(self, version: str, trace_id: str, parent_id: str, trace_flags: str): + self.version = version + self.trace_id = trace_id + self.parent_id = parent_id + self.trace_flags = trace_flags + + def __str__(self) -> str: + return f"{self.version}-{self.trace_id}-{self.parent_id}-{self.trace_flags}" + + @classmethod + def create(cls) -> "TraceParent": + return cls("00", secrets.token_hex(16), parent_id(), "00") + + @classmethod + def from_str(cls, traceparent: str | None) -> "TraceParent": + if traceparent is None: + return cls.create() + parts = traceparent.split("-") + try: + version, trace_id, parent_id, trace_flags = parts[:4] + except Exception: + raise ValueError(f"Invalid traceparent: {traceparent}") + if int(version, 16) == 255: + raise ValueError(f"Unsupported version: {version}") + for i in trace_id: + if i != "0": + break + raise ValueError(f"Invalid trace_id: {trace_id}") + for i in parent_id: + if i != "0": + break + raise ValueError(f"Invalid parent_id: {parent_id}") + + return cls(version, trace_id, parent_id, trace_flags) + + +class TraceState: + __slots__ = ["tracestate"] + + def __init__(self, tracestate: List[Tuple[str, str]] = []): + self.tracestate = tracestate + + @classmethod + def create(cls, parent: TraceParent, prev_state: str = "") -> "TraceState": + spans_out: List[Tuple[str, str]] = [] + for span in prev_state.split(","): + parts = span.split("=") + if len(parts) != 2: + # We are allowed to discard invalid states + continue + vendor, value = parts + if vendor == VENDOR_CODE: + continue + spans_out.append((vendor.lower().strip(), value.lower().strip())) + spans_out.insert(0, (VENDOR_CODE, f"{parent.parent_id}")) + return cls(spans_out) + + def __str__(self) -> str: + return ",".join([f"{k}={v}" for k, v in self.tracestate]) + + +class TraceContext: + __slots__ = ["traceparent", "tracestate"] + + def __init__(self, traceparent: TraceParent, tracestate: TraceState): + self.traceparent = traceparent + self.tracestate = tracestate + + def __repr__(self) -> str: + return str( + { + "traceparent": str(self.traceparent), + "tracestate": str(self.tracestate), + } + ) + + +def get_trace_context() -> TraceContext | None: + vars = get_contextvars() + if "trace_context" in vars: + trace_context = vars["trace_context"] + if isinstance(trace_context, TraceContext): + return trace_context + return None + + +class OpenTracingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: Any) -> Any: + parent_header = request.headers.get("traceparent") + try: + traceparent = TraceParent.from_str(parent_header) + except Exception: + traceparent = TraceParent.create() + state = request.headers.get("tracestate", "") + try: + tracestate = TraceState.create(traceparent, state) + except Exception: + tracestate = TraceState.create(traceparent, "") + context = TraceContext(traceparent, tracestate) + bind_contextvars(trace_context=context) + resp = await call_next(request) + resp.headers["traceparent"] = str(traceparent) + resp.headers["tracestate"] = str(tracestate) + return resp diff --git a/yente/provider/elastic.py b/yente/provider/elastic.py index 36d3b40e..40779259 100644 --- a/yente/provider/elastic.py +++ b/yente/provider/elastic.py @@ -13,6 +13,7 @@ from yente.logs import get_logger from yente.search.mapping import make_entity_mapping, INDEX_SETTINGS from yente.provider.base import SearchProvider, query_semaphore +from yente.middleware.open_tracing import get_trace_context log = get_logger(__name__) warnings.filterwarnings("ignore", category=ElasticsearchWarning) @@ -54,27 +55,40 @@ async def create(cls) -> "ElasticSearchProvider": raise RuntimeError("Could not connect to ElasticSearch.") def __init__(self, client: AsyncElasticsearch) -> None: - self.client = client + self._client = client + + def client(self, **kwargs: Any) -> AsyncElasticsearch: + """Get the client with the current context.""" + if trace_context := get_trace_context(): + arg_headers = kwargs.get("headers", {}) + headers = arg_headers | ( + dict( + traceparent=str(trace_context.traceparent), + tracestate=str(trace_context.tracestate), + ) + ) + kwargs.update(headers=headers) + return self._client.options(**kwargs) async def close(self) -> None: - await self.client.close() + await self._client.close() async def refresh(self, index: str) -> None: """Refresh the index to make changes visible.""" try: - await self.client.indices.refresh(index=index) + await self.client().indices.refresh(index=index) except NotFoundError as nfe: raise YenteNotFoundError(f"Index {index} does not exist.") from nfe async def get_all_indices(self) -> List[str]: """Get a list of all indices in the ElasticSearch cluster.""" - indices: Any = await self.client.cat.indices(format="json") + indices: Any = await self.client().cat.indices(format="json") return [index.get("index") for index in indices] async def get_alias_indices(self, alias: str) -> List[str]: """Get a list of indices that are aliased to the entity query alias.""" try: - resp = await self.client.indices.get_alias(name=alias) + resp = await self.client().indices.get_alias(name=alias) return list(resp.keys()) except NotFoundError: return [] @@ -88,7 +102,7 @@ async def rollover_index(self, alias: str, next_index: str, prefix: str) -> None actions = [] actions.append({"remove": {"index": f"{prefix}*", "alias": alias}}) actions.append({"add": {"index": next_index, "alias": alias}}) - await self.client.indices.update_aliases(actions=actions) + await self.client().indices.update_aliases(actions=actions) except (ApiError, TransportError) as te: raise YenteIndexError(f"Could not rollover index: {te}") from te @@ -97,19 +111,19 @@ async def clone_index(self, base_version: str, target_version: str) -> None: if base_version == target_version: raise ValueError("Cannot clone an index to itself.") try: - await self.client.indices.put_settings( + await self.client().indices.put_settings( index=base_version, settings={"index.blocks.read_only": True}, ) await self.delete_index(target_version) - await self.client.indices.clone( + await self.client().indices.clone( index=base_version, target=target_version, body={ "settings": {"index": {"blocks": {"read_only": False}}}, }, ) - await self.client.indices.put_settings( + await self.client().indices.put_settings( index=base_version, settings={"index.blocks.read_only": False}, ) @@ -122,7 +136,7 @@ async def create_index(self, index: str) -> None: """Create a new index with the given name.""" log.info("Create index", index=index) try: - await self.client.indices.create( + await self.client().indices.create( index=index, mappings=make_entity_mapping(), settings=INDEX_SETTINGS, @@ -135,7 +149,7 @@ async def create_index(self, index: str) -> None: async def delete_index(self, index: str) -> None: """Delete a given index if it exists.""" try: - await self.client.indices.delete(index=index) + await self.client().indices.delete(index=index) except NotFoundError: pass except (ApiError, TransportError) as te: @@ -144,7 +158,7 @@ async def delete_index(self, index: str) -> None: async def exists_index_alias(self, alias: str, index: str) -> bool: """Check if an index exists and is linked into the given alias.""" try: - exists = await self.client.indices.exists_alias(name=alias, index=index) + exists = await self.client().indices.exists_alias(name=alias, index=index) return True if exists.body else False except NotFoundError: return False @@ -153,8 +167,9 @@ async def exists_index_alias(self, alias: str, index: str) -> bool: async def check_health(self, index: str) -> bool: try: - client = self.client.options(request_timeout=5) - health = await client.cluster.health(index=index, timeout=0) + health = await self.client(request_timeout=5).cluster.health( + index=index, timeout=0 + ) return health.get("status") in ("yellow", "green") except NotFoundError as nfe: raise YenteNotFoundError(f"Index {index} does not exist.") from nfe @@ -182,7 +197,7 @@ async def search( try: async with query_semaphore: - response = await self.client.search( + response = await self.client().search( index=index, query=query, size=size, @@ -218,7 +233,7 @@ async def bulk_index(self, entities: AsyncIterator[Dict[str, Any]]) -> None: """Index a list of entities into the search index.""" try: await async_bulk( - self.client, + self.client(), entities, chunk_size=1000, yield_ok=False,