From 952b9f427251b2a8bb65492b92f6356a9b619fef Mon Sep 17 00:00:00 2001
From: Elia Migliore
Date: Tue, 24 Oct 2023 10:09:51 +0200
Subject: [PATCH] Added support for other naming strategies, refactored some
 unrelated code and added a couple of tests

---
 README.rst | 2 +-
 karapace/config.py | 21 +--
 karapace/kafka_rest_apis/__init__.py | 58 +++---
 karapace/protobuf/schema.py | 58 ++++--
 karapace/schema_reader.py | 35 +++-
 karapace/serialization.py | 131 ++++++++-----
 karapace/typing.py | 27 +++
 tests/conftest.py | 4 +-
 tests/integration/test_rest.py | 6 +-
 .../protobuf/test_protobuf_schema_name.py | 105 +++++++++++
 tests/unit/test_protobuf_serialization.py | 13 +-
 tests/unit/test_serialization.py | 177 +++++++++++++++---
 12 files changed, 484 insertions(+), 153 deletions(-)
 create mode 100644 tests/unit/protobuf/test_protobuf_schema_name.py

diff --git a/README.rst b/README.rst
index 93875719a..288685281 100644
--- a/README.rst
+++ b/README.rst
@@ -461,7 +461,7 @@ Keys to take special care are the ones needed to configure Kafka and advertised_
      - Runtime directory for the ``protoc`` protobuf schema parser and code generator
    * - ``name_strategy``
      - ``topic_name``
-     - Name strategy to use when storing schemas from the kafka rest proxy service
+     - Name strategy to use when storing schemas from the Kafka REST proxy service. You can choose between ``topic_name``, ``record_name`` and ``topic_record_name``
    * - ``name_strategy_validation``
      - ``true``
      - If enabled, validate that given schema is registered under used name strategy when producing messages from Kafka Rest
diff --git a/karapace/config.py b/karapace/config.py
index c87275a8f..9212f6348 100644
--- a/karapace/config.py
+++ b/karapace/config.py
@@ -6,8 +6,8 @@
 """
 from __future__ import annotations

-from enum import Enum, unique
 from karapace.constants import DEFAULT_AIOHTTP_CLIENT_MAX_SIZE, DEFAULT_PRODUCER_MAX_REQUEST, DEFAULT_SCHEMA_TOPIC
+from karapace.typing import ElectionStrategy, NameStrategy
 from karapace.utils import json_decode, json_encode, JSONDecodeError
 from pathlib import Path
 from typing import IO, Mapping
@@ -158,19 +158,6 @@ class InvalidConfiguration(Exception):
     pass


-@unique
-class ElectionStrategy(Enum):
-    highest = "highest"
-    lowest = "lowest"
-
-
-@unique
-class NameStrategy(Enum):
-    topic_name = "topic_name"
-    record_name = "record_name"
-    topic_record_name = "topic_record_name"
-
-
 def parse_env_value(value: str) -> str | int | bool:
     # we only have ints, strings and bools in the config
     try:
@@ -273,8 +260,10 @@ def validate_config(config: Config) -> None:
     try:
         NameStrategy(name_strategy)
     except ValueError:
-        valid_strategies = [strategy.value for strategy in NameStrategy]
-        raise InvalidConfiguration(f"Invalid name strategy: {name_strategy}, valid values are {valid_strategies}") from None
+        valid_strategies = list(NameStrategy)
+        raise InvalidConfiguration(
+            f"Invalid default name strategy: {name_strategy}, valid values are {valid_strategies}"
+        ) from None

     if config["rest_authorization"] and config["sasl_bootstrap_uri"] is None:
         raise InvalidConfiguration(
diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py
index c63194e52..b47dabad2 100644
--- a/karapace/kafka_rest_apis/__init__.py
+++ b/karapace/kafka_rest_apis/__init__.py
@@ -28,8 +28,14 @@
 from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE
 from karapace.schema_models import TypedSchema, ValidatedTypedSchema
 from karapace.schema_type import SchemaType
-from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError
-from karapace.typing import SchemaId, Subject
+from karapace.serialization import (
+    get_subject_name,
+    InvalidMessageSchema,
+    InvalidPayload,
+    SchemaRegistrySerializer,
+    SchemaRetrievalError,
+)
+from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
 from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
 from typing import Callable, Dict, List, Optional, Tuple, Union

@@ -39,7 +45,7 @@
 import logging
 import time

-RECORD_KEYS = ["key", "value", "partition"]
+SUBJECT_VALID_POSTFIX = [SubjectType.key, SubjectType.value]
 PUBLISH_KEYS = {"records", "value_schema", "value_schema_id", "key_schema", "key_schema_id"}
 RECORD_CODES = [42201, 42202]
 KNOWN_FORMATS = {"json", "avro", "protobuf", "binary"}
@@ -439,6 +445,7 @@ def __init__(
         self._async_producer_lock = asyncio.Lock()
         self._async_producer: Optional[AIOKafkaProducer] = None
+        self.naming_strategy = NameStrategy(self.config["name_strategy"])

     def __str__(self) -> str:
         return f"UserRestProxy(username={self.config['sasl_plain_username']})"
@@ -759,7 +766,7 @@ async def get_schema_id(
         self,
         data: dict,
         topic: str,
-        prefix: str,
+        subject_type: SubjectType,
         schema_type: SchemaType,
     ) -> SchemaId:
         """
@@ -770,21 +777,27 @@ async def get_schema_id(
         """
         log.debug("[resolve schema id] Retrieving schema id for %r", data)
         schema_id: Union[SchemaId, None] = (
-            SchemaId(int(data[f"{prefix}_schema_id"])) if f"{prefix}_schema_id" in data else None
+            SchemaId(int(data[f"{subject_type}_schema_id"])) if f"{subject_type}_schema_id" in data else None
         )
-        schema_str = data.get(f"{prefix}_schema")
+        schema_str = data.get(f"{subject_type}_schema")
         if schema_id is None and schema_str is None:
             raise InvalidSchema()

         if schema_id is None:
             parsed_schema = ValidatedTypedSchema.parse(schema_type, schema_str)
-            subject_name = self.serializer.get_subject_name(topic, parsed_schema, prefix, schema_type)
+
+            subject_name = get_subject_name(
+                topic,
+                parsed_schema,
+                subject_type,
+                self.naming_strategy,
+            )
             schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name)
         else:

             def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
-                subject = self.serializer.get_subject_name(topic, schema, prefix, schema_type)
+                subject = get_subject_name(topic, schema, subject_type, self.naming_strategy)
                 return subject not in subjects

             parsed_schema, valid_subjects = await self._query_schema_and_subjects(
@@ -833,7 +846,9 @@ async def _query_schema_id_from_cache_or_registry(
         )
         return schema_id

-    async def validate_schema_info(self, data: dict, prefix: str, content_type: str, topic: str, schema_type: str):
+    async def validate_schema_info(
+        self, data: dict, subject_type: SubjectType, content_type: str, topic: str, schema_type: str
+    ):
         try:
             schema_type = SCHEMA_MAPPINGS[schema_type]
         except KeyError:
@@ -848,7 +863,7 @@ async def validate_schema_info(self, data: dict, prefix: str, content_type: str,
         # will do in place updates of id keys, since calling these twice would be expensive
         try:
-            data[f"{prefix}_schema_id"] = await self.get_schema_id(data, topic, prefix, schema_type)
+            data[f"{subject_type}_schema_id"] = await self.get_schema_id(data, topic, subject_type, schema_type)
         except InvalidPayload:
             log.exception("Unable to retrieve schema id")
             KafkaRest.r(
@@ -863,16 +878,17 @@ async def validate_schema_info(self, data: dict, prefix: str, content_type: str,
             KafkaRest.r(
                 body={
                     "error_code": RESTErrorCodes.SCHEMA_RETRIEVAL_ERROR.value,
-                    "message": f"Error when registering schema. format = {schema_type.value}, subject = {topic}-{prefix}",
+                    "message": f"Error when registering schema. "
+                    f"format = {schema_type.value}, subject = {topic}-{subject_type}",
                 },
                 content_type=content_type,
                 status=HTTPStatus.REQUEST_TIMEOUT,
             )
         except InvalidSchema:
-            if f"{prefix}_schema" in data:
-                err = f'schema = {data[f"{prefix}_schema"]}'
+            if f"{subject_type}_schema" in data:
+                err = f'schema = {data[f"{subject_type}_schema"]}'
             else:
-                err = f'schema_id = {data[f"{prefix}_schema_id"]}'
+                err = f'schema_id = {data[f"{subject_type}_schema_id"]}'
             KafkaRest.r(
                 body={
                     "error_code": RESTErrorCodes.INVALID_DATA.value,
@@ -1002,7 +1018,7 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte
                     status=HTTPStatus.BAD_REQUEST,
                 )
             convert_to_int(r, "partition", content_type)
-            if set(r.keys()).difference(RECORD_KEYS):
+            if set(r.keys()).difference({subject_type.value for subject_type in SubjectType}):
                 KafkaRest.unprocessable_entity(
                     message="Invalid request format",
                     content_type=content_type,
@@ -1010,18 +1026,18 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte
                 )
         # disallow missing id and schema for any key/value list that has at least one populated element
         if formats["embedded_format"] in {"avro", "jsonschema", "protobuf"}:
-            for prefix, code in zip(RECORD_KEYS, RECORD_CODES):
-                if self.all_empty(data, prefix):
+            for subject_type, code in zip(SUBJECT_VALID_POSTFIX, RECORD_CODES):
+                if self.all_empty(data, subject_type):
                     continue
-                if not self.is_valid_schema_request(data, prefix):
+                if not self.is_valid_schema_request(data, subject_type):
                     KafkaRest.unprocessable_entity(
-                        message=f"Request includes {prefix}s and uses a format that requires schemas "
-                        f"but does not include the {prefix}_schema or {prefix}_schema_id fields",
+                        message=f"Request includes {subject_type}s and uses a format that requires schemas "
+                        f"but does not include the {subject_type}_schema or {subject_type}_schema_id fields",
                         content_type=content_type,
                         sub_code=code,
                     )
                 try:
-                    await self.validate_schema_info(data, prefix, content_type, topic, formats["embedded_format"])
+                    await self.validate_schema_info(data, subject_type, content_type, topic, formats["embedded_format"])
                 except InvalidMessageSchema as e:
                     KafkaRest.unprocessable_entity(
                         message=str(e),
diff --git a/karapace/protobuf/schema.py b/karapace/protobuf/schema.py
index 676591870..157eb5447 100644
--- a/karapace/protobuf/schema.py
+++ b/karapace/protobuf/schema.py
@@ -2,6 +2,9 @@
 Copyright (c) 2023 Aiven Ltd
 See LICENSE for details
 """
+
+from __future__ import annotations
+
 from karapace.dataclasses import default_dataclass

 # Ported from square/wire:
@@ -21,7 +24,7 @@
 from karapace.protobuf.type_element import TypeElement
 from karapace.protobuf.utils import append_documentation, append_indented
 from karapace.schema_references import Reference
-from typing import Iterable, List, Mapping, Optional, Sequence, Set, Tuple
+from typing import Iterable, Mapping, Sequence

 import itertools
@@ -126,10 +129,10 @@ class SourceFileReference:
 @default_dataclass
 class TypeTree:
     token: str
-    children: List["TypeTree"]
-    source_reference: Optional[SourceFileReference]
+    children: list[TypeTree]
+    source_reference: SourceFileReference | None

-    def source_reference_tree_recursive(self) -> Iterable[Optional[SourceFileReference]]:
+    def source_reference_tree_recursive(self) -> Iterable[SourceFileReference | None]:
         sources = [] if self.source_reference is None else [self.source_reference]
         for child in self.children:
             sources = itertools.chain(sources, child.source_reference_tree())
@@ -201,7 +204,7 @@ def __repr__(self) -> str:

 def _add_new_type_recursive(
     parent_tree: TypeTree,
-    remaining_tokens: List[str],
+    remaining_tokens: list[str],
     file: str,
     inserted_elements: int,
 ) -> None:
@@ -249,8 +252,8 @@ class ProtobufSchema:
     def __init__(
         self,
         schema: str,
-        references: Optional[Sequence[Reference]] = None,
-        dependencies: Optional[Mapping[str, Dependency]] = None,
+        references: Sequence[Reference] | None = None,
+        dependencies: Mapping[str, Dependency] | None = None,
     ) -> None:
         if type(schema).__name__ != "str":
             raise IllegalArgumentException("Non str type of schema string")
@@ -260,7 +263,7 @@ def __init__(
         self.references = references
         self.dependencies = dependencies

-    def type_in_tree(self, tree: TypeTree, remaining_tokens: List[str]) -> Optional[TypeTree]:
+    def type_in_tree(self, tree: TypeTree, remaining_tokens: list[str]) -> TypeTree | None:
         if remaining_tokens:
             to_seek = remaining_tokens.pop()

@@ -270,10 +273,33 @@ def type_in_tree(self, tree: TypeTree, remaining_tokens: List[str]) -> Optional[
             return None
         return tree

-    def type_exist_in_tree(self, tree: TypeTree, remaining_tokens: List[str]) -> bool:
+    def record_name(self) -> str | None:
+        if len(self.proto_file_element.types) == 0:
+            return None
+
+        package_name = (
+            self.proto_file_element.package_name + "." if self.proto_file_element.package_name not in [None, ""] else ""
+        )
+
+        first_element = None
+        first_enum = None
+
+        for inspected_type in self.proto_file_element.types:
+            if isinstance(inspected_type, MessageElement):
+                first_element = inspected_type
+                break
+
+            if first_enum is None and isinstance(inspected_type, EnumElement):
+                first_enum = inspected_type
+
+        naming_element = first_element if first_element is not None else first_enum
+
+        return package_name + naming_element.name
+
+    def type_exist_in_tree(self, tree: TypeTree, remaining_tokens: list[str]) -> bool:
         return self.type_in_tree(tree, remaining_tokens) is not None

-    def recursive_imports(self) -> Set[str]:
+    def recursive_imports(self) -> set[str]:
         imports = set(self.proto_file_element.imports)

         if self.dependencies:
@@ -282,7 +308,7 @@ def recursive_imports(self) -> Set[str]:

         return imports

-    def are_type_usage_valid(self, root_type_tree: TypeTree, used_types: List[UsedType]) -> Tuple[bool, Optional[str]]:
+    def are_type_usage_valid(self, root_type_tree: TypeTree, used_types: list[UsedType]) -> tuple[bool, str | None]:
         # Please note that this check only ensures the requested type exists. However, for performance reasons, it works in
         # the opposite way of how specificity works in Protobuf. In Protobuf, the type is matched not only to check if it
         # exists, but also based on the order of search: local definition comes before imported types. In this code, we
@@ -408,7 +434,7 @@ def types_tree(self) -> TypeTree:
         return root_tree

     @staticmethod
-    def used_type(parent: str, element_type: str) -> List[UsedType]:
+    def used_type(parent: str, element_type: str) -> list[UsedType]:
         if element_type.find("map<") == 0:
             end = element_type.find(">")
             virgule = element_type.find(",")
@@ -426,7 +452,7 @@ def used_type(parent: str, element_type: str) -> List[UsedType]:
         package_name: str,
         parent_name: str,
         one_of: OneOfElement,
-    ) -> List[UsedType]:
+    ) -> list[UsedType]:
         parent = package_name + "." + parent_name
         dependencies = []
         for field in one_of.fields:
@@ -438,7 +464,7 @@ def dependencies_one_of(
             )
         return dependencies

-    def used_types(self) -> List[UsedType]:
+    def used_types(self) -> list[UsedType]:
         dependencies_used_types = []
         if self.dependencies:
             for key in self.dependencies:
@@ -469,7 +495,7 @@ def nested_used_type(
         package_name: str,
         parent_name: str,
         element_type: TypeElement,
-    ) -> List[str]:
+    ) -> list[str]:
         used_types = []

         if isinstance(element_type, MessageElement):
@@ -540,7 +566,7 @@ def to_schema(self) -> str:

         return "".join(strings)

-    def compare(self, other: "ProtobufSchema", result: CompareResult) -> CompareResult:
+    def compare(self, other: ProtobufSchema, result: CompareResult) -> CompareResult:
         return self.proto_file_element.compare(
             other.proto_file_element,
             result,
diff --git a/karapace/schema_reader.py b/karapace/schema_reader.py
index 3dec4a887..36a9bd0e0 100644
--- a/karapace/schema_reader.py
+++ b/karapace/schema_reader.py
@@ -8,6 +8,7 @@
 from avro.schema import Schema as AvroSchema
 from contextlib import closing, ExitStack
+from enum import Enum
 from jsonschema.validators import Draft7Validator
 from kafka import KafkaConsumer, TopicPartition
 from kafka.admin import KafkaAdminClient, NewTopic
@@ -58,6 +59,13 @@
 METRIC_SUBJECT_DATA_SCHEMA_VERSIONS_GAUGE: Final = "karapace_schema_reader_subject_data_schema_versions"


+class MessageType(Enum):
+    config = "CONFIG"
+    schema = "SCHEMA"
+    delete_subject = "DELETE_SUBJECT"
+    no_operation = "NOOP"
+
+
 def _create_consumer_from_config(config: Config) -> KafkaConsumer:
     # Group not set on purpose, all consumers read the same data
     session_timeout_ms = config["session_timeout_ms"]
@@ -522,14 +530,25 @@ def _handle_msg_schema(self, key: dict, value: dict | None) -> None:
                     self.database.insert_referenced_by(subject=ref.subject, version=ref.version, schema_id=schema_id)

     def handle_msg(self, key: dict, value: dict | None) -> None:
-        if key["keytype"] == "CONFIG":
-            self._handle_msg_config(key, value)
-        elif key["keytype"] == "SCHEMA":
-            self._handle_msg_schema(key, value)
-        elif key["keytype"] == "DELETE_SUBJECT":
-            self._handle_msg_delete_subject(key, value)
-        elif key["keytype"] == "NOOP":  # for spec completeness
-            pass
+        if "keytype" in key:
+            try:
+                message_type = MessageType(key["keytype"])
+
+                if message_type == MessageType.config:
+                    self._handle_msg_config(key, value)
+                elif message_type == MessageType.schema:
+                    self._handle_msg_schema(key, value)
+                elif message_type == MessageType.delete_subject:
+                    self._handle_msg_delete_subject(key, value)
+                elif message_type == MessageType.no_operation:
+                    pass
+            except ValueError:
+                LOG.error("The message %s-%s has been discarded because the keytype %s is not supported", key, value, key["keytype"])

+        else:
+            LOG.error(
+                "The message %s-%s has been discarded because its key does not contain the `keytype` field", key, value
+            )

     def remove_referenced_by(
diff --git a/karapace/serialization.py b/karapace/serialization.py
index 29dc51a6c..c199bad7a 100644
--- a/karapace/serialization.py
+++ b/karapace/serialization.py
@@ -2,6 +2,8 @@
 Copyright (c) 2023 Aiven Ltd
 See LICENSE for details
 """
+from __future__ import annotations
+
 from aiohttp import BasicAuth
 from avro.io import BinaryDecoder, BinaryEncoder, DatumReader, DatumWriter
 from cachetools import TTLCache
@@ -13,11 +15,12 @@
 from karapace.errors import InvalidReferences
 from karapace.protobuf.exception import ProtobufTypeException
 from karapace.protobuf.io import ProtobufDatumReader, ProtobufDatumWriter
+from karapace.protobuf.schema import ProtobufSchema
 from karapace.schema_models import InvalidSchema, ParsedTypedSchema, SchemaType, TypedSchema, ValidatedTypedSchema
 from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping
-from karapace.typing import ResolvedVersion, SchemaId, Subject
+from karapace.typing import NameStrategy, ResolvedVersion, SchemaId, Subject, SubjectType
 from karapace.utils import json_decode, json_encode
-from typing import Any, Callable, Dict, List, MutableMapping, Optional, Set, Tuple
+from typing import Any, Callable, MutableMapping
 from urllib.parse import quote

 import asyncio
@@ -59,22 +62,44 @@ class SchemaUpdateError(SchemaError):
     pass


-def topic_name_strategy(topic_name: str, record_name: str) -> str:  # pylint: disable=unused-argument
-    return topic_name
+class InvalidRecord(Exception):
+    pass
+
+
+def topic_name_strategy(
+    topic_name: str,
+    record_name: str | None,  # pylint: disable=unused-argument
+    subject_type: SubjectType,
+) -> Subject:
+    return Subject(f"{topic_name}-{subject_type}")

+
+def record_name_strategy(
+    topic_name: str,  # pylint: disable=unused-argument
+    record_name: str | None,
+    subject_type: SubjectType,  # pylint: disable=unused-argument
+) -> Subject:
+    if record_name is None:
+        raise InvalidRecord(
+            "The provided record doesn't have a valid `record_name`, use another naming strategy or fix the schema"
+        )

-def record_name_strategy(topic_name: str, record_name: str) -> str:  # pylint: disable=unused-argument
-    return record_name
+    return Subject(record_name)

-def topic_record_name_strategy(topic_name: str, record_name: str) -> str:
-    return topic_name + "-" + record_name
+
+def topic_record_name_strategy(
+    topic_name: str,
+    record_name: str | None,
+    subject_type: SubjectType,
+) -> Subject:
+    validated_record_name = record_name_strategy(topic_name, record_name, subject_type)
+    return Subject(f"{topic_name}-{validated_record_name}")


 NAME_STRATEGIES = {
-    "topic_name": topic_name_strategy,
-    "record_name": record_name_strategy,
-    "topic_record_name": topic_record_name_strategy,
+    NameStrategy.topic_name: topic_name_strategy,
+    NameStrategy.record_name: record_name_strategy,
+    NameStrategy.topic_record_name: topic_record_name_strategy,
 }

@@ -82,14 +107,14 @@ class SchemaRegistryClient:
     def __init__(
         self,
         schema_registry_url: str = "http://localhost:8081",
-        server_ca: Optional[str] = None,
-        session_auth: Optional[BasicAuth] = None,
+        server_ca: str | None = None,
+        session_auth: BasicAuth | None = None,
     ):
         self.client = Client(server_uri=schema_registry_url, server_ca=server_ca, session_auth=session_auth)
         self.base_url = schema_registry_url

     async def post_new_schema(
-        self, subject: str, schema: ValidatedTypedSchema, references: Optional[Reference] = None
+        self, subject: str, schema: ValidatedTypedSchema, references: Reference | None = None
     ) -> SchemaId:
         if schema.schema_type is SchemaType.PROTOBUF:
             if references:
@@ -103,12 +128,12 @@ async def post_new_schema(
             raise SchemaRetrievalError(result.json())
         return SchemaId(result.json()["id"])

-    async def _get_schema_r(
+    async def _get_schema_recursive(
         self,
         subject: Subject,
-        explored_schemas: Set[Tuple[Subject, Optional[ResolvedVersion]]],
-        version: Optional[ResolvedVersion] = None,
-    ) -> Tuple[SchemaId, ValidatedTypedSchema, ResolvedVersion]:
+        explored_schemas: set[tuple[Subject, ResolvedVersion | None]],
+        version: ResolvedVersion | None = None,
+    ) -> tuple[SchemaId, ValidatedTypedSchema, ResolvedVersion]:
         if (subject, version) in explored_schemas:
             raise InvalidSchema(
                 f"The schema has at least a cycle in dependencies, "
@@ -131,7 +156,7 @@ async def _get_schema_r(
         references = [Reference.from_dict(data) for data in json_result["references"]]
         dependencies = {}
         for reference in references:
-            _, schema, version = await self._get_schema_r(reference.subject, explored_schemas, reference.version)
+            _, schema, version = await self._get_schema_recursive(reference.subject, explored_schemas, reference.version)
             dependencies[reference.name] = Dependency(
                 name=reference.name, subject=reference.subject, version=version, target_schema=schema
             )
@@ -158,8 +183,8 @@ async def _get_schema_r(
     async def get_schema(
         self,
         subject: Subject,
-        version: Optional[ResolvedVersion] = None,
-    ) -> Tuple[SchemaId, ValidatedTypedSchema, ResolvedVersion]:
+        version: ResolvedVersion | None = None,
+    ) -> tuple[SchemaId, ValidatedTypedSchema, ResolvedVersion]:
         """
         Retrieves the schema and its dependencies for the specified subject.

@@ -174,9 +199,9 @@ async def get_schema(
             - ValidatedTypedSchema: The retrieved schema, validated and typed.
             - ResolvedVersion: The version of the schema that was retrieved.
         """
-        return await self._get_schema_r(subject, set(), version)
+        return await self._get_schema_recursive(subject, set(), version)

-    async def get_schema_for_id(self, schema_id: SchemaId) -> Tuple[TypedSchema, List[Subject]]:
+    async def get_schema_for_id(self, schema_id: SchemaId) -> tuple[TypedSchema, list[Subject]]:
         result = await self.client.get(f"schemas/ids/{schema_id}", params={"includeSubjects": "True"})
         if not result.ok:
             raise SchemaRetrievalError(result.json()["message"])
@@ -225,6 +250,31 @@ async def close(self):
         await self.client.close()


+def get_subject_name(
+    topic_name: str,
+    schema: TypedSchema,
+    subject_type: SubjectType,
+    naming_strategy: NameStrategy,
+) -> Subject:
+    record_name = None
+
+    if schema.schema_type is SchemaType.AVRO:
+        if isinstance(schema.schema, avro.schema.NamedSchema):
+            record_name = schema.schema.fullname
+        else:
+            record_name = None
+
+    if schema.schema_type is SchemaType.JSONSCHEMA:
+        record_name = schema.to_dict().get("title", None)
+
+    if schema.schema_type is SchemaType.PROTOBUF:
+        assert isinstance(schema.schema, ProtobufSchema), "Expecting a protobuf schema"
+        record_name = schema.schema.record_name()
+
+    strategy = NAME_STRATEGIES[naming_strategy]
+    return strategy(topic_name, record_name, subject_type)
+
+
 class SchemaRegistrySerializer:
     def __init__(
         self,
@@ -232,7 +282,7 @@ def __init__(
     ) -> None:
         self.config = config
         self.state_lock = asyncio.Lock()
-        session_auth: Optional[BasicAuth] = None
+        session_auth: BasicAuth | None = None
         if self.config.get("registry_user") and self.config.get("registry_password"):
             session_auth = BasicAuth(self.config.get("registry_user"), self.config.get("registry_password"), encoding="utf8")
         if self.config.get("registry_ca"):
@@ -243,37 +293,16 @@ def __init__(
         else:
             registry_url = f"http://{self.config['registry_host']}:{self.config['registry_port']}"
         registry_client = SchemaRegistryClient(registry_url, session_auth=session_auth)
-        name_strategy = config.get("name_strategy", "topic_name")
-        self.subject_name_strategy = NAME_STRATEGIES.get(name_strategy, topic_name_strategy)
-        self.registry_client: Optional[SchemaRegistryClient] = registry_client
-        self.ids_to_schemas: Dict[int, TypedSchema] = {}
-        self.ids_to_subjects: MutableMapping[int, List[Subject]] = TTLCache(maxsize=10000, ttl=600)
-        self.schemas_to_ids: Dict[str, SchemaId] = {}
+        self.registry_client: SchemaRegistryClient | None = registry_client
+        self.ids_to_schemas: dict[int, TypedSchema] = {}
+        self.ids_to_subjects: MutableMapping[int, list[Subject]] = TTLCache(maxsize=10000, ttl=600)
+        self.schemas_to_ids: dict[str, SchemaId] = {}

     async def close(self) -> None:
         if self.registry_client:
             await self.registry_client.close()
             self.registry_client = None

-    def get_subject_name(
-        self,
-        topic_name: str,
-        schema: TypedSchema,
-        subject_type: str,
-        schema_type: SchemaType,
-    ) -> Subject:
-        namespace = "dummy"
-        if schema_type is SchemaType.AVRO:
-            if isinstance(schema.schema, avro.schema.NamedSchema):
-                namespace = schema.schema.namespace
-        if schema_type is SchemaType.JSONSCHEMA:
-            namespace = schema.to_dict().get("namespace", "dummy")
-        # Protobuf does not use namespaces in terms of AVRO
-        if schema_type is SchemaType.PROTOBUF:
-            namespace = ""
-
-        return Subject(f"{self.subject_name_strategy(topic_name, namespace)}-{subject_type}")
-
     async def get_schema_for_subject(self, subject: Subject) -> TypedSchema:
         assert self.registry_client, "must not call this method after the object is closed."
         schema_id, schema, _ = await self.registry_client.get_schema(subject)
@@ -303,8 +332,8 @@ async def get_schema_for_id(
         self,
         schema_id: SchemaId,
         *,
-        need_new_call: Optional[Callable[[TypedSchema, List[Subject]], bool]] = None,
-    ) -> Tuple[TypedSchema, List[Subject]]:
+        need_new_call: Callable[[TypedSchema, list[Subject]], bool] | None = None,
+    ) -> tuple[TypedSchema, list[Subject]]:
         assert self.registry_client, "must not call this method after the object is closed."
         if schema_id in self.ids_to_subjects:
             if need_new_call is None or not need_new_call(self.ids_to_schemas[schema_id], self.ids_to_subjects[schema_id]):
diff --git a/karapace/typing.py b/karapace/typing.py
index fb73c9370..48c6bd815 100644
--- a/karapace/typing.py
+++ b/karapace/typing.py
@@ -2,6 +2,7 @@
 Copyright (c) 2023 Aiven Ltd
 See LICENSE for details
 """
+from enum import Enum, unique
 from typing import Dict, List, Mapping, NewType, Sequence, Union
 from typing_extensions import TypeAlias

@@ -22,3 +23,29 @@
 # basically the same SchemaID refer always to the same TypedSchema.
 SchemaId = NewType("SchemaId", int)
 TopicName = NewType("TopicName", str)
+
+
+class StrEnum(str, Enum):
+    def __str__(self) -> str:
+        return str(self.value)
+
+
+@unique
+class ElectionStrategy(Enum):
+    highest = "highest"
+    lowest = "lowest"
+
+
+@unique
+class NameStrategy(StrEnum, Enum):
+    topic_name = "topic_name"
+    record_name = "record_name"
+    topic_record_name = "topic_record_name"
+
+
+@unique
+class SubjectType(StrEnum, Enum):
+    key = "key"
+    value = "value"
+    # `partition` is a method of `str`, which StrEnum inherits from, hence the trailing underscore.
+    partition_ = "partition"
diff --git a/tests/conftest.py b/tests/conftest.py
index 3b903c699..99ba55809 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -160,7 +160,7 @@ def fixture_session_logdir(request, tmp_path_factory, worker_id) -> Path:


 @pytest.fixture(scope="session", name="default_config_path")
-def fixture_default_config(session_logdir: Path) -> str:
+def fixture_default_config(session_logdir: Path) -> Path:
     path = session_logdir / "karapace_config.json"
     content = json.dumps({"registry_host": "localhost", "registry_port": 8081}).encode()
     content_len = len(content)
@@ -170,7 +170,7 @@ def fixture_default_config(session_logdir: Path) -> str:
             raise OSError(f"Writing config failed, tried to write {content_len} bytes, but only {written} were written")
         fp.flush()
         os.fsync(fp)
-    return str(path)
+    return path


 @pytest.fixture(name="tmp_file", scope="function")
diff --git a/tests/integration/test_rest.py b/tests/integration/test_rest.py
index 9fec19285..dc551dad0 100644
--- a/tests/integration/test_rest.py
+++ b/tests/integration/test_rest.py
@@ -7,7 +7,7 @@
 from kafka import KafkaProducer
 from kafka.errors import UnknownTopicOrPartitionError
 from karapace.client import Client
-from karapace.kafka_rest_apis import KafkaRest, KafkaRestAdminClient
+from karapace.kafka_rest_apis import KafkaRest, KafkaRestAdminClient, SUBJECT_VALID_POSTFIX
 from karapace.version import __version__
 from pytest import raises
 from tests.integration.conftest import REST_PRODUCER_MAX_REQUEST_BYTES
@@ -172,9 +172,9 @@ async def test_avro_publish(
     new_schema_id = res.json()["id"]

     # test checks schema id use for key and value, register schema for both with topic naming strategy
-    for pl_type in ["key", "value"]:
+    for pl_type in SUBJECT_VALID_POSTFIX:
         res = await registry_async_client.post(
-            f"subjects/{tn}-{pl_type}/versions", json={"schema": schema_avro_json_evolution}
+            f"subjects/{tn}-{pl_type.value}/versions", json={"schema": schema_avro_json_evolution}
         )
         assert res.ok
         assert res.json()["id"] == new_schema_id
diff --git a/tests/unit/protobuf/test_protobuf_schema_name.py b/tests/unit/protobuf/test_protobuf_schema_name.py
new file mode 100644
index 000000000..bcbca9fd0
--- /dev/null
+++ b/tests/unit/protobuf/test_protobuf_schema_name.py
@@ -0,0 +1,105 @@
+"""
+Copyright (c) 2023 Aiven Ltd
+See LICENSE for details
+"""
+from karapace.protobuf.schema import ProtobufSchema
+from karapace.schema_models import ValidatedTypedSchema
+from karapace.schema_type import SchemaType
+from tests.utils import schema_protobuf_second
+
+import pytest
+
+MESSAGE_WITH_ENUM = """\
+syntax = "proto3";
+
+option java_package = "com.codingharbour.protobuf";
+option java_outer_classname = "TestEnumOrder";
+
+message Speed {
+    Enum speed = 1;
+}
+
+enum Enum {
+    HIGH = 0;
+    MIDDLE = 1;
+    LOW = 2;
+}\
+"""
+
+MESSAGE_WITH_ENUM_REORDERED = """\
+syntax = "proto3";
+
+option java_package = "com.codingharbour.protobuf";
+option java_outer_classname = "TestEnumOrder";
+
+enum Enum {
+    HIGH = 0;
+    MIDDLE = 1;
+    LOW = 2;
+}
+
+message Speed {
+    Enum speed = 1;
+}\
+"""
+
+COMPLEX_MESSAGE_WITH_NESTING = """\
+syntax = "proto3";
+
+package fancy.company.in.party.v1;
+
+message AnotherMessage {
+    message WowANestedMessage {
+        enum BamFancyEnum {
+            // Hei! This is a comment!
+            MY_AWESOME_FIELD = 0;
+        }
+        message DeeplyNestedMsg {
+            message AnotherLevelOfNesting {
+                BamFancyEnum im_tricky_im_referring_to_the_previous_enum = 1;
+            }
+        }
+    }
+}\
+"""
+
+MESSAGE_WITH_JUST_ENUMS = """\
+syntax = "proto3";
+
+option java_package = "com.codingharbour.protobuf";
+option java_outer_classname = "TestEnumOrder";
+
+enum Enum {
+    HIGH = 0;
+    MIDDLE = 1;
+    LOW = 2;
+}
+
+enum Enum2 {
+    HIGH = 0;
+    MIDDLE = 1;
+    LOW = 2;
+}\
+"""
+
+
+def parse_protobuf_schema(schema: str) -> ProtobufSchema:
+    parsed_schema = ValidatedTypedSchema.parse(
+        SchemaType.PROTOBUF,
+        schema,
+    )
+    return parsed_schema.schema
+
+
+@pytest.mark.parametrize(
+    "schema,expected_record_name",
+    (
+        (parse_protobuf_schema(MESSAGE_WITH_ENUM), "Speed"),
+        (parse_protobuf_schema(MESSAGE_WITH_ENUM_REORDERED), "Speed"),
+        (parse_protobuf_schema(schema_protobuf_second), "SensorInfo"),
+        (parse_protobuf_schema(COMPLEX_MESSAGE_WITH_NESTING), "fancy.company.in.party.v1.AnotherMessage"),
+        (parse_protobuf_schema(MESSAGE_WITH_JUST_ENUMS), "Enum"),
+    ),
+)
+def test_record_name(schema: ProtobufSchema, expected_record_name: str):
+    assert schema.record_name() == expected_record_name
diff --git a/tests/unit/test_protobuf_serialization.py b/tests/unit/test_protobuf_serialization.py
index 3acd344b8..db039c64f 100644
--- a/tests/unit/test_protobuf_serialization.py
+++ b/tests/unit/test_protobuf_serialization.py
@@ -15,6 +15,7 @@
     START_BYTE,
 )
 from karapace.typing import ResolvedVersion, Subject
+from pathlib import Path
 from tests.utils import schema_protobuf, test_fail_objects_protobuf, test_objects_protobuf
 from unittest.mock import call, Mock

@@ -35,7 +36,7 @@ async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySeriali
     return serializer


-async def test_happy_flow(default_config_path):
+async def test_happy_flow(default_config_path: Path):
     mock_protobuf_registry_client = Mock()
     schema_for_id_one_future = asyncio.Future()
     schema_for_id_one_future.set_result(
@@ -61,7 +62,7 @@ async def test_happy_flow(default_config_path):
     assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)]


-async def test_happy_flow_references(default_config_path):
+async def test_happy_flow_references(default_config_path: Path):
     no_ref_schema_str = """
     |syntax = "proto3";
     |
@@ -129,7 +130,7 @@ async def test_happy_flow_references(default_config_path):
     assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)]


-async def test_happy_flow_references_two(default_config_path):
+async def test_happy_flow_references_two(default_config_path: Path):
     no_ref_schema_str = """
     |syntax = "proto3";
     |
@@ -216,7 +217,7 @@ async def test_happy_flow_references_two(default_config_path):
     assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)]


-async def test_serialization_fails(default_config_path):
+async def test_serialization_fails(default_config_path: Path):
     mock_protobuf_registry_client = Mock()
     get_latest_schema_future = asyncio.Future()
     get_latest_schema_future.set_result(
@@ -239,7 +240,7 @@ async def test_serialization_fails(default_config_path):
     assert mock_protobuf_registry_client.method_calls == [call.get_schema("top")]


-async def test_deserialization_fails(default_config_path):
+async def test_deserialization_fails(default_config_path: Path):
     mock_protobuf_registry_client = Mock()

     deserializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client)
@@ -258,7 +259,7 @@ async def test_deserialization_fails(default_config_path):
     assert mock_protobuf_registry_client.method_calls == [call.get_schema_for_id(500)]


-async def test_deserialization_fails2(default_config_path):
+async def test_deserialization_fails2(default_config_path: Path):
     mock_protobuf_registry_client = Mock()

     deserializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client)
diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py
index 029cae393..54ca8e99a 100644
--- a/tests/unit/test_serialization.py
+++ b/tests/unit/test_serialization.py
@@ -2,10 +2,12 @@
 Copyright (c) 2023 Aiven Ltd
 See LICENSE for details
 """
 from karapace.config import DEFAULTS, read_config
 from karapace.schema_models import SchemaType, ValidatedTypedSchema
 from karapace.serialization import (
     flatten_unions,
+    get_subject_name,
     HEADER_FORMAT,
     InvalidMessageHeader,
     InvalidMessageSchema,
@@ -14,7 +16,7 @@
     START_BYTE,
     write_value,
 )
-from karapace.typing import ResolvedVersion, Subject
+from karapace.typing import NameStrategy, ResolvedVersion, Subject, SubjectType
+from pathlib import Path
 from tests.utils import schema_avro_json, test_objects_avro
 from unittest.mock import call, Mock

@@ -29,6 +31,71 @@
 log = logging.getLogger(__name__)

+TYPED_AVRO_SCHEMA = ValidatedTypedSchema.parse(
+    SchemaType.AVRO,
+    json.dumps(
+        {
+            "namespace": "io.aiven.data",
+            "name": "Test",
+            "type": "record",
+            "fields": [
+                {
+                    "name": "attr1",
+                    "type": ["null", "string"],
+                },
+                {
+                    "name": "attr2",
+                    "type": ["null", "string"],
+                },
+            ],
+        }
+    ),
+)
+
+TYPED_JSON_SCHEMA = ValidatedTypedSchema.parse(
+    SchemaType.JSONSCHEMA,
+    json.dumps(
+        {
+            "$schema": "http://json-schema.org/draft-07/schema#",
+            "title": "Test",
+            "type": "object",
+            "properties": {"attr1": {"type": ["null", "string"]}, "attr2": {"type": ["null", "string"]}},
+        }
+    ),
+)
+
+TYPED_AVRO_SCHEMA_WITHOUT_NAMESPACE = ValidatedTypedSchema.parse(
+    SchemaType.AVRO,
+    json.dumps(
+        {
+            "name": "Test",
+            "type": "record",
+            "fields": [
+                {
+                    "name": "attr1",
+                    "type": ["null", "string"],
+                },
+                {
+                    "name": "attr2",
+                    "type": ["null", "string"],
+                },
+            ],
+        }
+    ),
+)
+
+TYPED_PROTOBUF_SCHEMA = ValidatedTypedSchema.parse(
+    SchemaType.PROTOBUF,
+    """\
+    syntax = "proto3";
+
+    message Test {
+        string attr1 = 1;
+        string attr2 = 2;
+    }\
+    """,
+)
+

 async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySerializer:
     with open(config_path, encoding="utf8") as handler:
@@ -39,7 +106,7 @@ async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySeriali
     return serializer


-async def test_happy_flow(default_config_path):
+async def test_happy_flow(default_config_path: Path):
     mock_registry_client = Mock()
     get_latest_schema_future = asyncio.Future()
     get_latest_schema_future.set_result(
@@ -47,7 +114,8 @@ async def test_happy_flow(default_config_path):
     )
     mock_registry_client.get_schema.return_value = get_latest_schema_future
     schema_for_id_one_future = asyncio.Future()
-    schema_for_id_one_future.set_result((ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), [Subject("stub")]))
+    schema_for_id_one_future.set_result(
+        (ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), [Subject("stub")]))
     mock_registry_client.get_schema_for_id.return_value = schema_for_id_one_future

     serializer = await make_ser_deser(default_config_path, mock_registry_client)
@@ -62,32 +130,12 @@ async def test_happy_flow(default_config_path):


 def test_flatten_unions_record() -> None:
-    typed_schema = ValidatedTypedSchema.parse(
-        SchemaType.AVRO,
-        json.dumps(
-            {
-                "namespace": "io.aiven.data",
-                "name": "Test",
-                "type": "record",
-                "fields": [
-                    {
-                        "name": "attr1",
-                        "type": ["null", "string"],
-                    },
-                    {
-                        "name": "attr2",
-                        "type": ["null", "string"],
-                    },
-                ],
-            }
-        ),
-    )
     record = {"attr1": {"string": "sample data"}, "attr2": None}
     flatten_record = {"attr1": "sample data", "attr2": None}
-    assert flatten_unions(typed_schema.schema, record) == flatten_record
+    assert flatten_unions(TYPED_AVRO_SCHEMA.schema, record) == flatten_record

     record = {"attr1": None, "attr2": None}
-    assert flatten_unions(typed_schema.schema, record) == record
+    assert flatten_unions(TYPED_AVRO_SCHEMA.schema, record) == record


 def test_flatten_unions_array() -> None:
@@ -222,7 +270,8 @@ def test_avro_json_write_accepts_json_encoded_data_without_tagged_unions() -> No
             {
                 "name": "outter",
                 "type": [
-                    {"type": "record", "name": duplicated_name, "fields": [{"name": duplicated_name, "type": "string"}]},
+                    {"type": "record", "name": duplicated_name,
+                     "fields": [{"name": duplicated_name, "type": "string"}]},
                     "int",
                 ],
             }
@@ -248,7 +297,7 @@ def test_avro_json_write_accepts_json_encoded_data_without_tagged_unions() -> No
     assert buffer_a.getbuffer() == buffer_b.getbuffer()


-async def test_serialization_fails(default_config_path):
+async def test_serialization_fails(default_config_path: Path):
     mock_registry_client = Mock()
     get_latest_schema_future = asyncio.Future()
     get_latest_schema_future.set_result(
@@ -264,10 +313,11 @@ async def test_serialization_fails(default_config_path):
     assert mock_registry_client.method_calls == [call.get_schema("topic")]


-async def test_deserialization_fails(default_config_path):
+async def test_deserialization_fails(default_config_path: Path):
     mock_registry_client = Mock()
     schema_for_id_one_future = asyncio.Future()
-    schema_for_id_one_future.set_result((ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), [Subject("stub")]))
+    schema_for_id_one_future.set_result(
+        (ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), [Subject("stub")]))
     mock_registry_client.get_schema_for_id.return_value = schema_for_id_one_future

     deserializer = await make_ser_deser(default_config_path, mock_registry_client)
@@ -310,3 +360,72 @@ async def test_deserialization_fails(default_config_path):
         await deserializer.deserialize(enc_bytes)

     assert mock_registry_client.method_calls == [call.get_schema_for_id(1)]
+
+
+@pytest.mark.parametrize(
+    "expected_subject,strategy,subject_type",
+    (
+        (Subject("foo-key"), NameStrategy.topic_name, SubjectType.key),
+        (Subject("io.aiven.data.Test"), NameStrategy.record_name, SubjectType.key),
+        (Subject("foo-io.aiven.data.Test"), NameStrategy.topic_record_name, SubjectType.key),
+        (Subject("foo-value"), NameStrategy.topic_name, SubjectType.value),
+        (Subject("io.aiven.data.Test"), NameStrategy.record_name, SubjectType.value),
+        (Subject("foo-io.aiven.data.Test"), NameStrategy.topic_record_name, SubjectType.value),
+    ),
+)
+def test_name_strategy_for_avro(expected_subject: Subject, strategy: NameStrategy, subject_type: SubjectType):
+    assert (
+        get_subject_name(topic_name="foo", schema=TYPED_AVRO_SCHEMA, subject_type=subject_type, naming_strategy=strategy)
+        == expected_subject
+    )
+
+
+@pytest.mark.parametrize(
+    "expected_subject,strategy,subject_type",
+    (
+        (Subject("Test"), NameStrategy.record_name, SubjectType.key),
+        (Subject("foo-Test"), NameStrategy.topic_record_name, SubjectType.key),
+        (Subject("Test"), NameStrategy.record_name, SubjectType.value),
+        (Subject("foo-Test"), NameStrategy.topic_record_name, SubjectType.value),
+    ),
+)
+def test_name_strategy_for_json_schema(expected_subject: Subject, strategy: NameStrategy, subject_type: SubjectType):
+    assert (
+        get_subject_name(topic_name="foo", schema=TYPED_JSON_SCHEMA, subject_type=subject_type, naming_strategy=strategy)
+        == expected_subject
+    )
+
+
+@pytest.mark.parametrize(
+    "expected_subject,strategy,subject_type",
+    (
+        (Subject("Test"), NameStrategy.record_name, SubjectType.key),
+        (Subject("foo-Test"), NameStrategy.topic_record_name, SubjectType.key),
+        (Subject("Test"), NameStrategy.record_name, SubjectType.value),
+        (Subject("foo-Test"), NameStrategy.topic_record_name, SubjectType.value),
+    ),
+)
+def test_name_strategy_for_avro_without_namespace(
+    expected_subject: Subject, strategy: NameStrategy, subject_type: SubjectType
+):
+    assert (
+        get_subject_name(
+            topic_name="foo", schema=TYPED_AVRO_SCHEMA_WITHOUT_NAMESPACE, subject_type=subject_type, naming_strategy=strategy
+        )
+        == expected_subject
+    )
+
+
+@pytest.mark.parametrize(
+    "expected_subject,strategy,subject_type",
+    (
+        (Subject("Test"), NameStrategy.record_name, SubjectType.key),
+        (Subject("foo-Test"), NameStrategy.topic_record_name, SubjectType.key),
+        (Subject("Test"), NameStrategy.record_name, SubjectType.value),
+        (Subject("foo-Test"), NameStrategy.topic_record_name, SubjectType.value),
+    ),
+)
+def test_name_strategy_for_protobuf(expected_subject: Subject, strategy: NameStrategy, subject_type: SubjectType):
+    assert (
+        get_subject_name(topic_name="foo", schema=TYPED_PROTOBUF_SCHEMA, subject_type=subject_type, naming_strategy=strategy)
+        == expected_subject
+    )
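
A quick usage sketch of the new get_subject_name helper (not part of the diff; the expected
values below mirror the assertions added in tests/unit/test_serialization.py):

    from karapace.schema_models import SchemaType, ValidatedTypedSchema
    from karapace.serialization import get_subject_name
    from karapace.typing import NameStrategy, SubjectType

    # An Avro record with a namespace, so its fully qualified name is "io.aiven.data.Test".
    schema = ValidatedTypedSchema.parse(
        SchemaType.AVRO,
        '{"namespace": "io.aiven.data", "name": "Test", "type": "record",'
        ' "fields": [{"name": "attr1", "type": ["null", "string"]}]}',
    )

    # topic_name: subject is the topic plus the key/value postfix; the record name is ignored.
    assert get_subject_name("foo", schema, SubjectType.value, NameStrategy.topic_name) == "foo-value"
    # record_name: subject is the record's fully qualified name; the topic is ignored.
    assert get_subject_name("foo", schema, SubjectType.value, NameStrategy.record_name) == "io.aiven.data.Test"
    # topic_record_name: topic and record name combined.
    assert get_subject_name("foo", schema, SubjectType.value, NameStrategy.topic_record_name) == "foo-io.aiven.data.Test"

With name_strategy_validation enabled, the REST proxy only accepts produce requests whose
schema is registered under the subject computed this way.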