diff --git a/.mypy.ini b/.mypy.ini index 0117be4..6cf1061 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -15,10 +15,10 @@ check_untyped_defs = True # get passing if you use a lot of untyped libraries disallow_subclassing_any = True disallow_untyped_decorators = True -; disallow_any_generics = True +disallow_any_generics = True # These next few are various gradations of forcing use of type annotations -; disallow_untyped_calls = True +disallow_untyped_calls = True ; disallow_incomplete_defs = True ; disallow_untyped_defs = True diff --git a/django_fsm/__init__.py b/django_fsm/__init__.py index 94eb31d..7c53fb9 100644 --- a/django_fsm/__init__.py +++ b/django_fsm/__init__.py @@ -34,21 +34,30 @@ if TYPE_CHECKING: from collections.abc import Callable + from collections.abc import Generator from collections.abc import Sequence from typing import Any - from django.contrib.auth.models import AbstractBaseUser + from django.contrib.auth.models import PermissionsMixin as UserWithPermissions from django.utils.functional import _StrOrPromise _Model = models.Model + _Field = models.Field[Any, Any] + CharField = models.CharField[str, str] + IntegerField = models.IntegerField[int, int] + ForeignKey = models.ForeignKey[Any, Any] else: _Model = object + _Field = object + CharField = models.CharField + IntegerField = models.IntegerField + ForeignKey = models.ForeignKey class TransitionNotAllowed(Exception): """Raised when a transition is not allowed""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self.object = kwargs.pop("object", None) self.method = kwargs.pop("method", None) super().__init__(*args, **kwargs) @@ -69,12 +78,12 @@ class ConcurrentTransition(Exception): class Transition: def __init__( self, - method: Callable, + method: Callable[..., Any], source: str | int | Sequence[str | int] | State, target: str | int | State | None, on_error: str | int | None, conditions: list[Callable[[Any], bool]], - permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None, + permission: str | Callable[[models.Model, UserWithPermissions], bool] | None, custom: dict[str, _StrOrPromise], ) -> None: self.method = method @@ -89,7 +98,7 @@ def __init__( def name(self) -> str: return self.method.__name__ - def has_perm(self, instance, user) -> bool: + def has_perm(self, instance, user: UserWithPermissions) -> bool: if not self.permission: return True elif callable(self.permission): @@ -102,7 +111,7 @@ def has_perm(self, instance, user) -> bool: return False -def get_available_FIELD_transitions(instance, field): +def get_available_FIELD_transitions(instance, field: FSMFieldMixin) -> Generator[Transition, None, None]: """ List of transitions available in current model state with all conditions met @@ -116,14 +125,16 @@ def get_available_FIELD_transitions(instance, field): yield meta.get_transition(curr_state) -def get_all_FIELD_transitions(instance, field): +def get_all_FIELD_transitions(instance, field: FSMFieldMixin) -> Generator[Transition, None, None]: """ List of all transitions available in current model state """ return field.get_all_transitions(instance.__class__) -def get_available_user_FIELD_transitions(instance, user, field): +def get_available_user_FIELD_transitions( + instance, user: UserWithPermissions, field: FSMFieldMixin +) -> Generator[Transition, None, None]: """ List of transitions available in current model state with all conditions met and user have rights on it @@ -142,7 +153,7 @@ def __init__(self, field, method) -> None: self.field = field self.transitions: dict[str, Any] = {} # source -> Transition - def get_transition(self, source): + def get_transition(self, source: str): transition = self.transitions.get(source, None) if transition is None: transition = self.transitions.get("*", None) @@ -150,7 +161,16 @@ def get_transition(self, source): transition = self.transitions.get("+", None) return transition - def add_transition(self, method, source, target, on_error=None, conditions=[], permission=None, custom={}) -> None: + def add_transition( + self, + method: Callable[..., Any], + source: str, + target: str | int, + on_error: str | int | None = None, + conditions: list[Callable[[Any], bool]] = [], + permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None, + custom: dict[str, _StrOrPromise] = {}, + ) -> None: if source in self.transitions: raise AssertionError(f"Duplicate transition for {source} state") @@ -192,7 +212,7 @@ def conditions_met(self, instance, state) -> bool: else: return all(map(lambda condition: condition(instance), transition.conditions)) - def has_transition_perm(self, instance, state, user) -> bool: + def has_transition_perm(self, instance, state, user: UserWithPermissions) -> bool: transition = self.get_transition(state) if not transition: @@ -235,10 +255,10 @@ def __set__(self, instance, value) -> None: self.field.set_state(instance, value) -class FSMFieldMixin(Field): +class FSMFieldMixin(_Field): descriptor_class = FSMFieldDescriptor - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self.protected = kwargs.pop("protected", False) self.transitions: dict[Any, dict[str, Any]] = {} # cls -> (transitions name -> method) self.state_proxy = {} # state -> ProxyClsRef @@ -263,15 +283,15 @@ def deconstruct(self): kwargs["protected"] = self.protected return name, path, args, kwargs - def get_state(self, instance): + def get_state(self, instance) -> Any: # The state field may be deferred. We delegate the logic of figuring this out # and loading the deferred field on-demand to Django's built-in DeferredAttribute class. return DeferredAttribute(self).__get__(instance) # type: ignore[attr-defined] - def set_state(self, instance, state): + def set_state(self, instance, state: str) -> None: instance.__dict__[self.name] = state - def set_proxy(self, instance, state): + def set_proxy(self, instance, state: str) -> None: """ Change class """ @@ -292,7 +312,7 @@ def set_proxy(self, instance, state): instance.__class__ = model - def change_state(self, instance, method, *args, **kwargs): + def change_state(self, instance, method, *args: Any, **kwargs: Any): meta = method._django_fsm method_name = method.__name__ current_state = self.get_state(instance) @@ -345,7 +365,7 @@ def change_state(self, instance, method, *args, **kwargs): return result - def get_all_transitions(self, instance_cls): + def get_all_transitions(self, instance_cls) -> Generator[Transition, None, None]: """ Returns [(source, target, name, method)] for all field transitions """ @@ -372,7 +392,7 @@ def contribute_to_class(self, cls, name, private_only=False, **kwargs): class_prepared.connect(self._collect_transitions) - def _collect_transitions(self, *args, **kwargs): + def _collect_transitions(self, *args: Any, **kwargs: Any): sender = kwargs["sender"] if not issubclass(sender, self.base_cls): @@ -401,17 +421,17 @@ def is_field_transition_method(attr): self.transitions[sender] = sender_transitions -class FSMField(FSMFieldMixin, models.CharField): +class FSMField(FSMFieldMixin, CharField): """ State Machine support for Django model as CharField """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.setdefault("max_length", 50) super().__init__(*args, **kwargs) -class FSMIntegerField(FSMFieldMixin, models.IntegerField): +class FSMIntegerField(FSMFieldMixin, IntegerField): """ Same as FSMField, but stores the state value in an IntegerField. """ @@ -419,7 +439,7 @@ class FSMIntegerField(FSMFieldMixin, models.IntegerField): pass -class FSMKeyField(FSMFieldMixin, models.ForeignKey): +class FSMKeyField(FSMFieldMixin, ForeignKey): """ State Machine support for Django model """ @@ -457,7 +477,7 @@ class ConcurrentTransitionMixin(_Model): state, thus practically negating their effect. """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._update_initial_state() @@ -495,14 +515,14 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat return updated - def _update_initial_state(self): + def _update_initial_state(self) -> None: self.__initial_states = {field.attname: field.value_from_object(self) for field in self.state_fields} - def refresh_from_db(self, *args, **kwargs): + def refresh_from_db(self, *args: Any, **kwargs: Any) -> None: super().refresh_from_db(*args, **kwargs) self._update_initial_state() - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: super().save(*args, **kwargs) self._update_initial_state() @@ -513,7 +533,7 @@ def transition( target: str | int | State | None = None, on_error: str | int | None = None, conditions: list[Callable[[Any], bool]] = [], - permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None = None, + permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None, custom: dict[str, _StrOrPromise] = {}, ): """ @@ -537,7 +557,7 @@ def inner_transition(func): func._django_fsm.add_transition(func, source, target, on_error, conditions, permission, custom) @wraps(func) - def _change_state(instance, *args, **kwargs): + def _change_state(instance, *args: Any, **kwargs: Any): return fsm_meta.field.change_state(instance, func, *args, **kwargs) if not wrapper_installed: @@ -548,7 +568,7 @@ def _change_state(instance, *args, **kwargs): return inner_transition -def can_proceed(bound_method, check_conditions=True) -> bool: +def can_proceed(bound_method, check_conditions: bool = True) -> bool: """ Returns True if model in state allows to call bound_method @@ -565,7 +585,7 @@ def can_proceed(bound_method, check_conditions=True) -> bool: return meta.has_transition(current_state) and (not check_conditions or meta.conditions_met(self, current_state)) -def has_transition_perm(bound_method, user) -> bool: +def has_transition_perm(bound_method, user: UserWithPermissions) -> bool: """ Returns True if model in state allows to call bound_method and user have rights on it """ @@ -589,7 +609,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}): class RETURN_VALUE(State): - def __init__(self, *allowed_states) -> None: + def __init__(self, *allowed_states: Sequence[str | int]) -> None: self.allowed_states = allowed_states if allowed_states else None def get_state(self, model, transition, result, args=[], kwargs={}): @@ -600,7 +620,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}): class GET_STATE(State): - def __init__(self, func, states=None) -> None: + def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] | None = None) -> None: self.func = func self.allowed_states = states