diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index ed99fe596..4f53406f5 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -16,7 +16,11 @@ from karapace.config import Config, create_client_ssl_context from karapace.errors import InvalidSchema from karapace.kafka_rest_apis.admin import KafkaRestAdminClient -from karapace.kafka_rest_apis.authentication import get_auth_config_from_header, get_kafka_client_auth_parameters_from_config +from karapace.kafka_rest_apis.authentication import ( + get_auth_config_from_header, + get_expiration_time_from_header, + get_kafka_client_auth_parameters_from_config, +) from karapace.kafka_rest_apis.consumer_manager import ConsumerManager from karapace.kafka_rest_apis.error_codes import RESTErrorCodes from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache @@ -31,6 +35,7 @@ import asyncio import base64 +import datetime import logging import time @@ -42,6 +47,7 @@ SCHEMA_MAPPINGS = {"avro": SchemaType.AVRO, "jsonschema": SchemaType.JSONSCHEMA, "protobuf": SchemaType.PROTOBUF} TypedConsumer = namedtuple("TypedConsumer", ["consumer", "serialization_format", "config"]) IDLE_PROXY_TIMEOUT = 5 * 60 +AUTH_EXPIRY_TOLERANCE = datetime.timedelta(seconds=IDLE_PROXY_TIMEOUT) log = logging.getLogger(__name__) @@ -87,6 +93,13 @@ async def _disconnect_idle_proxy_if_any(self) -> None: async with self._proxy_lock: # Always clean one at time, don't mutate dict while iterating for _key, _proxy in self.proxies.items(): + # In case of an OAuth2/OIDC token, the proxy is to be cleaned up _before_ the token expires + # If the token is still valid within the tolerance time, idleness is still checked + now = datetime.datetime.now(datetime.timezone.utc) + if _proxy.auth_expiry and _proxy.auth_expiry < now + AUTH_EXPIRY_TOLERANCE: + key, proxy = _key, _proxy + log.warning("Releasing unused connection for %s due to token expiry at %s", _proxy, _proxy.auth_expiry) + break # If UserRestProxy has consumers with state, disconnecting loses state if _proxy.num_consumers() > 0: if idle_consumer_timeout > 0 and _proxy.last_used + idle_consumer_timeout < time.monotonic(): @@ -273,6 +286,7 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy": if self.config.get("rest_authorization", False): auth_header = request.headers.get("Authorization") auth_config = get_auth_config_from_header(auth_header, self.config) + auth_expiry = get_expiration_time_from_header(auth_header) key = auth_header if self.proxies.get(key) is None: @@ -282,7 +296,7 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy": "SASL_SSL" if config["security_protocol"] in ("SSL", "SASL_SSL") else "SASL_PLAINTEXT" ) config.update(auth_config) - self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer) + self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer, auth_expiry) else: if self.proxies.get(key) is None: self.proxies[key] = UserRestProxy(self.config, self.kafka_timeout, self.serializer) @@ -400,7 +414,13 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq class UserRestProxy: - def __init__(self, config: Config, kafka_timeout: int, serializer): + def __init__( + self, + config: Config, + kafka_timeout: int, + serializer: SchemaRegistrySerializer, + auth_expiry: Optional[datetime.datetime] = None, + ): self.config = config self.kafka_timeout = kafka_timeout self.serializer = serializer @@ -415,6 +435,7 @@ def __init__(self, config: Config, kafka_timeout: int, serializer): self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer) self.init_admin_client() self._last_used = time.monotonic() + self._auth_expiry = auth_expiry self._async_producer_lock = asyncio.Lock() self._async_producer: Optional[AIOKafkaProducer] = None @@ -429,6 +450,10 @@ def last_used(self) -> int: def mark_used(self) -> None: self._last_used = time.monotonic() + @property + def auth_expiry(self) -> datetime.datetime: + return self._auth_expiry + def num_consumers(self) -> int: return len(self.consumer_manager.consumers) diff --git a/karapace/kafka_rest_apis/authentication.py b/karapace/kafka_rest_apis/authentication.py index c03df7ca2..88cab7289 100644 --- a/karapace/kafka_rest_apis/authentication.py +++ b/karapace/kafka_rest_apis/authentication.py @@ -13,7 +13,9 @@ import aiohttp import dataclasses +import datetime import enum +import jwt @enum.unique @@ -42,6 +44,11 @@ class SASLOauthConfig(TypedDict): sasl_oauth_token: str | None +def _split_auth_header(auth_header: str) -> tuple[str, str]: + token_type, _separator, token = auth_header.partition(" ") + return (token_type, token) + + def get_auth_config_from_header( auth_header: str | None, config: Config, @@ -57,7 +64,7 @@ def get_auth_config_from_header( if auth_header is None: raise_unauthorized() - token_type, _separator, token = auth_header.partition(" ") + token_type, token = _split_auth_header(auth_header) if token_type == TokenType.BEARER.value: return {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": token} @@ -77,6 +84,28 @@ def get_auth_config_from_header( raise_unauthorized() +def get_expiration_time_from_header(auth_header: str) -> datetime.datetime | None: + """Extract expiration from Authorization HTTP header. + + In case of an OAuth Bearer token, the `exp` claim is extracted and returned as a + `datetime.datetime` object. Otherwise it's safely assumed that the authentication + method is Basic, thus no expiry of the credentials. + + The signature is not verified as it is done by the Kafka clients using it and + discarding the token in case of any issues. + + :param auth_header: The Authorization header extracted from an HTTP request + """ + token_type, token = _split_auth_header(auth_header) + + if token_type == TokenType.BEARER.value: + exp_claim = jwt.decode(token, options={"verify_signature": False}).get("exp") + if exp_claim is not None: + return datetime.datetime.fromtimestamp(exp_claim, datetime.timezone.utc) + + return None + + @dataclasses.dataclass class SimpleOauthTokenProvider(AbstractTokenProvider): """A pass-through OAuth token provider to be used by synchronous Kafka clients. diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index de1aa2746..5dfa10c9e 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -146,6 +146,8 @@ pygments==2.15.1 # -r requirements.txt # pdbpp # rich +pyjwt==2.4.0 + # via -r requirements-dev.in pyrepl==0.9.0 # via fancycompleter pytest==7.4.0 diff --git a/requirements/requirements.in b/requirements/requirements.in index 25bfad6b7..6338e0ac6 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -6,6 +6,7 @@ isodate<1 jsonschema<5 networkx<4 protobuf<4 +pyjwt==2.4.0 python-dateutil<3 tenacity<9 typing-extensions diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 2f636cc77..01693b966 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -61,6 +61,8 @@ protobuf==3.20.3 # via -r requirements.in pygments==2.15.1 # via rich +pyjwt==2.4.0 + # via -r requirements.in python-dateutil==2.8.2 # via -r requirements.in referencing==0.30.0 diff --git a/tests/unit/test_authentication.py b/tests/unit/test_authentication.py index 86afd4a9f..3921f0560 100644 --- a/tests/unit/test_authentication.py +++ b/tests/unit/test_authentication.py @@ -8,6 +8,7 @@ from karapace.config import ConfigDefaults, set_config_defaults from karapace.kafka_rest_apis.authentication import ( get_auth_config_from_header, + get_expiration_time_from_header, get_kafka_client_auth_parameters_from_config, SimpleOauthTokenProvider, SimpleOauthTokenProviderAsync, @@ -15,6 +16,8 @@ from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE import base64 +import datetime +import jwt import pytest @@ -68,6 +71,23 @@ def test_get_auth_config_from_header( assert auth_config == expected_auth_config +@pytest.mark.parametrize( + ("auth_header", "expected_expiration"), + ( + (f"Basic {base64.b64encode(b'username:password').decode()}", None), + ( + f"Bearer {jwt.encode({'exp': 1697013997}, 'secret')}", + datetime.datetime.fromtimestamp(1697013997, datetime.timezone.utc), + ), + (f"Bearer {jwt.encode({}, 'secret')}", None), + ), +) +def test_get_expiration_time_from_header(auth_header: str, expected_expiration: datetime.datetime) -> None: + expiration = get_expiration_time_from_header(auth_header) + + assert expiration == expected_expiration + + def test_simple_oauth_token_provider_returns_configured_token() -> None: token_provider = SimpleOauthTokenProvider("TOKEN") assert token_provider.token() == "TOKEN" diff --git a/tests/unit/test_rest_auth.py b/tests/unit/test_rest_auth.py index 4f06c034a..86bb14b8a 100644 --- a/tests/unit/test_rest_auth.py +++ b/tests/unit/test_rest_auth.py @@ -3,15 +3,96 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ +from __future__ import annotations + from karapace.config import set_config_defaults -from karapace.kafka_rest_apis import KafkaRest +from karapace.kafka_rest_apis import AUTH_EXPIRY_TOLERANCE, KafkaRest, UserRestProxy from unittest.mock import call, Mock import asyncio +import datetime import time -async def test_rest_proxy_janitor_default(): +def _create_mock_proxy( + *, + num_consumers: int, + last_used: float, + auth_expiry: datetime.datetime | None = None, + _with_close_future: bool = False, +) -> Mock: + proxy = Mock(spec=UserRestProxy) + proxy.num_consumers.return_value = num_consumers + proxy.last_used = last_used + proxy.auth_expiry = auth_expiry + + if _with_close_future: + close_future = asyncio.Future() + close_future.set_result(True) + proxy.aclose.return_value = close_future + + return proxy + + +async def test_rest_proxy_janitor_expiring_credentials() -> None: + config = set_config_defaults( + { + "rest_authorization": True, + "sasl_bootstrap_uri": "localhost:9094", + } + ) + instance = KafkaRest(config=config) + + proxy_expiring_within_tolerance = _create_mock_proxy( + num_consumers=99, + last_used=time.monotonic(), + auth_expiry=datetime.datetime.now(datetime.timezone.utc) + (AUTH_EXPIRY_TOLERANCE / 2), + _with_close_future=True, + ) + instance.proxies["proxy_expiring_within_tolerance"] = proxy_expiring_within_tolerance + + proxy_already_expired = _create_mock_proxy( + num_consumers=99, + last_used=time.monotonic(), + auth_expiry=datetime.datetime.now(datetime.timezone.utc) - AUTH_EXPIRY_TOLERANCE, + _with_close_future=True, + ) + instance.proxies["proxy_already_expired"] = proxy_already_expired + + active_proxy_expiring_later_than_tolerance = _create_mock_proxy( + num_consumers=0, + last_used=time.monotonic(), + auth_expiry=datetime.datetime.now(datetime.timezone.utc) + 2 * AUTH_EXPIRY_TOLERANCE, + ) + instance.proxies["active_proxy_expiring_later_than_tolerance"] = active_proxy_expiring_later_than_tolerance + + unused_proxy_expiring_later_than_tolerance = _create_mock_proxy( + num_consumers=0, + last_used=time.monotonic() - 600, + auth_expiry=datetime.datetime.now(datetime.timezone.utc) + 2 * AUTH_EXPIRY_TOLERANCE, + _with_close_future=True, + ) + instance.proxies["unused_proxy_expiring_later_than_tolerance"] = unused_proxy_expiring_later_than_tolerance + + # Needs to be called multiple times to clean all expected proxies, this method only + # releases the first one it finds, then exits. + await instance._disconnect_idle_proxy_if_any() + await instance._disconnect_idle_proxy_if_any() + await instance._disconnect_idle_proxy_if_any() + + assert instance.proxies.get("proxy_expiring_within_tolerance") is None + assert instance.proxies.get("proxy_already_expired") is None + assert instance.proxies.get("active_proxy_expiring_later_than_tolerance") is not None + assert instance.proxies.get("unused_proxy_expiring_later_than_tolerance") is None + assert len(instance.proxies) == 1 + + assert proxy_expiring_within_tolerance.method_calls == [call.aclose()] + assert proxy_already_expired.method_calls == [call.aclose()] + assert active_proxy_expiring_later_than_tolerance.method_calls == [call.num_consumers()] + assert unused_proxy_expiring_later_than_tolerance.method_calls == [call.num_consumers(), call.aclose()] + + +async def test_rest_proxy_janitor_default() -> None: config = set_config_defaults( { "rest_authorization": True, @@ -20,22 +101,17 @@ async def test_rest_proxy_janitor_default(): ) instance = KafkaRest(config=config) - active_proxy_without_consumers = Mock() - active_proxy_without_consumers.num_consumers.return_value = 0 - active_proxy_without_consumers.last_used = time.monotonic() + active_proxy_without_consumers = _create_mock_proxy(num_consumers=0, last_used=time.monotonic()) instance.proxies["active_proxy_without_consumers"] = active_proxy_without_consumers - active_proxy_with_consumers = Mock() - active_proxy_with_consumers.num_consumers.return_value = 99 - active_proxy_with_consumers.last_used = time.monotonic() + active_proxy_with_consumers = _create_mock_proxy(num_consumers=99, last_used=time.monotonic()) instance.proxies["active_proxy_with_consumers"] = active_proxy_with_consumers - unused_proxy_without_consumers = Mock() - unused_proxy_without_consumers.num_consumers.return_value = 0 - unused_proxy_without_consumers.last_used = time.monotonic() - 600 - close_future = asyncio.Future() - close_future.set_result(True) - unused_proxy_without_consumers.aclose.return_value = close_future + unused_proxy_without_consumers = _create_mock_proxy( + num_consumers=0, + last_used=time.monotonic() - 600, + _with_close_future=True, + ) instance.proxies["unused_proxy_without_consumers"] = unused_proxy_without_consumers await instance._disconnect_idle_proxy_if_any() @@ -53,12 +129,11 @@ async def test_rest_proxy_janitor_default(): active_proxy_with_consumers.reset_mock() # Proxy with consumers is not deleted without explicit config - unused_proxy_with_consumers = Mock() - unused_proxy_with_consumers.num_consumers.return_value = 99 - unused_proxy_with_consumers.last_used = time.monotonic() - 600 - close_future = asyncio.Future() - close_future.set_result(True) - unused_proxy_with_consumers.aclose.return_value = close_future + unused_proxy_with_consumers = _create_mock_proxy( + num_consumers=99, + last_used=time.monotonic() - 600, + _with_close_future=True, + ) instance.proxies["unused_proxy_with_consumers"] = unused_proxy_with_consumers await instance._disconnect_idle_proxy_if_any() @@ -73,7 +148,7 @@ async def test_rest_proxy_janitor_default(): assert active_proxy_with_consumers.method_calls == [call.num_consumers()] -async def test_rest_proxy_janitor_destructive(): +async def test_rest_proxy_janitor_destructive() -> None: config = set_config_defaults( { "rest_authorization": True, @@ -83,22 +158,17 @@ async def test_rest_proxy_janitor_destructive(): ) instance = KafkaRest(config=config) - active_proxy_without_consumers = Mock() - active_proxy_without_consumers.num_consumers.return_value = 0 - active_proxy_without_consumers.last_used = time.monotonic() + active_proxy_without_consumers = _create_mock_proxy(num_consumers=0, last_used=time.monotonic()) instance.proxies["active_proxy_without_consumers"] = active_proxy_without_consumers - active_proxy_with_consumers = Mock() - active_proxy_with_consumers.num_consumers.return_value = 99 - active_proxy_with_consumers.last_used = time.monotonic() + active_proxy_with_consumers = _create_mock_proxy(num_consumers=99, last_used=time.monotonic()) instance.proxies["active_proxy_with_consumers"] = active_proxy_with_consumers - unused_proxy_without_consumers = Mock() - unused_proxy_without_consumers.num_consumers.return_value = 0 - unused_proxy_without_consumers.last_used = time.monotonic() - 600 - close_future = asyncio.Future() - close_future.set_result(True) - unused_proxy_without_consumers.aclose.return_value = close_future + unused_proxy_without_consumers = _create_mock_proxy( + num_consumers=0, + last_used=time.monotonic() - 600, + _with_close_future=True, + ) instance.proxies["unused_proxy_without_consumers"] = unused_proxy_without_consumers await instance._disconnect_idle_proxy_if_any() @@ -116,12 +186,11 @@ async def test_rest_proxy_janitor_destructive(): active_proxy_with_consumers.reset_mock() # Proxy with consumers gets deleted after enough time has passed - unused_proxy_with_consumers = Mock() - unused_proxy_with_consumers.num_consumers.return_value = 99 - unused_proxy_with_consumers.last_used = time.monotonic() - 600 - close_future = asyncio.Future() - close_future.set_result(True) - unused_proxy_with_consumers.aclose.return_value = close_future + unused_proxy_with_consumers = _create_mock_proxy( + num_consumers=99, + last_used=time.monotonic() - 600, + _with_close_future=True, + ) instance.proxies["unused_proxy_with_consumers"] = unused_proxy_with_consumers await instance._disconnect_idle_proxy_if_any()