Skip to content

Commit

Permalink
Fix ruff warning UP007 - Using | for type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
berroar committed Nov 22, 2024
1 parent f37859c commit b70d6de
Show file tree
Hide file tree
Showing 64 changed files with 284 additions and 320 deletions.
8 changes: 4 additions & 4 deletions app/authentication/authenticator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Any, Generator, Mapping, MutableMapping, Optional
from typing import Any, Generator, Mapping, MutableMapping
from uuid import uuid4

from blinker import ANY
Expand Down Expand Up @@ -30,15 +30,15 @@


@login_manager.user_loader
def user_loader(user_id: str) -> Optional[str]:
def user_loader(user_id: str) -> str | None:
logger.debug("loading user", user_id=user_id)
return load_user()


@login_manager.request_loader
def request_load_user(
request: Request,
) -> Optional[User]:
) -> User | None:
logger.debug("load user")

extend_session = not (
Expand Down Expand Up @@ -94,7 +94,7 @@ def _is_session_valid(session_store: SessionStore) -> bool:
)


def load_user(extend_session: bool = True) -> Optional[User]:
def load_user(extend_session: bool = True) -> User | None:
"""
Checks for the present of the JWT in the users sessions
:return: A user object if a JWT token is available in the session
Expand Down
5 changes: 1 addition & 4 deletions app/authentication/no_questionnaire_state_exception.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import Union


class NoQuestionnaireStateException(Exception):
def __init__(self, value: Union[str, int]) -> None:
def __init__(self, value: str | int) -> None:
super().__init__()
self.value = value

Expand Down
5 changes: 1 addition & 4 deletions app/authentication/no_token_exception.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import Union


class NoTokenException(Exception):
def __init__(self, value: Union[str, int]) -> None:
def __init__(self, value: str | int) -> None:
super().__init__()
self.value = value

Expand Down
29 changes: 17 additions & 12 deletions app/data_models/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,26 @@

from dataclasses import asdict, dataclass, field
from decimal import Decimal
from typing import Optional, TypedDict, Union, overload
from typing import TypedDict, overload

from markupsafe import Markup, escape

DictAnswer = dict[str, Union[int, str]]
DictAnswer = dict[str, int | str]
ListAnswer = list[str]
ListDictAnswer = list[DictAnswer]
DictAnswerEscaped = dict[str, Union[int, Markup]]
DictAnswerEscaped = dict[str, int | Markup]
ListAnswerEscaped = list[Markup]
ListDictAnswerEscaped = list[DictAnswerEscaped]

AnswerValueTypes = Union[str, int, Decimal, DictAnswer, ListAnswer, ListDictAnswer]
AnswerValueEscapedTypes = Union[
Markup, int, Decimal, DictAnswerEscaped, ListAnswerEscaped, ListDictAnswerEscaped
]
AnswerValueTypes = str | int | Decimal | DictAnswer | ListAnswer | ListDictAnswer
AnswerValueEscapedTypes = (
Markup
| int
| Decimal
| DictAnswerEscaped
| ListAnswerEscaped
| ListDictAnswerEscaped
)


class AnswerDict(TypedDict, total=False):
Expand All @@ -29,7 +34,7 @@ class AnswerDict(TypedDict, total=False):
class Answer:
answer_id: str
value: AnswerValueTypes
list_item_id: Optional[str] = field(default=None)
list_item_id: str | None = field(default=None)

@classmethod
def from_dict(cls, answer_dict: AnswerDict) -> Answer:
Expand Down Expand Up @@ -69,13 +74,13 @@ def escape_answer_value(value: str) -> Markup: ... # pragma: no cover

@overload
def escape_answer_value(
value: Union[None, int, Decimal]
) -> Union[None, int, Decimal]: ... # pragma: no cover
value: None | int | Decimal,
) -> None | int | Decimal: ... # pragma: no cover


def escape_answer_value(
value: Optional[AnswerValueTypes],
) -> Optional[AnswerValueEscapedTypes]:
value: AnswerValueTypes | None,
) -> AnswerValueEscapedTypes | None:
if isinstance(value, list):
return [escape(item) for item in value]

Expand Down
16 changes: 7 additions & 9 deletions app/data_models/answer_store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from typing import Iterable, Iterator, Optional
from typing import Iterable, Iterator

from app.data_models.answer import Answer, AnswerDict

AnswerKeyType = tuple[str, Optional[str]]
AnswerKeyType = tuple[str, str | None]


class AnswerStore:
Expand All @@ -21,7 +21,7 @@ class AnswerStore:
}
"""

