diff --git a/django_fsm/__init__.py b/django_fsm/__init__.py index 298cd71..94eb31d 100644 --- a/django_fsm/__init__.py +++ b/django_fsm/__init__.py @@ -7,8 +7,6 @@ from functools import partialmethod from functools import wraps from typing import TYPE_CHECKING -from typing import Any -from typing import Callable from django.apps import apps as django_apps from django.db import models @@ -35,6 +33,13 @@ ] if TYPE_CHECKING: + from collections.abc import Callable + from collections.abc import Sequence + from typing import Any + + from django.contrib.auth.models import AbstractBaseUser + from django.utils.functional import _StrOrPromise + _Model = models.Model else: _Model = object @@ -62,7 +67,16 @@ class ConcurrentTransition(Exception): class Transition: - def __init__(self, method: Callable, source, target, on_error, conditions, permission, custom) -> None: + def __init__( + self, + method: Callable, + source: str | int | Sequence[str | int] | State, + target: str | int | State | None, + on_error: str | int | None, + conditions: list[Callable[[Any], bool]], + permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None, + custom: dict[str, _StrOrPromise], + ) -> None: self.method = method self.source = source self.target = target @@ -493,7 +507,15 @@ def save(self, *args, **kwargs): self._update_initial_state() -def transition(field, source="*", target=None, on_error=None, conditions=[], permission=None, custom={}): +def transition( + field, + source: str | int | Sequence[str | int] | State = "*", + target: str | int | State | None = None, + on_error: str | int | None = None, + conditions: list[Callable[[Any], bool]] = [], + permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None = None, + custom: dict[str, _StrOrPromise] = {}, +): """ Method decorator to mark allowed transitions.