Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RESTClient: add support for relative next URLs in LinkPaginators #1163

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions dlt/sources/helpers/rest_client/paginators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from urllib.parse import urlparse, urljoin

from dlt.sources.helpers.requests import Response, Request
from dlt.common import jsonpath
Expand Down Expand Up @@ -102,6 +103,12 @@ def update_request(self, request: Request) -> None:

class BaseNextUrlPaginator(BasePaginator):
def update_request(self, request: Request) -> None:
# Handle relative URLs
if self.next_reference:
parsed_url = urlparse(self.next_reference)
if not parsed_url.scheme:
self.next_reference = urljoin(request.url, self.next_reference)

request.url = self.next_reference


Expand Down
88 changes: 60 additions & 28 deletions tests/sources/helpers/rest_client/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import NamedTuple, Callable, Pattern, List, TYPE_CHECKING
from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING
import base64

from urllib.parse import urlsplit, urlunsplit
Expand All @@ -10,17 +10,21 @@
from dlt.common import json

if TYPE_CHECKING:
RequestCallback = Callable[[requests_mock.Request, requests_mock.Context], str]
RequestCallback = Callable[
[requests_mock.Request, requests_mock.Context], Union[str, dict, list]
]
ResponseSerializer = Callable[[requests_mock.Request, requests_mock.Context], str]
else:
RequestCallback = Callable
ResponseSerializer = Callable

MOCK_BASE_URL = "https://api.example.com"


class Route(NamedTuple):
method: str
pattern: Pattern[str]
callback: RequestCallback
callback: ResponseSerializer


class APIRouter:
Expand All @@ -32,8 +36,17 @@ def _add_route(
self, method: str, pattern: str, func: RequestCallback
) -> RequestCallback:
compiled_pattern = re.compile(f"{self.base_url}{pattern}")
self.routes.append(Route(method, compiled_pattern, func))
return func

def serialize_response(request, context):
result = func(request, context)

if isinstance(result, dict) or isinstance(result, list):
return json.dumps(result)

return result

self.routes.append(Route(method, compiled_pattern, serialize_response))
return serialize_response

def get(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]:
def decorator(func: RequestCallback) -> RequestCallback:
Expand All @@ -59,9 +72,17 @@ def register_routes(self, mocker: requests_mock.Mocker) -> None:
router = APIRouter(MOCK_BASE_URL)


def serialize_page(records, page_number, total_pages, base_url, records_key="data"):
def serialize_page(
records,
page_number,
total_pages,
request_url,
records_key="data",
use_absolute_url=True,
):
"""Serialize a page of records into a dict with pagination metadata."""
if records_key is None:
return json.dumps(records)
return records

response = {
records_key: records,
Expand All @@ -72,11 +93,15 @@ def serialize_page(records, page_number, total_pages, base_url, records_key="dat
if page_number < total_pages:
next_page = page_number + 1

scheme, netloc, path, _, _ = urlsplit(base_url)
next_page = urlunsplit([scheme, netloc, path, f"page={next_page}", ""])
response["next_page"] = next_page
scheme, netloc, path, _, _ = urlsplit(request_url)
if use_absolute_url:
next_page_url = urlunsplit([scheme, netloc, path, f"page={next_page}", ""])
else:
next_page_url = f"{path}?page={next_page}"

return json.dumps(response)
response["next_page"] = next_page_url

return response


def generate_posts(count=100):
Expand All @@ -91,15 +116,22 @@ def get_page_number(qs, key="page", default=1):
return int(qs.get(key, [default])[0])


def paginate_response(request, records, page_size=10, records_key="data"):
def paginate_response(
request, records, page_size=10, records_key="data", use_absolute_url=True
):
page_number = get_page_number(request.qs)
total_records = len(records)
total_pages = (total_records + page_size - 1) // page_size
start_index = (page_number - 1) * 10
end_index = start_index + 10
records_slice = records[start_index:end_index]
return serialize_page(
records_slice, page_number, total_pages, request.url, records_key
records_slice,
page_number,
total_pages,
request.url,
records_key,
use_absolute_url,
)


Expand All @@ -115,6 +147,10 @@ def posts_no_key(request, context):
def posts(request, context):
return paginate_response(request, generate_posts())

@router.get(r"/posts_relative_next_url(\?page=\d+)?$")
def posts_relative_next_url(request, context):
return paginate_response(request, generate_posts(), use_absolute_url=False)

@router.get(r"/posts/(\d+)/comments")
def post_comments(request, context):
post_id = int(request.url.split("/")[-2])
Expand All @@ -123,17 +159,17 @@ def post_comments(request, context):
@router.get(r"/posts/\d+$")
def post_detail(request, context):
post_id = request.url.split("/")[-1]
return json.dumps({"id": post_id, "body": f"Post body {post_id}"})
return {"id": post_id, "body": f"Post body {post_id}"}

@router.get(r"/posts/\d+/some_details_404")
def post_detail_404(request, context):
"""Return 404 for post with id > 0. Used to test ignoring 404 errors."""
post_id = int(request.url.split("/")[-2])
if post_id < 1:
return json.dumps({"id": post_id, "body": f"Post body {post_id}"})
return {"id": post_id, "body": f"Post body {post_id}"}
else:
context.status_code = 404
return json.dumps({"error": "Post not found"})
return {"error": "Post not found"}

@router.get(r"/posts_under_a_different_key$")
def posts_with_results_key(request, context):
Expand All @@ -149,15 +185,15 @@ def protected_basic_auth(request, context):
if auth == f"Basic {creds_base64}":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.get("/protected/posts/bearer-token")
def protected_bearer_token(request, context):
auth = request.headers.get("Authorization")
if auth == "Bearer test-token":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.get("/protected/posts/bearer-token-plain-text-error")
def protected_bearer_token_plain_text_erorr(request, context):
Expand All @@ -173,31 +209,27 @@ def protected_api_key(request, context):
if api_key == "test-api-key":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.post("/oauth/token")
def oauth_token(request, context):
return json.dumps(
{
"access_token": "test-token",
"expires_in": 3600,
}
)
return {"access_token": "test-token", "expires_in": 3600}

@router.post("/auth/refresh")
def refresh_token(request, context):
body = request.json()
if body.get("refresh_token") == "valid-refresh-token":
return json.dumps({"access_token": "new-valid-token"})
return {"access_token": "new-valid-token"}
context.status_code = 401
return json.dumps({"error": "Invalid refresh token"})
return {"error": "Invalid refresh token"}

router.register_routes(m)

yield m


def assert_pagination(pages, expected_start=0, page_size=10):
def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10):
assert len(pages) == total_pages
for i, page in enumerate(pages):
assert page == [
{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)
Expand Down
17 changes: 17 additions & 0 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,23 @@ def test_default_paginator(self, rest_client: RESTClient):

assert_pagination(pages)

def test_excplicit_paginator(self, rest_client: RESTClient):
pages_iter = rest_client.paginate(
"/posts", paginator=JSONResponsePaginator(next_url_path="next_page")
)
pages = list(pages_iter)

assert_pagination(pages)

def test_excplicit_paginator_relative_next_url(self, rest_client: RESTClient):
pages_iter = rest_client.paginate(
"/posts_relative_next_url",
paginator=JSONResponsePaginator(next_url_path="next_page"),
)
pages = list(pages_iter)

assert_pagination(pages)

def test_paginate_with_hooks(self, rest_client: RESTClient):
def response_hook(response: Response, *args: Any, **kwargs: Any) -> None:
if response.status_code == 404:
Expand Down
Loading
Loading