Skip to content

Commit

Permalink
feat: add traceID and pass it to ES
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonThordal committed Jul 31, 2024
1 parent 144f640 commit 7478ccc
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 16 deletions.
2 changes: 2 additions & 0 deletions yente/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from yente.data import refresh_catalog
from yente.search.indexer import update_index_threaded
from yente.provider import close_provider
from yente.middleware import OpenTracingMiddleware

log = get_logger("yente")
ExceptionHandler = Callable[[Request, Any], Coroutine[Any, Any, Response]]
Expand Down Expand Up @@ -108,6 +109,7 @@ def create_app() -> FastAPI:
lifespan=lifespan,
)
app.middleware("http")(request_middleware)
app.add_middleware(OpenTracingMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand Down
3 changes: 3 additions & 0 deletions yente/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .open_tracing import OpenTracingMiddleware

__all__ = ["OpenTracingMiddleware", "get_trace_context"]
126 changes: 126 additions & 0 deletions yente/middleware/open_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.responses import Response
from typing import Any, Tuple, List
import secrets
from typing import List
from structlog.contextvars import get_contextvars, bind_contextvars

VENDOR_CODE = (
"yente" # It's available! https://w3c.github.io/tracestate-ids-registry/#registry
)


def parent_id() -> str:
return secrets.token_hex(8)


class TraceParent:
__slots__ = ["version", "trace_id", "parent_id", "trace_flags"]

def __init__(self, version: str, trace_id: str, parent_id: str, trace_flags: str):
self.version = version
self.trace_id = trace_id
self.parent_id = parent_id
self.trace_flags = trace_flags

def __str__(self) -> str:
return f"{self.version}-{self.trace_id}-{self.parent_id}-{self.trace_flags}"

@classmethod
def create(cls) -> "TraceParent":
return cls("00", secrets.token_hex(16), parent_id(), "00")

@classmethod
def from_str(cls, traceparent: str | None) -> "TraceParent":
if traceparent is None:
return cls.create()
parts = traceparent.split("-")
try:
version, trace_id, parent_id, trace_flags = parts[:4]
except Exception:
raise ValueError(f"Invalid traceparent: {traceparent}")
if int(version, 16) == 255:
raise ValueError(f"Unsupported version: {version}")
for i in trace_id:
if i != "0":
break
raise ValueError(f"Invalid trace_id: {trace_id}")
for i in parent_id:
if i != "0":
break
raise ValueError(f"Invalid parent_id: {parent_id}")

return cls(version, trace_id, parent_id, trace_flags)


class TraceState:
__slots__ = ["tracestate"]

def __init__(self, tracestate: List[Tuple[str, str]] = []):
self.tracestate = tracestate

@classmethod
def create(cls, parent: TraceParent, prev_state: str = "") -> "TraceState":
spans_out: List[Tuple[str, str]] = []
for span in prev_state.split(","):
parts = span.split("=")
if len(parts) != 2:
# We are allowed to discard invalid states
continue
vendor, value = parts
if vendor == VENDOR_CODE:
continue
spans_out.append((vendor.lower().strip(), value.lower().strip()))
spans_out.insert(0, (VENDOR_CODE, f"{parent.parent_id}"))
return cls(spans_out)

def __str__(self) -> str:
return ",".join([f"{k}={v}" for k, v in self.tracestate])


class TraceContext:
__slots__ = ["traceparent", "tracestate"]

def __init__(self, traceparent: TraceParent, tracestate: TraceState):
self.traceparent = traceparent
self.tracestate = tracestate

def __repr__(self) -> str:
return str(
{
"traceparent": str(self.traceparent),
"tracestate": str(self.tracestate),
}
)


def get_trace_context() -> TraceContext | None:
vars = get_contextvars()
if "trace_context" in vars:
trace_context = vars["trace_context"]
if isinstance(trace_context, TraceContext):
return trace_context
return None


class OpenTracingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Any) -> Any:
parent_header = request.headers.get("traceparent")
try:
traceparent = TraceParent.from_str(parent_header)
except Exception:
traceparent = TraceParent.create()
state = request.headers.get("tracestate", "")
try:
tracestate = TraceState.create(traceparent, state)
except Exception:
tracestate = TraceState.create(traceparent, "")
context = TraceContext(traceparent, tracestate)
bind_contextvars(trace_context=context)
resp = await call_next(request)
resp.headers["traceparent"] = str(traceparent)
resp.headers["tracestate"] = str(tracestate)
return resp
47 changes: 31 additions & 16 deletions yente/provider/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from yente.logs import get_logger
from yente.search.mapping import make_entity_mapping, INDEX_SETTINGS
from yente.provider.base import SearchProvider, query_semaphore
from yente.middleware.open_tracing import get_trace_context

