Skip to content

Commit

Permalink
Enable OAuth/OIDC authentication with Bearer token
Browse files Browse the repository at this point in the history
Handle OAuth/OIDC (Bearer token) auth headers, and use them when
instantiating Kafka clients (with the exception of backups for now).
Basic authentication behaviour is unchanged, just extracted and
unittested.
  • Loading branch information
Mátyás Kuti committed Oct 4, 2023
1 parent 8a4ff32 commit 310b273
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 23 deletions.
2 changes: 2 additions & 0 deletions karapace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class Config(TypedDict):
sasl_mechanism: str | None
sasl_plain_username: str | None
sasl_plain_password: str | None
sasl_oauth_token: str | None
topic_name: str
metadata_max_age_ms: int
admin_metadata_max_age: int
Expand Down Expand Up @@ -131,6 +132,7 @@ class ConfigDefaults(Config, total=False):
"sasl_mechanism": None,
"sasl_plain_username": None,
SASL_PLAIN_PASSWORD: None,
"sasl_oauth_token": None,
"topic_name": DEFAULT_SCHEMA_TOPIC,
"metadata_max_age_ms": 60000,
"admin_metadata_max_age": 5,
Expand Down
26 changes: 6 additions & 20 deletions karapace/kafka_rest_apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
from karapace.config import Config, create_client_ssl_context
from karapace.errors import InvalidSchema
from karapace.kafka_rest_apis.admin import KafkaRestAdminClient
from karapace.kafka_rest_apis.auth_utils import get_auth_config_from_header, get_kafka_client_auth_parameters_from_config
from karapace.kafka_rest_apis.consumer_manager import ConsumerManager
from karapace.kafka_rest_apis.error_codes import RESTErrorCodes
from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache
from karapace.karapace import KarapaceBase
from karapace.rapu import HTTPRequest, HTTPResponse, JSON_CONTENT_TYPE
from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE
from karapace.schema_models import TypedSchema, ValidatedTypedSchema
from karapace.schema_type import SchemaType
from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError
from karapace.typing import SchemaId, Subject
from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
from typing import Callable, Dict, List, Optional, Tuple, Union

import aiohttp.web
import asyncio
import base64
import logging
Expand Down Expand Up @@ -271,26 +271,16 @@ async def get_user_proxy(self, request: HTTPRequest) -> "UserRestProxy":
try:
if self.config.get("rest_authorization", False):
auth_header = request.headers.get("Authorization")
auth_config = get_auth_config_from_header(auth_header, self.config)

if auth_header is None:
raise HTTPResponse(
body='{"message": "Unauthorized"}',
status=HTTPStatus.UNAUTHORIZED,
content_type=JSON_CONTENT_TYPE,
headers={"WWW-Authenticate": 'Basic realm="Karapace REST Proxy"'},
)
key = auth_header
if self.proxies.get(key) is None:
auth = aiohttp.BasicAuth.decode(auth_header)
config = self.config.copy()
config["bootstrap_uri"] = config["sasl_bootstrap_uri"]
config["security_protocol"] = (
"SASL_SSL" if config["security_protocol"] in ("SSL", "SASL_SSL") else "SASL_PLAINTEXT"
)
if config["sasl_mechanism"] is None:
config["sasl_mechanism"] = "PLAIN"
config["sasl_plain_username"] = auth.login
config["sasl_plain_password"] = auth.password
config.update(auth_config)
self.proxies[key] = UserRestProxy(config, self.kafka_timeout, self.serializer)
else:
if self.proxies.get(key) is None:
Expand Down Expand Up @@ -471,9 +461,7 @@ async def _maybe_create_async_producer(self) -> AIOKafkaProducer:
metadata_max_age_ms=self.config["metadata_max_age_ms"],
security_protocol=self.config["security_protocol"],
ssl_context=ssl_context,
sasl_mechanism=self.config["sasl_mechanism"],
sasl_plain_username=self.config["sasl_plain_username"],
sasl_plain_password=self.config["sasl_plain_password"],
**get_kafka_client_auth_parameters_from_config(self.config),
)

try:
Expand Down Expand Up @@ -626,13 +614,11 @@ def init_admin_client(self):
ssl_cafile=self.config["ssl_cafile"],
ssl_certfile=self.config["ssl_certfile"],
ssl_keyfile=self.config["ssl_keyfile"],
sasl_mechanism=self.config["sasl_mechanism"],
sasl_plain_username=self.config["sasl_plain_username"],
sasl_plain_password=self.config["sasl_plain_password"],
api_version=(1, 0, 0),
metadata_max_age_ms=self.config["metadata_max_age_ms"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
kafka_client=KarapaceKafkaClient,
**get_kafka_client_auth_parameters_from_config(self.config, async_client=False),
)
break
except: # pylint: disable=bare-except
Expand Down
109 changes: 109 additions & 0 deletions karapace/kafka_rest_apis/auth_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""
from __future__ import annotations

