Skip to content

Commit

Permalink
👌 [#3328] PR feedback processing
Browse files Browse the repository at this point in the history
* Cleaned up base URL checking a bit
* Cleaned up hypothesis search strategy
* Clean up context manager code style
  • Loading branch information
sergei-maertens committed Sep 19, 2023
1 parent 66b7fb3 commit 753575e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
36 changes: 23 additions & 13 deletions src/api_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@
sentinel = object()


def has_same_base(url: furl, reference: furl) -> bool:
if not all(
(
getattr(url, attr) == getattr(reference, attr)
for attr in ("scheme", "username", "password", "host", "port")
)
):
return False
# finally, if all that is the same, the path base should match too
return str(url).startswith(str(reference))
def is_base_url(url: str | furl) -> bool:
"""
Check if a URL is not a relative path/URL.
A URL is considered a base URL if it has:
* a scheme
* a netloc
Protocol relative URLs like //example.com cannot be properly handled by requests,
as there is no default adapter available.
"""
if not isinstance(url, furl):
url = furl(url)
return bool(url.scheme and url.netloc)


class APIClient(Session):
Expand Down Expand Up @@ -97,10 +102,15 @@ def _maybe_close_session(self):

def to_absolute_url(self, maybe_relative_url: str) -> str:
base_furl = furl(self.base_url)
target_furl = furl(maybe_relative_url)
is_absolute = target_furl.path.isabsolute
# absolute here should be interpreted as "fully qualified url", with a protocol
# and netloc
is_absolute = is_base_url(maybe_relative_url)
if is_absolute:
if not has_same_base(target_furl, base_furl):
# we established the target URL is absolute, so ensure that it's contained
# within the self.base_url domain, otherwise you risk sending credentials
# intended for the base URL to some other domain.
has_same_base = maybe_relative_url.startswith(self.base_url)
if not has_same_base:
raise InvalidURLError(
f"Target URL {maybe_relative_url} has a different base URL than the "
f"client ({self.base_url})."
Expand Down
51 changes: 29 additions & 22 deletions src/api_client/tests/test_client_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,8 @@

from ..client import APIClient

http_methods = st.one_of(
st.just("GET"),
st.just("OPTIONS"),
st.just("HEAD"),
st.just("POST"),
st.just("PUT"),
st.just("PATCH"),
st.just("DELETE"),
http_methods = st.sampled_from(
["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
)


Expand Down Expand Up @@ -122,11 +116,13 @@ def test_request_kwargs_overrule_defaults(self, m):
def test_applies_to_any_http_method(self, method):
factory = TestFactory()

with requests_mock.Mocker() as m:
with (
requests_mock.Mocker() as m,
APIClient.configure_from(factory) as client,
):
m.register_uri(requests_mock.ANY, requests_mock.ANY)

with APIClient.configure_from(factory) as client:
client.request(method, "https://from-factory.example.com/foo")
client.request(method, "https://from-factory.example.com/foo")

self.assertEqual(len(m.request_history), 1)
self.assertEqual(m.last_request.url, "https://from-factory.example.com/foo")
Expand All @@ -142,7 +138,10 @@ def test_relative_urls_are_made_absolute(self, method):
factory = TestFactory()
client = APIClient.configure_from(factory)

with requests_mock.Mocker() as m, client:
with (
requests_mock.Mocker() as m,
client,
):
m.register_uri(requests_mock.ANY, requests_mock.ANY)

client.request(method, "foo")
Expand All @@ -163,7 +162,10 @@ def test_absolute_urls_must_match_base_url_happy_flow(self, method):
factory = TestFactory()
client = APIClient.configure_from(factory)

with requests_mock.Mocker() as m, client:
with (
requests_mock.Mocker() as m,
client,
):
m.register_uri(requests_mock.ANY, requests_mock.ANY)

client.request(method, "https://from-factory.example.com/foo/bar")
Expand All @@ -175,11 +177,13 @@ def test_absolute_urls_must_match_base_url_happy_flow(self, method):
def test_discouraged_usage_without_context(self, method):
client = APIClient("https://example.com")

with requests_mock.Mocker() as m:
with (
requests_mock.Mocker() as m,
patch.object(client, "close", wraps=client.close) as mock_close,
):
m.register_uri(requests_mock.ANY, requests_mock.ANY)

with patch.object(client, "close", wraps=client.close) as mock_close:
client.request(method, "foo")
client.request(method, "foo")

self.assertEqual(len(m.request_history), 1)
mock_close.assert_called_once()
Expand All @@ -188,14 +192,17 @@ def test_discouraged_usage_without_context(self, method):
def test_encouraged_usage_with_context_do_not_close_prematurely(self, method):
client = APIClient("https://example.com")

with patch.object(client, "close", wraps=client.close) as mock_close:
with requests_mock.Mocker() as m, client:
m.register_uri(requests_mock.ANY, requests_mock.ANY)
with (
patch.object(client, "close", wraps=client.close) as mock_close,
requests_mock.Mocker() as m,
client,
):
m.register_uri(requests_mock.ANY, requests_mock.ANY)

client.request(method, "foo")
client.request(method, "foo")

# may not be called inside context block
mock_close.assert_not_called()
# may not be called inside context block
mock_close.assert_not_called()

self.assertEqual(len(m.request_history), 1)
# must be called outside context block
Expand Down

0 comments on commit 753575e

Please sign in to comment.