Skip to content

Commit

Permalink
Step 7
Browse files Browse the repository at this point in the history
  • Loading branch information
pfouque committed Nov 30, 2023
1 parent 26368af commit 9194723
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 76 deletions.
38 changes: 0 additions & 38 deletions .mypy.ini

This file was deleted.

5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repos:


- repo: https://github.com/python-poetry/poetry
rev: 1.6.1
rev: 1.7.0
hooks:
- id: poetry-check
additional_dependencies:
Expand All @@ -42,8 +42,9 @@ repos:
- id: ruff

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.0
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies:
- django-stubs==4.2.6
- django-guardian
67 changes: 41 additions & 26 deletions django_fsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@
from __future__ import annotations

import inspect
from collections.abc import Callable
from collections.abc import Collection
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from functools import partialmethod
from functools import wraps
from typing import TYPE_CHECKING
from typing import Any

from django.apps import apps as django_apps
from django.db import models
from django.db.models import Field
from django.db.models import QuerySet
from django.db.models.query_utils import DeferredAttribute
from django.db.models.signals import class_prepared

Expand All @@ -33,30 +40,29 @@
]

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Any
from typing import Self

from _typeshed import Incomplete
from django.contrib.auth.models import PermissionsMixin as UserWithPermissions
from django.utils.functional import _StrOrPromise

_Model = models.Model
_FSMModel = models.Model
_Field = models.Field[Any, Any]
CharField = models.CharField[Any, Any]
IntegerField = models.IntegerField[Any, Any]
ForeignKey = models.ForeignKey[Any, Any]

_StateValue = str | int
_Permission = str | Callable[[_FSMModel, UserWithPermissions], bool]
_Instance = models.Model # TODO: use real type
_ToDo = Any # TODO: use real type

else:
_Model = object
_FSMModel = object
_Field = object
CharField = models.CharField
IntegerField = models.IntegerField
ForeignKey = models.ForeignKey
Self = Any


class TransitionNotAllowed(Exception):
Expand Down Expand Up @@ -265,7 +271,7 @@ class FSMFieldMixin(_Field):

def __init__(self, *args: Any, **kwargs: Any) -> None:
self.protected = kwargs.pop("protected", False)
self.transitions: dict[type[_Model], dict[str, Any]] = {} # cls -> (transitions name -> method)
self.transitions: dict[type[_FSMModel], dict[str, Any]] = {} # cls -> (transitions name -> method)
self.state_proxy = {} # state -> ProxyClsRef

state_choices = kwargs.pop("state_choices", None)
Expand Down Expand Up @@ -317,7 +323,7 @@ def set_proxy(self, instance: _Instance, state: str) -> None:

instance.__class__ = model

