Skip to content

Commit

Permalink
StateMachine parts removed out (almost)
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Nov 28, 2024
1 parent 4aef7e4 commit 684cc07
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 11 deletions.
9 changes: 4 additions & 5 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def event(
if from_states != '*':
if inspect.isclass(from_states):
from_states = (from_states,)
if not all(issubclass(state, State) for state in from_states): # type: ignore
raise TypeError(f'from_states: {from_states}')
# if not all(issubclass(state, State) for state in from_states): # type: ignore
# raise TypeError(f'from_states: {from_states}')
if to_states != '*':
if inspect.isclass(to_states):
to_states = (to_states,)
Expand Down Expand Up @@ -138,7 +138,6 @@ def label(self) -> LABEL_TYPE:
"""Convenience property to get the state label"""
return self.LABEL

@super_check
def enter(self) -> None:
"""Entering the state"""

Expand All @@ -158,7 +157,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'Sta
return self.state_machine.create_state(state_label, *args, **kwargs)

def do_enter(self) -> None:
call_with_super_check(self.enter)
self.enter()
self.in_state = True

def do_exit(self) -> None:
Expand Down Expand Up @@ -240,7 +239,7 @@ def __ensure_built(cls) -> None:
# Build the states map
cls._STATES_MAP = {}
for state_cls in cls.STATES:
assert issubclass(state_cls, State)
# assert issubclass(state_cls, State)
label = state_cls.LABEL
assert label not in cls._STATES_MAP, f"Duplicate label '{label}'"
cls._STATES_MAP[label] = state_cls
Expand Down
34 changes: 29 additions & 5 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import traceback
from enum import Enum
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Tuple, Type, Union, cast

import yaml
from yaml.loader import Loader

from plumpy.base.utils import call_with_super_check

try:
import tblib

Expand Down Expand Up @@ -265,8 +267,10 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State:
return cast(State, state) # casting from base.State to process.State


class Waiting(state_machine.State, persistence.Savable):
# class Waiting(state_machine.State):
class Waiting(state_machine.State, persistence.Savable):
"""The basic waiting state."""

LABEL = ProcessState.WAITING
ALLOWED = {
ProcessState.RUNNING,
Expand All @@ -280,6 +284,7 @@ class Waiting(state_machine.State, persistence.Savable):

_interruption = None
_auto_persist = {'msg', 'data', 'in_state'}
is_terminal_state = False

def __str__(self) -> str:
state_info = super().__str__()
Expand All @@ -295,7 +300,8 @@ def __init__(
data: Any | None = None,
saver: Savable | None = None,
) -> None:
super().__init__(process)
self._process = process
self.in_state: bool = False
self.done_callback = done_callback
self.msg = msg
self.data = data
Expand All @@ -306,7 +312,10 @@ def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine
return self._process

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'state_machine.State':
return self._process.create_state(state_label, *args, **kwargs)

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
Expand All @@ -316,7 +325,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)

self.state_machine = load_context.process
# FIXME: the save/load instance state methods should be generic from Saver
self._process = load_context.process
callback_name = saved_state.get(self.DONE_CALLBACK, None)
if callback_name is not None:
self.done_callback = getattr(self.process, callback_name)
Expand Down Expand Up @@ -353,6 +363,20 @@ def resume(self, value: Any = NULL) -> None:

self._waiting_future.set_result(value)

def do_enter(self) -> None:
self.in_state = True

def do_exit(self) -> None:
if self.is_terminal():
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

self.in_state = False

@classmethod
def is_terminal(cls) -> bool:
# deprecated using class attribute `is_terminal_state` directly.
return cls.is_terminal_state


class Excepted(State):
LABEL = ProcessState.EXCEPTED
Expand Down
16 changes: 15 additions & 1 deletion src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']:

@persistence.auto_persist('_awaiting')
class Waiting(process_states.Waiting):
"""Overwrite the waiting state"""
"""WorkChain waiting state that can wait on its awaiting list."""

def __init__(
self,
Expand All @@ -90,6 +90,13 @@ def enter(self) -> None:
for awaitable in self._awaiting:
awaitable.add_done_callback(self._awaitable_done)

def do_enter(self) -> None:
for awaitable in self._awaiting:
awaitable.add_done_callback(self._awaitable_done)

# FIXME:
self.in_state = True

def exit(self) -> None:
super().exit()
for awaitable in self._awaiting:
Expand All @@ -105,6 +112,13 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None:
if not self._awaiting:
self._waiting_future.set_result(lang.NULL)

def do_exit(self) -> None:
if self.is_terminal():
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

# FIXME:
self.in_state = False


class WorkChain(mixins.ContextMixin, processes.Process):
"""
Expand Down

0 comments on commit 684cc07

Please sign in to comment.