Skip to content

Commit

Permalink
step 1
Browse files Browse the repository at this point in the history
  • Loading branch information
pfouque committed Nov 22, 2023
1 parent e2cc04d commit 12c4184
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ strict_equality = True
extra_checks = True

# Strongly recommend enabling this one as soon as you can
; check_untyped_defs = True
check_untyped_defs = True

# These shouldn't be too much additional work, but may be tricky to
# get passing if you use a lot of untyped libraries
Expand Down
60 changes: 34 additions & 26 deletions django_fsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import inspect
from functools import partialmethod
from functools import wraps
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

from django.apps import apps as django_apps
from django.db import models
Expand All @@ -31,11 +34,16 @@
"RETURN_VALUE",
]

if TYPE_CHECKING:
_Model = models.Model
else:
_Model = object


class TransitionNotAllowed(Exception):
"""Raised when a transition is not allowed"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self.object = kwargs.pop("object", None)
self.method = kwargs.pop("method", None)
super().__init__(*args, **kwargs)
Expand All @@ -54,7 +62,7 @@ class ConcurrentTransition(Exception):


class Transition:
def __init__(self, method, source, target, on_error, conditions, permission, custom):
def __init__(self, method: Callable, source, target, on_error, conditions, permission, custom) -> None:
self.method = method
self.source = source
self.target = target
Expand All @@ -64,10 +72,10 @@ def __init__(self, method, source, target, on_error, conditions, permission, cus
self.custom = custom

@property
def name(self):
def name(self) -> str:
return self.method.__name__

def has_perm(self, instance, user):
def has_perm(self, instance, user) -> bool:
if not self.permission:
return True
elif callable(self.permission):
Expand Down Expand Up @@ -116,9 +124,9 @@ class FSMMeta:
Models methods transitions meta information
"""

def __init__(self, field, method):
def __init__(self, field, method) -> None:
self.field = field
self.transitions = {} # source -> Transition
self.transitions: dict[str, Any] = {} # source -> Transition

def get_transition(self, source):
transition = self.transitions.get(source, None)
Expand All @@ -128,7 +136,7 @@ def get_transition(self, source):
transition = self.transitions.get("+", None)
return transition

def add_transition(self, method, source, target, on_error=None, conditions=[], permission=None, custom={}):
def add_transition(self, method, source, target, on_error=None, conditions=[], permission=None, custom={}) -> None:
if source in self.transitions:
raise AssertionError(f"Duplicate transition for {source} state")

Expand All @@ -142,7 +150,7 @@ def add_transition(self, method, source, target, on_error=None, conditions=[], p
custom=custom,
)

def has_transition(self, state):
def has_transition(self, state) -> bool:
"""
Lookup if any transition exists from current model state using current method
"""
Expand All @@ -157,7 +165,7 @@ def has_transition(self, state):

return False

def conditions_met(self, instance, state):
def conditions_met(self, instance, state) -> bool:
"""
Check if all conditions have been met
"""
Expand All @@ -170,13 +178,13 @@ def conditions_met(self, instance, state):
else:
return all(map(lambda condition: condition(instance), transition.conditions))

def has_transition_perm(self, instance, state, user):
def has_transition_perm(self, instance, state, user) -> bool:
transition = self.get_transition(state)

if not transition:
return False
else:
return transition.has_perm(instance, user)
return bool(transition.has_perm(instance, user))

def next_state(self, current_state):
transition = self.get_transition(current_state)
Expand All @@ -196,15 +204,15 @@ def exception_state(self, current_state):


class FSMFieldDescriptor:
def __init__(self, field):
def __init__(self, field) -> None:
self.field = field

def __get__(self, instance, type=None):
if instance is None:
return self
return self.field.get_state(instance)

def __set__(self, instance, value):
def __set__(self, instance, value) -> None:
if self.field.protected and self.field.name in instance.__dict__:
raise AttributeError(f"Direct {self.field.name} modification is not allowed")

Expand All @@ -213,12 +221,12 @@ def __set__(self, instance, value):
self.field.set_state(instance, value)


class FSMFieldMixin:
class FSMFieldMixin(Field):
descriptor_class = FSMFieldDescriptor

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self.protected = kwargs.pop("protected", False)
self.transitions = {} # cls -> (transitions name -> method)
self.transitions: dict[Any, dict[str, Any]] = {} # cls -> (transitions name -> method)
self.state_proxy = {} # state -> ProxyClsRef

state_choices = kwargs.pop("state_choices", None)
Expand All @@ -244,7 +252,7 @@ def deconstruct(self):
def get_state(self, instance):
# The state field may be deferred. We delegate the logic of figuring this out
# and loading the deferred field on-demand to Django's built-in DeferredAttribute class.
return DeferredAttribute(self).__get__(instance)
return DeferredAttribute(self).__get__(instance) # type: ignore[attr-defined]

def set_state(self, instance, state):
instance.__dict__[self.name] = state
Expand Down Expand Up @@ -384,7 +392,7 @@ class FSMField(FSMFieldMixin, models.CharField):
State Machine support for Django model as CharField
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
kwargs.setdefault("max_length", 50)
super().__init__(*args, **kwargs)

Expand All @@ -409,7 +417,7 @@ def set_state(self, instance, state):
instance.__dict__[self.attname] = self.to_python(state)


class ConcurrentTransitionMixin:
class ConcurrentTransitionMixin(_Model):
"""
Protects a Model from undesirable effects caused by concurrently executed transitions,
e.g. running the same transition multiple times at the same time, or running different
Expand All @@ -435,7 +443,7 @@ class ConcurrentTransitionMixin:
state, thus practically negating their effect.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._update_initial_state()

Expand All @@ -453,7 +461,7 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat
# state filter will be used to narrow down the standard filter checking only PK
state_filter = {field.attname: self.__initial_states[field.attname] for field in filter_on}

updated = super()._do_update(
updated = super()._do_update( # type: ignore[misc]
base_qs=base_qs.filter(**state_filter),
using=using,
pk_val=pk_val,
Expand Down Expand Up @@ -518,7 +526,7 @@ def _change_state(instance, *args, **kwargs):
return inner_transition


def can_proceed(bound_method, check_conditions=True):
def can_proceed(bound_method, check_conditions=True) -> bool:
"""
Returns True if model in state allows to call bound_method
Expand All @@ -535,7 +543,7 @@ def can_proceed(bound_method, check_conditions=True):
return meta.has_transition(current_state) and (not check_conditions or meta.conditions_met(self, current_state))


def has_transition_perm(bound_method, user):
def has_transition_perm(bound_method, user) -> bool:
"""
Returns True if model in state allows to call bound_method and user have rights on it
"""
Expand All @@ -546,7 +554,7 @@ def has_transition_perm(bound_method, user):
self = bound_method.__self__
current_state = meta.field.get_state(self)

return (
return bool(
meta.has_transition(current_state)
and meta.conditions_met(self, current_state)
and meta.has_transition_perm(self, current_state, user)
Expand All @@ -559,7 +567,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}):


class RETURN_VALUE(State):
def __init__(self, *allowed_states):
def __init__(self, *allowed_states) -> None:
self.allowed_states = allowed_states if allowed_states else None

def get_state(self, model, transition, result, args=[], kwargs={}):
Expand All @@ -570,7 +578,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}):


class GET_STATE(State):
def __init__(self, func, states=None):
def __init__(self, func, states=None) -> None:
self.func = func
self.allowed_states = states

Expand Down

0 comments on commit 12c4184

Please sign in to comment.