def __init__(self, answers: Optional[Iterable[AnswerDict]] = None):
def __init__(self, answers: Iterable[AnswerDict] | None = None):
"""Instantiate an answer_store.
Args:
Expand Down Expand Up @@ -81,8 +81,8 @@ def add_or_update(self, answer: Answer) -> bool:
return False

def get_answer(
self, answer_id: str, list_item_id: Optional[str] = None
) -> Optional[Answer]:
self, answer_id: str, list_item_id: str | None = None
) -> Answer | None:
"""Get a single answer from the store
Args:
Expand All @@ -95,7 +95,7 @@ def get_answer(
return self.answer_map.get((answer_id, list_item_id))

def get_answers_by_answer_id(
self, answer_ids: Iterable[str], list_item_id: Optional[str] = None
self, answer_ids: Iterable[str], list_item_id: str | None = None
) -> list[Answer]:
"""Get multiple answers from the store using the answer_id
Expand All @@ -121,9 +121,7 @@ def clear(self) -> None:
"""
self.answer_map.clear()

def remove_answer(
self, answer_id: str, *, list_item_id: Optional[str] = None
) -> bool:
def remove_answer(self, answer_id: str, *, list_item_id: str | None = None) -> bool:
"""
Removes answer *in place* from the answer store.
:return: True if answer removed else False
Expand Down
22 changes: 11 additions & 11 deletions app/data_models/app_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timezone
from typing import Any, Optional, Union
from typing import Any

from marshmallow import Schema, fields, post_load, pre_dump

Expand All @@ -11,8 +11,8 @@ def __init__(
state_data: str,
collection_exercise_sid: str,
version: int,
submitted_at: Optional[datetime] = None,
expires_at: Optional[datetime] = None,
submitted_at: datetime | None = None,
expires_at: datetime | None = None,
):
self.user_id = user_id
self.state_data = state_data
Expand All @@ -28,9 +28,9 @@ class EQSession:
def __init__(
self,
eq_session_id: str,
user_id: Optional[str],
user_id: str | None,
expires_at: datetime,
session_data: Optional[str],
session_data: str | None,
):
self.eq_session_id = eq_session_id
self.user_id = user_id
Expand All @@ -52,9 +52,9 @@ class Timestamp(fields.Field):
def _serialize(
self,
value: datetime,
*args: Optional[list],
*args: list | None,
**kwargs: Any,
) -> Optional[int]:
) -> int | None:
if value:
# Timezone aware datetime to timestamp
return int(value.replace(tzinfo=timezone.utc).timestamp())
Expand All @@ -63,9 +63,9 @@ def _serialize(
def _deserialize(
self,
value: float,
*args: Optional[list],
*args: list | None,
**kwargs: Any,
) -> Optional[datetime]:
) -> datetime | None:
if value:
# Timestamp to timezone aware datetime
return datetime.fromtimestamp(value, tz=timezone.utc)
Expand All @@ -79,9 +79,9 @@ class DateTimeSchemaMixin:
@staticmethod
@pre_dump
def set_date(
data: Union[EQSession, QuestionnaireState],
data: EQSession | QuestionnaireState,
**kwargs: Any,
) -> Union[EQSession, QuestionnaireState]:
) -> EQSession | QuestionnaireState:
data.updated_at = datetime.now(tz=timezone.utc)
return data

Expand Down
10 changes: 5 additions & 5 deletions app/data_models/list_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
from functools import cached_property
from string import ascii_letters
from typing import Iterable, Iterator, Optional, TypedDict, overload
from typing import Iterable, Iterator, TypedDict, overload

from structlog import get_logger

Expand All @@ -27,9 +27,9 @@ class ListModel:
def __init__(
self,
name: str,
items: Optional[list[str]] = None,
primary_person: Optional[str] = None,
same_name_items: Optional[list[str]] = None,
items: list[str] | None = None,
primary_person: str | None = None,
same_name_items: list[str] | None = None,
):
self.name = name
self.items = items or []
Expand Down Expand Up @@ -127,7 +127,7 @@ class ListStore:
```
"""

def __init__(self, items: Optional[Iterable[ListModelDictType]] = None):
def __init__(self, items: Iterable[ListModelDictType] | None = None):
items = items or []

self._lists = self._build_map(items)
Expand Down
4 changes: 2 additions & 2 deletions app/data_models/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from enum import StrEnum
from typing import Mapping, Optional, TypedDict
from typing import Mapping, TypedDict


class CompletionStatus(StrEnum):
Expand All @@ -24,7 +24,7 @@ class Progress:
section_id: str
block_ids: list[str]
status: CompletionStatus
list_item_id: Optional[str] = None
list_item_id: str | None = None

@classmethod
def from_dict(cls, progress_dict: ProgressDict) -> Progress:
Expand Down
8 changes: 4 additions & 4 deletions app/data_models/questionnaire_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, MutableMapping, Optional
from typing import TYPE_CHECKING, MutableMapping

from app.data_models.answer_store import AnswerStore
from app.data_models.data_stores import DataStores
Expand All @@ -22,7 +22,7 @@ class QuestionnaireStore:
LATEST_VERSION = 1

def __init__(
self, storage: EncryptedQuestionnaireStorage, version: Optional[int] = None
self, storage: EncryptedQuestionnaireStorage, version: int | None = None
):
self._storage = storage
if version is None:
Expand All @@ -31,8 +31,8 @@ def __init__(
self._metadata: MutableMapping = {}
self._stores = DataStores()
self.data_stores = self._stores
self.submitted_at: Optional[datetime]
self.collection_exercise_sid: Optional[str]
self.submitted_at: datetime | None
self.collection_exercise_sid: str | None

(
raw_data,
Expand Down
8 changes: 3 additions & 5 deletions app/data_models/relationship_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass
from typing import Iterable, Iterator, Optional, TypedDict, cast
from typing import Iterable, Iterator, TypedDict, cast


class RelationshipDict(TypedDict, total=False):
Expand Down Expand Up @@ -27,9 +27,7 @@ class RelationshipStore:
Stores and updates relationships.
"""

def __init__(
self, relationships: Optional[Iterable[RelationshipDict]] = None
) -> None:
def __init__(self, relationships: Iterable[RelationshipDict] | None = None) -> None:
self._is_dirty = False
self._relationships = self._build_map(relationships or [])

Expand Down Expand Up @@ -60,7 +58,7 @@ def serialize(self) -> list[RelationshipDict]:

def get_relationship(
self, list_item_id: str, to_list_item_id: str
) -> Optional[Relationship]:
) -> Relationship | None:
key = (list_item_id, to_list_item_id)
return self._relationships.get(key)

Expand Down
4 changes: 2 additions & 2 deletions app/data_models/session_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Optional
from typing import Any


class SessionData:
def __init__(
self,
language_code: Optional[str] = None,
language_code: str | None = None,
confirmation_email_count: int = 0,
feedback_count: int = 0,
**_: Any,
Expand Down
13 changes: 6 additions & 7 deletions app/data_models/session_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from datetime import datetime
from typing import Optional

from flask import current_app
from jwcrypto.common import base64url_decode
Expand All @@ -17,19 +16,19 @@

class SessionStore:
def __init__(
self, user_ik: str, pepper: str, eq_session_id: Optional[str] = None
self, user_ik: str, pepper: str, eq_session_id: str | None = None
) -> None:
self.eq_session_id = eq_session_id
self.user_id: Optional[str] = None
self.user_id: str | None = None
self.user_ik = user_ik
self.session_data: Optional[SessionData] = None
self._eq_session: Optional[EQSession] = None
self.session_data: SessionData | None = None
self._eq_session: EQSession | None = None
self.pepper = pepper
if eq_session_id:
self._load()

@property
def expiration_time(self) -> Optional[datetime]:
def expiration_time(self) -> datetime | None:
"""
Checking if expires_at is available can be removed soon after deployment,
it is only needed to cater for in-flight sessions.
Expand Down Expand Up @@ -97,7 +96,7 @@ def _load(self) -> None:
logger.debug(
"finding eq_session_id in database", eq_session_id=self.eq_session_id
)
self._eq_session: Optional[EQSession] = current_app.eq["storage"].get(EQSession, self.eq_session_id) # type: ignore
self._eq_session: EQSession | None = current_app.eq["storage"].get(EQSession, self.eq_session_id) # type: ignore

if self._eq_session and self._eq_session.session_data:
self.user_id = self._eq_session.user_id
Expand Down
Loading

0 comments on commit b70d6de

Please sign in to comment.