Skip to content

Commit

Permalink
RESTClient: add support for relative next URLs in LinkPaginators (#1163)
Browse files Browse the repository at this point in the history
* Extend `mock_api_server()` to support relative next urls
* Enhance BaseNextUrlPaginator to support relative next URLs in pagination
  • Loading branch information
burnash authored Apr 2, 2024
1 parent ea22515 commit ecb5aa0
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 42 deletions.
7 changes: 7 additions & 0 deletions dlt/sources/helpers/rest_client/paginators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from urllib.parse import urlparse, urljoin

from dlt.sources.helpers.requests import Response, Request
from dlt.common import jsonpath
Expand Down Expand Up @@ -102,6 +103,12 @@ def update_request(self, request: Request) -> None:

class BaseNextUrlPaginator(BasePaginator):
def update_request(self, request: Request) -> None:
# Handle relative URLs
if self.next_reference:
parsed_url = urlparse(self.next_reference)
if not parsed_url.scheme:
self.next_reference = urljoin(request.url, self.next_reference)

request.url = self.next_reference


Expand Down
88 changes: 60 additions & 28 deletions tests/sources/helpers/rest_client/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import NamedTuple, Callable, Pattern, List, TYPE_CHECKING
from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING
import base64

from urllib.parse import urlsplit, urlunsplit
Expand All @@ -10,17 +10,21 @@
from dlt.common import json

if TYPE_CHECKING:
RequestCallback = Callable[[requests_mock.Request, requests_mock.Context], str]
RequestCallback = Callable[
[requests_mock.Request, requests_mock.Context], Union[str, dict, list]
]
ResponseSerializer = Callable[[requests_mock.Request, requests_mock.Context], str]
else:
RequestCallback = Callable
ResponseSerializer = Callable

MOCK_BASE_URL = "https://api.example.com"


class Route(NamedTuple):
method: str
pattern: Pattern[str]
callback: RequestCallback
callback: ResponseSerializer


class APIRouter:
Expand All @@ -32,8 +36,17 @@ def _add_route(
self, method: str, pattern: str, func: RequestCallback
) -> RequestCallback:
compiled_pattern = re.compile(f"{self.base_url}{pattern}")
self.routes.append(Route(method, compiled_pattern, func))
return func

def serialize_response(request, context):
result = func(request, context)

if isinstance(result, dict) or isinstance(result, list):
return json.dumps(result)

return result

self.routes.append(Route(method, compiled_pattern, serialize_response))
return serialize_response

def get(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]:
def decorator(func: RequestCallback) -> RequestCallback:
Expand All @@ -59,9 +72,17 @@ def register_routes(self, mocker: requests_mock.Mocker) -> None:
router = APIRouter(MOCK_BASE_URL)


def serialize_page(records, page_number, total_pages, base_url, records_key="data"):
def serialize_page(
records,
page_number,
total_pages,
request_url,
records_key="data",
use_absolute_url=True,
):
"""Serialize a page of records into a dict with pagination metadata."""
if records_key is None:
return json.dumps(records)
return records

response = {
records_key: records,
Expand All @@ -72,11 +93,15 @@ def serialize_page(records, page_number, total_pages, base_url, records_key="dat
if page_number < total_pages:
next_page = page_number + 1

scheme, netloc, path, _, _ = urlsplit(base_url)
next_page = urlunsplit([scheme, netloc, path, f"page={next_page}", ""])
response["next_page"] = next_page
scheme, netloc, path, _, _ = urlsplit(request_url)
if use_absolute_url:
next_page_url = urlunsplit([scheme, netloc, path, f"page={next_page}", ""])
else:
next_page_url = f"{path}?page={next_page}"

return json.dumps(response)
response["next_page"] = next_page_url

return response


def generate_posts(count=100):
Expand All @@ -91,15 +116,22 @@ def get_page_number(qs, key="page", default=1):
return int(qs.get(key, [default])[0])


def paginate_response(request, records, page_size=10, records_key="data"):
def paginate_response(
request, records, page_size=10, records_key="data", use_absolute_url=True
):
page_number = get_page_number(request.qs)
total_records = len(records)
total_pages = (total_records + page_size - 1) // page_size
start_index = (page_number - 1) * 10
end_index = start_index + 10
records_slice = records[start_index:end_index]
return serialize_page(
records_slice, page_number, total_pages, request.url, records_key
records_slice,
page_number,
total_pages,
request.url,
records_key,
use_absolute_url,
)


Expand All @@ -115,6 +147,10 @@ def posts_no_key(request, context):
def posts(request, context):
return paginate_response(request, generate_posts())

@router.get(r"/posts_relative_next_url(\?page=\d+)?$")
def posts_relative_next_url(request, context):
return paginate_response(request, generate_posts(), use_absolute_url=False)

@router.get(r"/posts/(\d+)/comments")
def post_comments(request, context):
post_id = int(request.url.split("/")[-2])
Expand All @@ -123,17 +159,17 @@ def post_comments(request, context):
@router.get(r"/posts/\d+$")
def post_detail(request, context):
post_id = request.url.split("/")[-1]
return json.dumps({"id": post_id, "body": f"Post body {post_id}"})
return {"id": post_id, "body": f"Post body {post_id}"}

@router.get(r"/posts/\d+/some_details_404")
def post_detail_404(request, context):
"""Return 404 for post with id > 0. Used to test ignoring 404 errors."""
post_id = int(request.url.split("/")[-2])
if post_id < 1:
return json.dumps({"id": post_id, "body": f"Post body {post_id}"})
return {"id": post_id, "body": f"Post body {post_id}"}
else:
context.status_code = 404
return json.dumps({"error": "Post not found"})
return {"error": "Post not found"}

@router.get(r"/posts_under_a_different_key$")
def posts_with_results_key(request, context):
Expand All @@ -149,15 +185,15 @@ def protected_basic_auth(request, context):
if auth == f"Basic {creds_base64}":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.get("/protected/posts/bearer-token")
def protected_bearer_token(request, context):
auth = request.headers.get("Authorization")
if auth == "Bearer test-token":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.get("/protected/posts/bearer-token-plain-text-error")
def protected_bearer_token_plain_text_erorr(request, context):
Expand All @@ -173,31 +209,27 @@ def protected_api_key(request, context):
if api_key == "test-api-key":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.post("/oauth/token")
def oauth_token(request, context):
return json.dumps(
{
"access_token": "test-token",
"expires_in": 3600,
}
)
return {"access_token": "test-token", "expires_in": 3600}

@router.post("/auth/refresh")
def refresh_token(request, context):
body = request.json()
if body.get("refresh_token") == "valid-refresh-token":
return json.dumps({"access_token": "new-valid-token"})
return {"access_token": "new-valid-token"}
context.status_code = 401
return json.dumps({"error": "Invalid refresh token"})
return {"error": "Invalid refresh token"}

router.register_routes(m)

yield m


def assert_pagination(pages, expected_start=0, page_size=10):
def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10):
assert len(pages) == total_pages
for i, page in enumerate(pages):
assert page == [
{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)
Expand Down
17 changes: 17 additions & 0 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,23 @@ def test_default_paginator(self, rest_client: RESTClient):

assert_pagination(pages)

def test_excplicit_paginator(self, rest_client: RESTClient):
pages_iter = rest_client.paginate(
"/posts", paginator=JSONResponsePaginator(next_url_path="next_page")
)
pages = list(pages_iter)

assert_pagination(pages)

def test_excplicit_paginator_relative_next_url(self, rest_client: RESTClient):
pages_iter = rest_client.paginate(
"/posts_relative_next_url",
paginator=JSONResponsePaginator(next_url_path="next_page"),
)
pages = list(pages_iter)

assert_pagination(pages)

def test_paginate_with_hooks(self, rest_client: RESTClient):
def response_hook(response: Response, *args: Any, **kwargs: Any) -> None:
if response.status_code == 404:
Expand Down
Loading

0 comments on commit ecb5aa0

Please sign in to comment.