def change_state(self, instance: _Instance, method: _ToDo, *args: Any, **kwargs: Any) -> Any:
def change_state(self, instance: _Instance, method: Incomplete, *args: Any, **kwargs: Any) -> Any:
meta = method._django_fsm
method_name = method.__name__
current_state = self.get_state(instance)
Expand Down Expand Up @@ -370,7 +376,7 @@ def change_state(self, instance: _Instance, method: _ToDo, *args: Any, **kwargs:

return result

def get_all_transitions(self, instance_cls: type[_Model]) -> Generator[Transition, None, None]:
def get_all_transitions(self, instance_cls: type[_FSMModel]) -> Generator[Transition, None, None]:
"""
Returns [(source, target, name, method)] for all field transitions
"""
Expand All @@ -382,7 +388,7 @@ def get_all_transitions(self, instance_cls: type[_Model]) -> Generator[Transitio
for transition in meta.transitions.values():
yield transition

def contribute_to_class(self, cls: type[_Model], name: str, private_only: bool = False, **kwargs: Any) -> None:
def contribute_to_class(self, cls: type[_FSMModel], name: str, private_only: bool = False, **kwargs: Any) -> None:
self.base_cls = cls

super().contribute_to_class(cls, name, private_only=private_only, **kwargs)
Expand All @@ -403,7 +409,7 @@ def _collect_transitions(self, *args: Any, **kwargs: Any) -> None:
if not issubclass(sender, self.base_cls):
return

def is_field_transition_method(attr: _ToDo) -> bool:
def is_field_transition_method(attr: Incomplete) -> bool:
return (
(inspect.ismethod(attr) or inspect.isfunction(attr))
and hasattr(attr, "_django_fsm")
Expand Down Expand Up @@ -449,14 +455,14 @@ class FSMKeyField(FSMFieldMixin, ForeignKey):
State Machine support for Django model
"""

def get_state(self, instance: _Instance) -> _ToDo:
def get_state(self, instance: _Instance) -> Incomplete:
return instance.__dict__[self.attname]

def set_state(self, instance: _Instance, state: str) -> None:
instance.__dict__[self.attname] = self.to_python(state)


class ConcurrentTransitionMixin(_Model):
class ConcurrentTransitionMixin(_FSMModel):
"""
Protects a Model from undesirable effects caused by concurrently executed transitions,
e.g. running the same transition multiple times at the same time, or running different
Expand Down Expand Up @@ -490,7 +496,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
def state_fields(self) -> Iterable[Any]:
return filter(lambda field: isinstance(field, FSMFieldMixin), self._meta.fields)

def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): # type: ignore[no-untyped-def]
def _do_update(
self,
base_qs: QuerySet[Self],
using: Any,
pk_val: Any,
values: Collection[Any] | None,
update_fields: Iterable[str] | None,
forced_update: bool,
) -> bool:
# _do_update is called once for each model class in the inheritance hierarchy.
# We can only filter the base_qs on state fields (can be more than one!) present in this particular model.

Expand All @@ -500,7 +514,7 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat
# state filter will be used to narrow down the standard filter checking only PK
state_filter = {field.attname: self.__initial_states[field.attname] for field in filter_on}

updated = super()._do_update( # type: ignore[misc]
updated: bool = super()._do_update( # type: ignore[misc]
base_qs=base_qs.filter(**state_filter),
using=using,
pk_val=pk_val,
Expand Down Expand Up @@ -538,7 +552,7 @@ def transition(
target: _StateValue | State | None = None,
on_error: _StateValue | None = None,
conditions: list[Callable[[Any], bool]] = [],
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
permission: _Permission | None = None,
custom: dict[str, _StrOrPromise] = {},
) -> Callable[[Any], Any]:
"""
Expand All @@ -548,21 +562,22 @@ def transition(
has not changed after the function call.
"""

def inner_transition(func: _ToDo) -> _ToDo:
def inner_transition(func: Incomplete) -> Incomplete:
wrapper_installed, fsm_meta = True, getattr(func, "_django_fsm", None)
if not fsm_meta:
wrapper_installed = False
fsm_meta = FSMMeta(field=field, method=func)
setattr(func, "_django_fsm", fsm_meta)

# if isinstance(source, Iterable):
if isinstance(source, (list, tuple, set)):
for state in source:
func._django_fsm.add_transition(func, state, target, on_error, conditions, permission, custom)
else:
func._django_fsm.add_transition(func, source, target, on_error, conditions, permission, custom)

@wraps(func)
def _change_state(instance: _Instance, *args: Any, **kwargs: Any) -> _ToDo:
def _change_state(instance: _Instance, *args: Any, **kwargs: Any) -> Incomplete:
return fsm_meta.field.change_state(instance, func, *args, **kwargs)

if not wrapper_installed:
Expand All @@ -573,7 +588,7 @@ def _change_state(instance: _Instance, *args: Any, **kwargs: Any) -> _ToDo:
return inner_transition


def can_proceed(bound_method: _ToDo, check_conditions: bool = True) -> bool:
def can_proceed(bound_method: Incomplete, check_conditions: bool = True) -> bool:
"""
Returns True if model in state allows to call bound_method
Expand All @@ -590,7 +605,7 @@ def can_proceed(bound_method: _ToDo, check_conditions: bool = True) -> bool:
return meta.has_transition(current_state) and (not check_conditions or meta.conditions_met(self, current_state))


def has_transition_perm(bound_method: _ToDo, user: UserWithPermissions) -> bool:
def has_transition_perm(bound_method: Incomplete, user: UserWithPermissions) -> bool:
"""
Returns True if model in state allows to call bound_method and user have rights on it
"""
Expand All @@ -609,15 +624,15 @@ def has_transition_perm(bound_method: _ToDo, user: UserWithPermissions) -> bool:


class State:
def get_state(self, model: _Model, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> _ToDo:
def get_state(self, model: _FSMModel, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> Incomplete:
raise NotImplementedError


class RETURN_VALUE(State):
def __init__(self, *allowed_states: Sequence[_StateValue]) -> None:
self.allowed_states = allowed_states if allowed_states else None

def get_state(self, model: _Model, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> _ToDo:
def get_state(self, model: _FSMModel, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> Incomplete:
if self.allowed_states is not None:
if result not in self.allowed_states:
raise InvalidResultState(f"{result} is not in list of allowed states\n{self.allowed_states}")
Expand All @@ -630,8 +645,8 @@ def __init__(self, func: Callable[..., _StateValue | Any], states: Sequence[_Sta
self.allowed_states = states

def get_state(
self, model: _Model, transition: Transition, result: _StateValue | Any, args: Any = [], kwargs: Any = {}
) -> _ToDo:
self, model: _FSMModel, transition: Transition, result: _StateValue | Any, args: Any = [], kwargs: Any = {}
) -> Incomplete:
result_state = self.func(model, *args, **kwargs)
if self.allowed_states is not None:
if result_state not in self.allowed_states:
Expand Down
73 changes: 73 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,79 @@ fixable = ["I"]
force-single-line = true
required-imports = ["from __future__ import annotations"]

[tool.django-stubs]
django_settings_module = "tests.settings"

[tool.mypy]
python_version = 3.11
plugins = ["mypy_django_plugin.main"]

# Start off with these
warn_unused_configs = true
warn_redundant_casts = true
warn_unused_ignores = true

# Getting these passing should be easy
strict_equality = true
extra_checks = true

# Strongly recommend enabling this one as soon as you can
check_untyped_defs = true

# These shouldn't be too much additional work, but may be tricky to
# get passing if you use a lot of untyped libraries
disallow_subclassing_any = true
disallow_untyped_decorators = true
disallow_any_generics = true

# These next few are various gradations of forcing use of type annotations
disallow_untyped_calls = true
disallow_incomplete_defs = true
disallow_untyped_defs = true

# This one isn't too hard to get passing, but return on investment is lower
no_implicit_reexport = true

# This one can be tricky to get passing if you use a lot of untyped libraries
warn_return_any = true

[[tool.mypy.overrides]]
module = [
"tests.*",
"django_fsm.tests.*"
]
ignore_errors = true

# Start off with these
warn_unused_ignores = true

# Getting these passing should be easy
strict_equality = false
extra_checks = false

# Strongly recommend enabling this one as soon as you can
check_untyped_defs = false
# These shouldn't be too much additional work, but may be tricky to
# get passing if you use a lot of untyped libraries
disallow_subclassing_any = false
disallow_untyped_decorators = false
disallow_any_generics = false

# These next few are various gradations of forcing use of type annotations
disallow_untyped_calls = false
disallow_incomplete_defs = false
disallow_untyped_defs = false

# This one isn't too hard to get passing, but return on investment is lower
no_implicit_reexport = false

# This one can be tricky to get passing if you use a lot of untyped libraries
warn_return_any = false

[[tool.mypy.overrides]]
module = "django_fsm.management.commands.graph_transitions"
ignore_errors = true

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
2 changes: 1 addition & 1 deletion tests/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class BlogPost(models.Model):

state = FSMField(default="new", protected=True)

def can_restore(self, user):
def can_restore(self, user) -> bool:
return user.is_superuser or user.is_staff

@transition(field=state, source="new", target="published", on_error="failed", permission="testapp.can_publish_post")
Expand Down
6 changes: 3 additions & 3 deletions tests/testapp/tests/test_multidecorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django_fsm.signals import post_transition


class TestModel(models.Model):
class MultipletransitionsModel(models.Model):
counter = models.IntegerField(default=0)
signal_counter = models.IntegerField(default=0)
state = FSMField(default="SUBMITTED_BY_USER")
Expand All @@ -27,12 +27,12 @@ def count_calls(sender, instance, name, source, target, **kwargs):
instance.signal_counter += 1


post_transition.connect(count_calls, sender=TestModel)
post_transition.connect(count_calls, sender=MultipletransitionsModel)


class TestStateProxy(TestCase):
def test_transition_method_called_once(self):
model = TestModel()
model = MultipletransitionsModel()
model.review()
self.assertEqual(1, model.counter)
self.assertEqual(1, model.signal_counter)
Loading

0 comments on commit 9194723

Please sign in to comment.