Skip to content

Commit

Permalink
Clean up proxies with expiring credentials
Browse files Browse the repository at this point in the history
Due to the behaviour of the Kafka clients, we want to avoid them going
into a retry loop in case of an expired OAuth token. The current proxy
janitor _would_ clean them up eventually but this is behaviour to be
avoided.

Thus the proxy janitor is extended to clean up proxies that have auth
tokens expiring within a "tolerance" amount of time.

To extract the expiry time (`exp` claim) from an OAuth token, the
`pyjwt` library is used. With this action we do not verify the token
signature (which would not be possible without knowing the signing
public key). Normally this would present the issue of tokens that have
been tampered with and not re-signed. However, when using a token for
the first time, the Karapace REST proxy instantiates a Kafka admin
client, which results in immediate verification of the token by Kafka.
So even if the expiration time has been maliciously changed, the
connection will be refused by Kafka, resulting in an Unauthorized
response from Karapace to its client.
  • Loading branch information
Mátyás Kuti committed Oct 12, 2023
1 parent b19cf19 commit 79b6ae3
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 43 deletions.
30 changes: 27 additions & 3 deletions karapace/kafka_rest_apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from karapace.config import Config, create_client_ssl_context
from karapace.errors import InvalidSchema
from karapace.kafka_rest_apis.admin import KafkaRestAdminClient
from karapace.kafka_rest_apis.authentication import get_auth_config_from_header, get_kafka_client_auth_parameters_from_config
from karapace.kafka_rest_apis.authentication import (
get_auth_config_from_header,
get_expiration_time_from_header,
get_kafka_client_auth_parameters_from_config,
)
from karapace.kafka_rest_apis.consumer_manager import ConsumerManager
from karapace.kafka_rest_apis.error_codes import RESTErrorCodes
from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache
Expand All @@ -31,6 +35,7 @@

import asyncio
import base64
import datetime
import logging
import time

Expand All @@ -42,6 +47,7 @@
SCHEMA_MAPPINGS = {"avro": SchemaType.AVRO, "jsonschema": SchemaType.JSONSCHEMA, "protobuf": SchemaType.PROTOBUF}
TypedConsumer = namedtuple("TypedConsumer", ["consumer", "serialization_format", "config"])
IDLE_PROXY_TIMEOUT = 5 * 60
AUTH_EXPIRY_TOLERANCE = datetime.timedelta(seconds=IDLE_PROXY_TIMEOUT)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,6 +93,12 @@ async def _disconnect_idle_proxy_if_any(self) -> None:
async with self._proxy_lock:
# Always clean one at time, don't mutate dict while iterating
for _key, _proxy in self.proxies.items():
# In case of an OAuth2/OIDC token, the proxy is to be cleaned up _before_ the token expires
# If the token is still valid within the tolerance time, idleness is still checked
if _proxy.auth_expiry and _proxy.auth_expiry < datetime.datetime.now() + AUTH_EXPIRY_TOLERANCE:
key, proxy = _key, _proxy
log.warning("Releasing unused connection for %s due to token expiry at %s", _proxy, _proxy.auth_expiry)
break
# If UserRestProxy has consumers with state, disconnecting loses state
if _proxy.num_consumers() > 0:
if idle_consumer_timeout > 0 and _proxy.last_used + idle_consumer_timeout < time.monotonic():
Expand Down Expand Up @@ -273,6 +285,7 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy":
if self.config.get("rest_authorization", False):
auth_header = request.headers.get("Authorization")
auth_config = get_auth_config_from_header(auth_header, self.config)
auth_expiry = get_expiration_time_from_header(auth_header)

key = auth_header
if self.proxies.get(key) is None:
Expand All @@ -282,7 +295,7 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy":
"SASL_SSL" if config["security_protocol"] in ("SSL", "SASL_SSL") else "SASL_PLAINTEXT"
)
config.update(auth_config)
self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer)
self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer, auth_expiry)
else:
if self.proxies.get(key) is None:
self.proxies[key] = UserRestProxy(self.config, self.kafka_timeout, self.serializer)
Expand Down Expand Up @@ -400,7 +413,13 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq


class UserRestProxy:
def __init__(self, config: Config, kafka_timeout: int, serializer):
def __init__(
self,
config: Config,
kafka_timeout: int,
serializer: SchemaRegistrySerializer,
auth_expiry: datetime.datetime | None = None,
):
self.config = config
self.kafka_timeout = kafka_timeout
self.serializer = serializer
Expand All @@ -415,6 +434,7 @@ def __init__(self, config: Config, kafka_timeout: int, serializer):
self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer)
self.init_admin_client()
self._last_used = time.monotonic()
self._auth_expiry = auth_expiry

