From c93bf5d882ef952276aba7ef6b12ddf34e0baf24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1ty=C3=A1s=20Kuti?= Date: Mon, 23 Oct 2023 11:19:56 +0200 Subject: [PATCH] Return Unauthorized on malformed OAuth token Handle JWT DecodeErrors when extracting the expiration timestamp from an OIDC/OAuth2 JWT token, otherwise this would result in an HTTP 500 response. Catching PyJWT's `DecodeError` is sufficient, as all other exceptions from the library (see https://pyjwt.readthedocs.io/en/stable/api.html#exceptions) are related to proper verification, which we do not do at the moment. --- karapace/kafka_rest_apis/authentication.py | 6 +++++- tests/unit/test_authentication.py | 24 +++++++++++++++++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/karapace/kafka_rest_apis/authentication.py b/karapace/kafka_rest_apis/authentication.py index 88cab7289..ec2af41f1 100644 --- a/karapace/kafka_rest_apis/authentication.py +++ b/karapace/kafka_rest_apis/authentication.py @@ -99,7 +99,11 @@ def get_expiration_time_from_header(auth_header: str) -> datetime.datetime | Non token_type, token = _split_auth_header(auth_header) if token_type == TokenType.BEARER.value: - exp_claim = jwt.decode(token, options={"verify_signature": False}).get("exp") + try: + exp_claim = jwt.decode(token, options={"verify_signature": False}).get("exp") + except jwt.exceptions.DecodeError: + raise_unauthorized() + if exp_claim is not None: return datetime.datetime.fromtimestamp(exp_claim, datetime.timezone.utc) diff --git a/tests/unit/test_authentication.py b/tests/unit/test_authentication.py index 3921f0560..485ff86d1 100644 --- a/tests/unit/test_authentication.py +++ b/tests/unit/test_authentication.py @@ -21,6 +21,13 @@ import pytest +def _assert_unauthorized_http_response(http_response: HTTPResponse) -> None: + assert http_response.body == '{"message": "Unauthorized"}' + assert http_response.status == HTTPStatus.UNAUTHORIZED + assert http_response.headers["Content-Type"] == JSON_CONTENT_TYPE + assert http_response.headers["WWW-Authenticate"] == 'Basic realm="Karapace REST Proxy"' + + @pytest.mark.parametrize( "auth_header", (None, "Digest foo=bar"), @@ -31,11 +38,7 @@ def test_get_auth_config_from_header_raises_unauthorized_on_invalid_header(auth_ with pytest.raises(HTTPResponse) as exc_info: get_auth_config_from_header(auth_header, config) - http_resonse = exc_info.value - assert http_resonse.body == '{"message": "Unauthorized"}' - assert http_resonse.status == HTTPStatus.UNAUTHORIZED - assert http_resonse.headers["Content-Type"] == JSON_CONTENT_TYPE - assert http_resonse.headers["WWW-Authenticate"] == 'Basic realm="Karapace REST Proxy"' + _assert_unauthorized_http_response(exc_info.value) @pytest.mark.parametrize( @@ -88,6 +91,17 @@ def test_get_expiration_time_from_header(auth_header: str, expected_expiration: assert expiration == expected_expiration +@pytest.mark.parametrize( + "auth_header", + (f"Bearer {jwt.encode({'exp': 1697013997}, 'secret')}XX", "Bearer NotAToken"), +) +def test_get_expiration_time_from_header_malformed_bearer_token_raises_unauthorized(auth_header: str) -> None: + with pytest.raises(HTTPResponse) as exc_info: + get_expiration_time_from_header(auth_header) + + _assert_unauthorized_http_response(exc_info.value) + + def test_simple_oauth_token_provider_returns_configured_token() -> None: token_provider = SimpleOauthTokenProvider("TOKEN") assert token_provider.token() == "TOKEN"