Skip to content

Commit

Permalink
Mapping states from state name
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 3, 2024
1 parent e5c74ad commit 7c105bd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None:
def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
"""Transite to the new state.
The new target state will be create lazily when the state is not yet instantiated,
Expand All @@ -332,9 +332,9 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) ->
try:
self._transitioning = True

if not isinstance(new_state, State):
# Make sure we have a state instance
new_state = self._create_state_instance(new_state, **kwargs)
# if not isinstance(new_state, State):
# # Make sure we have a state instance
# new_state = self._create_state_instance(new_state, **kwargs)

label = new_state.LABEL

Expand Down
16 changes: 12 additions & 4 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,9 @@ def transition_failed(
if final_state == process_states.ProcessState.CREATED:
raise exception.with_traceback(trace)

self.transition_to(process_states.Excepted, exception=exception, trace_back=trace)
state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.
Expand Down Expand Up @@ -1127,7 +1129,9 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu

def do_kill(_next_state: process_states.State) -> Any:
try:
self.transition_to(process_states.Killed, msg=exception.msg)
state_class = self.get_states_map()[process_states.ProcessState.KILLED]
new_state = self._create_state_instance(state_class, msg=exception.msg)
self.transition_to(new_state)
return True
finally:
self._killing = None
Expand Down Expand Up @@ -1179,7 +1183,9 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
:param exception: The exception that caused the failure
:param trace_back: Optional exception traceback
"""
self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back)
state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
"""
Expand Down Expand Up @@ -1207,7 +1213,9 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

self.transition_to(process_states.Killed, msg=msg)
state_class = self.get_states_map()[process_states.ProcessState.KILLED]
new_state = self._create_state_instance(state_class, msg=msg)
self.transition_to(new_state)
return True

@property
Expand Down
9 changes: 5 additions & 4 deletions tests/base/test_statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ class Paused(state_machine.State):
def __init__(self, player, playing_state):
assert isinstance(playing_state, Playing), 'Must provide the playing state to pause'
super().__init__(player)
self._player = player
self.playing_state = playing_state

def __str__(self):
return f'|| ({self.playing_state})'

def play(self, track=None):
if track is not None:
self.state_machine.transition_to(Playing, track=track)
self.state_machine.transition_to(Playing(player=self.state_machine, track=track))
else:
self.state_machine.transition_to(self.playing_state)

Expand All @@ -80,7 +81,7 @@ def __str__(self):
return '[]'

def play(self, track):
self.state_machine.transition_to(Playing, track=track)
self.state_machine.transition_to(Playing(self.state_machine, track=track))


class CdPlayer(state_machine.StateMachine):
Expand All @@ -107,12 +108,12 @@ def play(self, track=None):

@state_machine.event(from_states=Playing, to_states=Paused)
def pause(self):
self.transition_to(Paused, playing_state=self._state)
self.transition_to(Paused(self, playing_state=self._state))
return True

@state_machine.event(from_states=(Playing, Paused), to_states=Stopped)
def stop(self):
self.transition_to(Stopped)
self.transition_to(Stopped(self))


class TestStateMachine(unittest.TestCase):
Expand Down

0 comments on commit 7c105bd

Please sign in to comment.