From 79b6ae3739453426beae2012cc0f1f27a6888108 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1ty=C3=A1s=20Kuti?=
Date: Thu, 12 Oct 2023 11:48:31 +0200
Subject: [PATCH] Clean up proxies with expiring credentials

Due to the behaviour of the Kafka clients, we want to avoid them going
into a retry loop in case of an expired OAuth token. The current proxy
janitor _would_ clean them up eventually but this is behaviour to be
avoided. Thus the proxy janitor is extended to clean up proxies that
have auth tokens expiring within a "tolerance" amount of time.

To extract the expiry time (`exp` claim) from an OAuth token, the
`pyjwt` library is used. With this action we do not verify the token
signature (which would not be possible without knowing the signing
public key). Normally this would present the issue of tokens that have
been tampered with and not re-signed. However, when using a token for
the first time, the Karapace REST proxy instantiates a Kafka admin
client, which results in immediate verification of the token by Kafka.
So even if the expiration time has been maliciously changed, the
connection will be refused by Kafka, resulting in an Unauthorized
response from Karapace to its client.
---
 karapace/kafka_rest_apis/__init__.py       |  30 ++++-
 karapace/kafka_rest_apis/authentication.py |  31 ++++-
 requirements/requirements-dev.txt          |   2 +
 requirements/requirements.in               |   1 +
 requirements/requirements.txt              |   2 +
 tests/unit/test_authentication.py          |  17 +++
 tests/unit/test_rest_auth.py               | 145 +++++++++++++++------
 7 files changed, 185 insertions(+), 43 deletions(-)

diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py
index ed99fe596..74a24e180 100644
--- a/karapace/kafka_rest_apis/__init__.py
+++ b/karapace/kafka_rest_apis/__init__.py
@@ -16,7 +16,11 @@
 from karapace.config import Config, create_client_ssl_context
 from karapace.errors import InvalidSchema
 from karapace.kafka_rest_apis.admin import KafkaRestAdminClient
-from karapace.kafka_rest_apis.authentication import get_auth_config_from_header, get_kafka_client_auth_parameters_from_config
+from karapace.kafka_rest_apis.authentication import (
+    get_auth_config_from_header,
+    get_expiration_time_from_header,
+    get_kafka_client_auth_parameters_from_config,
+)
 from karapace.kafka_rest_apis.consumer_manager import ConsumerManager
 from karapace.kafka_rest_apis.error_codes import RESTErrorCodes
 from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache
@@ -31,6 +35,7 @@
 
 import asyncio
 import base64
+import datetime
 import logging
 import time
 
@@ -42,6 +47,7 @@
 SCHEMA_MAPPINGS = {"avro": SchemaType.AVRO, "jsonschema": SchemaType.JSONSCHEMA, "protobuf": SchemaType.PROTOBUF}
 TypedConsumer = namedtuple("TypedConsumer", ["consumer", "serialization_format", "config"])
 IDLE_PROXY_TIMEOUT = 5 * 60
+AUTH_EXPIRY_TOLERANCE = datetime.timedelta(seconds=IDLE_PROXY_TIMEOUT)
 
 log = logging.getLogger(__name__)
 
@@ -87,6 +93,12 @@ async def _disconnect_idle_proxy_if_any(self) -> None:
         async with self._proxy_lock:
             # Always clean one at time, don't mutate dict while iterating
             for _key, _proxy in self.proxies.items():
+                # In case of an OAuth2/OIDC token, the proxy is to be cleaned up _before_ the token expires
+                # If the token is still valid within the tolerance time, idleness is still checked
+                if _proxy.auth_expiry and _proxy.auth_expiry < datetime.datetime.now() + AUTH_EXPIRY_TOLERANCE:
+                    key, proxy = _key, _proxy
+                    log.warning("Releasing unused connection for %s due to token expiry at %s", _proxy, _proxy.auth_expiry)
+                    break
                 # If UserRestProxy has consumers with state, disconnecting loses state
                 if _proxy.num_consumers() > 0:
                     if idle_consumer_timeout > 0 and _proxy.last_used + idle_consumer_timeout < time.monotonic():
@@ -273,6 +285,7 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy":
             if self.config.get("rest_authorization", False):
                 auth_header = request.headers.get("Authorization")
                 auth_config = get_auth_config_from_header(auth_header, self.config)
+                auth_expiry = get_expiration_time_from_header(auth_header)
 
                 key = auth_header
                 if self.proxies.get(key) is None:
@@ -282,7 +295,7 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy":
                         "SASL_SSL" if config["security_protocol"] in ("SSL", "SASL_SSL") else "SASL_PLAINTEXT"
                     )
                     config.update(auth_config)
-                    self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer)
+                    self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer, auth_expiry)
             else:
                 if self.proxies.get(key) is None:
                     self.proxies[key] = UserRestProxy(self.config, self.kafka_timeout, self.serializer)
@@ -400,7 +413,13 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq
 
 
 class UserRestProxy:
-    def __init__(self, config: Config, kafka_timeout: int, serializer):
+    def __init__(
+        self,
+        config: Config,
+        kafka_timeout: int,
+        serializer: SchemaRegistrySerializer,
+        auth_expiry: datetime.datetime | None = None,
+    ):
         self.config = config
         self.kafka_timeout = kafka_timeout
         self.serializer = serializer
@@ -415,6 +434,7 @@ def __init__(self, config: Config, kafka_timeout: int, serializer):
         self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer)
         self.init_admin_client()
         self._last_used = time.monotonic()
+        self._auth_expiry = auth_expiry
 
         self._async_producer_lock = asyncio.Lock()
         self._async_producer: Optional[AIOKafkaProducer] = None
@@ -429,6 +449,10 @@ def last_used(self) -> int:
     def mark_used(self) -> None:
         self._last_used = time.monotonic()
 
+    @property
+    def auth_expiry(self) -> datetime.datetime:
+        return self._auth_expiry
+
     def num_consumers(self) -> int:
         return len(self.consumer_manager.consumers)
 
diff --git a/karapace/kafka_rest_apis/authentication.py b/karapace/kafka_rest_apis/authentication.py
index c03df7ca2..743eb7e11 100644
--- a/karapace/kafka_rest_apis/authentication.py
+++ b/karapace/kafka_rest_apis/authentication.py
@@ -13,7 +13,9 @@
 
 import aiohttp
 import dataclasses
+import datetime
 import enum
+import jwt
 
 
 @enum.unique
@@ -42,6 +44,11 @@ class SASLOauthConfig(TypedDict):
     sasl_oauth_token: str | None
 
 
+def _split_auth_header(auth_header: str) -> tuple[str, str]:
+    token_type, _separator, token = auth_header.partition(" ")
+    return (token_type, token)
+
+
 def get_auth_config_from_header(
     auth_header: str | None,
     config: Config,
@@ -57,7 +64,7 @@ def get_auth_config_from_header(
     if auth_header is None:
         raise_unauthorized()
 
-    token_type, _separator, token = auth_header.partition(" ")
+    token_type, token = _split_auth_header(auth_header)
 
     if token_type == TokenType.BEARER.value:
         return {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": token}
@@ -77,6 +84,28 @@
     raise_unauthorized()
 
 
+def get_expiration_time_from_header(auth_header: str) -> datetime.datetime | None:
+    """Extract expiration from Authorization HTTP header.
+
+    In case of an OAuth Bearer token, the `exp` claim is extracted and returned as a
+    `datetime.datetime` object. Otherwise it's safely assumed that the authentication
+    method is Basic, thus no expiry of the credentials.
+
+    The signature is not verified as it is done by the Kafka clients using it and
+    discarding the token in case of any issues.
+
+    :param auth_header: The Authorization header extracted from an HTTP request
+    """
+    token_type, token = _split_auth_header(auth_header)
+
+    if token_type == TokenType.BEARER.value:
+        exp_claim = jwt.decode(token, options={"verify_signature": False}).get("exp")
+        if exp_claim is not None:
+            return datetime.datetime.fromtimestamp(exp_claim)
+
+    return None
+
+
 @dataclasses.dataclass
 class SimpleOauthTokenProvider(AbstractTokenProvider):
     """A pass-through OAuth token provider to be used by synchronous Kafka clients.
diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt
index de1aa2746..ca073a342 100644
--- a/requirements/requirements-dev.txt
+++ b/requirements/requirements-dev.txt
@@ -146,6 +146,8 @@ pygments==2.15.1
     #   -r requirements.txt
     #   pdbpp
     #   rich
+pyjwt==2.8.0
+    # via -r requirements-dev.in
 pyrepl==0.9.0
     # via fancycompleter
 pytest==7.4.0
diff --git a/requirements/requirements.in b/requirements/requirements.in
index 25bfad6b7..dae6c50ae 100644
--- a/requirements/requirements.in
+++ b/requirements/requirements.in
@@ -6,6 +6,7 @@ isodate<1
 jsonschema<5
 networkx<4
 protobuf<4
+pyjwt<3
 python-dateutil<3
 tenacity<9
 typing-extensions
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 2f636cc77..e45aff0fe 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -61,6 +61,8 @@ protobuf==3.20.3
     # via -r requirements.in
 pygments==2.15.1
     # via rich
+pyjwt==2.8.0
+    # via -r requirements.in
 python-dateutil==2.8.2
     # via -r requirements.in
 referencing==0.30.0
diff --git a/tests/unit/test_authentication.py b/tests/unit/test_authentication.py
index 86afd4a9f..94f1365ba 100644
--- a/tests/unit/test_authentication.py
+++ b/tests/unit/test_authentication.py
@@ -8,6 +8,7 @@
 from karapace.config import ConfigDefaults, set_config_defaults
 from karapace.kafka_rest_apis.authentication import (
     get_auth_config_from_header,
+    get_expiration_time_from_header,
     get_kafka_client_auth_parameters_from_config,
     SimpleOauthTokenProvider,
     SimpleOauthTokenProviderAsync,
@@ -15,6 +16,8 @@
 from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE
 
 import base64
+import datetime
+import jwt
 import pytest
 
 
@@ -68,6 +71,20 @@ def test_get_auth_config_from_header(
     assert auth_config == expected_auth_config
 
 
+@pytest.mark.parametrize(
+    ("auth_header", "expected_expiration"),
+    (
+        (f"Basic {base64.b64encode(b'username:password').decode()}", None),
+        (f"Bearer {jwt.encode({'exp': 1697013997}, 'secret')}", datetime.datetime.fromtimestamp(1697013997)),
+        (f"Bearer {jwt.encode({}, 'secret')}", None),
+    ),
+)
+def test_get_expiration_time_from_header(auth_header: str, expected_expiration: datetime.datetime) -> None:
+    expiration = get_expiration_time_from_header(auth_header)
+
+    assert expiration == expected_expiration
+
+
 def test_simple_oauth_token_provider_returns_configured_token() -> None:
     token_provider = SimpleOauthTokenProvider("TOKEN")
     assert token_provider.token() == "TOKEN"
diff --git a/tests/unit/test_rest_auth.py b/tests/unit/test_rest_auth.py
index 4f06c034a..7d9364ee8 100644
--- a/tests/unit/test_rest_auth.py
+++ b/tests/unit/test_rest_auth.py
@@ -4,14 +4,35 @@
 See LICENSE for details
 """
 from karapace.config import set_config_defaults
-from karapace.kafka_rest_apis import KafkaRest
+from karapace.kafka_rest_apis import KafkaRest, UserRestProxy, AUTH_EXPIRY_TOLERANCE
 from unittest.mock import call, Mock
 
 import asyncio
+import datetime
 import time
 
 
-async def test_rest_proxy_janitor_default():
+def _create_mock_proxy(
+    *,
+    num_consumers: int,
+    last_used: float,
+    auth_expiry: datetime.datetime | None = None,
+    _with_close_future: bool = False,
+) -> Mock:
+    proxy = Mock(spec=UserRestProxy)
+    proxy.num_consumers.return_value = num_consumers
+    proxy.last_used = last_used
+    proxy.auth_expiry = auth_expiry
+
+    if _with_close_future:
+        close_future = asyncio.Future()
+        close_future.set_result(True)
+        proxy.aclose.return_value = close_future
+
+    return proxy
+
+
+async def test_rest_proxy_janitor_expiring_credentials() -> None:
     config = set_config_defaults(
         {
             "rest_authorization": True,
@@ -20,22 +41,75 @@
     )
     instance = KafkaRest(config=config)
 
-    active_proxy_without_consumers = Mock()
-    active_proxy_without_consumers.num_consumers.return_value = 0
-    active_proxy_without_consumers.last_used = time.monotonic()
+    proxy_expiring_within_tolerance = _create_mock_proxy(
+        num_consumers=99,
+        last_used=time.monotonic(),
+        auth_expiry=datetime.datetime.now() + (AUTH_EXPIRY_TOLERANCE / 2),
+        _with_close_future=True,
+    )
+    instance.proxies["proxy_expiring_within_tolerance"] = proxy_expiring_within_tolerance
+
+    proxy_already_expired = _create_mock_proxy(
+        num_consumers=99,
+        last_used=time.monotonic(),
+        auth_expiry=datetime.datetime.now() - AUTH_EXPIRY_TOLERANCE,
+        _with_close_future=True,
+    )
+    instance.proxies["proxy_already_expired"] = proxy_already_expired
+
+    active_proxy_expiring_later_than_tolerance = _create_mock_proxy(
+        num_consumers=0,
+        last_used=time.monotonic(),
+        auth_expiry=datetime.datetime.now() + 2 * AUTH_EXPIRY_TOLERANCE,
+    )
+    instance.proxies["active_proxy_expiring_later_than_tolerance"] = active_proxy_expiring_later_than_tolerance
+
+    unused_proxy_expiring_later_than_tolerance = _create_mock_proxy(
+        num_consumers=0,
+        last_used=time.monotonic() - 600,
+        auth_expiry=datetime.datetime.now() + 2 * AUTH_EXPIRY_TOLERANCE,
+        _with_close_future=True,
+    )
+    instance.proxies["unused_proxy_expiring_later_than_tolerance"] = unused_proxy_expiring_later_than_tolerance
+
+    # Needs to be called multiple times to clean all expected proxies, this method only
+    # releases the first one it finds, then exits.
+    await instance._disconnect_idle_proxy_if_any()
+    await instance._disconnect_idle_proxy_if_any()
+    await instance._disconnect_idle_proxy_if_any()
+
+    assert instance.proxies.get("proxy_expiring_within_tolerance") is None
+    assert instance.proxies.get("proxy_already_expired") is None
+    assert instance.proxies.get("active_proxy_expiring_later_than_tolerance") is not None
+    assert instance.proxies.get("unused_proxy_expiring_later_than_tolerance") is None
+    assert len(instance.proxies) == 1
+
+    assert proxy_expiring_within_tolerance.method_calls == [call.aclose()]
+    assert proxy_already_expired.method_calls == [call.aclose()]
+    assert active_proxy_expiring_later_than_tolerance.method_calls == [call.num_consumers()]
+    assert unused_proxy_expiring_later_than_tolerance.method_calls == [call.num_consumers(), call.aclose()]
+
+
+async def test_rest_proxy_janitor_default() -> None:
+    config = set_config_defaults(
+        {
+            "rest_authorization": True,
+            "sasl_bootstrap_uri": "localhost:9094",
+        }
+    )
+    instance = KafkaRest(config=config)
+
+    active_proxy_without_consumers = _create_mock_proxy(num_consumers=0, last_used=time.monotonic())
     instance.proxies["active_proxy_without_consumers"] = active_proxy_without_consumers
 
-    active_proxy_with_consumers = Mock()
-    active_proxy_with_consumers.num_consumers.return_value = 99
-    active_proxy_with_consumers.last_used = time.monotonic()
+    active_proxy_with_consumers = _create_mock_proxy(num_consumers=99, last_used=time.monotonic())
     instance.proxies["active_proxy_with_consumers"] = active_proxy_with_consumers
 
-    unused_proxy_without_consumers = Mock()
-    unused_proxy_without_consumers.num_consumers.return_value = 0
-    unused_proxy_without_consumers.last_used = time.monotonic() - 600
-    close_future = asyncio.Future()
-    close_future.set_result(True)
-    unused_proxy_without_consumers.aclose.return_value = close_future
+    unused_proxy_without_consumers = _create_mock_proxy(
+        num_consumers=0,
+        last_used=time.monotonic() - 600,
+        _with_close_future=True,
+    )
     instance.proxies["unused_proxy_without_consumers"] = unused_proxy_without_consumers
 
     await instance._disconnect_idle_proxy_if_any()
@@ -53,12 +127,11 @@ async def test_rest_proxy_janitor_default():
     active_proxy_with_consumers.reset_mock()
 
     # Proxy with consumers is not deleted without explicit config
-    unused_proxy_with_consumers = Mock()
-    unused_proxy_with_consumers.num_consumers.return_value = 99
-    unused_proxy_with_consumers.last_used = time.monotonic() - 600
-    close_future = asyncio.Future()
-    close_future.set_result(True)
-    unused_proxy_with_consumers.aclose.return_value = close_future
+    unused_proxy_with_consumers = _create_mock_proxy(
+        num_consumers=99,
+        last_used=time.monotonic() - 600,
+        _with_close_future=True,
+    )
     instance.proxies["unused_proxy_with_consumers"] = unused_proxy_with_consumers
 
     await instance._disconnect_idle_proxy_if_any()
@@ -73,7 +146,7 @@
     assert active_proxy_with_consumers.method_calls == [call.num_consumers()]
 
 
-async def test_rest_proxy_janitor_destructive():
+async def test_rest_proxy_janitor_destructive() -> None:
     config = set_config_defaults(
         {
             "rest_authorization": True,
@@ -83,22 +156,17 @@
     )
     instance = KafkaRest(config=config)
 
-    active_proxy_without_consumers = Mock()
-    active_proxy_without_consumers.num_consumers.return_value = 0
-    active_proxy_without_consumers.last_used = time.monotonic()
+    active_proxy_without_consumers = _create_mock_proxy(num_consumers=0, last_used=time.monotonic())
     instance.proxies["active_proxy_without_consumers"] = active_proxy_without_consumers
 
-    active_proxy_with_consumers = Mock()
-    active_proxy_with_consumers.num_consumers.return_value = 99
-    active_proxy_with_consumers.last_used = time.monotonic()
+    active_proxy_with_consumers = _create_mock_proxy(num_consumers=99, last_used=time.monotonic())
     instance.proxies["active_proxy_with_consumers"] = active_proxy_with_consumers
 
-    unused_proxy_without_consumers = Mock()
-    unused_proxy_without_consumers.num_consumers.return_value = 0
-    unused_proxy_without_consumers.last_used = time.monotonic() - 600
-    close_future = asyncio.Future()
-    close_future.set_result(True)
-    unused_proxy_without_consumers.aclose.return_value = close_future
+    unused_proxy_without_consumers = _create_mock_proxy(
+        num_consumers=0,
+        last_used=time.monotonic() - 600,
+        _with_close_future=True,
+    )
     instance.proxies["unused_proxy_without_consumers"] = unused_proxy_without_consumers
 
     await instance._disconnect_idle_proxy_if_any()
@@ -116,12 +184,11 @@
     active_proxy_with_consumers.reset_mock()
 
     # Proxy with consumers gets deleted after enough time has passed
-    unused_proxy_with_consumers = Mock()
-    unused_proxy_with_consumers.num_consumers.return_value = 99
-    unused_proxy_with_consumers.last_used = time.monotonic() - 600
-    close_future = asyncio.Future()
-    close_future.set_result(True)
-    unused_proxy_with_consumers.aclose.return_value = close_future
+    unused_proxy_with_consumers = _create_mock_proxy(
+        num_consumers=99,
+        last_used=time.monotonic() - 600,
+        _with_close_future=True,
+    )
     instance.proxies["unused_proxy_with_consumers"] = unused_proxy_with_consumers
 
     await instance._disconnect_idle_proxy_if_any()
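
For context on the mechanism above: the janitor now treats a proxy as expiring when the token's `exp` claim falls within AUTH_EXPIRY_TOLERANCE of the current time, and that claim is read with an unverified jwt.decode. The standalone sketch below (not part of the patch) reproduces that check outside Karapace; the expires_soon helper name, the local copy of the five-minute tolerance, and the "secret" signing key are illustrative assumptions, while the PyJWT calls mirror those used in get_expiration_time_from_header.

# Standalone sketch, assuming PyJWT 2.x. AUTH_EXPIRY_TOLERANCE is copied here so the
# snippet runs outside Karapace; expires_soon is a hypothetical helper, not patch code.
from __future__ import annotations

import datetime

import jwt

AUTH_EXPIRY_TOLERANCE = datetime.timedelta(seconds=5 * 60)


def expires_soon(bearer_token: str, now: datetime.datetime | None = None) -> bool:
    """Return True if the token's `exp` claim falls within the tolerance window."""
    now = now or datetime.datetime.now()
    # Signature verification is skipped, as in the patch: Kafka re-verifies the token
    # when the admin client connects, so a tampered expiry still gets rejected there.
    claims = jwt.decode(bearer_token, options={"verify_signature": False})
    exp_claim = claims.get("exp")
    if exp_claim is None:
        return False  # No `exp` claim: treat the credentials as non-expiring (Basic auth case).
    return datetime.datetime.fromtimestamp(exp_claim) < now + AUTH_EXPIRY_TOLERANCE


if __name__ == "__main__":
    soon = datetime.datetime.now() + datetime.timedelta(minutes=2)
    token = jwt.encode({"exp": int(soon.timestamp())}, "secret")
    print(expires_soon(token))  # True: the token expires inside the five-minute tolerance.

Skipping verification here is a deliberate trade-off, per the commit message: the expiry value is only used to schedule cleanup, and any tampered token is rejected by Kafka itself the first time the admin client connects.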