Skip to content

Commit

Permalink
Step 3
Browse files Browse the repository at this point in the history
  • Loading branch information
pfouque committed Nov 22, 2023
1 parent 0b001e6 commit 253e685
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 35 deletions.
4 changes: 2 additions & 2 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ check_untyped_defs = True
# get passing if you use a lot of untyped libraries
disallow_subclassing_any = True
disallow_untyped_decorators = True
; disallow_any_generics = True
disallow_any_generics = True

# These next few are various gradations of forcing use of type annotations
; disallow_untyped_calls = True
disallow_untyped_calls = True
; disallow_incomplete_defs = True
; disallow_untyped_defs = True

Expand Down
86 changes: 53 additions & 33 deletions django_fsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,30 @@

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Sequence
from typing import Any

from django.contrib.auth.models import AbstractBaseUser
from django.contrib.auth.models import PermissionsMixin as UserWithPermissions
from django.utils.functional import _StrOrPromise

_Model = models.Model
_Field = models.Field[Any, Any]
CharField = models.CharField[str, str]
IntegerField = models.IntegerField[int, int]
ForeignKey = models.ForeignKey[Any, Any]
else:
_Model = object
_Field = object
CharField = models.CharField
IntegerField = models.IntegerField
ForeignKey = models.ForeignKey


class TransitionNotAllowed(Exception):
"""Raised when a transition is not allowed"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.object = kwargs.pop("object", None)
self.method = kwargs.pop("method", None)
super().__init__(*args, **kwargs)
Expand All @@ -69,12 +78,12 @@ class ConcurrentTransition(Exception):
class Transition:
def __init__(
self,
method: Callable,
method: Callable[..., Any],
source: str | int | Sequence[str | int] | State,
target: str | int | State | None,
on_error: str | int | None,
conditions: list[Callable[[Any], bool]],
permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None,
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None,
custom: dict[str, _StrOrPromise],
) -> None:
self.method = method
Expand All @@ -89,7 +98,7 @@ def __init__(
def name(self) -> str:
return self.method.__name__

def has_perm(self, instance, user) -> bool:
def has_perm(self, instance, user: UserWithPermissions) -> bool:
if not self.permission:
return True
elif callable(self.permission):
Expand All @@ -102,7 +111,7 @@ def has_perm(self, instance, user) -> bool:
return False


def get_available_FIELD_transitions(instance, field):
def get_available_FIELD_transitions(instance, field: FSMFieldMixin) -> Generator[Transition, None, None]:
"""
List of transitions available in current model state
with all conditions met
Expand All @@ -116,14 +125,16 @@ def get_available_FIELD_transitions(instance, field):
yield meta.get_transition(curr_state)


def get_all_FIELD_transitions(instance, field):
def get_all_FIELD_transitions(instance, field: FSMFieldMixin) -> Generator[Transition, None, None]:
"""
List of all transitions available in current model state
"""
return field.get_all_transitions(instance.__class__)


def get_available_user_FIELD_transitions(instance, user, field):
def get_available_user_FIELD_transitions(
instance, user: UserWithPermissions, field: FSMFieldMixin
) -> Generator[Transition, None, None]:
"""
List of transitions available in current model state
with all conditions met and user have rights on it
Expand All @@ -142,15 +153,24 @@ def __init__(self, field, method) -> None:
self.field = field
self.transitions: dict[str, Any] = {} # source -> Transition

def get_transition(self, source):
def get_transition(self, source: str):
transition = self.transitions.get(source, None)
if transition is None:
transition = self.transitions.get("*", None)
if transition is None:
transition = self.transitions.get("+", None)
return transition

def add_transition(self, method, source, target, on_error=None, conditions=[], permission=None, custom={}) -> None:
def add_transition(
self,
method: Callable[..., Any],
source: str,
target: str | int,
on_error: str | int | None = None,
conditions: list[Callable[[Any], bool]] = [],
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
custom: dict[str, _StrOrPromise] = {},
) -> None:
if source in self.transitions:
raise AssertionError(f"Duplicate transition for {source} state")

Expand Down Expand Up @@ -192,7 +212,7 @@ def conditions_met(self, instance, state) -> bool:
else:
return all(map(lambda condition: condition(instance), transition.conditions))

def has_transition_perm(self, instance, state, user) -> bool:
def has_transition_perm(self, instance, state, user: UserWithPermissions) -> bool:
transition = self.get_transition(state)

if not transition:
Expand Down Expand Up @@ -235,10 +255,10 @@ def __set__(self, instance, value) -> None:
self.field.set_state(instance, value)


class FSMFieldMixin(Field):
class FSMFieldMixin(_Field):
descriptor_class = FSMFieldDescriptor

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.protected = kwargs.pop("protected", False)
self.transitions: dict[Any, dict[str, Any]] = {} # cls -> (transitions name -> method)
self.state_proxy = {} # state -> ProxyClsRef
Expand All @@ -263,15 +283,15 @@ def deconstruct(self):
kwargs["protected"] = self.protected
return name, path, args, kwargs

def get_state(self, instance):
def get_state(self, instance) -> Any:
# The state field may be deferred. We delegate the logic of figuring this out
# and loading the deferred field on-demand to Django's built-in DeferredAttribute class.
return DeferredAttribute(self).__get__(instance) # type: ignore[attr-defined]

def set_state(self, instance, state):
def set_state(self, instance, state: str) -> None:
instance.__dict__[self.name] = state

def set_proxy(self, instance, state):
def set_proxy(self, instance, state: str) -> None:
"""
Change class
"""
Expand All @@ -292,7 +312,7 @@ def set_proxy(self, instance, state):

instance.__class__ = model

def change_state(self, instance, method, *args, **kwargs):
def change_state(self, instance, method, *args: Any, **kwargs: Any):
meta = method._django_fsm
method_name = method.__name__
current_state = self.get_state(instance)
Expand Down Expand Up @@ -345,7 +365,7 @@ def change_state(self, instance, method, *args, **kwargs):

return result

def get_all_transitions(self, instance_cls):
def get_all_transitions(self, instance_cls) -> Generator[Transition, None, None]:
"""
Returns [(source, target, name, method)] for all field transitions
"""
Expand All @@ -372,7 +392,7 @@ def contribute_to_class(self, cls, name, private_only=False, **kwargs):

class_prepared.connect(self._collect_transitions)

def _collect_transitions(self, *args, **kwargs):
def _collect_transitions(self, *args: Any, **kwargs: Any):
sender = kwargs["sender"]

if not issubclass(sender, self.base_cls):
Expand Down Expand Up @@ -401,25 +421,25 @@ def is_field_transition_method(attr):
self.transitions[sender] = sender_transitions


class FSMField(FSMFieldMixin, models.CharField):
class FSMField(FSMFieldMixin, CharField):
"""
State Machine support for Django model as CharField
"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs.setdefault("max_length", 50)
super().__init__(*args, **kwargs)


