Skip to content

Commit

Permalink
Reformat with black (#1179)
Browse files Browse the repository at this point in the history
  • Loading branch information
burnash authored Apr 3, 2024
1 parent 6bf1940 commit 73e8176
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 63 deletions.
21 changes: 8 additions & 13 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
else:
PrivateKeyTypes = Any

TApiKeyLocation = Literal[
"header", "cookie", "query", "param"
] # Alias for scheme "in" field
TApiKeyLocation = Literal["header", "cookie", "query", "param"] # Alias for scheme "in" field


class AuthConfigBase(AuthBase, CredentialsConfiguration):
Expand Down Expand Up @@ -102,7 +100,8 @@ def parse_native_representation(self, value: Any) -> None:
raise NativeValueError(
type(self),
value,
f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}",
"HttpBasicAuth username and password must be a tuple of two strings, got"
f" {type(value)}",
)

def __call__(self, request: PreparedRequest) -> PreparedRequest:
Expand Down Expand Up @@ -147,9 +146,7 @@ class OAuthJWTAuth(BearerTokenAuth):
default_token_expiration: int = 3600

def __post_init__(self) -> None:
self.scopes = (
self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
)
self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
self.token = None
self.token_expiry: Optional[pendulum.DateTime] = None

Expand All @@ -171,9 +168,7 @@ def obtain_token(self) -> None:
payload = self.create_jwt_payload()
data = {
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": jwt.encode(
payload, self.load_private_key(), algorithm="RS256"
),
"assertion": jwt.encode(payload, self.load_private_key(), algorithm="RS256"),
}

logger.debug(f"Obtaining token from {self.auth_endpoint}")
Expand Down Expand Up @@ -208,8 +203,8 @@ def load_private_key(self) -> "PrivateKeyTypes":
private_key_bytes = self.private_key.encode("utf-8")
return serialization.load_pem_private_key(
private_key_bytes,
password=self.private_key_passphrase.encode("utf-8")
if self.private_key_passphrase
else None,
password=(
self.private_key_passphrase.encode("utf-8") if self.private_key_passphrase else None
),
backend=default_backend(),
)
24 changes: 6 additions & 18 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,18 @@ def _send_request(self, request: Request) -> Response:

return self.session.send(prepared_request)

def request(
self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any
) -> Response:
def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) -> Response:
prepared_request = self._create_request(
path=path,
method=method,
**kwargs,
)
return self._send_request(prepared_request)

