From 5f04a8930287c713fb5aad011b0c0820dee4455b Mon Sep 17 00:00:00 2001 From: SimonThordal Date: Tue, 16 Jul 2024 09:42:38 +0200 Subject: [PATCH] trace_ids should be local --- yente/provider/base.py | 1 + yente/provider/elastic.py | 45 +++++++++++++++++++----------------- yente/provider/opensearch.py | 38 +++++++++++++++--------------- yente/routers/util.py | 3 +-- yente/search/base.py | 6 ++--- 5 files changed, 47 insertions(+), 46 deletions(-) diff --git a/yente/provider/base.py b/yente/provider/base.py index 78224888..181b072d 100644 --- a/yente/provider/base.py +++ b/yente/provider/base.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import Any, Dict, List, Optional from typing import AsyncIterator from threading import Lock diff --git a/yente/provider/elastic.py b/yente/provider/elastic.py index 8610376e..753dcd2d 100644 --- a/yente/provider/elastic.py +++ b/yente/provider/elastic.py @@ -11,7 +11,7 @@ from yente import settings from yente.exc import IndexNotReadyError, YenteIndexError, YenteNotFoundError from yente.logs import get_logger -from yente.search.base import query_semaphore +from yente.search.base import query_semaphore, get_trace_id from yente.search.mapping import make_entity_mapping, INDEX_SETTINGS from yente.provider.base import SearchProvider @@ -55,31 +55,27 @@ async def create(cls) -> "ElasticSearchProvider": raise RuntimeError("Could not connect to ElasticSearch.") def __init__(self, client: AsyncElasticsearch) -> None: - self.client = client + self._client = client async def close(self) -> None: - await self.client.close() - - def set_trace_id(self, id: str) -> None: - """Set the trace ID for the requests.""" - self.client = self.client.options(opaque_id=id) + await self._client.close() async def refresh(self, index: str) -> None: """Refresh the index to make changes visible.""" try: - await self.client.indices.refresh(index=index) + await self.client().indices.refresh(index=index) except NotFoundError as nfe: raise YenteNotFoundError(f"Index {index} does not exist.") from nfe async def get_all_indices(self) -> List[str]: """Get a list of all indices in the ElasticSearch cluster.""" - indices: Any = await self.client.cat.indices(format="json") + indices: Any = await self.client().cat.indices(format="json") return [index.get("index") for index in indices] async def get_alias_indices(self, alias: str) -> List[str]: """Get a list of indices that are aliased to the entity query alias.""" try: - resp = await self.client.indices.get_alias(name=alias) + resp = await self.client().indices.get_alias(name=alias) return list(resp.keys()) except NotFoundError: return [] @@ -93,7 +89,7 @@ async def rollover_index(self, alias: str, next_index: str, prefix: str) -> None actions = [] actions.append({"remove": {"index": f"{prefix}*", "alias": alias}}) actions.append({"add": {"index": next_index, "alias": alias}}) - await self.client.indices.update_aliases(actions=actions) + await self.client().indices.update_aliases(actions=actions) except (ApiError, TransportError) as te: raise YenteIndexError(f"Could not rollover index: {te}") from te @@ -102,19 +98,19 @@ async def clone_index(self, base_version: str, target_version: str) -> None: if base_version == target_version: raise ValueError("Cannot clone an index to itself.") try: - await self.client.indices.put_settings( + await self.client().indices.put_settings( index=base_version, settings={"index.blocks.read_only": True}, ) await self.delete_index(target_version) - await self.client.indices.clone( + await self.client().indices.clone( index=base_version, target=target_version, body={ "settings": {"index": {"blocks": {"read_only": False}}}, }, ) - await self.client.indices.put_settings( + await self.client().indices.put_settings( index=base_version, settings={"index.blocks.read_only": False}, ) @@ -127,7 +123,7 @@ async def create_index(self, index: str) -> None: """Create a new index with the given name.""" log.info("Create index", index=index) try: - await self.client.indices.create( + await self.client().indices.create( index=index, mappings=make_entity_mapping(), settings=INDEX_SETTINGS, @@ -140,7 +136,7 @@ async def create_index(self, index: str) -> None: async def delete_index(self, index: str) -> None: """Delete a given index if it exists.""" try: - await self.client.indices.delete(index=index) + await self.client().indices.delete(index=index) except NotFoundError: pass except (ApiError, TransportError) as te: @@ -149,7 +145,7 @@ async def delete_index(self, index: str) -> None: async def exists_index_alias(self, alias: str, index: str) -> bool: """Check if an index exists and is linked into the given alias.""" try: - exists = await self.client.indices.exists_alias(name=alias, index=index) + exists = await self.client().indices.exists_alias(name=alias, index=index) return True if exists.body else False except NotFoundError: return False @@ -158,8 +154,9 @@ async def exists_index_alias(self, alias: str, index: str) -> bool: async def check_health(self, index: str) -> bool: try: - es_ = self.client.options(request_timeout=5) - health = await es_.cluster.health(index=index, timeout=0) + health = await self.client(request_timeout=5).cluster.health( + index=index, timeout=0 + ) return health.get("status") in ("yellow", "green") except NotFoundError as nfe: raise YenteNotFoundError(f"Index {index} does not exist.") from nfe @@ -187,7 +184,7 @@ async def search( try: async with query_semaphore: - response = await self.client.search( + response = await self.client().search( index=index, query=query, size=size, @@ -223,7 +220,7 @@ async def bulk_index(self, entities: AsyncIterator[Dict[str, Any]]) -> None: """Index a list of entities into the search index.""" try: await async_bulk( - self.client, + self.client(), entities, chunk_size=1000, yield_ok=False, @@ -231,3 +228,9 @@ async def bulk_index(self, entities: AsyncIterator[Dict[str, Any]]) -> None: ) except BulkIndexError as exc: raise YenteIndexError(f"Could not index entities: {exc}") from exc + + def client(self, **kwargs: Any) -> AsyncElasticsearch: + args = { + "opaque_id": get_trace_id(), + } | kwargs or {} + return self._client.options(**args) diff --git a/yente/provider/opensearch.py b/yente/provider/opensearch.py index adbe8b21..bb93c8d7 100644 --- a/yente/provider/opensearch.py +++ b/yente/provider/opensearch.py @@ -49,32 +49,27 @@ async def create(cls) -> "OpenSearchProvider": raise RuntimeError("Could not connect to OpenSearch.") def __init__(self, client: AsyncOpenSearch) -> None: - self.client = client + self._client = client async def close(self) -> None: - await self.client.close() - - def set_trace_id(self, id: str) -> None: - """Set the trace ID for the requests.""" - # self.client.transport. - pass + await self._client.close() async def refresh(self, index: str) -> None: """Refresh the index to make changes visible.""" try: - await self.client.indices.refresh(index=index) + await self.client().indices.refresh(index=index) except NotFoundError as nfe: raise YenteNotFoundError(f"Index {index} does not exist.") from nfe async def get_all_indices(self) -> List[str]: """Get a list of all indices in the ElasticSearch cluster.""" - indices: Any = await self.client.cat.indices(format="json") + indices: Any = await self.client().cat.indices(format="json") return [index.get("index") for index in indices] async def get_alias_indices(self, alias: str) -> List[str]: """Get a list of indices that are aliased to the entity query alias.""" try: - resp = await self.client.indices.get_alias(name=alias) + resp = await self.client().indices.get_alias(name=alias) return list(resp.keys()) except NotFoundError: return [] @@ -91,7 +86,7 @@ async def rollover_index(self, alias: str, next_index: str, prefix: str) -> None {"add": {"index": next_index, "alias": alias}}, ] } - await self.client.indices.update_aliases(body) + await self.client().indices.update_aliases(body) except TransportError as te: raise YenteIndexError(f"Could not rollover index: {te}") from te @@ -100,19 +95,19 @@ async def clone_index(self, base_version: str, target_version: str) -> None: if base_version == target_version: raise ValueError("Cannot clone an index to itself.") try: - await self.client.indices.put_settings( + await self.client().indices.put_settings( index=base_version, body={"settings": {"index.blocks.read_only": True}}, ) await self.delete_index(target_version) - await self.client.indices.clone( + await self.client().indices.clone( index=base_version, target=target_version, body={ "settings": {"index": {"blocks": {"read_only": False}}}, }, ) - await self.client.indices.put_settings( + await self.client().indices.put_settings( index=base_version, body={"settings": {"index.blocks.read_only": False}}, ) @@ -129,7 +124,7 @@ async def create_index(self, index: str) -> None: "settings": INDEX_SETTINGS, "mappings": make_entity_mapping(), } - await self.client.indices.create(index=index, body=body) + await self.client().indices.create(index=index, body=body) except TransportError as exc: if exc.error == "resource_already_exists_exception": return @@ -138,7 +133,7 @@ async def create_index(self, index: str) -> None: async def delete_index(self, index: str) -> None: """Delete a given index if it exists.""" try: - await self.client.indices.delete(index=index) + await self.client().indices.delete(index=index) except NotFoundError: pass except TransportError as te: @@ -147,7 +142,7 @@ async def delete_index(self, index: str) -> None: async def exists_index_alias(self, alias: str, index: str) -> bool: """Check if an index exists and is linked into the given alias.""" try: - resp = await self.client.indices.exists_alias(name=alias, index=index) + resp = await self.client().indices.exists_alias(name=alias, index=index) return bool(resp) except NotFoundError: return False @@ -156,7 +151,7 @@ async def exists_index_alias(self, alias: str, index: str) -> bool: async def check_health(self, index: str) -> bool: try: - health = await self.client.cluster.health(index=index, timeout=5) + health = await self.client().cluster.health(index=index, timeout=5) return health.get("status") in ("yellow", "green") except NotFoundError as nfe: raise YenteNotFoundError(f"Index {index} does not exist.") from nfe @@ -189,7 +184,7 @@ async def search( body["aggregations"] = aggregations if sort is not None: body["sort"] = sort - response = await self.client.search( + response = await self.client().search( index=index, size=size, from_=from_, @@ -217,7 +212,7 @@ async def bulk_index(self, entities: AsyncIterator[Dict[str, Any]]) -> None: """Index a list of entities into the search index.""" try: await async_bulk( - self.client, + self.client(), entities, chunk_size=1000, yield_ok=False, @@ -225,3 +220,6 @@ async def bulk_index(self, entities: AsyncIterator[Dict[str, Any]]) -> None: ) except BulkIndexError as exc: raise YenteIndexError(f"Could not index entities: {exc}") from exc + + def client(self) -> AsyncOpenSearch: + return self._client diff --git a/yente/routers/util.py b/yente/routers/util.py index 4a40ceb6..394d4ec6 100644 --- a/yente/routers/util.py +++ b/yente/routers/util.py @@ -6,7 +6,7 @@ from yente import settings from yente.data import get_catalog from yente.data.dataset import Dataset -from yente.search.base import get_opaque_id +from yente.search.base import get_trace_id from yente.provider import SearchProvider, with_provider @@ -43,5 +43,4 @@ async def get_dataset(name: str) -> Dataset: async def get_request_provider() -> AsyncIterator[SearchProvider]: async with with_provider() as provider: - provider.set_trace_id(get_opaque_id()) yield provider diff --git a/yente/search/base.py b/yente/search/base.py index 96d480db..2f9d7f5c 100644 --- a/yente/search/base.py +++ b/yente/search/base.py @@ -1,5 +1,5 @@ import asyncio -from typing import cast +from typing import Any from structlog.contextvars import get_contextvars from yente import settings @@ -10,6 +10,6 @@ query_semaphore = asyncio.Semaphore(settings.QUERY_CONCURRENCY) -def get_opaque_id() -> str: +def get_trace_id() -> Any: ctx = get_contextvars() - return cast(str, ctx.get("trace_id")) + return ctx.get("trace_id")