class FSMIntegerField(FSMFieldMixin, models.IntegerField):
class FSMIntegerField(FSMFieldMixin, IntegerField):
"""
Same as FSMField, but stores the state value in an IntegerField.
"""

pass


class FSMKeyField(FSMFieldMixin, models.ForeignKey):
class FSMKeyField(FSMFieldMixin, ForeignKey):
"""
State Machine support for Django model
"""
Expand Down Expand Up @@ -457,7 +477,7 @@ class ConcurrentTransitionMixin(_Model):
state, thus practically negating their effect.
"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._update_initial_state()

Expand Down Expand Up @@ -495,14 +515,14 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat

return updated

def _update_initial_state(self):
def _update_initial_state(self) -> None:
self.__initial_states = {field.attname: field.value_from_object(self) for field in self.state_fields}

def refresh_from_db(self, *args, **kwargs):
def refresh_from_db(self, *args: Any, **kwargs: Any) -> None:
super().refresh_from_db(*args, **kwargs)
self._update_initial_state()

def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
super().save(*args, **kwargs)
self._update_initial_state()

Expand All @@ -513,7 +533,7 @@ def transition(
target: str | int | State | None = None,
on_error: str | int | None = None,
conditions: list[Callable[[Any], bool]] = [],
permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None = None,
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
custom: dict[str, _StrOrPromise] = {},
):
"""
Expand All @@ -537,7 +557,7 @@ def inner_transition(func):
func._django_fsm.add_transition(func, source, target, on_error, conditions, permission, custom)

@wraps(func)
def _change_state(instance, *args, **kwargs):
def _change_state(instance, *args: Any, **kwargs: Any):
return fsm_meta.field.change_state(instance, func, *args, **kwargs)

if not wrapper_installed:
Expand All @@ -548,7 +568,7 @@ def _change_state(instance, *args, **kwargs):
return inner_transition


def can_proceed(bound_method, check_conditions=True) -> bool:
def can_proceed(bound_method, check_conditions: bool = True) -> bool:
"""
Returns True if model in state allows to call bound_method
Expand All @@ -565,7 +585,7 @@ def can_proceed(bound_method, check_conditions=True) -> bool:
return meta.has_transition(current_state) and (not check_conditions or meta.conditions_met(self, current_state))


def has_transition_perm(bound_method, user) -> bool:
def has_transition_perm(bound_method, user: UserWithPermissions) -> bool:
"""
Returns True if model in state allows to call bound_method and user have rights on it
"""
Expand All @@ -589,7 +609,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}):


class RETURN_VALUE(State):
def __init__(self, *allowed_states) -> None:
def __init__(self, *allowed_states: Sequence[str | int]) -> None:
self.allowed_states = allowed_states if allowed_states else None

def get_state(self, model, transition, result, args=[], kwargs={}):
Expand All @@ -600,7 +620,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}):


class GET_STATE(State):
def __init__(self, func, states=None) -> None:
def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] | None = None) -> None:
self.func = func
self.allowed_states = states

Expand Down

0 comments on commit 253e685

Please sign in to comment.