from aiokafka.abc import AbstractTokenProvider as AbstractTokenProviderAsync
from http import HTTPStatus
from kafka.oauth.abstract import AbstractTokenProvider
from karapace.config import Config
from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE
from typing import NoReturn, TypedDict

import aiohttp
import dataclasses
import enum


@enum.unique
class TokenType(enum.Enum):
BASIC = "Basic"
BEARER = "Bearer"


def raise_unauthorized() -> NoReturn:
raise HTTPResponse(
body='{"message": "Unauthorized"}',
status=HTTPStatus.UNAUTHORIZED,
content_type=JSON_CONTENT_TYPE,
headers={"WWW-Authenticate": 'Basic realm="Karapace REST Proxy"'},
)


class SASLPlainConfig(TypedDict):
sasl_mechanism: str | None
sasl_plain_username: str | None
sasl_plain_password: str | None


class SASLOauthConfig(TypedDict):
sasl_mechanism: str | None
sasl_oauth_token: str | None


def get_auth_config_from_header(
auth_header: str | None,
config: Config,
) -> SASLPlainConfig | SASLOauthConfig:
if auth_header is None:
raise_unauthorized()

token_type, _separator, token = auth_header.partition(" ")

if token_type == TokenType.BEARER.value:
return {"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": token}

if token_type == TokenType.BASIC.value:
basic_auth = aiohttp.BasicAuth.decode(auth_header)
sasl_mechanism = config["sasl_mechanism"]
if sasl_mechanism is None:
sasl_mechanism = "PLAIN"

return {
"sasl_mechanism": sasl_mechanism,
"sasl_plain_username": basic_auth.login,
"sasl_plain_password": basic_auth.password,
}

raise_unauthorized()


@dataclasses.dataclass
class SimpleOauthTokenProvider(AbstractTokenProvider):
_token: str

def token(self) -> str:
return self._token


@dataclasses.dataclass
class SimpleOauthTokenProviderAsync(AbstractTokenProviderAsync):
_token: str

async def token(self) -> str:
return self._token


class SASLOauthParams(TypedDict):
sasl_mechanism: str
sasl_oauth_token_provider: AbstractTokenProvider | AbstractTokenProviderAsync


def get_kafka_client_auth_parameters_from_config(
config: Config,
*,
async_client: bool = True,
) -> SASLPlainConfig | SASLOauthParams:
if config["sasl_mechanism"] == "OAUTHBEARER":
token_provider_cls = SimpleOauthTokenProviderAsync if async_client else SimpleOauthTokenProvider
return {
"sasl_mechanism": config["sasl_mechanism"],
"sasl_oauth_token_provider": token_provider_cls(config["sasl_oauth_token"]),
}

