diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index c1fdb3b2..b544d38b 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -281,7 +281,7 @@ " def continue_fn(self):\n", " print('continuing')\n", " # message is stored in the process status\n", - " return plumpy.Kill('I was killed')\n", + " return plumpy.Kill(plumpy.MessageBuilder.kill('I was killed'))\n", "\n", "\n", "process = ContinueProcess()\n", @@ -1118,7 +1118,7 @@ "\n", "process = SimpleProcess(communicator=communicator)\n", "\n", - "pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())" + "pprint(communicator.rpc_send(str(process.pid), plumpy.MessageBuilder.status()).result())" ] }, { diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index d99d0705..681858f0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The state machine for processes""" +from __future__ import annotations + import enum import functools import inspect @@ -8,7 +10,19 @@ import os import sys from types import TracebackType -from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Sequence, + Set, + Type, + Union, +) from plumpy.futures import Future @@ -31,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -187,7 +201,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': :param kwargs: Any keyword arguments to be passed to the constructor :return: An instance of the state machine """ - inst = super().__call__(*args, **kwargs) + inst: StateMachine = super().__call__(*args, **kwargs) inst.transition_to(inst.create_initial_state()) call_with_super_check(inst.init) return inst @@ -300,16 +314,25 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: + """Transite to the new state. + + The new target state will be create lazily when the state is not yet instantiated, + which will happened for states not in the expect path such as pause and kill. + The arguments are passed to the state class to create state instance. + (process arg does not need to pass since it will always call with 'self' as process) + """ assert not self._transitioning, 'Cannot call transition_to when already transitioning state' + if new_state is None: + # early return if the new state is `None` + # it can happened when transit from terminal state + return None + initial_state_label = self._state.LABEL if self._state is not None else None label = None try: self._transitioning = True - - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -319,8 +342,7 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A try: self._enter_next_state(new_state) except StateEntryFailed as exception: - # Make sure we have a state instance - new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs) + new_state = exception.state label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) @@ -338,7 +360,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transitioning = False def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: """Called when a state transitions fails. @@ -355,6 +381,10 @@ def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic + # because the label is defined after the state and required to be know before calling this function. + # This method should be replaced by `_create_state_instance`. + # aiida-core using this method for its Waiting state override. try: return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: @@ -383,20 +413,10 @@ def _enter_next_state(self, next_state: State) -> None: self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State: - if isinstance(state, State): - # It's already a state instance - return state - - # OK, have to create it - state_cls = self._ensure_state_class(state) - return state_cls(self, *args, **kwargs) + def _create_state_instance(self, state_cls: Hashable, **kwargs: Any) -> State: + if state_cls not in self.get_states_map(): + raise ValueError(f'{state_cls} is not a valid state') - def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: - if inspect.isclass(state) and issubclass(state, State): - return state + cls = self.get_states_map()[state_cls] - try: - return self.get_states_map()[cast(Hashable, state)] - except KeyError: - raise ValueError(f'{state} is not a valid state') + return cls(self, **kwargs) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 293c680b..e615ee4a 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" +from __future__ import annotations + import asyncio -import copy import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast @@ -12,10 +13,7 @@ from .utils import PID_TYPE __all__ = [ - 'KILL_MSG', - 'PAUSE_MSG', - 'PLAY_MSG', - 'STATUS_MSG', + 'MessageBuilder', 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', @@ -31,6 +29,7 @@ INTENT_KEY = 'intent' MESSAGE_KEY = 'message' +FORCE_KILL_KEY = 'force_kill' class Intent: @@ -42,10 +41,45 @@ class Intent: STATUS: str = 'status' -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} -PLAY_MSG = {INTENT_KEY: Intent.PLAY} -KILL_MSG = {INTENT_KEY: Intent.KILL} -STATUS_MSG = {INTENT_KEY: Intent.STATUS} +MessageType = Dict[str, Any] + + +class MessageBuilder: + """MessageBuilder will construct different messages that can passing over communicator.""" + + @classmethod + def play(cls, text: str | None = None) -> MessageType: + """The play message send over communicator.""" + return { + INTENT_KEY: Intent.PLAY, + MESSAGE_KEY: text, + } + + @classmethod + def pause(cls, text: str | None = None) -> MessageType: + """The pause message send over communicator.""" + return { + INTENT_KEY: Intent.PAUSE, + MESSAGE_KEY: text, + } + + @classmethod + def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: + """The kill message send over communicator.""" + return { + INTENT_KEY: Intent.KILL, + MESSAGE_KEY: text, + FORCE_KILL_KEY: force_kill, + } + + @classmethod + def status(cls, text: str | None = None) -> MessageType: + """The status message send over communicator.""" + return { + INTENT_KEY: Intent.STATUS, + MESSAGE_KEY: text, + } + TASK_KEY = 'task' TASK_ARGS = 'args' @@ -162,7 +196,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': :param pid: the process id :return: the status response from the process """ - future = self._communicator.rpc_send(pid, STATUS_MSG) + future = self._communicator.rpc_send(pid, MessageBuilder.status()) result = await asyncio.wrap_future(future) return result @@ -174,11 +208,9 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr :param msg: optional pause message :return: True if paused, False otherwise """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = MessageBuilder.pause(text=msg) - pause_future = self._communicator.rpc_send(pid, message) + pause_future = self._communicator.rpc_send(pid, msg) # rpc_send return a thread future from communicator future = await asyncio.wrap_future(pause_future) # future is just returned from rpc call which return a kiwipy future @@ -192,12 +224,12 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': :param pid: the pid of the process to play :return: True if played, False otherwise """ - play_future = self._communicator.rpc_send(pid, PLAY_MSG) + play_future = self._communicator.rpc_send(pid, MessageBuilder.play()) future = await asyncio.wrap_future(play_future) result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': """ Kill the process @@ -205,12 +237,11 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro :param msg: optional kill message :return: True if killed, False otherwise """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = MessageBuilder.kill() # Wait for the communication to go through - kill_future = self._communicator.rpc_send(pid, message) + kill_future = self._communicator.rpc_send(pid, msg) future = await asyncio.wrap_future(kill_future) # Now wait for the kill to be enacted result = await asyncio.wrap_future(future) @@ -331,7 +362,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: :param pid: the process id :return: the status response from the process """ - return self._communicator.rpc_send(pid, STATUS_MSG) + return self._communicator.rpc_send(pid, MessageBuilder.status()) def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: """ @@ -342,11 +373,9 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu :return: a response future from the process to be paused """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = MessageBuilder.pause(text=msg) - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) def pause_all(self, msg: Any) -> None: """ @@ -364,7 +393,7 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: :return: a response future from the process to be played """ - return self._communicator.rpc_send(pid, PLAY_MSG) + return self._communicator.rpc_send(pid, MessageBuilder.play()) def play_all(self) -> None: """ @@ -372,7 +401,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: """ Kill the process @@ -381,18 +410,20 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut :return: a response future from the process to be killed """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = MessageBuilder.kill() - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[Any]) -> None: + def kill_all(self, msg: Optional[MessageType]) -> None: """ Kill all processes that are subscribed to the same communicator :param msg: an optional pause message """ + if msg is None: + msg = MessageBuilder.kill() + self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index cf29973a..d369a1e9 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import sys import traceback from enum import Enum @@ -8,6 +10,8 @@ import yaml from yaml.loader import Loader +from plumpy.process_comms import MessageBuilder, MessageType + try: import tblib @@ -48,7 +52,12 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - pass + def __init__(self, msg: MessageType | None): + super().__init__() + if msg is None: + msg = MessageBuilder.kill() + + self.msg: MessageType = msg class PauseInterruption(Interruption): @@ -64,7 +73,7 @@ class Command(persistence.Savable): @auto_persist('msg') class Kill(Command): - def __init__(self, msg: Optional[Any] = None): + def __init__(self, msg: Optional[MessageType] = None): super().__init__() self.msg = msg @@ -76,7 +85,10 @@ class Pause(Command): @auto_persist('msg', 'data') class Wait(Command): def __init__( - self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None + self, + continue_fn: Optional[Callable[..., Any]] = None, + msg: Optional[Any] = None, + data: Optional[Any] = None, ): super().__init__() self.continue_fn = continue_fn @@ -349,13 +361,23 @@ def resume(self, value: Any = NULL) -> None: class Excepted(State): + """ + Excepted state, can optionally provide exception and trace_back + + :param exception: The exception instance + :param trace_back: An optional exception traceback + """ + LABEL = ProcessState.EXCEPTED EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' def __init__( - self, process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None + self, + process: 'Process', + exception: Optional[BaseException], + trace_back: Optional[TracebackType] = None, ): """ :param process: The associated process @@ -387,15 +409,27 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self.traceback = None - def get_exc_info(self) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: + def get_exc_info( + self, + ) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: """ Recreate the exc_info tuple and return it """ - return type(self.exception) if self.exception else None, self.exception, self.traceback + return ( + type(self.exception) if self.exception else None, + self.exception, + self.traceback, + ) @auto_persist('result', 'successful') class Finished(State): + """State for process is finished. + + :param result: The result of process + :param successful: Boolean for the exit code is ``0`` the process is successful. + """ + LABEL = ProcessState.FINISHED def __init__(self, process: 'Process', result: Any, successful: bool) -> None: @@ -406,13 +440,21 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: @auto_persist('msg') class Killed(State): + """ + Represents a state where a process has been killed. + + This state is used to indicate that a process has been terminated and can optionally + include a message providing details about the termination. + + :param msg: An optional message explaining the reason for the process termination. + """ + LABEL = ProcessState.KILLED - def __init__(self, process: 'Process', msg: Optional[str]): + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process :param msg: Optional kill message - """ super().__init__(process) self.msg = msg diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ffddf7b5..0866ee41 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -26,6 +26,7 @@ Sequence, Tuple, Type, + TypeVar, Union, cast, ) @@ -39,15 +40,27 @@ import yaml from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed -from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils +from . import ( + events, + exceptions, + futures, + persistence, + ports, + process_comms, + process_states, + utils, +) from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper +from .process_comms import MESSAGE_KEY, MessageBuilder, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected +T = TypeVar('T') + __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) @@ -91,7 +104,13 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: @persistence.auto_persist( - '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' + '_pid', + '_creation_time', + '_future', + '_paused', + '_status', + '_pre_paused_status', + '_event_helper', ) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ @@ -231,7 +250,9 @@ def get_description(cls) -> Dict[str, Any]: @classmethod def recreate_from( - cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, ) -> 'Process': """ Recreate a process from a saved state, passing any positional and @@ -314,14 +335,21 @@ def init(self) -> None: identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except kiwipy.TimeoutError: - self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) + self.logger.exception( + 'Process<%s>: failed to register as a broadcast subscriber', + self.pid, + ) if not self._future.done(): def try_killing(future: futures.Future) -> None: if future.cancelled(): - if not self.kill('Killed by future being cancelled'): - self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) + msg = MessageBuilder.kill(text='Killed by future being cancelled') + if not self.kill(msg): + self.logger.warning( + 'Process<%s>: Failed to kill process on future cancel', + self.pid, + ) self._future.add_done_callback(try_killing) @@ -425,7 +453,13 @@ def launch( The process is started asynchronously, without blocking other task in the event loop. """ - process = process_class(inputs=inputs, pid=pid, logger=logger, loop=self.loop, communicator=self._communicator) + process = process_class( + inputs=inputs, + pid=pid, + logger=logger, + loop=self.loop, + communicator=self._communicator, + ) self.loop.create_task(process.step_until_terminated()) return process @@ -477,7 +511,7 @@ def killed(self) -> bool: """Return whether the process is killed.""" return self.state == process_states.ProcessState.KILLED - def killed_msg(self) -> Optional[str]: + def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" if isinstance(self._state, process_states.Killed): return self._state.msg @@ -529,7 +563,10 @@ def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> return handle def callback_excepted( - self, _callback: Callable[..., Any], exception: Optional[BaseException], trace: Optional[TracebackType] + self, + _callback: Callable[..., Any], + exception: Optional[BaseException], + trace: Optional[TracebackType], ) -> None: if self.state != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @@ -555,7 +592,7 @@ def _process_scope(self) -> Generator[None, None, None]: stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) -> T: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. @@ -576,7 +613,9 @@ async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: An # region Persistence def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] + self, + out_state: SAVED_STATE_TYPE, + save_context: Optional[persistence.LoadSaveContext], ) -> None: """ Ask the process to save its current instance state. @@ -828,7 +867,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False) + state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] + finished_state = state_cls(self, result=result, successful=False) + raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -857,10 +898,15 @@ def on_excepted(self) -> None: self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check - def on_kill(self, msg: Optional[str]) -> None: + def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" - self.set_status(msg) - self.future().set_exception(exceptions.KilledError(msg)) + if msg is None: + msg_txt = '' + else: + msg_txt = msg[MESSAGE_KEY] or '' + + self.set_status(msg_txt) + self.future().set_exception(exceptions.KilledError(msg_txt)) @super_check def on_killed(self) -> None: @@ -906,7 +952,12 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg) + self.logger.debug( + "Process<%s>: received RPC message with communicator '%s': %r", + self.pid, + _comm, + msg, + ) intent = msg[process_comms.INTENT_KEY] @@ -915,7 +966,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -935,7 +986,11 @@ def broadcast_receive( """ self.logger.debug( - "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body + "Process<%s>: received broadcast message '%s' with communicator '%s': %r", + self.pid, + subject, + _comm, + body, ) # If we get a message we recognise then action it, otherwise ignore @@ -1001,13 +1056,20 @@ def close(self) -> None: # region State related methods def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: # If we are creating, then reraise instead of failing. if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace) + new_state = self._create_state_instance( + process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace + ) + self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1070,8 +1132,8 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - # Ignore the next state - self.transition_to(process_states.ProcessState.KILLED, str(exception)) + new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=exception.msg) + self.transition_to(new_state) return True finally: self._killing = None @@ -1123,9 +1185,12 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) + new_state = self._create_state_instance( + process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace_back + ) + self.transition_to(new_state) - def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message @@ -1151,7 +1216,8 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.ProcessState.KILLED, msg) + new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) + self.transition_to(new_state) return True @property @@ -1168,7 +1234,10 @@ def create_initial_state(self) -> process_states.State: :return: A Created state """ - return cast(process_states.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run)) + return cast( + process_states.State, + self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), + ) def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: """ diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 5b4b73d8..3a1621a2 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -57,6 +57,7 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) + self._player = player self.playing_state = playing_state def __str__(self): @@ -64,7 +65,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) else: self.state_machine.transition_to(self.playing_state) @@ -80,7 +81,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(self.state_machine, track=track)) class CdPlayer(state_machine.StateMachine): @@ -107,12 +108,12 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase): diff --git a/tests/persistence/test_inmemory.py b/tests/persistence/test_inmemory.py index b0db46e7..9e3141de 100644 --- a/tests/persistence/test_inmemory.py +++ b/tests/persistence/test_inmemory.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- import unittest -from ..utils import ProcessWithCheckpoint - import plumpy -import plumpy +from ..utils import ProcessWithCheckpoint class TestInMemoryPersister(unittest.TestCase): diff --git a/tests/persistence/test_pickle.py b/tests/persistence/test_pickle.py index dd68b4fd..da4ede51 100644 --- a/tests/persistence/test_pickle.py +++ b/tests/persistence/test_pickle.py @@ -5,10 +5,10 @@ if getattr(tempfile, 'TemporaryDirectory', None) is None: from backports import tempfile -from ..utils import ProcessWithCheckpoint - import plumpy +from ..utils import ProcessWithCheckpoint + class TestPicklePersister(unittest.TestCase): def test_save_load_roundtrip(self): diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 7223b888..a6249d10 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import asyncio -import copy import kiwipy import pytest @@ -196,8 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = copy.copy(process_comms.KILL_MSG) - msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down') sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) diff --git a/tests/test_expose.py b/tests/test_expose.py index 0f6f8087..c5e6014c 100644 --- a/tests/test_expose.py +++ b/tests/test_expose.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- import unittest -from .utils import NewLoopProcess - from plumpy.ports import PortNamespace from plumpy.process_spec import ProcessSpec from plumpy.processes import Process +from .utils import NewLoopProcess + def validator_function(input, port): pass diff --git a/tests/test_process_comms.py b/tests/test_process_comms.py index c59737ac..44947230 100644 --- a/tests/test_process_comms.py +++ b/tests/test_process_comms.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import pytest -from tests import utils import plumpy from plumpy import process_comms +from tests import utils class Process(plumpy.Process): diff --git a/tests/test_processes.py b/tests/test_processes.py index faea9eae..7b21c463 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -2,18 +2,17 @@ """Process tests""" import asyncio -import copy import enum import unittest import kiwipy import pytest -from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import MessageBuilder from plumpy.utils import AttributesFrozendict +from tests import utils class ForgetToCallParent(plumpy.Process): @@ -323,8 +322,7 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Farewell!' + msg = MessageBuilder.kill(text='Farewell!') proc.kill(msg) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg(), msg) @@ -430,8 +428,7 @@ class KillProcess(Process): after_kill = False def run(self, **kwargs): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = MessageBuilder.kill(text='killed') self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state diff --git a/tests/utils.py b/tests/utils.py index f2a58dfc..13abc38c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,13 +3,12 @@ import asyncio import collections -import copy import unittest from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import MessageBuilder Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -86,8 +85,7 @@ def last_step(self): class KillProcess(processes.Process): @utils.override def run(self): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = MessageBuilder.kill(text='killed') return process_states.Kill(msg=msg)