From ecb5aa0015bf1e910a1c61d0992f5fca1e5f6514 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Apr 2024 22:42:54 +0300 Subject: [PATCH] RESTClient: add support for relative next URLs in LinkPaginators (#1163) * Extend `mock_api_server()` to support relative next urls * Enhance BaseNextUrlPaginator to support relative next URLs in pagination --- dlt/sources/helpers/rest_client/paginators.py | 7 + tests/sources/helpers/rest_client/conftest.py | 88 +++++++---- .../helpers/rest_client/test_client.py | 17 +++ .../helpers/rest_client/test_paginators.py | 139 ++++++++++++++++-- 4 files changed, 209 insertions(+), 42 deletions(-) diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index c098ea667f..48dfdf6e4f 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Optional +from urllib.parse import urlparse, urljoin from dlt.sources.helpers.requests import Response, Request from dlt.common import jsonpath @@ -102,6 +103,12 @@ def update_request(self, request: Request) -> None: class BaseNextUrlPaginator(BasePaginator): def update_request(self, request: Request) -> None: + # Handle relative URLs + if self.next_reference: + parsed_url = urlparse(self.next_reference) + if not parsed_url.scheme: + self.next_reference = urljoin(request.url, self.next_reference) + request.url = self.next_reference diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index 7eec090db6..cffce7cb07 100644 --- a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -1,5 +1,5 @@ import re -from typing import NamedTuple, Callable, Pattern, List, TYPE_CHECKING +from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING import base64 from urllib.parse import urlsplit, urlunsplit @@ -10,9 +10,13 @@ from dlt.common import json if TYPE_CHECKING: - RequestCallback = Callable[[requests_mock.Request, requests_mock.Context], str] + RequestCallback = Callable[ + [requests_mock.Request, requests_mock.Context], Union[str, dict, list] + ] + ResponseSerializer = Callable[[requests_mock.Request, requests_mock.Context], str] else: RequestCallback = Callable + ResponseSerializer = Callable MOCK_BASE_URL = "https://api.example.com" @@ -20,7 +24,7 @@ class Route(NamedTuple): method: str pattern: Pattern[str] - callback: RequestCallback + callback: ResponseSerializer class APIRouter: @@ -32,8 +36,17 @@ def _add_route( self, method: str, pattern: str, func: RequestCallback ) -> RequestCallback: compiled_pattern = re.compile(f"{self.base_url}{pattern}") - self.routes.append(Route(method, compiled_pattern, func)) - return func + + def serialize_response(request, context): + result = func(request, context) + + if isinstance(result, dict) or isinstance(result, list): + return json.dumps(result) + + return result + + self.routes.append(Route(method, compiled_pattern, serialize_response)) + return serialize_response def get(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: def decorator(func: RequestCallback) -> RequestCallback: @@ -59,9 +72,17 @@ def register_routes(self, mocker: requests_mock.Mocker) -> None: router = APIRouter(MOCK_BASE_URL) -def serialize_page(records, page_number, total_pages, base_url, records_key="data"): +def serialize_page( + records, + page_number, + total_pages, + request_url, + records_key="data", + use_absolute_url=True, +): + """Serialize a page of records into a dict with pagination metadata.""" if records_key is None: - return json.dumps(records) + return records response = { records_key: records, @@ -72,11 +93,15 @@ def serialize_page(records, page_number, total_pages, base_url, records_key="dat if page_number < total_pages: next_page = page_number + 1 - scheme, netloc, path, _, _ = urlsplit(base_url) - next_page = urlunsplit([scheme, netloc, path, f"page={next_page}", ""]) - response["next_page"] = next_page + scheme, netloc, path, _, _ = urlsplit(request_url) + if use_absolute_url: + next_page_url = urlunsplit([scheme, netloc, path, f"page={next_page}", ""]) + else: + next_page_url = f"{path}?page={next_page}" - return json.dumps(response) + response["next_page"] = next_page_url + + return response def generate_posts(count=100): @@ -91,7 +116,9 @@ def get_page_number(qs, key="page", default=1): return int(qs.get(key, [default])[0]) -def paginate_response(request, records, page_size=10, records_key="data"): +def paginate_response( + request, records, page_size=10, records_key="data", use_absolute_url=True +): page_number = get_page_number(request.qs) total_records = len(records) total_pages = (total_records + page_size - 1) // page_size @@ -99,7 +126,12 @@ def paginate_response(request, records, page_size=10, records_key="data"): end_index = start_index + 10 records_slice = records[start_index:end_index] return serialize_page( - records_slice, page_number, total_pages, request.url, records_key + records_slice, + page_number, + total_pages, + request.url, + records_key, + use_absolute_url, ) @@ -115,6 +147,10 @@ def posts_no_key(request, context): def posts(request, context): return paginate_response(request, generate_posts()) + @router.get(r"/posts_relative_next_url(\?page=\d+)?$") + def posts_relative_next_url(request, context): + return paginate_response(request, generate_posts(), use_absolute_url=False) + @router.get(r"/posts/(\d+)/comments") def post_comments(request, context): post_id = int(request.url.split("/")[-2]) @@ -123,17 +159,17 @@ def post_comments(request, context): @router.get(r"/posts/\d+$") def post_detail(request, context): post_id = request.url.split("/")[-1] - return json.dumps({"id": post_id, "body": f"Post body {post_id}"}) + return {"id": post_id, "body": f"Post body {post_id}"} @router.get(r"/posts/\d+/some_details_404") def post_detail_404(request, context): """Return 404 for post with id > 0. Used to test ignoring 404 errors.""" post_id = int(request.url.split("/")[-2]) if post_id < 1: - return json.dumps({"id": post_id, "body": f"Post body {post_id}"}) + return {"id": post_id, "body": f"Post body {post_id}"} else: context.status_code = 404 - return json.dumps({"error": "Post not found"}) + return {"error": "Post not found"} @router.get(r"/posts_under_a_different_key$") def posts_with_results_key(request, context): @@ -149,7 +185,7 @@ def protected_basic_auth(request, context): if auth == f"Basic {creds_base64}": return paginate_response(request, generate_posts()) context.status_code = 401 - return json.dumps({"error": "Unauthorized"}) + return {"error": "Unauthorized"} @router.get("/protected/posts/bearer-token") def protected_bearer_token(request, context): @@ -157,7 +193,7 @@ def protected_bearer_token(request, context): if auth == "Bearer test-token": return paginate_response(request, generate_posts()) context.status_code = 401 - return json.dumps({"error": "Unauthorized"}) + return {"error": "Unauthorized"} @router.get("/protected/posts/bearer-token-plain-text-error") def protected_bearer_token_plain_text_erorr(request, context): @@ -173,31 +209,27 @@ def protected_api_key(request, context): if api_key == "test-api-key": return paginate_response(request, generate_posts()) context.status_code = 401 - return json.dumps({"error": "Unauthorized"}) + return {"error": "Unauthorized"} @router.post("/oauth/token") def oauth_token(request, context): - return json.dumps( - { - "access_token": "test-token", - "expires_in": 3600, - } - ) + return {"access_token": "test-token", "expires_in": 3600} @router.post("/auth/refresh") def refresh_token(request, context): body = request.json() if body.get("refresh_token") == "valid-refresh-token": - return json.dumps({"access_token": "new-valid-token"}) + return {"access_token": "new-valid-token"} context.status_code = 401 - return json.dumps({"error": "Invalid refresh token"}) + return {"error": "Invalid refresh token"} router.register_routes(m) yield m -def assert_pagination(pages, expected_start=0, page_size=10): +def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10): + assert len(pages) == total_pages for i, page in enumerate(pages): assert page == [ {"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10) diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 7a4c55f9a6..88653efefe 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -74,6 +74,23 @@ def test_default_paginator(self, rest_client: RESTClient): assert_pagination(pages) + def test_excplicit_paginator(self, rest_client: RESTClient): + pages_iter = rest_client.paginate( + "/posts", paginator=JSONResponsePaginator(next_url_path="next_page") + ) + pages = list(pages_iter) + + assert_pagination(pages) + + def test_excplicit_paginator_relative_next_url(self, rest_client: RESTClient): + pages_iter = rest_client.paginate( + "/posts_relative_next_url", + paginator=JSONResponsePaginator(next_url_path="next_page"), + ) + pages = list(pages_iter) + + assert_pagination(pages) + def test_paginate_with_hooks(self, rest_client: RESTClient): def response_hook(response: Response, *args: Any, **kwargs: Any) -> None: if response.status_code == 404: diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index cc4dea65dc..bd38a2e421 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -1,7 +1,8 @@ -import pytest from unittest.mock import Mock -from requests.models import Response +import pytest + +from requests.models import Response, Request from dlt.sources.helpers.rest_client.paginators import ( SinglePagePaginator, @@ -29,21 +30,131 @@ def test_update_state_without_next(self): class TestJSONResponsePaginator: - def test_update_state_with_next(self): - paginator = JSONResponsePaginator() - response = Mock( - Response, json=lambda: {"next": "http://example.com/next", "results": []} - ) + @pytest.mark.parametrize( + "test_case", + [ + # Test with empty next_url_path, e.g. auto-detect + { + "next_url_path": None, + "response_json": {"next": "http://example.com/next", "results": []}, + "expected": { + "next_reference": "http://example.com/next", + "has_next_page": True, + }, + }, + # Test with explicit next_url_path + { + "next_url_path": "next_page", + "response_json": { + "next_page": "http://example.com/next", + "results": [], + }, + "expected": { + "next_reference": "http://example.com/next", + "has_next_page": True, + }, + }, + # Test with nested next_url_path + { + "next_url_path": "next_page.url", + "response_json": { + "next_page": {"url": "http://example.com/next"}, + "results": [], + }, + "expected": { + "next_reference": "http://example.com/next", + "has_next_page": True, + }, + }, + # Test without next_page + { + "next_url_path": None, + "response_json": {"results": []}, + "expected": { + "next_reference": None, + "has_next_page": False, + }, + }, + ], + ) + def test_update_state(self, test_case): + next_url_path = test_case["next_url_path"] + + if next_url_path is None: + paginator = JSONResponsePaginator() + else: + paginator = JSONResponsePaginator(next_url_path=next_url_path) + response = Mock(Response, json=lambda: test_case["response_json"]) paginator.update_state(response) - assert paginator.next_reference == "http://example.com/next" - assert paginator.has_next_page is True + assert paginator.next_reference == test_case["expected"]["next_reference"] + assert paginator.has_next_page == test_case["expected"]["has_next_page"] - def test_update_state_without_next(self): + # Test update_request from BaseNextUrlPaginator + @pytest.mark.parametrize( + "test_case", + [ + # Test with absolute URL + { + "next_reference": "http://example.com/api/resource?page=2", + "request_url": "http://example.com/api/resource", + "expected": "http://example.com/api/resource?page=2", + }, + # Test with relative URL + { + "next_reference": "/api/resource?page=2", + "request_url": "http://example.com/api/resource", + "expected": "http://example.com/api/resource?page=2", + }, + # Test with more nested path + { + "next_reference": "/api/resource/subresource?page=3&sort=desc", + "request_url": "http://example.com/api/resource/subresource", + "expected": "http://example.com/api/resource/subresource?page=3&sort=desc", + }, + # Test with 'page' in path + { + "next_reference": "/api/page/4/items?filter=active", + "request_url": "http://example.com/api/page/3/items", + "expected": "http://example.com/api/page/4/items?filter=active", + }, + # Test with complex query parameters + { + "next_reference": "/api/resource?page=3&category=books&sort=author", + "request_url": "http://example.com/api/resource?page=2", + "expected": "http://example.com/api/resource?page=3&category=books&sort=author", + }, + # Test with URL having port number + { + "next_reference": "/api/resource?page=2", + "request_url": "http://example.com:8080/api/resource", + "expected": "http://example.com:8080/api/resource?page=2", + }, + # Test with HTTPS protocol + { + "next_reference": "https://secure.example.com/api/resource?page=2", + "request_url": "https://secure.example.com/api/resource", + "expected": "https://secure.example.com/api/resource?page=2", + }, + # Test with encoded characters in URL + { + "next_reference": "/api/resource?page=2&query=%E3%81%82", + "request_url": "http://example.com/api/resource", + "expected": "http://example.com/api/resource?page=2&query=%E3%81%82", + }, + # Test with missing 'page' parameter in next_reference + { + "next_reference": "/api/resource?sort=asc", + "request_url": "http://example.com/api/resource?page=1", + "expected": "http://example.com/api/resource?sort=asc", + }, + ], + ) + def test_update_request(self, test_case): paginator = JSONResponsePaginator() - response = Mock(Response, json=lambda: {"results": []}) - paginator.update_state(response) - assert paginator.next_reference is None - assert paginator.has_next_page is False + paginator.next_reference = test_case["next_reference"] + request = Mock(Request, url=test_case["request_url"]) + paginator.update_request(request) + assert request.url == test_case["expected"] class TestSinglePagePaginator: