Skip to content

Commit

Permalink
Fix: allow unexpected fields in responses (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
angela-tran authored Sep 17, 2024
2 parents 2e55308 + b756c42 commit 16e7150
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 5 deletions.
37 changes: 37 additions & 0 deletions littlepay/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,57 @@
from dataclasses import dataclass
from inspect import signature
import logging
from typing import Generator, Protocol, TypeVar

from authlib.integrations.requests_client import OAuth2Session


logger = logging.getLogger(__name__)


# Generic type parameter, used to represent the result of an API call.
TResponse = TypeVar("TResponse")


def from_kwargs(cls, **kwargs):
"""
Helper function meant to be used as a @classmethod
for instantiating a dataclass and allowing unexpected fields
See https://stackoverflow.com/a/55101438
"""
# fetch the constructor's signature
class_fields = {field for field in signature(cls).parameters}

# split the kwargs into native ones and new ones
native_args, new_args = {}, {}
for name, val in kwargs.items():
if name in class_fields:
native_args[name] = val
else:
new_args[name] = val

# use the native ones to create the class ...
instance = cls(**native_args)

# ... and log any unexpected args
for new_name, new_val in new_args.items():
logger.info(f"Ran into an unexpected arg: {new_name} = {new_val}")

return instance


@dataclass
class ListResponse:
"""An API response with list and total_count attributes."""

list: list
total_count: int

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class ClientProtocol(Protocol):
"""Protocol describing key functionality for an API connection."""
Expand Down
3 changes: 2 additions & 1 deletion littlepay/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def _delete(self, endpoint: str) -> bool:
def _get(self, endpoint: str, response_cls: TResponse, **kwargs) -> TResponse:
response = self.oauth.get(endpoint, headers=self.headers, params=kwargs)
response.raise_for_status()
return response_cls(**response.json())

return response_cls.from_kwargs(**response.json())

def _get_list(self, endpoint: str, **kwargs) -> Generator[dict, None, None]:
params = dict(page=1, per_page=100)
Expand Down
27 changes: 27 additions & 0 deletions littlepay/api/funding_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from littlepay.api import ClientProtocol

from . import from_kwargs


@dataclass
class FundingSourceResponse:
Expand All @@ -17,6 +19,7 @@ class FundingSourceResponse:
participant_id: str
is_fpan: bool
related_funding_sources: List[dict]
created_date: datetime | None = None
card_category: Optional[str] = None
issuer_country_code: Optional[str] = None
issuer_country_numeric_code: Optional[str] = None
Expand All @@ -25,6 +28,26 @@ class FundingSourceResponse:
token_key_id: Optional[str] = None
icc_hash: Optional[str] = None

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)

def __post_init__(self):
"""Parses any date parameters into Python datetime objects.
For @dataclasses with a generated __init__ function, this function is called automatically.
Includes a workaround for Python 3.10 where datetime.fromisoformat() can only parse the format output
by datetime.isoformat(), i.e. without a trailing 'Z' offset character and with UTC offset expressed
as +/-HH:mm
https://docs.python.org/3.11/library/datetime.html#datetime.datetime.fromisoformat
"""
if self.created_date:
self.created_date = datetime.fromisoformat(self.created_date.replace("Z", "+00:00", 1))
else:
self.created_date = None


@dataclass
class FundingSourceDateFields:
Expand Down Expand Up @@ -65,6 +88,10 @@ class FundingSourceGroupResponse(FundingSourceDateFields):
group_id: str
label: str

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class FundingSourcesMixin(ClientProtocol):
"""Mixin implements APIs for funding sources."""
Expand Down
10 changes: 9 additions & 1 deletion littlepay/api/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime, timezone
from typing import Generator

from littlepay.api import ClientProtocol
from littlepay.api import ClientProtocol, from_kwargs
from littlepay.api.funding_sources import FundingSourceDateFields, FundingSourcesMixin


Expand All @@ -24,11 +24,19 @@ def csv_header() -> str:
instance = GroupResponse("", "", "")
return ",".join(vars(instance).keys())

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


@dataclass(kw_only=True)
class GroupFundingSourceResponse(FundingSourceDateFields):
id: str

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class GroupsMixin(ClientProtocol):
"""Mixin implements APIs for concession groups."""
Expand Down
6 changes: 5 additions & 1 deletion littlepay/api/products.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Generator

from littlepay.api import ClientProtocol
from littlepay.api import ClientProtocol, from_kwargs
from littlepay.api.groups import GroupsMixin


Expand All @@ -26,6 +26,10 @@ def csv_header() -> str:
instance = ProductResponse("", "", "", "", "", "")
return ",".join(vars(instance).keys())

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class ProductsMixin(GroupsMixin, ClientProtocol):
"""Mixin implements APIs for products."""
Expand Down
38 changes: 37 additions & 1 deletion tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from requests import HTTPError

from littlepay.api import ListResponse
from littlepay.api import ListResponse, from_kwargs
from littlepay.api.client import _client_from_active_config, _fix_bearer_token_header, _json_post_credentials, Client
from littlepay.config import Config

Expand Down Expand Up @@ -48,12 +48,21 @@ class SampleResponse:
two: str
three: int

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


@pytest.fixture
def SampleResponse_json():
return {"one": "single", "two": "double", "three": 3}


@pytest.fixture
def SampleResponse_json_with_unexpected_field():
return {"one": "single", "two": "double", "three": 3, "four": "4"}


