Skip to content

Commit

Permalink
Support for automatic retry (#48)
Browse files Browse the repository at this point in the history
* Support for automatic retry

* Fix test

* Add assertions and fix retryer

* Apply exponential backoff
  • Loading branch information
hexoul authored Oct 15, 2024
1 parent b55469e commit 611e935
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 63 deletions.
37 changes: 28 additions & 9 deletions centraldogma/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
from typing import Dict, Union, Callable, TypeVar, Optional

from httpx import Client, HTTPTransport, Limits, Response
from tenacity import stop_after_attempt, wait_exponential, Retrying

from centraldogma.exceptions import to_exception

T = TypeVar("T")


class BaseClient:
PATH_PREFIX = "/api/v1"

def __init__(
self,
base_url: str,
Expand All @@ -33,14 +32,21 @@ def __init__(
max_keepalive_connections: int = 2,
**configs,
):
assert retries >= 0, "retries must be greater than or equal to zero"
assert max_connections > 0, "max_connections must be greater than zero"
assert (
max_keepalive_connections > 0
), "max_keepalive_connections must be greater than zero"

base_url = base_url[:-1] if base_url[-1] == "/" else base_url

for key in ["transport", "limits"]:
if key in configs:
del configs[key]

self.retries = retries
self.client = Client(
base_url=f"{base_url}{self.PATH_PREFIX}",
base_url=f"{base_url}/api/v1",
http2=http2,
transport=HTTPTransport(retries=retries),
limits=Limits(
Expand All @@ -53,7 +59,6 @@ def __init__(
self.headers = self._get_headers(token)
self.patch_headers = self._get_patch_headers(token)

# TODO(@hexoul): Support automatic retry with `tenacity` even if an exception occurs.
def request(
self,
method: str,
Expand All @@ -62,6 +67,25 @@ def request(
**kwargs,
) -> Union[Response, T]:
kwargs = self._set_request_headers(method, **kwargs)
retryer = Retrying(
stop=stop_after_attempt(self.retries + 1),
wait=wait_exponential(max=60),
reraise=True,
)
return retryer(self._request, method, path, handler, **kwargs)

def _set_request_headers(self, method: str, **kwargs) -> Dict:
default_headers = self.patch_headers if method == "patch" else self.headers
kwargs["headers"] = {**default_headers, **(kwargs.get("headers") or {})}
return kwargs

def _request(
self,
method: str,
path: str,
handler: Optional[Dict[int, Callable[[Response], T]]] = None,
**kwargs,
):
resp = self.client.request(method, path, **kwargs)
if handler:
converter = handler.get(resp.status_code)
Expand All @@ -71,11 +95,6 @@ def request(
raise to_exception(resp)
return resp

def _set_request_headers(self, method: str, **kwargs) -> Dict:
default_headers = self.patch_headers if method == "patch" else self.headers
kwargs["headers"] = {**default_headers, **(kwargs.get("headers") or {})}
return kwargs

@staticmethod
def _get_headers(token: str) -> Dict:
return {
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ httpx[http2]
marshmallow
pydantic
python-dateutil
tenacity

# Dev dependencies
black
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_install_requires():

setup(
name="centraldogma-python",
version="0.3.0",
version="0.4.0",
description="Python client library for Central Dogma",
long_description=get_long_description(),
long_description_content_type="text/markdown",
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_content_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from centraldogma.query import Query

dogma = Dogma()
dogma = Dogma(retries=3)
project_name = "TestProject"
repo_name = "TestRepository"

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_project_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ProjectNotFoundException,
)

dogma = Dogma()
dogma = Dogma(retries=3)
project_name = "TestProject"


Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_repository_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
import os

dogma = Dogma()
dogma = Dogma(retries=3)
project_name = "TestProject"
repo_name = "TestRepository"

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from centraldogma.query import Query
from centraldogma.watcher import Watcher, Latest

dogma = Dogma()
dogma = Dogma(retries=3)
project_name = "TestProject"
repo_name = "TestRepository"

Expand Down
52 changes: 39 additions & 13 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

from centraldogma.exceptions import UnauthorizedException, NotFoundException
from centraldogma.base_client import BaseClient
from httpx import Response
from httpx import ConnectError, NetworkError, Response
import pytest

client = BaseClient("http://baseurl", "token")
base_url = "http://baseurl"
client = BaseClient(base_url, "token", retries=0)

configs = {
"auth": None,
Expand All @@ -37,7 +38,7 @@
"app": None,
"trust_env": True,
}
client_with_configs = BaseClient("http://baseurl", "token", **configs)
client_with_configs = BaseClient(base_url, "token", **configs)

ok_handler = {HTTPStatus.OK: lambda resp: resp}

Expand Down Expand Up @@ -66,7 +67,7 @@ def test_set_request_headers():
def test_request_with_configs(respx_mock):
methods = ["get", "post", "put", "delete", "patch", "options"]
for method in methods:
getattr(respx_mock, method)("http://baseurl/api/v1/path").mock(
getattr(respx_mock, method)(f"{base_url}/api/v1/path").mock(
return_value=Response(200, text="success")
)
client.request(
Expand All @@ -82,7 +83,7 @@ def test_request_with_configs(respx_mock):


def test_delete(respx_mock):
route = respx_mock.delete("http://baseurl/api/v1/path").mock(
route = respx_mock.delete(f"{base_url}/api/v1/path").mock(
return_value=Response(200, text="success")
)
resp = client.request("delete", "/path", params={"a": "b"})
Expand All @@ -95,36 +96,61 @@ def test_delete(respx_mock):

def test_delete_exception_authorization(respx_mock):
with pytest.raises(UnauthorizedException):
respx_mock.delete("http://baseurl/api/v1/path").mock(return_value=Response(401))
respx_mock.delete(f"{base_url}/api/v1/path").mock(return_value=Response(401))
client.request("delete", "/path", handler=ok_handler)


def test_get(respx_mock):
route = respx_mock.get("http://baseurl/api/v1/path").mock(
route = respx_mock.get(f"{base_url}/api/v1/path").mock(
return_value=Response(200, text="success")
)
resp = client.request("get", "/path", params={"a": "b"}, handler=ok_handler)

assert route.called
assert route.call_count == 1
assert resp.request.headers["Authorization"] == "bearer token"
assert resp.request.headers["Content-Type"] == "application/json"
assert resp.request.url.params.multi_items() == [("a", "b")]


def test_get_exception_authorization(respx_mock):
with pytest.raises(UnauthorizedException):
respx_mock.get("http://baseurl/api/v1/path").mock(return_value=Response(401))
respx_mock.get(f"{base_url}/api/v1/path").mock(return_value=Response(401))
client.request("get", "/path", handler=ok_handler)


def test_get_exception_not_found(respx_mock):
with pytest.raises(NotFoundException):
respx_mock.get("http://baseurl/api/v1/path").mock(return_value=Response(404))
respx_mock.get(f"{base_url}/api/v1/path").mock(return_value=Response(404))
client.request("get", "/path", handler=ok_handler)


def test_get_with_retry_by_response(respx_mock):
route = respx_mock.get(f"{base_url}/api/v1/path").mock(
side_effect=[Response(503), Response(404), Response(200)],
)

retry_client = BaseClient(base_url, "token", retries=2)
retry_client.request("get", "/path", handler=ok_handler)

assert route.called
assert route.call_count == 3


def test_get_with_retry_by_client(respx_mock):
route = respx_mock.get(f"{base_url}/api/v1/path").mock(
side_effect=[ConnectError, ConnectError, NetworkError, Response(200)],
)

retry_client = BaseClient(base_url, "token", retries=10)
retry_client.request("get", "/path", handler=ok_handler)

assert route.called
assert route.call_count == 4


def test_patch(respx_mock):
route = respx_mock.patch("http://baseurl/api/v1/path").mock(
route = respx_mock.patch(f"{base_url}/api/v1/path").mock(
return_value=Response(200, text="success")
)
resp = client.request("patch", "/path", json={"a": "b"}, handler=ok_handler)
Expand All @@ -137,12 +163,12 @@ def test_patch(respx_mock):

def test_patch_exception_authorization(respx_mock):
with pytest.raises(UnauthorizedException):
respx_mock.patch("http://baseurl/api/v1/path").mock(return_value=Response(401))
respx_mock.patch(f"{base_url}/api/v1/path").mock(return_value=Response(401))
client.request("patch", "/path", json={"a": "b"}, handler=ok_handler)


def test_post(respx_mock):
route = respx_mock.post("http://baseurl/api/v1/path").mock(
route = respx_mock.post(f"{base_url}/api/v1/path").mock(
return_value=Response(200, text="success")
)
resp = client.request("post", "/path", json={"a": "b"}, handler=ok_handler)
Expand All @@ -155,5 +181,5 @@ def test_post(respx_mock):

def test_post_exception_authorization(respx_mock):
with pytest.raises(UnauthorizedException):
respx_mock.post("http://baseurl/api/v1/path").mock(return_value=Response(401))
respx_mock.post(f"{base_url}/api/v1/path").mock(return_value=Response(401))
client.request("post", "/path", handler=ok_handler)
Loading

0 comments on commit 611e935

Please sign in to comment.