Skip to content

Commit

Permalink
Issue #254/#691 introduce _on_auth_update handler
Browse files Browse the repository at this point in the history
- to make sure all cases are covered
- include authenticate_oidc_access_token
  • Loading branch information
soxofaan committed Jan 17, 2025
1 parent 1923035 commit 315ce7a
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 56 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Clear capabilities cache on login ([#254](https://github.com/Open-EO/openeo-python-client/issues/254))


## [0.36.0] - 2024-12-10

Expand Down
21 changes: 18 additions & 3 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
slow_response_threshold: Optional[float] = None,
):
self._root_url = root_url
self._auth = None
self.auth = auth or NullAuth()
self.session = session or requests.Session()
self.default_timeout = default_timeout or DEFAULT_TIMEOUT
Expand All @@ -129,6 +130,18 @@ def __init__(
def root_url(self):
return self._root_url

@property
def auth(self) -> Union[AuthBase, None]:
return self._auth

@auth.setter
def auth(self, auth: Union[AuthBase, None]):
self._auth = auth
self._on_auth_update()

def _on_auth_update(self):
pass

def build_url(self, path: str):
return url_join(self._root_url, path)

Expand Down Expand Up @@ -340,12 +353,12 @@ def __init__(
if "://" not in url:
url = "https://" + url
self._orig_url = url
self._capabilities_cache = LazyLoadCache()
super().__init__(
root_url=self.version_discovery(url, session=session, timeout=default_timeout),
auth=auth, session=session, default_timeout=default_timeout,
slow_response_threshold=slow_response_threshold,
)
self._capabilities_cache = LazyLoadCache()

# Initial API version check.
self._api_version.require_at_least(self._MINIMUM_API_VERSION)
Expand Down Expand Up @@ -380,6 +393,10 @@ def version_discovery(
# Be very lenient about failing on the well-known URI strategy.
return url

def _on_auth_update(self):
super()._on_auth_update()
self._capabilities_cache.clear()

def _get_auth_config(self) -> AuthConfig:
if self._auth_config is None:
self._auth_config = AuthConfig()
Expand Down Expand Up @@ -411,7 +428,6 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[
).json()
# Switch to bearer based authentication in further requests.
self.auth = BasicBearerAuth(access_token=resp["access_token"])
self._capabilities_cache.clear()
return self

def _get_oidc_provider(
Expand Down Expand Up @@ -546,7 +562,6 @@ def _authenticate_oidc(
_log.warning("No OIDC refresh token to store.")
token = tokens.access_token
self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token)
self._capabilities_cache.clear()
self._oidc_auth_renewer = oidc_auth_renewer
return self

Expand Down
126 changes: 73 additions & 53 deletions tests/rest/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

API_URL = "https://oeo.test/"

# TODO: eliminate this and replace with `build_capabilities` usage
BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}]


Expand Down Expand Up @@ -551,83 +552,102 @@ def test_capabilities_caching(requests_mock):
assert con.capabilities().api_version() == "1.0.0"
assert m.call_count == 1

def test_capabilities_caching_after_authenticate_basic(requests_mock):
user, pwd = "john262", "J0hndo3"

def get_capabilities(request, context):
endpoints = BASIC_ENDPOINTS.copy()
if "Authorization" in request.headers:
endpoints.append({"path": "/account/status", "methods": ["GET"]})
return {"api_version": "1.0.0", "endpoints": endpoints}
def _get_capabilities_auth_dependent(request, context):
capabilities = build_capabilities()
capabilities["endpoints"] = [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/credentials/oidc"},
]
if "Authorization" in request.headers:
capabilities["endpoints"].append({"methods": ["GET"], "path": "/me"})
return capabilities


get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities)
def test_capabilities_caching_after_authenticate_basic(requests_mock):
user, pwd = "john262", "J0hndo3"
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
requests_mock.get(API_URL + 'credentials/basic', text=_credentials_basic_handler(user, pwd))

con = Connection(API_URL)
assert con.capabilities().capabilities == {
"api_version": "1.0.0",
"endpoints": [
{"methods": ["GET"], "path": "/credentials/basic"},
],
}
assert con.capabilities().capabilities["endpoints"] == [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/credentials/oidc"},
]
assert get_capabilities_mock.call_count == 1
con.capabilities()
assert get_capabilities_mock.call_count == 1

con.authenticate_basic(user, pwd)
con.authenticate_basic(username=user, password=pwd)
assert get_capabilities_mock.call_count == 1
assert con.capabilities().capabilities == {
"api_version": "1.0.0",
"endpoints": [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/account/status"},
],
}
assert get_capabilities_mock.call_count == 2
assert con.capabilities().capabilities["endpoints"] == [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/credentials/oidc"},
{"methods": ["GET"], "path": "/me"},
]

assert get_capabilities_mock.call_count == 2


def test_capabilities_caching_after_authenticate_oidc(requests_mock):
def test_capabilities_caching_after_authenticate_oidc_refresh_token(requests_mock):
client_id = "myclient"

def get_capabilities(request, context):
endpoints = BASIC_ENDPOINTS.copy()
if "Authorization" in request.headers:
endpoints.append({"path": "/account/status", "methods": ["GET"]})
return {"api_version": "1.0.0", "endpoints": endpoints}

get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities)
requests_mock.get(API_URL + 'credentials/oidc', json={
"providers": [{"id": "fauth", "issuer": "https://fauth.test", "title": "Foo Auth", "scopes": ["openid", "im"]}]
})
refresh_token = "fr65h!"
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
requests_mock.get(
API_URL + "credentials/oidc",
json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]},
)
oidc_mock = OidcMock(
requests_mock=requests_mock,
expected_grant_type="authorization_code",
expected_grant_type="refresh_token",
expected_client_id=client_id,
expected_fields={"scope": "im openid"},
oidc_issuer="https://fauth.test",
scopes_supported=["openid", "im"],
expected_fields={"refresh_token": refresh_token},
)