def get(
self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Response:
def get(self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response:
return self.request(path, method="GET", params=params, **kwargs)

def post(
self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Response:
def post(self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response:
return self.request(path, method="POST", json=json, **kwargs)

def paginate(
Expand Down Expand Up @@ -224,16 +218,12 @@ def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None:
paginator.update_request(request)

# yield data with context
yield PageData(
data, request=request, response=response, paginator=paginator, auth=auth
)
yield PageData(data, request=request, response=response, paginator=paginator, auth=auth)

if not paginator.has_next_page:
break

def extract_response(
self, response: Response, data_selector: jsonpath.TJsonPath
) -> List[Any]:
def extract_response(self, response: Response, data_selector: jsonpath.TJsonPath) -> List[Any]:
if data_selector:
# we should compile data_selector
data: Any = jsonpath.find_values(data_selector, response.json())
Expand All @@ -257,8 +247,6 @@ def detect_paginator(self, response: Response) -> BasePaginator:
"""
paginator = self.pagination_factory.create_paginator(response)
if paginator is None:
raise ValueError(
f"No suitable paginator found for the response at {response.url}"
)
raise ValueError(f"No suitable paginator found for the response at {response.url}")
logger.info(f"Detected paginator: {paginator.__class__.__name__}")
return paginator
7 changes: 2 additions & 5 deletions dlt/sources/helpers/rest_client/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def find_records(
return next(
list_info[2]
for list_info in lists
if list_info[1] in RECORD_KEY_PATTERNS
and list_info[1] not in NON_RECORD_KEY_PATTERNS
if list_info[1] in RECORD_KEY_PATTERNS and list_info[1] not in NON_RECORD_KEY_PATTERNS
)
except StopIteration:
# return the least nested element
Expand Down Expand Up @@ -142,9 +141,7 @@ def single_page_detector(response: Response) -> Optional[SinglePagePaginator]:


class PaginatorFactory:
def __init__(
self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None
):
def __init__(self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None):
if detectors is None:
detectors = [
header_links_detector,
Expand Down
4 changes: 1 addition & 3 deletions dlt/sources/helpers/rest_client/paginators.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def update_state(self, response: Response) -> None:
total = values[0] if values else None

if total is None:
raise ValueError(
f"Total count not found in response for {self.__class__.__name__}"
)
raise ValueError(f"Total count not found in response for {self.__class__.__name__}")

try:
total = int(total)
Expand Down
16 changes: 4 additions & 12 deletions tests/sources/helpers/rest_client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def __init__(self, base_url: str):
self.routes: List[Route] = []
self.base_url = base_url

def _add_route(
self, method: str, pattern: str, func: RequestCallback
) -> RequestCallback:
def _add_route(self, method: str, pattern: str, func: RequestCallback) -> RequestCallback:
compiled_pattern = re.compile(f"{self.base_url}{pattern}")

def serialize_response(request, context):
Expand Down Expand Up @@ -116,9 +114,7 @@ def get_page_number(qs, key="page", default=1):
return int(qs.get(key, [default])[0])


def paginate_response(
request, records, page_size=10, records_key="data", use_absolute_url=True
):
def paginate_response(request, records, page_size=10, records_key="data", use_absolute_url=True):
page_number = get_page_number(request.qs)
total_records = len(records)
total_pages = (total_records + page_size - 1) // page_size
Expand Down Expand Up @@ -173,9 +169,7 @@ def post_detail_404(request, context):

@router.get(r"/posts_under_a_different_key$")
def posts_with_results_key(request, context):
return paginate_response(
request, generate_posts(), records_key="many-results"
)
return paginate_response(request, generate_posts(), records_key="many-results")

@router.get("/protected/posts/basic-auth")
def protected_basic_auth(request, context):
Expand Down Expand Up @@ -231,6 +225,4 @@ def refresh_token(request, context):
def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10):
assert len(pages) == total_pages
for i, page in enumerate(pages):
assert page == [
{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)
]
assert page == [{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)]
4 changes: 1 addition & 3 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,7 @@ def test_bearer_token_auth_success(self, rest_client: RESTClient):
def test_api_key_auth_success(self, rest_client: RESTClient):
response = rest_client.get(
"/protected/posts/api-key",
auth=APIKeyAuth(
name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key")
),
auth=APIKeyAuth(name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key")),
)
assert response.status_code == 200
assert response.json()["data"][0] == {"id": 0, "title": "Post 0"}
Expand Down
8 changes: 2 additions & 6 deletions tests/sources/helpers/rest_client/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@
},
{
"response": {
"_embedded": {
"items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]
},
"_embedded": {"items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]},
"_links": {
"first": {"href": "http://api.example.com/items?page=0&size=2"},
"self": {"href": "http://api.example.com/items?page=1&size=2"},
Expand Down Expand Up @@ -315,9 +313,7 @@ def test_find_records(test_case):
@pytest.mark.parametrize("test_case", TEST_RESPONSES)
def test_find_next_page_key(test_case):
response = test_case["response"]
expected = test_case.get("expected").get(
"next_path", None
) # Some cases may not have next_path
expected = test_case.get("expected").get("next_path", None) # Some cases may not have next_path
assert find_next_page_path(response) == expected


Expand Down
4 changes: 1 addition & 3 deletions tests/sources/helpers/rest_client/test_paginators.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ def test_update_state(self):

def test_update_state_with_next(self):
paginator = SinglePagePaginator()
response = Mock(
Response, json=lambda: {"next": "http://example.com/next", "results": []}
)
response = Mock(Response, json=lambda: {"next": "http://example.com/next", "results": []})
response.links = {"next": {"url": "http://example.com/next"}}
paginator.update_state(response)
assert paginator.has_next_page is False
Expand Down

0 comments on commit 73e8176

Please sign in to comment.