diff --git a/README.rst b/README.rst index 606bbd9cc..fc39bbd0b 100644 --- a/README.rst +++ b/README.rst @@ -565,6 +565,24 @@ Example of complete authorization file ] } +OAuth2 authentication and authorization of Karapace REST proxy +=================================================================== + +The Karapace REST proxy supports passing OAuth2 credentials to the underlying Kafka service (defined in the ``sasl_bootstrap_uri`` configuration parameter). The JSON Web Token (JWT) is extracted from the ``Authorization`` HTTP header if the authorization scheme is ``Bearer``, +eg. ``Authorization: Bearer $JWT``. If a ``Bearer`` token is present, the Kafka clients managed by Karapace will be created to use the SASL ``OAUTHBEARER`` mechanism and the JWT will be passed along. The Karapace REST proxy does not verify the token, that is done by +the underlying Kafka service itself, if it's configured accordingly. + +Authorization is also done by Kafka itself, typically using the ``sub`` claim (although it's configurable) from the JWT as the username, checked against the configured ACLs. + +OAuth2 and ``Bearer`` token usage is dependent on the ``rest_authorization`` configuration parameter being ``true``. + +Token expiry +------------ + +The REST proxy process manages a set of producer and consumer clients, which are identified by the OAuth2 JWT token. These are periodically cleaned up if they are idle, as well as *before* the JWT token expires (the clean up currently runs every 5 minutes). + +Before a client refreshes its OAuth2 JWT token, it is expected to remove currently running consumers (eg. after committing their offsets) and producers using the current token. + Uninstall ========= diff --git a/karapace/config.py b/karapace/config.py index f426dfcc5..4358b3bfe 100644 --- a/karapace/config.py +++ b/karapace/config.py @@ -62,6 +62,7 @@ class Config(TypedDict): sasl_mechanism: str | None sasl_plain_username: str | None sasl_plain_password: str | None + sasl_oauth_token: str | None topic_name: str metadata_max_age_ms: int admin_metadata_max_age: int @@ -131,6 +132,7 @@ class ConfigDefaults(Config, total=False): "sasl_mechanism": None, "sasl_plain_username": None, SASL_PLAIN_PASSWORD: None, + "sasl_oauth_token": None, "topic_name": DEFAULT_SCHEMA_TOPIC, "metadata_max_age_ms": 60000, "admin_metadata_max_age": 5, diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index 1b6bf4b6f..4f53406f5 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -5,6 +5,7 @@ from contextlib import AsyncExitStack, closing from http import HTTPStatus from kafka.errors import ( + AuthenticationFailedError, BrokerResponseError, KafkaTimeoutError, NoBrokersAvailable, @@ -15,11 +16,16 @@ from karapace.config import Config, create_client_ssl_context from karapace.errors import InvalidSchema from karapace.kafka_rest_apis.admin import KafkaRestAdminClient +from karapace.kafka_rest_apis.authentication import ( + get_auth_config_from_header, + get_expiration_time_from_header, + get_kafka_client_auth_parameters_from_config, +) from karapace.kafka_rest_apis.consumer_manager import ConsumerManager from karapace.kafka_rest_apis.error_codes import RESTErrorCodes from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache from karapace.karapace import KarapaceBase -from karapace.rapu import HTTPRequest, HTTPResponse, JSON_CONTENT_TYPE +from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE from karapace.schema_models import TypedSchema, ValidatedTypedSchema from karapace.schema_type import SchemaType from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError @@ -27,9 +33,9 @@ from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient from typing import Callable, Dict, List, Optional, Tuple, Union -import aiohttp.web import asyncio import base64 +import datetime import logging import time @@ -41,6 +47,7 @@ SCHEMA_MAPPINGS = {"avro": SchemaType.AVRO, "jsonschema": SchemaType.JSONSCHEMA, "protobuf": SchemaType.PROTOBUF} TypedConsumer = namedtuple("TypedConsumer", ["consumer", "serialization_format", "config"]) IDLE_PROXY_TIMEOUT = 5 * 60 +AUTH_EXPIRY_TOLERANCE = datetime.timedelta(seconds=IDLE_PROXY_TIMEOUT) log = logging.getLogger(__name__) @@ -86,6 +93,13 @@ async def _disconnect_idle_proxy_if_any(self) -> None: async with self._proxy_lock: # Always clean one at time, don't mutate dict while iterating for _key, _proxy in self.proxies.items(): + # In case of an OAuth2/OIDC token, the proxy is to be cleaned up _before_ the token expires + # If the token is still valid within the tolerance time, idleness is still checked + now = datetime.datetime.now(datetime.timezone.utc) + if _proxy.auth_expiry and _proxy.auth_expiry < now + AUTH_EXPIRY_TOLERANCE: + key, proxy = _key, _proxy + log.warning("Releasing unused connection for %s due to token expiry at %s", _proxy, _proxy.auth_expiry) + break # If UserRestProxy has consumers with state, disconnecting loses state if _proxy.num_consumers() > 0: if idle_consumer_timeout > 0 and _proxy.last_used + idle_consumer_timeout < time.monotonic(): @@ -271,33 +285,25 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy": try: if self.config.get("rest_authorization", False): auth_header = request.headers.get("Authorization") + auth_config = get_auth_config_from_header(auth_header, self.config) + auth_expiry = get_expiration_time_from_header(auth_header) - if auth_header is None: - raise HTTPResponse( - body='{"message": "Unauthorized"}', - status=HTTPStatus.UNAUTHORIZED, - content_type=JSON_CONTENT_TYPE, - headers={"WWW-Authenticate": 'Basic realm="Karapace REST Proxy"'}, - ) key = auth_header if self.proxies.get(key) is None: - auth = aiohttp.BasicAuth.decode(auth_header) config = self.config.copy() config["bootstrap_uri"] = config["sasl_bootstrap_uri"] config["security_protocol"] = ( "SASL_SSL" if config["security_protocol"] in ("SSL", "SASL_SSL") else "SASL_PLAINTEXT" ) - if config["sasl_mechanism"] is None: - config["sasl_mechanism"] = "PLAIN" - config["sasl_plain_username"] = auth.login - config["sasl_plain_password"] = auth.password - self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer) + config.update(auth_config) + self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer, auth_expiry) else: if self.proxies.get(key) is None: self.proxies[key] = UserRestProxy(self.config, self.kafka_timeout, self.serializer) - except NoBrokersAvailable: - # This can be caused also due misconfigration, but kafka-python's - # KafkaAdminClient cannot currently distinguish those two cases + except (NoBrokersAvailable, AuthenticationFailedError): + # NoBrokersAvailable can be caused also due to misconfigration, but kafka-python's + # KafkaAdminClient cannot currently distinguish those two cases. + # A more expressive AuthenticationFailedError is raised in case of OAuth2 log.exception("Failed to connect to Kafka with the credentials") self.r(body={"message": "Forbidden"}, content_type=JSON_CONTENT_TYPE, status=HTTPStatus.FORBIDDEN) proxy = self.proxies[key] @@ -408,7 +414,13 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq class UserRestProxy: - def __init__(self, config: Config, kafka_timeout: int, serializer): + def __init__( + self, + config: Config, + kafka_timeout: int, + serializer: SchemaRegistrySerializer, + auth_expiry: Optional[datetime.datetime] = None, + ): self.config = config self.kafka_timeout = kafka_timeout self.serializer = serializer @@ -423,6 +435,7 @@ def __init__(self, config: Config, kafka_timeout: int, serializer): self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer) self.init_admin_client() self._last_used = time.monotonic() + self._auth_expiry = auth_expiry self._async_producer_lock = asyncio.Lock() self._async_producer: Optional[AIOKafkaProducer] = None @@ -437,6 +450,10 @@ def last_used(self) -> int: def mark_used(self) -> None: self._last_used = time.monotonic() + @property + def auth_expiry(self) -> datetime.datetime: + return self._auth_expiry + def num_consumers(self) -> int: return len(self.consumer_manager.consumers) @@ -471,9 +488,7 @@ async def _maybe_create_async_producer(self) -> AIOKafkaProducer: metadata_max_age_ms=self.config["metadata_max_age_ms"], security_protocol=self.config["security_protocol"], ssl_context=ssl_context, - sasl_mechanism=self.config["sasl_mechanism"], - sasl_plain_username=self.config["sasl_plain_username"], - sasl_plain_password=self.config["sasl_plain_password"], + **get_kafka_client_auth_parameters_from_config(self.config), ) try: @@ -626,13 +641,11 @@ def init_admin_client(self): ssl_cafile=self.config["ssl_cafile"], ssl_certfile=self.config["ssl_certfile"], ssl_keyfile=self.config["ssl_keyfile"], - sasl_mechanism=self.config["sasl_mechanism"], - sasl_plain_username=self.config["sasl_plain_username"], - sasl_plain_password=self.config["sasl_plain_password"], api_version=(1, 0, 0), metadata_max_age_ms=self.config["metadata_max_age_ms"], connections_max_idle_ms=self.config["connections_max_idle_ms"], kafka_client=KarapaceKafkaClient, + **get_kafka_client_auth_parameters_from_config(self.config, async_client=False), ) break except: # pylint: disable=bare-except diff --git a/karapace/kafka_rest_apis/authentication.py b/karapace/kafka_rest_apis/authentication.py new file mode 100644 index 000000000..88cab7289 --- /dev/null +++ b/karapace/kafka_rest_apis/authentication.py @@ -0,0 +1,165 @@ +""" +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +from __future__ import annotations + +from aiokafka.abc import AbstractTokenProvider as AbstractTokenProviderAsync +from http import HTTPStatus +from kafka.oauth.abstract import AbstractTokenProvider +from karapace.config import Config +from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE +from typing import NoReturn, TypedDict + +import aiohttp +import dataclasses +import datetime +import enum +import jwt + + +@enum.unique +class TokenType(enum.Enum): + BASIC = "Basic" + BEARER = "Bearer" + + +def raise_unauthorized() -> NoReturn: + raise HTTPResponse( + body='{"message": "Unauthorized"}', + status=HTTPStatus.UNAUTHORIZED, + content_type=JSON_CONTENT_TYPE, + headers={"WWW-Authenticate": 'Basic realm="Karapace REST Proxy"'}, + ) + + +class SASLPlainConfig(TypedDict): + sasl_mechanism: str | None + sasl_plain_username: str | None + sasl_plain_password: str | None + + +class SASLOauthConfig(TypedDict): + sasl_mechanism: str | None + sasl_oauth_token: str | None + + +def _split_auth_header(auth_header: str) -> tuple[str, str]: + token_type, _separator, token = auth_header.partition(" ") + return (token_type, token) + + +def get_auth_config_from_header( + auth_header: str | None, + config: Config, +) -> SASLPlainConfig | SASLOauthConfig: + """Verify the given Authorization HTTP header and constructs config parameters based on it. + + In case the Authorization header is `None`, or unknown, raises an Unauthorized HTTP response. + Known/possible authentication tokens are `Bearer` and `Basic`. + + :param auth_header: The Authorization header extracted from an HTTP request + :param config: Current config of Karapace, necessary to decide on the SASL mechanism + """ + if auth_header is None: + raise_unauthorized() + + token_type, token = _split_auth_header(auth_header) + + if token_type == TokenType.BEARER.value: + return {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": token} + + if token_type == TokenType.BASIC.value: + basic_auth = aiohttp.BasicAuth.decode(auth_header) + sasl_mechanism = config["sasl_mechanism"] + if sasl_mechanism is None: + sasl_mechanism = "PLAIN" + + return { + "sasl_mechanism": sasl_mechanism, + "sasl_plain_username": basic_auth.login, + "sasl_plain_password": basic_auth.password, + } + + raise_unauthorized() + + +def get_expiration_time_from_header(auth_header: str) -> datetime.datetime | None: + """Extract expiration from Authorization HTTP header. + + In case of an OAuth Bearer token, the `exp` claim is extracted and returned as a + `datetime.datetime` object. Otherwise it's safely assumed that the authentication + method is Basic, thus no expiry of the credentials. + + The signature is not verified as it is done by the Kafka clients using it and + discarding the token in case of any issues. + + :param auth_header: The Authorization header extracted from an HTTP request + """ + token_type, token = _split_auth_header(auth_header) + + if token_type == TokenType.BEARER.value: + exp_claim = jwt.decode(token, options={"verify_signature": False}).get("exp") + if exp_claim is not None: + return datetime.datetime.fromtimestamp(exp_claim, datetime.timezone.utc) + + return None + + +@dataclasses.dataclass +class SimpleOauthTokenProvider(AbstractTokenProvider): + """A pass-through OAuth token provider to be used by synchronous Kafka clients. + + The token is meant to be extracted from an HTTP Authorization header. + """ + + _token: str + + def token(self) -> str: + return self._token + + +@dataclasses.dataclass +class SimpleOauthTokenProviderAsync(AbstractTokenProviderAsync): + """A pass-through OAuth token provider to be used by asynchronous Kafka clients. + + The token is meant to be extracted from an HTTP Authorization header. + """ + + _token: str + + async def token(self) -> str: + return self._token + + +class SASLOauthParams(TypedDict): + sasl_mechanism: str + sasl_oauth_token_provider: AbstractTokenProvider | AbstractTokenProviderAsync + + +def get_kafka_client_auth_parameters_from_config( + config: Config, + *, + async_client: bool = True, +) -> SASLPlainConfig | SASLOauthParams: + """Create authentication parameters for a Kafka client based on the Karapace config. + + In case of an `OAUTHBEARER` SASL mechanism present in the config, will create the + OAuth token provider needed by the Kafka client - the `async_client` parameter + decides whether this will be a sync or async one. + + :param config: Current config of Karapace + :param async_client: Flag to indicate whether the Kafka client using the returned paramaters is async + """ + if config["sasl_mechanism"] == "OAUTHBEARER": + token_provider_cls = SimpleOauthTokenProviderAsync if async_client else SimpleOauthTokenProvider + return { + "sasl_mechanism": config["sasl_mechanism"], + "sasl_oauth_token_provider": token_provider_cls(config["sasl_oauth_token"]), + } + + return { + "sasl_mechanism": config["sasl_mechanism"], + "sasl_plain_username": config["sasl_plain_username"], + "sasl_plain_password": config["sasl_plain_password"], + } diff --git a/karapace/kafka_rest_apis/consumer_manager.py b/karapace/kafka_rest_apis/consumer_manager.py index 72a1bea9a..b2960c95d 100644 --- a/karapace/kafka_rest_apis/consumer_manager.py +++ b/karapace/kafka_rest_apis/consumer_manager.py @@ -10,6 +10,7 @@ from kafka.errors import GroupAuthorizationFailedError, IllegalStateError, KafkaConfigurationError, KafkaError from kafka.structs import TopicPartition from karapace.config import Config, create_client_ssl_context +from karapace.kafka_rest_apis.authentication import get_kafka_client_auth_parameters_from_config from karapace.kafka_rest_apis.error_codes import RESTErrorCodes from karapace.karapace import empty_response, KarapaceBase from karapace.serialization import DeserializationError, InvalidMessageHeader, InvalidPayload, SchemaRegistrySerializer @@ -205,9 +206,6 @@ async def create_kafka_consumer(self, fetch_min_bytes, group_name, internal_name client_id=internal_name, security_protocol=self.config["security_protocol"], ssl_context=ssl_context, - sasl_mechanism=self.config["sasl_mechanism"], - sasl_plain_username=self.config["sasl_plain_username"], - sasl_plain_password=self.config["sasl_plain_password"], group_id=group_name, fetch_min_bytes=max(1, fetch_min_bytes), # Discard earlier negative values fetch_max_bytes=self.config["consumer_request_max_bytes"], @@ -218,6 +216,7 @@ async def create_kafka_consumer(self, fetch_min_bytes, group_name, internal_name enable_auto_commit=request_data["auto.commit.enable"], auto_offset_reset=request_data["auto.offset.reset"], session_timeout_ms=session_timeout_ms, + **get_kafka_client_auth_parameters_from_config(self.config), ) await c.start() return c diff --git a/mypy.ini b/mypy.ini index 291bb88b7..e06dfb201 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,6 +86,9 @@ ignore_errors = True # dependencies. # - Write your own stubs. You don't need to write stubs for the whole library, # only the parts that Karapace is interacting with. +[mypy-aiokafka.*] +ignore_missing_imports = True + [mypy-kafka.*] ignore_missing_imports = True diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index de1aa2746..ca073a342 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -146,6 +146,8 @@ pygments==2.15.1 # -r requirements.txt # pdbpp # rich +pyjwt==2.8.0 + # via -r requirements-dev.in pyrepl==0.9.0 # via fancycompleter pytest==7.4.0 diff --git a/requirements/requirements.in b/requirements/requirements.in index 25bfad6b7..d3ad217f2 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -6,6 +6,7 @@ isodate<1 jsonschema<5 networkx<4 protobuf<4 +pyjwt>=2.4.0<3 python-dateutil<3 tenacity<9 typing-extensions diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 2f636cc77..e45aff0fe 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -61,6 +61,8 @@ protobuf==3.20.3 # via -r requirements.in pygments==2.15.1 # via rich +pyjwt==2.8.0 + # via -r requirements.in python-dateutil==2.8.2 # via -r requirements.in referencing==0.30.0 diff --git a/setup.py b/setup.py index d36bcb666..9abf3be3c 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "kafka-python", "networkx", "protobuf", + "pyjwt", "python-dateutil", ], extras_require={ diff --git a/tests/unit/test_authentication.py b/tests/unit/test_authentication.py new file mode 100644 index 000000000..3921f0560 --- /dev/null +++ b/tests/unit/test_authentication.py @@ -0,0 +1,130 @@ +""" +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +from __future__ import annotations + +from http import HTTPStatus +from karapace.config import ConfigDefaults, set_config_defaults +from karapace.kafka_rest_apis.authentication import ( + get_auth_config_from_header, + get_expiration_time_from_header, + get_kafka_client_auth_parameters_from_config, + SimpleOauthTokenProvider, + SimpleOauthTokenProviderAsync, +) +from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE + +import base64 +import datetime +import jwt +import pytest + + +@pytest.mark.parametrize( + "auth_header", + (None, "Digest foo=bar"), +) +def test_get_auth_config_from_header_raises_unauthorized_on_invalid_header(auth_header: str | None) -> None: + config = set_config_defaults({}) + + with pytest.raises(HTTPResponse) as exc_info: + get_auth_config_from_header(auth_header, config) + + http_resonse = exc_info.value + assert http_resonse.body == '{"message": "Unauthorized"}' + assert http_resonse.status == HTTPStatus.UNAUTHORIZED + assert http_resonse.headers["Content-Type"] == JSON_CONTENT_TYPE + assert http_resonse.headers["WWW-Authenticate"] == 'Basic realm="Karapace REST Proxy"' + + +@pytest.mark.parametrize( + ("auth_header", "config_override", "expected_auth_config"), + ( + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": None}, + {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": "PLAIN"}, + {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + f"Basic {base64.b64encode(b'username:password').decode()}", + {"sasl_mechanism": "SCRAM"}, + {"sasl_mechanism": "SCRAM", "sasl_plain_username": "username", "sasl_plain_password": "password"}, + ), + ( + "Bearer ", + {}, + {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": ""}, + ), + ), +) +def test_get_auth_config_from_header( + auth_header: str, config_override: ConfigDefaults, expected_auth_config: ConfigDefaults +) -> None: + config = set_config_defaults(config_override) + auth_config = get_auth_config_from_header(auth_header, config) + assert auth_config == expected_auth_config + + +@pytest.mark.parametrize( + ("auth_header", "expected_expiration"), + ( + (f"Basic {base64.b64encode(b'username:password').decode()}", None), + ( + f"Bearer {jwt.encode({'exp': 1697013997}, 'secret')}", + datetime.datetime.fromtimestamp(1697013997, datetime.timezone.utc), + ), + (f"Bearer {jwt.encode({}, 'secret')}", None), + ), +) +def test_get_expiration_time_from_header(auth_header: str, expected_expiration: datetime.datetime) -> None: + expiration = get_expiration_time_from_header(auth_header) + + assert expiration == expected_expiration + + +def test_simple_oauth_token_provider_returns_configured_token() -> None: + token_provider = SimpleOauthTokenProvider("TOKEN") + assert token_provider.token() == "TOKEN" + + +async def test_simple_oauth_token_provider_async_returns_configured_token() -> None: + token_provider = SimpleOauthTokenProviderAsync("TOKEN") + assert await token_provider.token() == "TOKEN" + + +def test_get_client_auth_parameters_from_config_sasl_plain() -> None: + config = set_config_defaults( + {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"} + ) + + client_auth_params = get_kafka_client_auth_parameters_from_config(config) + + assert client_auth_params == { + "sasl_mechanism": "PLAIN", + "sasl_plain_username": "username", + "sasl_plain_password": "password", + } + + +def test_get_client_auth_parameters_from_config_oauth() -> None: + config = set_config_defaults({"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": "TOKEN"}) + + client_auth_params = get_kafka_client_auth_parameters_from_config(config, async_client=False) + + assert client_auth_params["sasl_mechanism"] == "OAUTHBEARER" + assert client_auth_params["sasl_oauth_token_provider"].token() == "TOKEN" + + +async def test_get_client_auth_parameters_from_config_oauth_async() -> None: + config = set_config_defaults({"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": "TOKEN"}) + + client_auth_params = get_kafka_client_auth_parameters_from_config(config, async_client=True) + + assert client_auth_params["sasl_mechanism"] == "OAUTHBEARER" + assert await client_auth_params["sasl_oauth_token_provider"].token() == "TOKEN" diff --git a/tests/unit/test_rest_auth.py b/tests/unit/test_rest_auth.py index 4f06c034a..86bb14b8a 100644 --- a/tests/unit/test_rest_auth.py +++ b/tests/unit/test_rest_auth.py @@ -3,15 +3,96 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ +from __future__ import annotations + from karapace.config import set_config_defaults -from karapace.kafka_rest_apis import KafkaRest +from karapace.kafka_rest_apis import AUTH_EXPIRY_TOLERANCE, KafkaRest, UserRestProxy from unittest.mock import call, Mock import asyncio +import datetime import time -async def test_rest_proxy_janitor_default(): +def _create_mock_proxy( + *, + num_consumers: int, + last_used: float, + auth_expiry: datetime.datetime | None = None, + _with_close_future: bool = False, +) -> Mock: + proxy = Mock(spec=UserRestProxy) + proxy.num_consumers.return_value = num_consumers + proxy.last_used = last_used + proxy.auth_expiry = auth_expiry + + if _with_close_future: + close_future = asyncio.Future() + close_future.set_result(True) + proxy.aclose.return_value = close_future + + return proxy + + +async def test_rest_proxy_janitor_expiring_credentials() -> None: + config = set_config_defaults( + { + "rest_authorization": True, + "sasl_bootstrap_uri": "localhost:9094", + } + ) + instance = KafkaRest(config=config) + + proxy_expiring_within_tolerance = _create_mock_proxy( + num_consumers=99, + last_used=time.monotonic(), + auth_expiry=datetime.datetime.now(datetime.timezone.utc) + (AUTH_EXPIRY_TOLERANCE / 2), + _with_close_future=True, + ) + instance.proxies["proxy_expiring_within_tolerance"] = proxy_expiring_within_tolerance + + proxy_already_expired = _create_mock_proxy( + num_consumers=99, + last_used=time.monotonic(), + auth_expiry=datetime.datetime.now(datetime.timezone.utc) - AUTH_EXPIRY_TOLERANCE, + _with_close_future=True, + ) + instance.proxies["proxy_already_expired"] = proxy_already_expired + + active_proxy_expiring_later_than_tolerance = _create_mock_proxy( + num_consumers=0, + last_used=time.monotonic(), + auth_expiry=datetime.datetime.now(datetime.timezone.utc) + 2 * AUTH_EXPIRY_TOLERANCE, + ) + instance.proxies["active_proxy_expiring_later_than_tolerance"] = active_proxy_expiring_later_than_tolerance + + unused_proxy_expiring_later_than_tolerance = _create_mock_proxy( + num_consumers=0, + last_used=time.monotonic() - 600, + auth_expiry=datetime.datetime.now(datetime.timezone.utc) + 2 * AUTH_EXPIRY_TOLERANCE, + _with_close_future=True, + ) + instance.proxies["unused_proxy_expiring_later_than_tolerance"] = unused_proxy_expiring_later_than_tolerance + + # Needs to be called multiple times to clean all expected proxies, this method only + # releases the first one it finds, then exits. + await instance._disconnect_idle_proxy_if_any() + await instance._disconnect_idle_proxy_if_any() + await instance._disconnect_idle_proxy_if_any() + + assert instance.proxies.get("proxy_expiring_within_tolerance") is None + assert instance.proxies.get("proxy_already_expired") is None + assert instance.proxies.get("active_proxy_expiring_later_than_tolerance") is not None + assert instance.proxies.get("unused_proxy_expiring_later_than_tolerance") is None + assert len(instance.proxies) == 1 + + assert proxy_expiring_within_tolerance.method_calls == [call.aclose()] + assert proxy_already_expired.method_calls == [call.aclose()] + assert active_proxy_expiring_later_than_tolerance.method_calls == [call.num_consumers()] + assert unused_proxy_expiring_later_than_tolerance.method_calls == [call.num_consumers(), call.aclose()] + + +async def test_rest_proxy_janitor_default() -> None: config = set_config_defaults( { "rest_authorization": True, @@ -20,22 +101,17 @@ async def test_rest_proxy_janitor_default(): ) instance = KafkaRest(config=config) - active_proxy_without_consumers = Mock() - active_proxy_without_consumers.num_consumers.return_value = 0 - active_proxy_without_consumers.last_used = time.monotonic() + active_proxy_without_consumers = _create_mock_proxy(num_consumers=0, last_used=time.monotonic()) instance.proxies["active_proxy_without_consumers"] = active_proxy_without_consumers - active_proxy_with_consumers = Mock() - active_proxy_with_consumers.num_consumers.return_value = 99 - active_proxy_with_consumers.last_used = time.monotonic() + active_proxy_with_consumers = _create_mock_proxy(num_consumers=99, last_used=time.monotonic()) instance.proxies["active_proxy_with_consumers"] = active_proxy_with_consumers - unused_proxy_without_consumers = Mock() - unused_proxy_without_consumers.num_consumers.return_value = 0 - unused_proxy_without_consumers.last_used = time.monotonic() - 600 - close_future = asyncio.Future() - close_future.set_result(True) - unused_proxy_without_consumers.aclose.return_value = close_future + unused_proxy_without_consumers = _create_mock_proxy( + num_consumers=0, + last_used=time.monotonic() - 600, + _with_close_future=True, + ) instance.proxies["unused_proxy_without_consumers"] = unused_proxy_without_consumers await instance._disconnect_idle_proxy_if_any() @@ -53,12 +129,11 @@ async def test_rest_proxy_janitor_default(): active_proxy_with_consumers.reset_mock() # Proxy with consumers is not deleted without explicit config - unused_proxy_with_consumers = Mock() - unused_proxy_with_consumers.num_consumers.return_value = 99 - unused_proxy_with_consumers.last_used = time.monotonic() - 600 - close_future = asyncio.Future() - close_future.set_result(True) - unused_proxy_with_consumers.aclose.return_value = close_future + unused_proxy_with_consumers = _create_mock_proxy( + num_consumers=99, + last_used=time.monotonic() - 600, + _with_close_future=True, + ) instance.proxies["unused_proxy_with_consumers"] = unused_proxy_with_consumers await instance._disconnect_idle_proxy_if_any() @@ -73,7 +148,7 @@ async def test_rest_proxy_janitor_default(): assert active_proxy_with_consumers.method_calls == [call.num_consumers()] -async def test_rest_proxy_janitor_destructive(): +async def test_rest_proxy_janitor_destructive() -> None: config = set_config_defaults( { "rest_authorization": True, @@ -83,22 +158,17 @@ async def test_rest_proxy_janitor_destructive(): ) instance = KafkaRest(config=config) - active_proxy_without_consumers = Mock() - active_proxy_without_consumers.num_consumers.return_value = 0 - active_proxy_without_consumers.last_used = time.monotonic() + active_proxy_without_consumers = _create_mock_proxy(num_consumers=0, last_used=time.monotonic()) instance.proxies["active_proxy_without_consumers"] = active_proxy_without_consumers - active_proxy_with_consumers = Mock() - active_proxy_with_consumers.num_consumers.return_value = 99 - active_proxy_with_consumers.last_used = time.monotonic() + active_proxy_with_consumers = _create_mock_proxy(num_consumers=99, last_used=time.monotonic()) instance.proxies["active_proxy_with_consumers"] = active_proxy_with_consumers - unused_proxy_without_consumers = Mock() - unused_proxy_without_consumers.num_consumers.return_value = 0 - unused_proxy_without_consumers.last_used = time.monotonic() - 600 - close_future = asyncio.Future() - close_future.set_result(True) - unused_proxy_without_consumers.aclose.return_value = close_future + unused_proxy_without_consumers = _create_mock_proxy( + num_consumers=0, + last_used=time.monotonic() - 600, + _with_close_future=True, + ) instance.proxies["unused_proxy_without_consumers"] = unused_proxy_without_consumers await instance._disconnect_idle_proxy_if_any() @@ -116,12 +186,11 @@ async def test_rest_proxy_janitor_destructive(): active_proxy_with_consumers.reset_mock() # Proxy with consumers gets deleted after enough time has passed - unused_proxy_with_consumers = Mock() - unused_proxy_with_consumers.num_consumers.return_value = 99 - unused_proxy_with_consumers.last_used = time.monotonic() - 600 - close_future = asyncio.Future() - close_future.set_result(True) - unused_proxy_with_consumers.aclose.return_value = close_future + unused_proxy_with_consumers = _create_mock_proxy( + num_consumers=99, + last_used=time.monotonic() - 600, + _with_close_future=True, + ) instance.proxies["unused_proxy_with_consumers"] = unused_proxy_with_consumers await instance._disconnect_idle_proxy_if_any()