diff --git a/src/api_client/client.py b/src/api_client/client.py index 90e2be16de..b2b83761fb 100644 --- a/src/api_client/client.py +++ b/src/api_client/client.py @@ -19,16 +19,21 @@ sentinel = object() -def has_same_base(url: furl, reference: furl) -> bool: - if not all( - ( - getattr(url, attr) == getattr(reference, attr) - for attr in ("scheme", "username", "password", "host", "port") - ) - ): - return False - # finally, if all that is the same, the path base should match too - return str(url).startswith(str(reference)) +def is_base_url(url: str | furl) -> bool: + """ + Check if a URL is not a relative path/URL. + + A URL is considered a base URL if it has: + + * a scheme + * a netloc + + Protocol relative URLs like //example.com cannot be properly handled by requests, + as there is no default adapter available. + """ + if not isinstance(url, furl): + url = furl(url) + return bool(url.scheme and url.netloc) class APIClient(Session): @@ -97,10 +102,15 @@ def _maybe_close_session(self): def to_absolute_url(self, maybe_relative_url: str) -> str: base_furl = furl(self.base_url) - target_furl = furl(maybe_relative_url) - is_absolute = target_furl.path.isabsolute + # absolute here should be interpreted as "fully qualified url", with a protocol + # and netloc + is_absolute = is_base_url(maybe_relative_url) if is_absolute: - if not has_same_base(target_furl, base_furl): + # we established the target URL is absolute, so ensure that it's contained + # within the self.base_url domain, otherwise you risk sending credentials + # intended for the base URL to some other domain. + has_same_base = maybe_relative_url.startswith(self.base_url) + if not has_same_base: raise InvalidURLError( f"Target URL {maybe_relative_url} has a different base URL than the " f"client ({self.base_url})." diff --git a/src/api_client/tests/test_client_api.py b/src/api_client/tests/test_client_api.py index 47cc22001e..6a93b98658 100644 --- a/src/api_client/tests/test_client_api.py +++ b/src/api_client/tests/test_client_api.py @@ -10,14 +10,8 @@ from ..client import APIClient -http_methods = st.one_of( - st.just("GET"), - st.just("OPTIONS"), - st.just("HEAD"), - st.just("POST"), - st.just("PUT"), - st.just("PATCH"), - st.just("DELETE"), +http_methods = st.sampled_from( + ["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"] ) @@ -122,11 +116,13 @@ def test_request_kwargs_overrule_defaults(self, m): def test_applies_to_any_http_method(self, method): factory = TestFactory() - with requests_mock.Mocker() as m: + with ( + requests_mock.Mocker() as m, + APIClient.configure_from(factory) as client, + ): m.register_uri(requests_mock.ANY, requests_mock.ANY) - with APIClient.configure_from(factory) as client: - client.request(method, "https://from-factory.example.com/foo") + client.request(method, "https://from-factory.example.com/foo") self.assertEqual(len(m.request_history), 1) self.assertEqual(m.last_request.url, "https://from-factory.example.com/foo") @@ -142,7 +138,10 @@ def test_relative_urls_are_made_absolute(self, method): factory = TestFactory() client = APIClient.configure_from(factory) - with requests_mock.Mocker() as m, client: + with ( + requests_mock.Mocker() as m, + client, + ): m.register_uri(requests_mock.ANY, requests_mock.ANY) client.request(method, "foo") @@ -163,7 +162,10 @@ def test_absolute_urls_must_match_base_url_happy_flow(self, method): factory = TestFactory() client = APIClient.configure_from(factory) - with requests_mock.Mocker() as m, client: + with ( + requests_mock.Mocker() as m, + client, + ): m.register_uri(requests_mock.ANY, requests_mock.ANY) client.request(method, "https://from-factory.example.com/foo/bar") @@ -175,11 +177,13 @@ def test_absolute_urls_must_match_base_url_happy_flow(self, method): def test_discouraged_usage_without_context(self, method): client = APIClient("https://example.com") - with requests_mock.Mocker() as m: + with ( + requests_mock.Mocker() as m, + patch.object(client, "close", wraps=client.close) as mock_close, + ): m.register_uri(requests_mock.ANY, requests_mock.ANY) - with patch.object(client, "close", wraps=client.close) as mock_close: - client.request(method, "foo") + client.request(method, "foo") self.assertEqual(len(m.request_history), 1) mock_close.assert_called_once() @@ -188,14 +192,17 @@ def test_discouraged_usage_without_context(self, method): def test_encouraged_usage_with_context_do_not_close_prematurely(self, method): client = APIClient("https://example.com") - with patch.object(client, "close", wraps=client.close) as mock_close: - with requests_mock.Mocker() as m, client: - m.register_uri(requests_mock.ANY, requests_mock.ANY) + with ( + patch.object(client, "close", wraps=client.close) as mock_close, + requests_mock.Mocker() as m, + client, + ): + m.register_uri(requests_mock.ANY, requests_mock.ANY) - client.request(method, "foo") + client.request(method, "foo") - # may not be called inside context block - mock_close.assert_not_called() + # may not be called inside context block + mock_close.assert_not_called() self.assertEqual(len(m.request_history), 1) # must be called outside context block