return {
"sasl_mechanism": config["sasl_mechanism"],
"sasl_plain_username": config["sasl_plain_username"],
"sasl_plain_password": config["sasl_plain_password"],
}
5 changes: 2 additions & 3 deletions karapace/kafka_rest_apis/consumer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from kafka.errors import GroupAuthorizationFailedError, IllegalStateError, KafkaConfigurationError, KafkaError
from kafka.structs import TopicPartition
from karapace.config import Config, create_client_ssl_context
from karapace.kafka_rest_apis.auth_utils import get_kafka_client_auth_parameters_from_config
from karapace.kafka_rest_apis.error_codes import RESTErrorCodes
from karapace.karapace import empty_response, KarapaceBase
from karapace.serialization import DeserializationError, InvalidMessageHeader, InvalidPayload, SchemaRegistrySerializer
Expand Down Expand Up @@ -205,9 +206,6 @@ async def create_kafka_consumer(self, fetch_min_bytes, group_name, internal_name
client_id=internal_name,
security_protocol=self.config["security_protocol"],
ssl_context=ssl_context,
sasl_mechanism=self.config["sasl_mechanism"],
sasl_plain_username=self.config["sasl_plain_username"],
sasl_plain_password=self.config["sasl_plain_password"],
group_id=group_name,
fetch_min_bytes=max(1, fetch_min_bytes), # Discard earlier negative values
fetch_max_bytes=self.config["consumer_request_max_bytes"],
Expand All @@ -218,6 +216,7 @@ async def create_kafka_consumer(self, fetch_min_bytes, group_name, internal_name
enable_auto_commit=request_data["auto.commit.enable"],
auto_offset_reset=request_data["auto.offset.reset"],
session_timeout_ms=session_timeout_ms,
**get_kafka_client_auth_parameters_from_config(self.config),
)
await c.start()
return c
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ ignore_errors = True
# dependencies.
# - Write your own stubs. You don't need to write stubs for the whole library,
# only the parts that Karapace is interacting with.
[mypy-aiokafka.*]
ignore_missing_imports = True

[mypy-kafka.*]
ignore_missing_imports = True

Expand Down
109 changes: 109 additions & 0 deletions tests/unit/test_auth_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""
from http import HTTPStatus
from karapace.config import ConfigDefaults, set_config_defaults
from karapace.kafka_rest_apis.auth_utils import (
get_auth_config_from_header,
get_kafka_client_auth_parameters_from_config,
SimpleOauthTokenProvider,
SimpleOauthTokenProviderAsync,
)
from karapace.rapu import HTTPResponse, JSON_CONTENT_TYPE
from typing import Optional

import base64
import pytest


@pytest.mark.parametrize(
"auth_header",
(None, "Digest foo=bar"),
)
def test_get_auth_config_from_header_raises_unauthorized_on_invalid_header(auth_header: Optional[str]) -> None:
config = set_config_defaults({})

with pytest.raises(HTTPResponse) as exc_info:
get_auth_config_from_header(auth_header, config)

http_resonse = exc_info.value
assert http_resonse.body == '{"message": "Unauthorized"}'
assert http_resonse.status == HTTPStatus.UNAUTHORIZED
assert http_resonse.headers["Content-Type"] == JSON_CONTENT_TYPE
assert http_resonse.headers["WWW-Authenticate"] == 'Basic realm="Karapace REST Proxy"'


@pytest.mark.parametrize(
("auth_header", "config_override", "expected_auth_config"),
(
(
f"Basic {base64.b64encode(b'username:password').decode()}",
{"sasl_mechanism": None},
{"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"},
),
(
f"Basic {base64.b64encode(b'username:password').decode()}",
{"sasl_mechanism": "PLAIN"},
{"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"},
),
(
f"Basic {base64.b64encode(b'username:password').decode()}",
{"sasl_mechanism": "SCRAM"},
{"sasl_mechanism": "SCRAM", "sasl_plain_username": "username", "sasl_plain_password": "password"},
),
(
"Bearer <TOKEN>",
{},
{"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": "<TOKEN>"},
),
),
)
def test_get_auth_config_from_header(
auth_header: str, config_override: ConfigDefaults, expected_auth_config: ConfigDefaults
) -> None:
config = set_config_defaults(config_override)
auth_config = get_auth_config_from_header(auth_header, config)
assert auth_config == expected_auth_config


def test_simple_oauth_token_provider_returns_configured_token() -> None:
token_provider = SimpleOauthTokenProvider("TOKEN")
assert token_provider.token() == "TOKEN"


async def test_simple_oauth_token_provider_async_returns_configured_token() -> None:
token_provider = SimpleOauthTokenProviderAsync("TOKEN")
assert await token_provider.token() == "TOKEN"


def test_get_client_auth_parameters_from_config_sasl_plain() -> None:
config = set_config_defaults(
{"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}
)

client_auth_params = get_kafka_client_auth_parameters_from_config(config)

assert client_auth_params == {
"sasl_mechanism": "PLAIN",
"sasl_plain_username": "username",
"sasl_plain_password": "password",
}


def test_get_client_auth_parameters_from_config_oauth() -> None:
config = set_config_defaults({"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": "TOKEN"})

client_auth_params = get_kafka_client_auth_parameters_from_config(config, async_client=False)

assert client_auth_params["sasl_mechanism"] == "OAUTHBEARER"
assert client_auth_params["sasl_oauth_token_provider"].token() == "TOKEN"


async def test_get_client_auth_parameters_from_config_oauth_async() -> None:
config = set_config_defaults({"sasl_mechanism": "OAUTHBEARER", "sasl_oauth_token": "TOKEN"})

client_auth_params = get_kafka_client_auth_parameters_from_config(config, async_client=True)

assert client_auth_params["sasl_mechanism"] == "OAUTHBEARER"
assert await client_auth_params["sasl_oauth_token_provider"].token() == "TOKEN"

0 comments on commit 310b273

Please sign in to comment.