self._async_producer_lock = asyncio.Lock()
self._async_producer: Optional[AIOKafkaProducer] = None
Expand All @@ -429,6 +449,10 @@ def last_used(self) -> int:
def mark_used(self) -> None:
self._last_used = time.monotonic()

@property
def auth_expiry(self) -> datetime.datetime:
return self._auth_expiry

def num_consumers(self) -> int:
return len(self.consumer_manager.consumers)

Expand Down
31 changes: 30 additions & 1 deletion karapace/kafka_rest_apis/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

import aiohttp
import dataclasses
import datetime
import enum
import jwt


@enum.unique
Expand Down Expand Up @@ -42,6 +44,11 @@ class SASLOauthConfig(TypedDict):
sasl_oauth_token: str | None


def _split_auth_header(auth_header: str) -> tuple[str, str]:
token_type, _separator, token = auth_header.partition(" ")
return (token_type, token)


def get_auth_config_from_header(
auth_header: str | None,
config: Config,
Expand All @@ -57,7 +64,7 @@ def get_auth_config_from_header(
if auth_header is None:
raise_unauthorized()

token_type, _separator, token = auth_header.partition(" ")
token_type, token = _split_auth_header(auth_header)

if token_type == TokenType.BEARER.value:
return {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": token}
Expand All @@ -77,6 +84,28 @@ def get_auth_config_from_header(
raise_unauthorized()


def get_expiration_time_from_header(auth_header: str) -> datetime.datetime | None:
"""Extract expiration from Authorization HTTP header.
In case of an OAuth Bearer token, the `exp` claim is extracted and returned as a
`datetime.datetime` object. Otherwise it's safely assumed that the authentication
method is Basic, thus no expiry of the credentials.
The signature is not verified as it is done by the Kafka clients using it and
discarding the token in case of any issues.
:param auth_header: The Authorization header extracted from an HTTP request
"""
token_type, token = _split_auth_header(auth_header)

if token_type == TokenType.BEARER.value:
exp_claim = jwt.decode(token, options={"verify_signature": False}).get("exp")
if exp_claim is not None:
return datetime.datetime.fromtimestamp(exp_claim)

return None


@dataclasses.dataclass
class SimpleOauthTokenProvider(AbstractTokenProvider):
"""A pass-through OAuth token provider to be used by synchronous Kafka clients.
Expand Down
2 changes: 2 additions & 0 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ pygments==2.15.1
# -r requirements.txt
# pdbpp
# rich
pyjwt==2.8.0
# via -r requirements-dev.in
pyrepl==0.9.0
# via fancycompleter
pytest==7.4.0
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ isodate<1
jsonschema<5
networkx<4
protobuf<4
pyjwt<3
python-dateutil<3
tenacity<9
typing-extensions
Expand Down
2 changes: 2 additions & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ protobuf==3.20.3
# via -r requirements.in
pygments==2.15.1
# via rich
pyjwt==2.8.0
# via -r requirements.in
python-dateutil==2.8.2
# via -r requirements.in
referencing==0.30.0
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from karapace.config import ConfigDefaults, set_config_defaults
from karapace.kafka_rest_apis.authentication import (
get_auth_config_from_header,
get_expiration_time_from_header,
get_kafka_client_auth_parameters_from_config,
SimpleOauthTokenProvider,
SimpleOauthTokenProviderAsync,
)
from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE

import base64
import datetime
import jwt
import pytest


Expand Down Expand Up @@ -68,6 +71,20 @@ def test_get_auth_config_from_header(
assert auth_config == expected_auth_config


@pytest.mark.parametrize(
("auth_header", "expected_expiration"),
(
(f"Basic {base64.b64encode(b'username:password').decode()}", None),
(f"Bearer {jwt.encode({'exp': 1697013997}, 'secret')}", datetime.datetime.fromtimestamp(1697013997)),
(f"Bearer {jwt.encode({}, 'secret')}", None),
),
)
def test_get_expiration_time_from_header(auth_header: str, expected_expiration: datetime.datetime) -> None:
expiration = get_expiration_time_from_header(auth_header)

assert expiration == expected_expiration


def test_simple_oauth_token_provider_returns_configured_token() -> None:
token_provider = SimpleOauthTokenProvider("TOKEN")
assert token_provider.token() == "TOKEN"
Expand Down
Loading

0 comments on commit 79b6ae3

Please sign in to comment.