conn = Connection(API_URL)
assert conn.capabilities().capabilities == {
"api_version": "1.0.0",
"endpoints": [
{"methods": ["GET"], "path": "/credentials/basic"},
],
}
assert conn.capabilities().capabilities["endpoints"] == [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/credentials/oidc"},
]

assert get_capabilities_mock.call_count == 1
conn.capabilities()
assert get_capabilities_mock.call_count == 1

conn.authenticate_oidc_authorization_code(client_id=client_id, webbrowser_open=oidc_mock.webbrowser_open)
conn.authenticate_oidc_refresh_token(client_id=client_id, refresh_token=refresh_token)
assert get_capabilities_mock.call_count == 1
assert conn.capabilities().capabilities == {
"api_version": "1.0.0",
"endpoints": [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/account/status"},
],
}
assert conn.capabilities().capabilities["endpoints"] == [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/credentials/oidc"},
{"methods": ["GET"], "path": "/me"},
]
assert get_capabilities_mock.call_count == 2


def test_capabilities_caching_after_authenticate_oidc_access_token(requests_mock):
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
requests_mock.get(
API_URL + "credentials/oidc",
json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]},
)

conn = Connection(API_URL)
assert conn.capabilities().capabilities["endpoints"] == [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/credentials/oidc"},
]

assert get_capabilities_mock.call_count == 1
conn.capabilities()
assert get_capabilities_mock.call_count == 1

conn.authenticate_oidc_access_token(access_token="6cc355!")
assert get_capabilities_mock.call_count == 1
assert conn.capabilities().capabilities["endpoints"] == [
{"methods": ["GET"], "path": "/credentials/basic"},
{"methods": ["GET"], "path": "/credentials/oidc"},
{"methods": ["GET"], "path": "/me"},
]
assert get_capabilities_mock.call_count == 2


Expand Down

0 comments on commit 315ce7a

Please sign in to comment.