Skip to content

Commit

Permalink
Step 6
Browse files Browse the repository at this point in the history
  • Loading branch information
pfouque committed Nov 22, 2023
1 parent 4cb9f7f commit 26368af
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions django_fsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@

_Model = models.Model
_Field = models.Field[Any, Any]
CharField = models.CharField[str, str]
IntegerField = models.IntegerField[int, int]
CharField = models.CharField[Any, Any]
IntegerField = models.IntegerField[Any, Any]
ForeignKey = models.ForeignKey[Any, Any]

_StateValue = str | int
Expand Down Expand Up @@ -168,10 +168,10 @@ def get_transition(self, source: str) -> Transition | None:

def add_transition(
self,
method: Callable[..., str | int | None],
method: Callable[..., _StateValue | Any],
source: str,
target: str | int,
on_error: str | int | None = None,
target: _StateValue,
on_error: _StateValue | None = None,
conditions: list[Callable[[_Instance], bool]] = [],
permission: str | Callable[[_Instance, UserWithPermissions], bool] | None = None,
custom: dict[str, _StrOrPromise] = {},
Expand Down Expand Up @@ -225,15 +225,15 @@ def has_transition_perm(self, instance: _Instance, state: str, user: UserWithPer
else:
return bool(transition.has_perm(instance, user))

def next_state(self, current_state: str) -> str | int:
def next_state(self, current_state: str) -> _StateValue:
transition = self.get_transition(current_state)

if transition is None:
raise TransitionNotAllowed(f"No transition from {current_state}")

return transition.target

def exception_state(self, current_state: str) -> str | int | None:
def exception_state(self, current_state: str) -> _StateValue | None:
transition = self.get_transition(current_state)

if transition is None:
Expand Down Expand Up @@ -534,9 +534,9 @@ def save(self, *args: Any, **kwargs: Any) -> None:

def transition(
field: FSMFieldMixin,
source: str | int | Sequence[str | int] = "*",
target: str | int | State | None = None,
on_error: str | int | None = None,
source: _StateValue | Sequence[_StateValue] = "*",
target: _StateValue | State | None = None,
on_error: _StateValue | None = None,
conditions: list[Callable[[Any], bool]] = [],
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
custom: dict[str, _StrOrPromise] = {},
Expand Down Expand Up @@ -614,7 +614,7 @@ def get_state(self, model: _Model, transition: Transition, result: Any, args: An


class RETURN_VALUE(State):
def __init__(self, *allowed_states: Sequence[str | int]) -> None:
def __init__(self, *allowed_states: Sequence[_StateValue]) -> None:
self.allowed_states = allowed_states if allowed_states else None

def get_state(self, model: _Model, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> _ToDo:
Expand All @@ -625,7 +625,7 @@ def get_state(self, model: _Model, transition: Transition, result: Any, args: An


class GET_STATE(State):
def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] | None = None) -> None:
def __init__(self, func: Callable[..., _StateValue | Any], states: Sequence[_StateValue] | None = None) -> None:
self.func = func
self.allowed_states = states

Expand Down

0 comments on commit 26368af

Please sign in to comment.