fix cache: the cache was replacing the original metadata on every refresh, even when the metadata request covered only a subset of topics; this PR addresses this by updating the metadata object only for the requested topics.
eliax1996 committed Jun 4, 2024
1 parent 3f899ac commit be669bd
Showing 3 changed files with 364 additions and 40 deletions.
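In short, the change replaces the single cache-wide timestamp with per-topic timestamps, so that a request for a subset of topics only refreshes and overwrites the cached entries for those topics instead of replacing the whole metadata object. The snippet below is a minimal, standalone sketch of that idea, not the karapace code itself; the TopicMetadataCache class, the fetch_metadata callable, and the max_age handling are simplified assumptions for illustration.

from __future__ import annotations

import time
from typing import Callable


class TopicMetadataCache:
    """Toy model of a metadata cache refreshed per topic instead of wholesale (illustrative only)."""

    def __init__(self, fetch_metadata: Callable[[list[str] | None], dict], max_age: float) -> None:
        self._fetch_metadata = fetch_metadata  # assumed to behave like admin_client.cluster_metadata
        self._max_age = max_age
        self._topic_birth: dict[str, float] = {}  # per-topic refresh timestamps
        self._cache: dict = {"topics": {}, "brokers": []}

    def _is_stale(self, topics: list[str]) -> bool:
        # Stale if any queried topic was never cached or its entry is older than max_age.
        if not topics or not all(topic in self._topic_birth for topic in topics):
            return True
        oldest = min(self._topic_birth[topic] for topic in topics)
        return (time.monotonic() - oldest) > self._max_age

    def get(self, topics: list[str]) -> dict:
        if not self._is_stale(topics):
            # Serve the requested subset straight from the cache.
            return {**self._cache, "topics": {t: self._cache["topics"][t] for t in topics}}
        now = time.monotonic()
        fresh = self._fetch_metadata(topics)
        # Merge only the requested topics; cached entries for other topics stay untouched.
        for topic, topic_metadata in fresh["topics"].items():
            self._topic_birth[topic] = now
            self._cache["topics"][topic] = topic_metadata
        self._cache["brokers"] = fresh.get("brokers", self._cache["brokers"])
        return fresh

The actual implementation in the diff below additionally keeps a separate timestamp and completeness flag for full-cluster requests (_global_metadata_birth, _cluster_metadata_complete), so a full refresh can still be served without re-fetching.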
151 changes: 111 additions & 40 deletions karapace/kafka_rest_apis/__init__.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from binascii import Error as B64DecodeError
from collections import namedtuple
from confluent_kafka.error import KafkaException
@@ -36,7 +38,7 @@
)
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
from karapace.utils import convert_to_int, json_encode
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, TypedDict

import asyncio
import base64
@@ -66,10 +68,10 @@ def __init__(self, config: Config) -> None:
super().__init__(config=config)
self._add_kafka_rest_routes()
self.serializer = SchemaRegistrySerializer(config=config)
self.proxies: Dict[str, "UserRestProxy"] = {}
self.proxies: dict[str, UserRestProxy] = {}
self._proxy_lock = asyncio.Lock()
log.info("REST proxy starting with (delegated authorization=%s)", self.config.get("rest_authorization", False))
self._idle_proxy_janitor_task: Optional[asyncio.Task] = None
self._idle_proxy_janitor_task: asyncio.Task | None = None

async def close(self) -> None:
log.info("Closing REST proxy application")
@@ -416,32 +418,56 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq
await proxy.topic_publish(topic, content_type, request=request)


class _ReplicaMetadata(TypedDict):
broker: int
leader: bool
in_sync: bool


class _PartitionMetadata(TypedDict):
partition: int
leader: int
replicas: list[_ReplicaMetadata]


class _TopicMetadata(TypedDict):
partitions: list[_PartitionMetadata]


class _ClusterMetadata(TypedDict):
topics: dict[str, _TopicMetadata]
brokers: list[int]


class UserRestProxy:
def __init__(
self,
config: Config,
kafka_timeout: int,
serializer: SchemaRegistrySerializer,
auth_expiry: Optional[datetime.datetime] = None,
auth_expiry: datetime.datetime | None = None,
verify_connection: bool = True,
):
self.config = config
self.kafka_timeout = kafka_timeout
self.serializer = serializer
self._cluster_metadata = None
self._cluster_metadata_complete = False
self._metadata_birth = None
# timestamp of the last full refresh (i.e. when the request asked for all the metadata available in the cluster)
self._global_metadata_birth: float | None = None
self._cluster_metadata_topic_birth: dict[str, float] = {}
self.metadata_max_age = self.config["admin_metadata_max_age"]
self.admin_client = None
self.admin_lock = asyncio.Lock()
self.metadata_cache = None
self.topic_schema_cache = TopicSchemaCache()
self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer)
self.init_admin_client()
self.init_admin_client(verify_connection)
self._last_used = time.monotonic()
self._auth_expiry = auth_expiry

self._async_producer_lock = asyncio.Lock()
self._async_producer: Optional[AsyncKafkaProducer] = None
self._async_producer: AsyncKafkaProducer | None = None
self.naming_strategy = NameStrategy(self.config["name_strategy"])

def __str__(self) -> str:
@@ -607,28 +633,72 @@ async def get_topic_config(self, topic: str) -> dict:
async with self.admin_lock:
return self.admin_client.get_topic_config(topic)

async def cluster_metadata(self, topics: Optional[List[str]] = None) -> dict:
async with self.admin_lock:
if self._metadata_birth is None or time.monotonic() - self._metadata_birth > self.metadata_max_age:
self._cluster_metadata = None
def is_global_metadata_old(self) -> bool:
return (
self._global_metadata_birth is None or (time.monotonic() - self._global_metadata_birth) > self.metadata_max_age
)

if self._cluster_metadata:
# Return from metadata only if all queried topics have cached metadata
if topics is None:
if self._cluster_metadata_complete:
return self._cluster_metadata
elif all(topic in self._cluster_metadata["topics"] for topic in topics):
return {
**self._cluster_metadata,
"topics": {topic: self._cluster_metadata["topics"][topic] for topic in topics},
}
def is_metadata_of_topics_old(self, topics: list[str]) -> bool:
# Cached metadata can be served only if every queried topic has a fresh enough cached entry

if self._cluster_metadata_topic_birth is None:
return True

are_all_topic_queried_at_least_once = all(topic in self._cluster_metadata_topic_birth for topic in topics)

if not are_all_topic_queried_at_least_once:
return True

oldest_requested_topic_update_timestamp = min(self._cluster_metadata_topic_birth[topic] for topic in topics)
return (
are_all_topic_queried_at_least_once
and (time.monotonic() - oldest_requested_topic_update_timestamp) > self.metadata_max_age
)

def _update_all_metadata(self) -> _ClusterMetadata:
if not self.is_global_metadata_old() and self._cluster_metadata_complete:
return self._cluster_metadata

metadata_birth = time.monotonic()
metadata = self.admin_client.cluster_metadata(None)
for topic in metadata["topics"]:
self._cluster_metadata_topic_birth[topic] = metadata_birth

self._global_metadata_birth = metadata_birth
self._cluster_metadata = metadata
self._cluster_metadata_complete = True
return metadata

def _empty_cluster_metadata_cache(self) -> _ClusterMetadata:
return {"topics": {}, "brokers": []}

def _update_metadata_for_topics(self, topics: list[str]) -> _ClusterMetadata:
if not self.is_metadata_of_topics_old(topics):
return {
**self._cluster_metadata,
"topics": {topic: self._cluster_metadata["topics"][topic] for topic in topics},
}

metadata_birth = time.monotonic()
metadata = self.admin_client.cluster_metadata(topics)

if self._cluster_metadata is None:
self._cluster_metadata = self._empty_cluster_metadata_cache()

for topic in metadata["topics"]:
self._cluster_metadata_topic_birth[topic] = metadata_birth
self._cluster_metadata["topics"][topic] = metadata["topics"][topic]

self._cluster_metadata_complete = False
return metadata

async def cluster_metadata(self, topics: list[str] | None = None) -> _ClusterMetadata:
async with self.admin_lock:
try:
metadata_birth = time.monotonic()
metadata = self.admin_client.cluster_metadata(topics)
self._metadata_birth = metadata_birth
self._cluster_metadata = metadata
self._cluster_metadata_complete = topics is None
if topics is None:
metadata = self._update_all_metadata()
else:
metadata = self._update_metadata_for_topics(topics)
except KafkaException:
log.warning("Could not refresh cluster metadata")
KafkaRest.r(
@@ -641,7 +711,7 @@ async def cluster_metadata(self, topics: Optional[List[str]] = None) -> dict:
)
return metadata

def init_admin_client(self):
def init_admin_client(self, verify_connection: bool = True) -> KafkaAdminClient:
for retry in [True, True, False]:
try:
self.admin_client = KafkaAdminClient(
@@ -652,6 +722,7 @@ def init_admin_client(self):
ssl_keyfile=self.config["ssl_keyfile"],
metadata_max_age_ms=self.config["metadata_max_age_ms"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
verify_connection=verify_connection,
**get_kafka_client_auth_parameters_from_config(self.config),
)
break
@@ -675,7 +746,7 @@ async def aclose(self) -> None:
self.admin_client = None
self.consumer_manager = None

async def publish(self, topic: str, partition_id: Optional[str], content_type: str, request: HTTPRequest) -> None:
async def publish(self, topic: str, partition_id: str | None, content_type: str, request: HTTPRequest) -> None:
"""
:raises NoBrokersAvailable:
:raises AuthenticationFailedError:
@@ -797,7 +868,7 @@ async def get_schema_id(
:raises InvalidSchema:
"""
log.debug("[resolve schema id] Retrieving schema id for %r", data)
schema_id: Union[SchemaId, None] = (
schema_id: SchemaId | None = (
SchemaId(int(data[f"{subject_type}_schema_id"])) if f"{subject_type}_schema_id" in data else None
)
schema_str = data.get(f"{subject_type}_schema")
@@ -817,7 +888,7 @@ async def get_schema_id(
schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name)
else:

def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
def subject_not_included(schema: TypedSchema, subjects: list[Subject]) -> bool:
subject = get_subject_name(topic, schema, subject_type, self.naming_strategy)
return subject not in subjects

@@ -832,8 +903,8 @@ def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
return schema_id

async def _query_schema_and_subjects(
self, schema_id: SchemaId, *, need_new_call: Optional[Callable[[TypedSchema, List[Subject]], bool]]
) -> Tuple[TypedSchema, List[Subject]]:
self, schema_id: SchemaId, *, need_new_call: Callable[[TypedSchema, list[Subject]], bool] | None
) -> tuple[TypedSchema, list[Subject]]:
try:
return await self.serializer.get_schema_for_id(schema_id, need_new_call=need_new_call)
except SchemaRetrievalError as schema_error:
@@ -924,10 +995,10 @@ async def _prepare_records(
content_type: str,
data: dict,
ser_format: str,
key_schema_id: Optional[int],
value_schema_id: Optional[int],
default_partition: Optional[int] = None,
) -> List[Tuple]:
key_schema_id: int | None,
value_schema_id: int | None,
default_partition: int | None = None,
) -> list[tuple]:
prepared_records = []
for record in data["records"]:
key = record.get("key")
@@ -978,8 +1049,8 @@ async def serialize(
self,
content_type: str,
obj=None,
ser_format: Optional[str] = None,
schema_id: Optional[int] = None,
ser_format: str | None = None,
schema_id: int | None = None,
) -> bytes:
if not obj:
return b""
@@ -1003,7 +1074,7 @@ async def serialize(
return await self.schema_serialize(obj, schema_id)
raise FormatError(f"Unknown format: {ser_format}")

async def schema_serialize(self, obj: dict, schema_id: Optional[int]) -> bytes:
async def schema_serialize(self, obj: dict, schema_id: int | None) -> bytes:
schema, _ = await self.serializer.get_schema_for_id(schema_id)
bytes_ = await self.serializer.serialize(schema, obj)
return bytes_
@@ -1066,7 +1137,7 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte
sub_code=RESTErrorCodes.INVALID_DATA.value,
)

async def produce_messages(self, *, topic: str, prepared_records: List) -> List:
async def produce_messages(self, *, topic: str, prepared_records: list) -> list:
"""
:raises NoBrokersAvailable:
:raises AuthenticationFailedError:
