diff --git a/karapace/config.py b/karapace/config.py index f426dfcc5..4358b3bfe 100644 --- a/karapace/config.py +++ b/karapace/config.py @@ -62,6 +62,7 @@ class Config(TypedDict): sasl_mechanism: str | None sasl_plain_username: str | None sasl_plain_password: str | None + sasl_oauth_token: str | None topic_name: str metadata_max_age_ms: int admin_metadata_max_age: int @@ -131,6 +132,7 @@ class ConfigDefaults(Config, total=False): "sasl_mechanism": None, "sasl_plain_username": None, SASL_PLAIN_PASSWORD: None, + "sasl_oauth_token": None, "topic_name": DEFAULT_SCHEMA_TOPIC, "metadata_max_age_ms": 60000, "admin_metadata_max_age": 5, diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index 1b6bf4b6f..1afd44807 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -15,11 +15,12 @@ from karapace.config import Config, create_client_ssl_context from karapace.errors import InvalidSchema from karapace.kafka_rest_apis.admin import KafkaRestAdminClient +from karapace.kafka_rest_apis.auth_utils import get_auth_config_from_header, get_kafka_client_auth_parameters_from_config from karapace.kafka_rest_apis.consumer_manager import ConsumerManager from karapace.kafka_rest_apis.error_codes import RESTErrorCodes from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache from karapace.karapace import KarapaceBase -from karapace.rapu import HTTPRequest, HTTPResponse, JSON_CONTENT_TYPE +from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE from karapace.schema_models import TypedSchema, ValidatedTypedSchema from karapace.schema_type import SchemaType from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError @@ -27,7 +28,6 @@ from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient from typing import Callable, Dict, List, Optional, Tuple, Union -import aiohttp.web import asyncio import base64 import logging @@ -271,26 +271,16 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy": try: if self.config.get("rest_authorization", False): auth_header = request.headers.get("Authorization") + auth_config = get_auth_config_from_header(auth_header, self.config) - if auth_header is None: - raise HTTPResponse( - body='{"message": "Unauthorized"}', - status=HTTPStatus.UNAUTHORIZED, - content_type=JSON_CONTENT_TYPE, - headers={"WWW-Authenticate": 'Basic realm="Karapace REST Proxy"'}, - ) key = auth_header if self.proxies.get(key) is None: - auth = aiohttp.BasicAuth.decode(auth_header) config = self.config.copy() config["bootstrap_uri"] = config["sasl_bootstrap_uri"] config["security_protocol"] = ( "SASL_SSL" if config["security_protocol"] in ("SSL", "SASL_SSL") else "SASL_PLAINTEXT" ) - if config["sasl_mechanism"] is None: - config["sasl_mechanism"] = "PLAIN" - config["sasl_plain_username"] = auth.login - config["sasl_plain_password"] = auth.password + config.update(auth_config) self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer) else: if self.proxies.get(key) is None: @@ -471,9 +461,7 @@ async def _maybe_create_async_producer(self) -> AIOKafkaProducer: metadata_max_age_ms=self.config["metadata_max_age_ms"], security_protocol=self.config["security_protocol"], ssl_context=ssl_context, - sasl_mechanism=self.config["sasl_mechanism"], - sasl_plain_username=self.config["sasl_plain_username"], - sasl_plain_password=self.config["sasl_plain_password"], + **get_kafka_client_auth_parameters_from_config(self.config), ) try: @@ -626,13 +614,11 @@ def init_admin_client(self): ssl_cafile=self.config["ssl_cafile"], ssl_certfile=self.config["ssl_certfile"], ssl_keyfile=self.config["ssl_keyfile"], - sasl_mechanism=self.config["sasl_mechanism"], - sasl_plain_username=self.config["sasl_plain_username"], - sasl_plain_password=self.config["sasl_plain_password"], api_version=(1, 0, 0), metadata_max_age_ms=self.config["metadata_max_age_ms"], connections_max_idle_ms=self.config["connections_max_idle_ms"], kafka_client=KarapaceKafkaClient, + **get_kafka_client_auth_parameters_from_config(self.config, async_client=False), ) break except: # pylint: disable=bare-except diff --git a/karapace/kafka_rest_apis/auth_utils.py b/karapace/kafka_rest_apis/auth_utils.py new file mode 100644 index 000000000..c06550ea0 --- /dev/null +++ b/karapace/kafka_rest_apis/auth_utils.py @@ -0,0 +1,109 @@ +""" +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +from __future__ import annotations + +from aiokafka.abc import AbstractTokenProvider as AbstractTokenProviderAsync +from http import HTTPStatus +from kafka.oauth.abstract import AbstractTokenProvider +from karapace.config import Config +from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE +from typing import NoReturn, TypedDict + +import aiohttp +import dataclasses +import enum + + +@enum.unique +class TokenType(enum.Enum): + BASIC = "Basic" + BEARER = "Bearer" + + +def raise_unauthorized() -> NoReturn: + raise HTTPResponse( + body='{"message": "Unauthorized"}', + status=HTTPStatus.UNAUTHORIZED, + content_type=JSON_CONTENT_TYPE, + headers={"WWW-Authenticate": 'Basic realm="Karapace REST Proxy"'}, + ) + + +class SASLPlainConfig(TypedDict): + sasl_mechanism: str | None + sasl_plain_username: str | None + sasl_plain_password: str | None + + +class SASLOauthConfig(TypedDict): + sasl_mechanism: str | None + sasl_oauth_token: str | None + + +def get_auth_config_from_header( + auth_header: str | None, + config: Config, +) -> SASLPlainConfig | SASLOauthConfig: + if auth_header is None: + raise_unauthorized() + + token_type, _separator, token = auth_header.partition(" ") + + if token_type == TokenType.BEARER.value: + return {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": token} + + if token_type == TokenType.BASIC.value: + basic_auth = aiohttp.BasicAuth.decode(auth_header) + sasl_mechanism = config["sasl_mechanism"] + if sasl_mechanism is None: + sasl_mechanism = "PLAIN" + + return { + "sasl_mechanism": sasl_mechanism, + "sasl_plain_username": basic_auth.login, + "sasl_plain_password": basic_auth.password, + } + + raise_unauthorized() + + +@dataclasses.dataclass +class SimpleOauthTokenProvider(AbstractTokenProvider): + _token: str + + def token(self) -> str: + return self._token + + +@dataclasses.dataclass +class SimpleOauthTokenProviderAsync(AbstractTokenProviderAsync): + _token: str + + async def token(self) -> str: + return self._token + + +class SASLOauthParams(TypedDict): + sasl_mechanism: str + sasl_oauth_token_provider: AbstractTokenProvider | AbstractTokenProviderAsync + + +def get_kafka_client_auth_parameters_from_config( + config: Config, + *, + async_client: bool = True, +) -> SASLPlainConfig | SASLOauthParams: + if config["sasl_mechanism"] == "OAUTHBEARER": + token_provider_cls = SimpleOauthTokenProviderAsync if async_client else SimpleOauthTokenProvider + return { + "sasl_mechanism": config["sasl_mechanism"], + "sasl_oauth_token_provider": token_provider_cls(config["sasl_oauth_token"]), + } + + return { + "sasl_mechanism": config["sasl_mechanism"], + "sasl_plain_username": config["sasl_plain_username"], + "sasl_plain_password": config["sasl_plain_password"], + } diff --git a/karapace/kafka_rest_apis/consumer_manager.py b/karapace/kafka_rest_apis/consumer_manager.py index 72a1bea9a..c5fe4ab03 100644 --- a/karapace/kafka_rest_apis/consumer_manager.py +++ b/karapace/kafka_rest_apis/consumer_manager.py @@ -10,6 +10,7 @@ from kafka.errors import GroupAuthorizationFailedError, IllegalStateError, KafkaConfigurationError, KafkaError from kafka.structs import TopicPartition from karapace.config import Config, create_client_ssl_context +from karapace.kafka_rest_apis.auth_utils import get_kafka_client_auth_parameters_from_config from karapace.kafka_rest_apis.error_codes import RESTErrorCodes from karapace.karapace import empty_response, KarapaceBase from karapace.serialization import DeserializationError, InvalidMessageHeader, InvalidPayload, SchemaRegistrySerializer @@ -205,9 +206,6 @@ async def create_kafka_consumer(self, fetch_min_bytes, group_name, internal_name client_id=internal_name, security_protocol=self.config["security_protocol"], ssl_context=ssl_context, - sasl_mechanism=self.config["sasl_mechanism"], - sasl_plain_username=self.config["sasl_plain_username"], - sasl_plain_password=self.config["sasl_plain_password"], group_id=group_name, fetch_min_bytes=max(1, fetch_min_bytes), # Discard earlier negative values fetch_max_bytes=self.config["consumer_request_max_bytes"], @@ -218,6 +216,7 @@ async def create_kafka_consumer(self, fetch_min_bytes, group_name, internal_name enable_auto_commit=request_data["auto.commit.enable"], auto_offset_reset=request_data["auto.offset.reset"], session_timeout_ms=session_timeout_ms, + **get_kafka_client_auth_parameters_from_config(self.config), ) await c.start() return c diff --git a/mypy.ini b/mypy.ini index 291bb88b7..e06dfb201 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,6 +86,9 @@ ignore_errors = True # dependencies. # - Write your own stubs. You don't need to write stubs for the whole library, # only the parts that Karapace is interacting with. +[mypy-aiokafka.*] +ignore_missing_imports = True + [mypy-kafka.*] ignore_missing_imports = True diff --git a/tests/unit/test_auth_utils.py b/tests/unit/test_auth_utils.py new file mode 100644 index 000000000..26b6e023b --- /dev/null +++ b/tests/unit/test_auth_utils.py @@ -0,0 +1,109 @@ +""" +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +from http import HTTPStatus +from karapace.config import ConfigDefaults, set_config_defaults +from karapace.kafka_rest_apis.auth_utils import ( + get_auth_config_from_header, + get_kafka_client_auth_parameters_from_config, + SimpleOauthTokenProvider, + SimpleOauthTokenProviderAsync, +) +from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE +from typing import Optional + +import base64 +import pytest + + +@pytest.mark.parametrize( + "auth_header", + (None, "Digest foo=bar"), +) +def test_get_auth_config_from_header_raises_unauthorized_on_invalid_header(auth_header: Optional[str]) -> None: + config = set_config_defaults({}) + + with pytest.raises(HTTPResponse) as exc_info: + get_auth_config_from_header(auth_header, config) + + http_resonse = exc_info.value + assert http_resonse.body == '{"message": "Unauthorized"}' + assert http_resonse.status == HTTPStatus.UNAUTHORIZED + assert http_resonse.headers["Content-Type"] == JSON_CONTENT_TYPE + assert http_resonse.headers["WWW-Authenticate"] == 'Basic realm="Karapace REST Proxy"' + + +@pytest.mark.parametrize( + ("auth_header", "config_override", "expected_auth_config"), + ( + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": None}, + {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": "PLAIN"}, + {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": "SCRAM"}, + {"sasl_mechanism": "SCRAM", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + "Bearer ", + {}, + {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": ""}, + ), + ), +) +def test_get_auth_config_from_header( + auth_header: str, config_override: ConfigDefaults, expected_auth_config: ConfigDefaults +) -> None: + config = set_config_defaults(config_override) + auth_config = get_auth_config_from_header(auth_header, config) + assert auth_config == expected_auth_config + + +def test_simple_oauth_token_provider_returns_configured_token() -> None: + token_provider = SimpleOauthTokenProvider("TOKEN") + assert token_provider.token() == "TOKEN" + + +async def test_simple_oauth_token_provider_async_returns_configured_token() -> None: + token_provider = SimpleOauthTokenProviderAsync("TOKEN") + assert await token_provider.token() == "TOKEN" + + +def test_get_client_auth_parameters_from_config_sasl_plain() -> None: + config = set_config_defaults( + {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"} + ) + + client_auth_params = get_kafka_client_auth_parameters_from_config(config) + + assert client_auth_params == { + "sasl_mechanism": "PLAIN", + "sasl_plain_username": "username", + "sasl_plain_password": "password", + } + + +def test_get_client_auth_parameters_from_config_oauth() -> None: + config = set_config_defaults({"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": "TOKEN"}) + + client_auth_params = get_kafka_client_auth_parameters_from_config(config, async_client=False) + + assert client_auth_params["sasl_mechanism"] == "OAUTHBEARER" + assert client_auth_params["sasl_oauth_token_provider"].token() == "TOKEN" + + +async def test_get_client_auth_parameters_from_config_oauth_async() -> None: + config = set_config_defaults({"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": "TOKEN"}) + + client_auth_params = get_kafka_client_auth_parameters_from_config(config, async_client=True) + + assert client_auth_params["sasl_mechanism"] == "OAUTHBEARER" + assert await client_auth_params["sasl_oauth_token_provider"].token() == "TOKEN"