log = get_logger(__name__)
warnings.filterwarnings("ignore", category=ElasticsearchWarning)
Expand Down Expand Up @@ -54,27 +55,40 @@ async def create(cls) -> "ElasticSearchProvider":
raise RuntimeError("Could not connect to ElasticSearch.")

def __init__(self, client: AsyncElasticsearch) -> None:
self.client = client
self._client = client

def client(self, **kwargs: Any) -> AsyncElasticsearch:
"""Get the client with the current context."""
if trace_context := get_trace_context():
arg_headers = kwargs.get("headers", {})
headers = arg_headers | (
dict(
traceparent=str(trace_context.traceparent),
tracestate=str(trace_context.tracestate),
)
)
kwargs.update(headers=headers)
return self._client.options(**kwargs)

async def close(self) -> None:
await self.client.close()
await self._client.close()

async def refresh(self, index: str) -> None:
"""Refresh the index to make changes visible."""
try:
await self.client.indices.refresh(index=index)
await self.client().indices.refresh(index=index)
except NotFoundError as nfe:
raise YenteNotFoundError(f"Index {index} does not exist.") from nfe

async def get_all_indices(self) -> List[str]:
"""Get a list of all indices in the ElasticSearch cluster."""
indices: Any = await self.client.cat.indices(format="json")
indices: Any = await self.client().cat.indices(format="json")
return [index.get("index") for index in indices]

async def get_alias_indices(self, alias: str) -> List[str]:
"""Get a list of indices that are aliased to the entity query alias."""
try:
resp = await self.client.indices.get_alias(name=alias)
resp = await self.client().indices.get_alias(name=alias)
return list(resp.keys())
except NotFoundError:
return []
Expand All @@ -88,7 +102,7 @@ async def rollover_index(self, alias: str, next_index: str, prefix: str) -> None
actions = []
actions.append({"remove": {"index": f"{prefix}*", "alias": alias}})
actions.append({"add": {"index": next_index, "alias": alias}})
await self.client.indices.update_aliases(actions=actions)
await self.client().indices.update_aliases(actions=actions)
except (ApiError, TransportError) as te:
raise YenteIndexError(f"Could not rollover index: {te}") from te

Expand All @@ -97,19 +111,19 @@ async def clone_index(self, base_version: str, target_version: str) -> None:
if base_version == target_version:
raise ValueError("Cannot clone an index to itself.")
try:
await self.client.indices.put_settings(
await self.client().indices.put_settings(
index=base_version,
settings={"index.blocks.read_only": True},
)
await self.delete_index(target_version)
await self.client.indices.clone(
await self.client().indices.clone(
index=base_version,
target=target_version,
body={
"settings": {"index": {"blocks": {"read_only": False}}},
},
)
await self.client.indices.put_settings(
await self.client().indices.put_settings(
index=base_version,
settings={"index.blocks.read_only": False},
)
Expand All @@ -122,7 +136,7 @@ async def create_index(self, index: str) -> None:
"""Create a new index with the given name."""
log.info("Create index", index=index)
try:
await self.client.indices.create(
await self.client().indices.create(
index=index,
mappings=make_entity_mapping(),
settings=INDEX_SETTINGS,
Expand All @@ -135,7 +149,7 @@ async def create_index(self, index: str) -> None:
async def delete_index(self, index: str) -> None:
"""Delete a given index if it exists."""
try:
await self.client.indices.delete(index=index)
await self.client().indices.delete(index=index)
except NotFoundError:
pass
except (ApiError, TransportError) as te:
Expand All @@ -144,7 +158,7 @@ async def delete_index(self, index: str) -> None:
async def exists_index_alias(self, alias: str, index: str) -> bool:
"""Check if an index exists and is linked into the given alias."""
try:
exists = await self.client.indices.exists_alias(name=alias, index=index)
exists = await self.client().indices.exists_alias(name=alias, index=index)
return True if exists.body else False
except NotFoundError:
return False
Expand All @@ -153,8 +167,9 @@ async def exists_index_alias(self, alias: str, index: str) -> bool:

async def check_health(self, index: str) -> bool:
try:
client = self.client.options(request_timeout=5)
health = await client.cluster.health(index=index, timeout=0)
health = await self.client(request_timeout=5).cluster.health(
index=index, timeout=0
)
return health.get("status") in ("yellow", "green")
except NotFoundError as nfe:
raise YenteNotFoundError(f"Index {index} does not exist.") from nfe
Expand Down Expand Up @@ -182,7 +197,7 @@ async def search(

try:
async with query_semaphore:
response = await self.client.search(
response = await self.client().search(
index=index,
query=query,
size=size,
Expand Down Expand Up @@ -218,7 +233,7 @@ async def bulk_index(self, entities: AsyncIterator[Dict[str, Any]]) -> None:
"""Index a list of entities into the search index."""
try:
await async_bulk(
self.client,
self.client(),
entities,
chunk_size=1000,
yield_ok=False,
Expand Down

0 comments on commit 7478ccc

Please sign in to comment.