diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index 1b6bf4b6f..369c93a45 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -15,11 +15,12 @@ from karapace.config import Config, create_client_ssl_context from karapace.errors import InvalidSchema from karapace.kafka_rest_apis.admin import KafkaRestAdminClient +from karapace.kafka_rest_apis.auth_utils import get_auth_config_from_header from karapace.kafka_rest_apis.consumer_manager import ConsumerManager from karapace.kafka_rest_apis.error_codes import RESTErrorCodes from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache from karapace.karapace import KarapaceBase -from karapace.rapu import HTTPRequest, HTTPResponse, JSON_CONTENT_TYPE +from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE from karapace.schema_models import TypedSchema, ValidatedTypedSchema from karapace.schema_type import SchemaType from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError @@ -27,7 +28,6 @@ from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient from typing import Callable, Dict, List, Optional, Tuple, Union -import aiohttp.web import asyncio import base64 import logging @@ -271,26 +271,16 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy": try: if self.config.get("rest_authorization", False): auth_header = request.headers.get("Authorization") + auth_config = get_auth_config_from_header(auth_header, self.config) - if auth_header is None: - raise HTTPResponse( - body='{"message": "Unauthorized"}', - status=HTTPStatus.UNAUTHORIZED, - content_type=JSON_CONTENT_TYPE, - headers={"WWW-Authenticate": 'Basic realm="Karapace REST Proxy"'}, - ) key = auth_header if self.proxies.get(key) is None: - auth = aiohttp.BasicAuth.decode(auth_header) config = self.config.copy() config["bootstrap_uri"] = config["sasl_bootstrap_uri"] config["security_protocol"] = ( "SASL_SSL" if config["security_protocol"] in ("SSL", "SASL_SSL") else "SASL_PLAINTEXT" ) - if config["sasl_mechanism"] is None: - config["sasl_mechanism"] = "PLAIN" - config["sasl_plain_username"] = auth.login - config["sasl_plain_password"] = auth.password + config.update(auth_config) self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer) else: if self.proxies.get(key) is None: diff --git a/karapace/kafka_rest_apis/auth_utils.py b/karapace/kafka_rest_apis/auth_utils.py new file mode 100644 index 000000000..0f646178b --- /dev/null +++ b/karapace/kafka_rest_apis/auth_utils.py @@ -0,0 +1,67 @@ +""" +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +import aiohttp +import dataclasses +import enum +from http import HTTPStatus +from typing import NoReturn, Optional, TypedDict, Union + +from kafka.oauth.abstract import AbstractTokenProvider + +from karapace.config import Config +from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE + + +@enum.unique +class TokenType(enum.Enum): + BASIC = "Basic" + BEARER = "Bearer" + + +def raise_unauthorized() -> NoReturn: + raise HTTPResponse( + body='{"message": "Unauthorized"}', + status=HTTPStatus.UNAUTHORIZED, + content_type=JSON_CONTENT_TYPE, + headers={"WWW-Authenticate": 'Basic realm="Karapace REST Proxy"'}, + ) + + +class SASLPlainConfig(TypedDict): + sasl_mechanism: Optional[str] + sasl_plain_username: Optional[str] + sasl_plain_password: Optional[str] + + +class SASLOauthConfig(TypedDict): + sasl_mechanism: Optional[str] + sasl_oauth_token: Optional[str] + + +def get_auth_config_from_header( + auth_header: Optional[str], + config: Config, +) -> Union[SASLPlainConfig, SASLOauthConfig]: + if auth_header is None: + raise_unauthorized() + + token_type, _separator, token = auth_header.partition(" ") + + if token_type == TokenType.BEARER.value: + return {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": token} + + if token_type == TokenType.BASIC.value: + basic_auth = aiohttp.BasicAuth.decode(auth_header) + sasl_mechanism = config["sasl_mechanism"] + if sasl_mechanism is None: + sasl_mechanism = "PLAIN" + + return { + "sasl_mechanism": sasl_mechanism, + "sasl_plain_username": basic_auth.login, + "sasl_plain_password": basic_auth.password, + } + + raise_unauthorized() diff --git a/tests/unit/test_auth_utils.py b/tests/unit/test_auth_utils.py new file mode 100644 index 000000000..3c6b78278 --- /dev/null +++ b/tests/unit/test_auth_utils.py @@ -0,0 +1,67 @@ +""" +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +import base64 +from http import HTTPStatus +from typing import Optional + +import pytest + +from karapace.config import set_config_defaults, ConfigDefaults +from karapace.kafka_rest_apis.auth_utils import ( + get_auth_config_from_header, + get_kafka_client_auth_parameters_from_config, + SimpleOauthTokenProvider, +) +from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE + + +@pytest.mark.parametrize( + "auth_header", + (None, "Digest foo=bar"), +) +def test_get_auth_config_from_header_raises_unauthorized_on_invalid_header(auth_header: Optional[str]) -> None: + config = set_config_defaults({}) + + with pytest.raises(HTTPResponse) as exc_info: + get_auth_config_from_header(auth_header, config) + + http_resonse = exc_info.value + assert http_resonse.body == '{"message": "Unauthorized"}' + assert http_resonse.status == HTTPStatus.UNAUTHORIZED + assert http_resonse.headers["Content-Type"] == JSON_CONTENT_TYPE + assert http_resonse.headers["WWW-Authenticate"] == 'Basic realm="Karapace REST Proxy"' + + +@pytest.mark.parametrize( + ("auth_header", "config_override", "expected_auth_config"), + ( + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": None}, + {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": "PLAIN"}, + {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": "SCRAM"}, + {"sasl_mechanism": "SCRAM", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + "Bearer ", + {}, + {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": ""}, + ), + ), +) +def test_get_auth_config_from_header( + auth_header: str, config_override: ConfigDefaults, expected_auth_config: ConfigDefaults +) -> None: + config = set_config_defaults(config_override) + auth_config = get_auth_config_from_header(auth_header, config) + assert auth_config == expected_auth_config