@pytest.fixture
def default_list_params():
return dict(page=1, per_page=100)
Expand Down Expand Up @@ -232,6 +241,26 @@ def test_Client_get_params(mocker, make_client: ClientFunc, url, SampleResponse_
assert result.three == 3


def test_Client_get_response_has_unexpected_fields(
mocker, make_client: ClientFunc, url, SampleResponse_json_with_unexpected_field
):
client = make_client()
mock_response = mocker.Mock(
raise_for_status=mocker.Mock(return_value=False),
json=mocker.Mock(return_value=SampleResponse_json_with_unexpected_field),
)
req_spy = mocker.patch.object(client.oauth, "get", return_value=mock_response)

result = client._get(url, SampleResponse)

req_spy.assert_called_once_with(url, headers=client.headers, params={})
assert isinstance(result, SampleResponse)
assert result.one == "single"
assert result.two == "double"
assert result.three == 3
assert not hasattr(result, "four")


def test_Client_get_error_status(mocker, make_client: ClientFunc, url):
client = make_client()
mock_response = mocker.Mock(raise_for_status=mocker.Mock(side_effect=HTTPError))
Expand All @@ -243,6 +272,13 @@ def test_Client_get_error_status(mocker, make_client: ClientFunc, url):
req_spy.assert_called_once_with(url, headers=client.headers, params={})


def test_ListResponse_unexpected_fields():
response_json = {"list": [1, 2, 3], "total_count": 3, "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
ListResponse.from_kwargs(**response_json)


def test_Client_get_list(mocker, make_client: ClientFunc, url, default_list_params, ListResponse_sample):
client = make_client()
req_spy = mocker.patch.object(client, "_get", return_value=ListResponse_sample)
Expand Down
66 changes: 65 additions & 1 deletion tests/api/test_funding_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def ListResponse_FundingSourceGroups(expected_expiry_str):


@pytest.fixture
def mock_ClientProtocol_get_FundingResource(mocker):
def mock_ClientProtocol_get_FundingResource(mocker, expected_expiry_str):
funding_source = FundingSourceResponse(
id="0",
card_first_digits="0000",
Expand All @@ -47,6 +47,7 @@ def mock_ClientProtocol_get_FundingResource(mocker):
participant_id="cst",
is_fpan=True,
related_funding_sources=[],
created_date=expected_expiry_str,
)
return mocker.patch("littlepay.api.ClientProtocol._get", return_value=funding_source)

Expand All @@ -59,6 +60,62 @@ def mock_ClientProtocol_get_list_FundingSourceGroup(mocker, ListResponse_Funding
)


def test_FundingSourceResponse_unexpected_fields():
response_json = {
"id": "0",
"card_first_digits": "0000",
"card_last_digits": "0000",
"card_expiry_month": "11",
"card_expiry_year": "24",
"card_scheme": "Visa",
"form_factor": "unknown",
"participant_id": "cst",
"is_fpan": True,
"related_funding_sources": [],
"unexpected_field": "test value",
}

# this test will fail if any error occurs from instantiating the class
FundingSourceResponse.from_kwargs(**response_json)


def test_FundingSourceResponse_no_date_field():
response_json = {
"id": "0",
"card_first_digits": "0000",
"card_last_digits": "0000",
"card_expiry_month": "11",
"card_expiry_year": "24",
"card_scheme": "Visa",
"form_factor": "unknown",
"participant_id": "cst",
"is_fpan": True,
"related_funding_sources": [],
}

funding_source = FundingSourceResponse.from_kwargs(**response_json)
assert funding_source.created_date is None


def test_FundingSourceResponse_with_date_field(expected_expiry_str, expected_expiry):
response_json = {
"id": "0",
"card_first_digits": "0000",
"card_last_digits": "0000",
"card_expiry_month": "11",
"card_expiry_year": "24",
"card_scheme": "Visa",
"form_factor": "unknown",
"participant_id": "cst",
"is_fpan": True,
"related_funding_sources": [],
"created_date": expected_expiry_str,
}

funding_source = FundingSourceResponse.from_kwargs(**response_json)
assert funding_source.created_date == expected_expiry


def test_FundingSourceDateFields(expected_expiry_str, expected_expiry):
fields = FundingSourceDateFields(
created_date=expected_expiry_str, updated_date=expected_expiry_str, expiry_date=expected_expiry_str
Expand All @@ -69,6 +126,13 @@ def test_FundingSourceDateFields(expected_expiry_str, expected_expiry):
assert fields.expiry_date == expected_expiry


def test_FundingSourceGroupResponse_unexpected_fields():
response_json = {"id": "id", "group_id": "group_id", "label": "label", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
FundingSourceGroupResponse.from_kwargs(**response_json)


def test_FundingSourceGroupResponse_no_dates():
response = FundingSourceGroupResponse(id="id", group_id="group_id", label="label")

Expand Down
14 changes: 14 additions & 0 deletions tests/api/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def mock_ClientProtocol_put_update_concession_group_funding_source(mocker):
return mocker.patch("littlepay.api.ClientProtocol._put", side_effect=lambda *args, **kwargs: response)


def test_GroupResponse_unexpected_fields():
response_json = {"id": "id", "label": "label", "participant_id": "participant", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
GroupResponse.from_kwargs(**response_json)


def test_GroupResponse_csv():
group = GroupResponse("id", "label", "participant")
assert group.csv() == "id,label,participant"
Expand All @@ -81,6 +88,13 @@ def test_GroupResponse_csv_header():
assert GroupResponse.csv_header() == "id,label,participant_id"


def test_GroupFundingSourceResponse_unexpected_fields():
response_json = {"id": "id", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
GroupFundingSourceResponse.from_kwargs(**response_json)


def test_GroupFundingSourceResponse_no_dates():
response = GroupFundingSourceResponse(id="id")

Expand Down
15 changes: 15 additions & 0 deletions tests/api/test_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ def mock_ClientProtocol_post(mocker):
return mocker.patch("littlepay.api.ClientProtocol._post", side_effect=lambda *args, **kwargs: response)


def test_ProductResponse_unexpected_fields():
response_json = {
"id": "id",
"code": "code",
"status": "status",
"type": "type",
"description": "description",
"participant_id": "participant",
"unexpected_field": "test value",
}

# this test will fail if any error occurs from instantiating the class
ProductResponse.from_kwargs(**response_json)


def test_ProductResponse_csv():
product = ProductResponse("id", "code", "status", "type", "description", "participant")
assert product.csv() == "id,code,status,type,description,participant"
Expand Down

0 comments on commit 16e7150

Please sign in to comment.