From 684cc070e40f62e9be6261bd30c5790ce8c10691 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Thu, 28 Nov 2024 16:15:59 +0100 Subject: [PATCH] StateMachine parts removed out (almost) --- src/plumpy/base/state_machine.py | 9 ++++----- src/plumpy/process_states.py | 34 +++++++++++++++++++++++++++----- src/plumpy/workchains.py | 16 ++++++++++++++- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index d99d0705..36ca73ba 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -74,8 +74,8 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore - raise TypeError(f'from_states: {from_states}') + # if not all(issubclass(state, State) for state in from_states): # type: ignore + # raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) @@ -138,7 +138,6 @@ def label(self) -> LABEL_TYPE: """Convenience property to get the state label""" return self.LABEL - @super_check def enter(self) -> None: """Entering the state""" @@ -158,7 +157,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'Sta return self.state_machine.create_state(state_label, *args, **kwargs) def do_enter(self) -> None: - call_with_super_check(self.enter) + self.enter() self.in_state = True def do_exit(self) -> None: @@ -240,7 +239,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + # assert issubclass(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 8e4390e4..2446b044 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -4,11 +4,13 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Tuple, Type, Union, cast import yaml from yaml.loader import Loader +from plumpy.base.utils import call_with_super_check + try: import tblib @@ -265,8 +267,10 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: return cast(State, state) # casting from base.State to process.State -class Waiting(state_machine.State, persistence.Savable): # class Waiting(state_machine.State): +class Waiting(state_machine.State, persistence.Savable): + """The basic waiting state.""" + LABEL = ProcessState.WAITING ALLOWED = { ProcessState.RUNNING, @@ -280,6 +284,7 @@ class Waiting(state_machine.State, persistence.Savable): _interruption = None _auto_persist = {'msg', 'data', 'in_state'} + is_terminal_state = False def __str__(self) -> str: state_info = super().__str__() @@ -295,7 +300,8 @@ def __init__( data: Any | None = None, saver: Savable | None = None, ) -> None: - super().__init__(process) + self._process = process + self.in_state: bool = False self.done_callback = done_callback self.msg = msg self.data = data @@ -306,7 +312,10 @@ def process(self) -> state_machine.StateMachine: """ :return: The process """ - return self.state_machine + return self._process + + def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'state_machine.State': + return self._process.create_state(state_label, *args, **kwargs) def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -316,7 +325,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + # FIXME: the save/load instance state methods should be generic from Saver + self._process = load_context.process callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: self.done_callback = getattr(self.process, callback_name) @@ -353,6 +363,20 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) + def do_enter(self) -> None: + self.in_state = True + + def do_exit(self) -> None: + if self.is_terminal(): + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + + @classmethod + def is_terminal(cls) -> bool: + # deprecated using class attribute `is_terminal_state` directly. + return cls.is_terminal_state + class Excepted(State): LABEL = ProcessState.EXCEPTED diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 748a44d7..0f5dd97e 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -70,7 +70,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): - """Overwrite the waiting state""" + """WorkChain waiting state that can wait on its awaiting list.""" def __init__( self, @@ -90,6 +90,13 @@ def enter(self) -> None: for awaitable in self._awaiting: awaitable.add_done_callback(self._awaitable_done) + def do_enter(self) -> None: + for awaitable in self._awaiting: + awaitable.add_done_callback(self._awaitable_done) + + # FIXME: + self.in_state = True + def exit(self) -> None: super().exit() for awaitable in self._awaiting: @@ -105,6 +112,13 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: if not self._awaiting: self._waiting_future.set_result(lang.NULL) + def do_exit(self) -> None: + if self.is_terminal(): + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + # FIXME: + self.in_state = False + class WorkChain(mixins.ContextMixin, processes.Process): """