fix cache: the cache metadata is always replaced even if only a subportion of it is being updated #892

Merged
merged 2 commits on Jun 5, 2024
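The gist of the fix: previously any call to `cluster_metadata` overwrote the whole cached cluster metadata and its single birth timestamp, even when only a few topics were requested. The diff below keeps a per-topic birth timestamp, merges subset replies into the existing cache, and reuses the full-cluster view only while it is both fresh and complete. Below is a minimal, self-contained sketch of that strategy; `TopicMetadataCache`, `METADATA_MAX_AGE` and `fetch_metadata` are illustrative names only, not Karapace's API (the real implementation is `UserRestProxy` in the diff).

```python
import time

METADATA_MAX_AGE = 5.0  # seconds; the real proxy reads this from config ("admin_metadata_max_age")


class TopicMetadataCache:
    """Simplified per-topic metadata cache: a subset query only merges the
    entries it fetched instead of replacing the whole cached cluster view."""

    def __init__(self, fetch_metadata):
        # fetch_metadata(topics_or_None) stands in for the admin client's cluster_metadata() call.
        self._fetch = fetch_metadata
        self._topics = {}         # topic name -> topic metadata
        self._topic_birth = {}    # topic name -> monotonic timestamp of its last refresh
        self._global_birth = 0.0  # 0.0 so the first full-cluster call always refreshes
        self._complete = False    # True only while the cache holds every topic in the cluster

    def _topics_are_stale(self, topics):
        if any(topic not in self._topic_birth for topic in topics):
            return True
        oldest = min(self._topic_birth[topic] for topic in topics)
        return (time.monotonic() - oldest) > METADATA_MAX_AGE

    def get(self, topics=None):
        now = time.monotonic()
        if not topics:  # full-cluster request
            if self._complete and (now - self._global_birth) <= METADATA_MAX_AGE:
                return dict(self._topics)
            self._topics = dict(self._fetch(None))
            self._topic_birth = {topic: now for topic in self._topics}
            self._global_birth = now
            self._complete = True
            return dict(self._topics)
        if not self._topics_are_stale(topics):
            return {topic: self._topics[topic] for topic in topics}
        fresh = self._fetch(topics)
        for topic, metadata in fresh.items():
            if topic not in self._topics:
                self._complete = False  # an unknown topic appeared: the full view is no longer complete
            self._topics[topic] = metadata  # merge: only the queried topics are replaced
            self._topic_birth[topic] = now
        return dict(fresh)


# Usage: the subset call below is served by merging into the existing cache, so the
# entry previously fetched for "topic-b" survives.
cache = TopicMetadataCache(
    lambda topics: {name: {"partitions": []} for name in (topics or ["topic-a", "topic-b"])}
)
cache.get()             # full fetch: caches topic-a and topic-b, marks the cache complete
cache.get(["topic-a"])  # still fresh: answered from the cache, topic-b stays cached
```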
175 changes: 133 additions & 42 deletions karapace/kafka_rest_apis/__init__.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from binascii import Error as B64DecodeError
from collections import namedtuple
from confluent_kafka.error import KafkaException
@@ -36,7 +38,7 @@
)
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
from karapace.utils import convert_to_int, json_encode
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, TypedDict

import asyncio
import base64
@@ -66,10 +68,10 @@ def __init__(self, config: Config) -> None:
super().__init__(config=config)
self._add_kafka_rest_routes()
self.serializer = SchemaRegistrySerializer(config=config)
self.proxies: Dict[str, "UserRestProxy"] = {}
self.proxies: dict[str, UserRestProxy] = {}
self._proxy_lock = asyncio.Lock()
log.info("REST proxy starting with (delegated authorization=%s)", self.config.get("rest_authorization", False))
self._idle_proxy_janitor_task: Optional[asyncio.Task] = None
self._idle_proxy_janitor_task: asyncio.Task | None = None

async def close(self) -> None:
log.info("Closing REST proxy application")
@@ -416,32 +418,56 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq
await proxy.topic_publish(topic, content_type, request=request)


class _ReplicaMetadata(TypedDict):
broker: int
leader: bool
in_sync: bool


class _PartitionMetadata(TypedDict):
partition: int
leader: int
replicas: list[_ReplicaMetadata]


class _TopicMetadata(TypedDict):
partitions: list[_PartitionMetadata]


class _ClusterMetadata(TypedDict):
topics: dict[str, _TopicMetadata]
brokers: list[int]


class UserRestProxy:
def __init__(
self,
config: Config,
kafka_timeout: int,
serializer: SchemaRegistrySerializer,
auth_expiry: Optional[datetime.datetime] = None,
auth_expiry: datetime.datetime | None = None,
verify_connection: bool = True,
):
self.config = config
self.kafka_timeout = kafka_timeout
self.serializer = serializer
self._cluster_metadata = None
self._cluster_metadata: _ClusterMetadata = self._empty_cluster_metadata_cache()
self._cluster_metadata_complete = False
self._metadata_birth = None
# timestamp of the last request that required all the metadata available in the cluster
self._global_metadata_birth: float = 0.0  # initialised to 0.0 so that the first call always triggers a refresh
self._cluster_metadata_topic_birth: dict[str, float] = {}
self.metadata_max_age = self.config["admin_metadata_max_age"]
self.admin_client = None
self.admin_lock = asyncio.Lock()
self.metadata_cache = None
self.topic_schema_cache = TopicSchemaCache()
self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer)
self.init_admin_client()
self.init_admin_client(verify_connection)
self._last_used = time.monotonic()
self._auth_expiry = auth_expiry

self._async_producer_lock = asyncio.Lock()
self._async_producer: Optional[AsyncKafkaProducer] = None
self._async_producer: AsyncKafkaProducer | None = None
self.naming_strategy = NameStrategy(self.config["name_strategy"])

def __str__(self) -> str:
@@ -607,28 +633,92 @@ async def get_topic_config(self, topic: str) -> dict:
async with self.admin_lock:
return self.admin_client.get_topic_config(topic)

async def cluster_metadata(self, topics: Optional[List[str]] = None) -> dict:
async with self.admin_lock:
if self._metadata_birth is None or time.monotonic() - self._metadata_birth > self.metadata_max_age:
self._cluster_metadata = None

if self._cluster_metadata:
# Return from metadata only if all queried topics have cached metadata
if topics is None:
if self._cluster_metadata_complete:
return self._cluster_metadata
elif all(topic in self._cluster_metadata["topics"] for topic in topics):
return {
**self._cluster_metadata,
"topics": {topic: self._cluster_metadata["topics"][topic] for topic in topics},
}
def is_global_metadata_old(self) -> bool:
return (time.monotonic() - self._global_metadata_birth) > self.metadata_max_age

def is_metadata_of_topics_old(self, topics: list[str]) -> bool:
# Return from metadata only if all queried topics have cached metadata
are_all_topic_queried_at_least_once = all(topic in self._cluster_metadata_topic_birth for topic in topics)

if not are_all_topic_queried_at_least_once:
return True

oldest_requested_topic_update_timestamp = min(self._cluster_metadata_topic_birth[topic] for topic in topics)
return (
are_all_topic_queried_at_least_once
and (time.monotonic() - oldest_requested_topic_update_timestamp) > self.metadata_max_age
)

def _update_all_metadata(self) -> _ClusterMetadata:
if not self.is_global_metadata_old() and self._cluster_metadata_complete:
return self._cluster_metadata

metadata_birth = time.monotonic()
metadata = self.admin_client.cluster_metadata(None)
for topic in metadata["topics"]:
self._cluster_metadata_topic_birth[topic] = metadata_birth

self._global_metadata_birth = metadata_birth
self._cluster_metadata = metadata
self._cluster_metadata_complete = True
return metadata

def _empty_cluster_metadata_cache(self) -> _ClusterMetadata:
return {"topics": {}, "brokers": []}

def _update_metadata_for_topics(self, topics: list[str]) -> _ClusterMetadata:
if not self.is_metadata_of_topics_old(topics):
return {
**self._cluster_metadata,
"topics": {topic: self._cluster_metadata["topics"][topic] for topic in topics},
}

metadata_birth = time.monotonic()
metadata = self.admin_client.cluster_metadata(topics)

if self._cluster_metadata is None:
self._cluster_metadata = self._empty_cluster_metadata_cache()

# we need to refresh if at least 1 broker isn't present in the current metadata
need_refresh = not all(broker in self._cluster_metadata["brokers"] for broker in metadata["brokers"])

for topic in metadata["topics"]:
# or if there is a new topic
need_refresh = (
need_refresh
or (topic not in self._cluster_metadata["topics"])
# or if a topic has new/different data.
# nb: an equality check is valid since the _ClusterMetadata object is structurally
# composed only of primitives, lists and dicts
or (self._cluster_metadata["topics"][topic] != metadata["topics"][topic])
)
self._cluster_metadata_topic_birth[topic] = metadata_birth
self._cluster_metadata["topics"][topic] = metadata["topics"][topic]

if need_refresh:
# we don't need to reason about expiration time here since it is checked
# before every request for the global metadata,
# so we only need to guard against newly missing pieces of info
self._cluster_metadata_complete = False
else:
# to guard against malicious actors we could also cache the fact that a certain topic (which does not exist)
# has been queried and that, for a while, no reply was present. Not implementing this now since it is additional
# complexity that may be unneeded. Leaving a comment and a warning here: if this shows up often in the logs,
# the feature may be needed.
log.warning(
"Requested metadata for topics %s but the reply didn't triggered a cache invalidation. "
"Data not present on server side",
topics,
)
return metadata

async def cluster_metadata(self, topics: list[str] | None = None) -> _ClusterMetadata:
async with self.admin_lock:
try:
metadata_birth = time.monotonic()
metadata = self.admin_client.cluster_metadata(topics)
self._metadata_birth = metadata_birth
self._cluster_metadata = metadata
self._cluster_metadata_complete = topics is None
if topics is None or len(topics) == 0:
metadata = self._update_all_metadata()
else:
metadata = self._update_metadata_for_topics(topics)
except KafkaException:
log.warning("Could not refresh cluster metadata")
KafkaRest.r(
@@ -641,7 +731,7 @@ async def cluster_metadata(self, topics: Optional[List[str]] = None) -> dict:
)
return metadata

def init_admin_client(self):
def init_admin_client(self, verify_connection: bool = True) -> KafkaAdminClient:
for retry in [True, True, False]:
try:
self.admin_client = KafkaAdminClient(
@@ -652,6 +742,7 @@ def init_admin_client(self):
ssl_keyfile=self.config["ssl_keyfile"],
metadata_max_age_ms=self.config["metadata_max_age_ms"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
verify_connection=verify_connection,
**get_kafka_client_auth_parameters_from_config(self.config),
)
break
@@ -675,7 +766,7 @@ async def aclose(self) -> None:
self.admin_client = None
self.consumer_manager = None

async def publish(self, topic: str, partition_id: Optional[str], content_type: str, request: HTTPRequest) -> None:
async def publish(self, topic: str, partition_id: str | None, content_type: str, request: HTTPRequest) -> None:
"""
:raises NoBrokersAvailable:
:raises AuthenticationFailedError:
@@ -797,7 +888,7 @@ async def get_schema_id(
:raises InvalidSchema:
"""
log.debug("[resolve schema id] Retrieving schema id for %r", data)
schema_id: Union[SchemaId, None] = (
schema_id: SchemaId | None = (
SchemaId(int(data[f"{subject_type}_schema_id"])) if f"{subject_type}_schema_id" in data else None
)
schema_str = data.get(f"{subject_type}_schema")
@@ -817,7 +908,7 @@ async def get_schema_id(
schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name)
else:

def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
def subject_not_included(schema: TypedSchema, subjects: list[Subject]) -> bool:
subject = get_subject_name(topic, schema, subject_type, self.naming_strategy)
return subject not in subjects

@@ -832,8 +923,8 @@ def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
return schema_id

async def _query_schema_and_subjects(
self, schema_id: SchemaId, *, need_new_call: Optional[Callable[[TypedSchema, List[Subject]], bool]]
) -> Tuple[TypedSchema, List[Subject]]:
self, schema_id: SchemaId, *, need_new_call: Callable[[TypedSchema, list[Subject]], bool] | None
) -> tuple[TypedSchema, list[Subject]]:
try:
return await self.serializer.get_schema_for_id(schema_id, need_new_call=need_new_call)
except SchemaRetrievalError as schema_error:
@@ -924,10 +1015,10 @@ async def _prepare_records(
content_type: str,
data: dict,
ser_format: str,
key_schema_id: Optional[int],
value_schema_id: Optional[int],
default_partition: Optional[int] = None,
) -> List[Tuple]:
key_schema_id: int | None,
value_schema_id: int | None,
default_partition: int | None = None,
) -> list[tuple]:
prepared_records = []
for record in data["records"]:
key = record.get("key")
@@ -978,8 +1069,8 @@ async def serialize(
self,
content_type: str,
obj=None,
ser_format: Optional[str] = None,
schema_id: Optional[int] = None,
ser_format: str | None = None,
schema_id: int | None = None,
) -> bytes:
if not obj:
return b""
@@ -1003,7 +1094,7 @@ async def serialize(
return await self.schema_serialize(obj, schema_id)
raise FormatError(f"Unknown format: {ser_format}")

async def schema_serialize(self, obj: dict, schema_id: Optional[int]) -> bytes:
async def schema_serialize(self, obj: dict, schema_id: int | None) -> bytes:
schema, _ = await self.serializer.get_schema_for_id(schema_id)
bytes_ = await self.serializer.serialize(schema, obj)
return bytes_
@@ -1066,7 +1157,7 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte
sub_code=RESTErrorCodes.INVALID_DATA.value,
)

async def produce_messages(self, *, topic: str, prepared_records: List) -> List:
async def produce_messages(self, *, topic: str, prepared_records: list) -> list:
"""
:raises NoBrokersAvailable:
:raises AuthenticationFailedError:
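For reference, a standalone sketch of the invalidation rule `_update_metadata_for_topics` applies: the cached full-cluster view stays "complete" only if the subset reply introduces no unknown broker and no new or changed topic. The helper name `keeps_complete_cache` and the literal broker/topic data below are made up for illustration.

```python
def keeps_complete_cache(cached: dict, fresh: dict) -> bool:
    """cached/fresh are {"brokers": [...], "topics": {...}} mappings, like _ClusterMetadata."""
    all_brokers_known = all(broker in cached["brokers"] for broker in fresh["brokers"])
    topics_unchanged = all(
        topic in cached["topics"] and cached["topics"][topic] == fresh["topics"][topic]
        for topic in fresh["topics"]
    )
    return all_brokers_known and topics_unchanged


cached = {"brokers": [1, 2], "topics": {"topic-a": {"partitions": [0, 1]}}}
same = {"brokers": [1], "topics": {"topic-a": {"partitions": [0, 1]}}}
changed = {"brokers": [1, 3], "topics": {"topic-a": {"partitions": [0, 1, 2]}}}

assert keeps_complete_cache(cached, same) is True      # nothing new: the cache stays complete
assert keeps_complete_cache(cached, changed) is False   # new broker / changed topic: mark incomplete
```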