diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 656c8d7e..ad524fb2 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -88,7 +88,7 @@ jobs: run: pip install -e .[tests] - name: Run pytest - run: pytest -sv --cov=plumpy test + run: pytest -s --cov=plumpy tests - name: Create xml coverage run: coverage xml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b6e5416f..6756d22a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,7 +45,7 @@ jobs: run: pip install .[tests] - name: Run pytest - run: pytest -s --cov=plumpy test + run: pytest -s --cov=plumpy tests/ - name: Create xml coverage run: coverage xml diff --git a/test/__init__.py b/tests/__init__.py similarity index 100% rename from test/__init__.py rename to tests/__init__.py diff --git a/test/base/__init__.py b/tests/base/__init__.py similarity index 100% rename from test/base/__init__.py rename to tests/base/__init__.py diff --git a/test/base/test_statemachine.py b/tests/base/test_statemachine.py similarity index 88% rename from test/base/test_statemachine.py rename to tests/base/test_statemachine.py index bd006146..9d89a41a 100644 --- a/test/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -25,7 +25,7 @@ def __init__(self, player, track): super().__init__(player) self.track = track self._last_time = None - self._played = 0.0 + self._played = 0. def __str__(self): if self.in_state: @@ -40,7 +40,7 @@ def exit(self): super().exit() self._update_time() - def play(self, track=None): + def play(self, track=None): # pylint: disable=no-self-use, unused-argument return False def _update_time(self): @@ -55,7 +55,8 @@ class Paused(state_machine.State): TRANSITIONS = {STOP: STOPPED} def __init__(self, player, playing_state): - assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' + assert isinstance(playing_state, Playing), \ + 'Must provide the playing state to pause' super().__init__(player) self.playing_state = playing_state @@ -64,7 +65,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track) + self.state_machine.transition_to(Playing, track=track) else: self.state_machine.transition_to(self.playing_state) @@ -80,7 +81,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track) + self.state_machine.transition_to(Playing, track=track) class CdPlayer(state_machine.StateMachine): @@ -107,7 +108,7 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, self._state) + self.transition_to(Paused, playing_state=self._state) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) @@ -116,13 +117,14 @@ def stop(self): class TestStateMachine(unittest.TestCase): + def test_basic(self): cd_player = CdPlayer() self.assertEqual(cd_player.state, STOPPED) cd_player.play('Eminem - The Real Slim Shady') self.assertEqual(cd_player.state, PLAYING) - time.sleep(1.0) + time.sleep(1.) cd_player.pause() self.assertEqual(cd_player.state, PAUSED) diff --git a/test/base/test_utils.py b/tests/base/test_utils.py similarity index 100% rename from test/base/test_utils.py rename to tests/base/test_utils.py diff --git a/test/conftest.py b/tests/conftest.py similarity index 100% rename from test/conftest.py rename to tests/conftest.py diff --git a/test/notebooks/get_event_loop.ipynb b/tests/notebooks/get_event_loop.ipynb similarity index 100% rename from test/notebooks/get_event_loop.ipynb rename to tests/notebooks/get_event_loop.ipynb diff --git a/test/persistence/__init__.py b/tests/persistence/__init__.py similarity index 100% rename from test/persistence/__init__.py rename to tests/persistence/__init__.py diff --git a/test/persistence/test_inmemory.py b/tests/persistence/test_inmemory.py similarity index 100% rename from test/persistence/test_inmemory.py rename to tests/persistence/test_inmemory.py diff --git a/test/persistence/test_pickle.py b/tests/persistence/test_pickle.py similarity index 100% rename from test/persistence/test_pickle.py rename to tests/persistence/test_pickle.py diff --git a/test/rmq/__init__.py b/tests/rmq/__init__.py similarity index 100% rename from test/rmq/__init__.py rename to tests/rmq/__init__.py diff --git a/test/rmq/docker-compose.yml b/tests/rmq/docker-compose.yml similarity index 100% rename from test/rmq/docker-compose.yml rename to tests/rmq/docker-compose.yml diff --git a/test/rmq/test_communicator.py b/tests/rmq/test_communicator.py similarity index 100% rename from test/rmq/test_communicator.py rename to tests/rmq/test_communicator.py diff --git a/test/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py similarity index 93% rename from test/rmq/test_process_comms.py rename to tests/rmq/test_process_comms.py index 2859e7e8..6bbf27db 100644 --- a/test/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- import asyncio +import copy import kiwipy +from kiwipy import rmq import pytest import shortuuid -from kiwipy import rmq import plumpy -import plumpy.communications from plumpy import process_comms +import plumpy.communications from .. import utils @@ -43,11 +44,12 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): class TestRemoteProcessController: + @pytest.mark.asyncio async def test_pause(self, thread_communicator, async_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) # Run the process in the background - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + asyncio.ensure_future(proc.step_until_terminated()) # Send a pause message result = await async_controller.pause_process(proc.pid) @@ -59,7 +61,7 @@ async def test_pause(self, thread_communicator, async_controller): async def test_play(self, thread_communicator, async_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) # Run the process in the background - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + asyncio.ensure_future(proc.step_until_terminated()) assert proc.pause() # Send a play message @@ -77,7 +79,7 @@ async def test_play(self, thread_communicator, async_controller): async def test_kill(self, thread_communicator, async_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) # Run the process in the event loop - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + asyncio.ensure_future(proc.step_until_terminated()) # Send a kill message and wait for it to be done result = await async_controller.kill_process(proc.pid) @@ -90,7 +92,7 @@ async def test_kill(self, thread_communicator, async_controller): async def test_status(self, thread_communicator, async_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) # Run the process in the background - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + asyncio.ensure_future(proc.step_until_terminated()) # Send a status message status = await async_controller.get_status(proc.pid) @@ -121,6 +123,7 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: + @pytest.mark.asyncio async def test_pause(self, thread_communicator, sync_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) @@ -195,7 +198,10 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - sync_controller.kill_all('bang bang, I shot you down') + msg = copy.copy(process_comms.KILL_MSG) + msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + + sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) @@ -203,7 +209,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): async def test_status(self, thread_communicator, sync_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) # Run the process in the background - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + asyncio.ensure_future(proc.step_until_terminated()) # Send a status message status_future = sync_controller.get_status(proc.pid) diff --git a/test/test_communications.py b/tests/test_communications.py similarity index 100% rename from test/test_communications.py rename to tests/test_communications.py diff --git a/test/test_events.py b/tests/test_events.py similarity index 100% rename from test/test_events.py rename to tests/test_events.py diff --git a/test/test_expose.py b/tests/test_expose.py similarity index 100% rename from test/test_expose.py rename to tests/test_expose.py diff --git a/test/test_lang.py b/tests/test_lang.py similarity index 100% rename from test/test_lang.py rename to tests/test_lang.py diff --git a/test/test_loaders.py b/tests/test_loaders.py similarity index 100% rename from test/test_loaders.py rename to tests/test_loaders.py diff --git a/test/test_persistence.py b/tests/test_persistence.py similarity index 100% rename from test/test_persistence.py rename to tests/test_persistence.py diff --git a/test/test_port.py b/tests/test_port.py similarity index 100% rename from test/test_port.py rename to tests/test_port.py diff --git a/test/test_process_comms.py b/tests/test_process_comms.py similarity index 87% rename from test/test_process_comms.py rename to tests/test_process_comms.py index ed2be6fa..6d3d335c 100644 --- a/test/test_process_comms.py +++ b/tests/test_process_comms.py @@ -1,17 +1,23 @@ # -*- coding: utf-8 -*- +import asyncio +from test import utils +import unittest + +from kiwipy import rmq import pytest import plumpy -from plumpy import process_comms -from test import utils +from plumpy import communications, process_comms class Process(plumpy.Process): + def run(self): pass class CustomObjectLoader(plumpy.DefaultObjectLoader): + def load_object(self, identifier): if identifier == 'jimmy': return Process @@ -35,6 +41,7 @@ async def test_continue(): pid = process.pid persister.save_checkpoint(process) del process + process = None result = await launcher._continue(None, **plumpy.create_continue_body(pid)[process_comms.TASK_ARGS]) assert result == utils.DummyProcess.EXPECTED_OUTPUTS @@ -42,7 +49,7 @@ async def test_continue(): @pytest.mark.asyncio async def test_loader_is_used(): - """Make sure that the provided class loader is used by the process launcher""" + """ Make sure that the provided class loader is used by the process launcher """ loader = CustomObjectLoader() proc = Process() persister = plumpy.InMemoryPersister(loader=loader) diff --git a/test/test_process_spec.py b/tests/test_process_spec.py similarity index 100% rename from test/test_process_spec.py rename to tests/test_process_spec.py diff --git a/test/test_processes.py b/tests/test_processes.py similarity index 95% rename from test/test_processes.py rename to tests/test_processes.py index ff7ba90d..f118b77c 100644 --- a/test/test_processes.py +++ b/tests/test_processes.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- """Process tests""" - import asyncio +import copy import enum +from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from test import utils import unittest import kiwipy @@ -11,10 +13,10 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState from plumpy.utils import AttributesFrozendict -from test import utils class ForgetToCallParent(plumpy.Process): + def __init__(self, forget_on): super().__init__() self.forget_on = forget_on @@ -42,7 +44,9 @@ def on_kill(self, msg): @pytest.mark.asyncio async def test_process_scope(): + class ProcessTaskInterleave(plumpy.Process): + async def task(self, steps: list): steps.append(f'[{self.pid}] started') assert plumpy.Process.current() is self @@ -62,6 +66,7 @@ async def task(self, steps: list): class TestProcess(unittest.TestCase): + def test_spec(self): """ Check that the references to specs are doing the right thing... @@ -79,10 +84,12 @@ class Proc(utils.DummyProcess): self.assertIs(p.spec(), Proc.spec()) def test_dynamic_inputs(self): + class NoDynamic(Process): pass class WithDynamic(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -95,7 +102,9 @@ def define(cls, spec): proc.execute() def test_inputs(self): + class Proc(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -115,6 +124,7 @@ def test_raw_inputs(self): """ class Proc(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -130,7 +140,9 @@ def define(cls, spec): self.assertDictEqual(dict(process.raw_inputs), {'a': 5, 'nested': {'a': 'value'}}) def test_inputs_default(self): + class Proc(utils.DummyProcess): + @classmethod def define(cls, spec): super().define(spec) @@ -189,6 +201,7 @@ def test_inputs_default_that_evaluate_to_false(self): for def_val in (True, False, 0, 1): class Proc(utils.DummyProcess): + @classmethod def define(cls, spec): super().define(spec) @@ -203,6 +216,7 @@ def test_nested_namespace_defaults(self): """Process with a default in a nested namespace should be created, even if top level namespace not supplied.""" class SomeProcess(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -217,6 +231,7 @@ def test_raise_in_define(self): """Process which raises in its 'define' method. Check that the spec is not set.""" class BrokenProcess(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -280,11 +295,12 @@ def test_run_kill(self): proc.execute() def test_get_description(self): + class ProcWithoutSpec(Process): pass class ProcWithSpec(Process): - """Process with a spec and a docstring""" + """ Process with a spec and a docstring """ @classmethod def define(cls, spec): @@ -310,7 +326,9 @@ def define(cls, spec): self.assertIsInstance(desc_with_spec['description'], str) def test_logging(self): + class LoggerTester(Process): + def run(self, **kwargs): self.logger.info('Test') @@ -319,11 +337,13 @@ def run(self, **kwargs): proc.execute() def test_kill(self): - proc = utils.DummyProcess() + proc: Process = utils.DummyProcess() - proc.kill('Farewell!') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'Farewell!' + proc.kill(msg) self.assertTrue(proc.killed()) - self.assertEqual(proc.killed_msg(), 'Farewell!') + self.assertEqual(proc.killed_msg(), msg) self.assertEqual(proc.state, ProcessState.KILLED) def test_wait_continue(self): @@ -380,7 +400,7 @@ async def async_test(): self.assertTrue(proc.has_terminated()) self.assertEqual(proc.state, ProcessState.FINISHED) - loop.create_task(proc.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) def test_pause_play_status_messaging(self): @@ -390,8 +410,8 @@ def test_pause_play_status_messaging(self): Any process can have its status set to a given message. When pausing, a pause message can be set for the status, which should store the current status, which should be restored, once the process is played again. """ - PLAY_STATUS = 'process was played by Hans Klok' # noqa: N806 - PAUSE_STATUS = 'process was paused by Evel Knievel' # noqa: N806 + PLAY_STATUS = 'process was played by Hans Klok' + PAUSE_STATUS = 'process was paused by Evel Knievel' loop = asyncio.get_event_loop() proc = utils.WaitForSignalProcess() @@ -415,18 +435,21 @@ async def async_test(): await proc.future() # Check it's done - loop.create_task(proc.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) self.assertTrue(proc.has_terminated()) self.assertEqual(proc.state, ProcessState.FINISHED) def test_kill_in_run(self): + class KillProcess(Process): after_kill = False def run(self, **kwargs): - self.kill('killed') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'killed' + self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state self.after_kill = True @@ -439,7 +462,9 @@ def run(self, **kwargs): self.assertEqual(proc.state, ProcessState.KILLED) def test_kill_when_paused_in_run(self): + class PauseProcess(Process): + def run(self, **kwargs): self.pause() self.kill() @@ -457,6 +482,8 @@ def test_kill_when_paused(self): async def async_test(): await utils.run_until_waiting(proc) + saved_state = plumpy.Bundle(proc) + result = await proc.pause() self.assertTrue(result) self.assertTrue(proc.paused) @@ -467,7 +494,7 @@ async def async_test(): with self.assertRaises(plumpy.KilledError): result = await proc.future() - loop.create_task(proc.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) self.assertEqual(proc.state, ProcessState.KILLED) @@ -479,7 +506,7 @@ def test_run_multiple(self): procs = [] for proc_class in utils.TEST_PROCESSES: proc = proc_class() - loop.create_task(proc.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc.step_until_terminated()) procs.append(proc) tasks = asyncio.gather(*[p.future() for p in procs]) @@ -489,7 +516,9 @@ def test_run_multiple(self): self.assertDictEqual(proc_class.EXPECTED_OUTPUTS, result) def test_invalid_output(self): + class InvalidOutput(plumpy.Process): + def run(self): self.out('invalid', 5) @@ -510,26 +539,28 @@ def test_missing_output(self): self.assertFalse(proc.is_successful) def test_unsuccessful_result(self): - error_code = 256 + ERROR_CODE = 256 class Proc(Process): + @classmethod def define(cls, spec): super().define(spec) def run(self): - return plumpy.UnsuccessfulResult(error_code) + return plumpy.UnsuccessfulResult(ERROR_CODE) proc = Proc() proc.execute() - self.assertEqual(proc.result(), error_code) + self.assertEqual(proc.result(), ERROR_CODE) def test_pause_in_process(self): - """Test that we can pause and cancel that by playing within the process""" + """ Test that we can pause and cancel that by playing within the process """ test_case = self class TestPausePlay(plumpy.Process): + def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -542,18 +573,19 @@ def run(self): proc = TestPausePlay() proc.add_process_listener(listener) - loop.create_task(proc.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc.step_until_terminated()) loop.run_forever() self.assertTrue(proc.paused) self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) def test_pause_play_in_process(self): - """Test that we can pause and play that by playing within the process""" + """ Test that we can pause and play that by playing within the process """ test_case = self class TestPausePlay(plumpy.Process): + def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -570,6 +602,7 @@ def test_process_stack(self): test_case = self class StackTest(plumpy.Process): + def run(self): test_case.assertIs(self, Process.current()) @@ -586,6 +619,7 @@ def test_nested(process): expect_true.append(process == Process.current()) class StackTest(plumpy.Process): + def run(self): # TODO: unexpected behaviour here # if assert error happend here not raise @@ -595,6 +629,7 @@ def run(self): test_nested(self) class ParentProcess(plumpy.Process): + def run(self): expect_true.append(self == Process.current()) StackTest().execute() @@ -617,17 +652,21 @@ def test_process_nested(self): """ class StackTest(plumpy.Process): + def run(self): pass class ParentProcess(plumpy.Process): + def run(self): StackTest().execute() ParentProcess().execute() def test_call_soon(self): + class CallSoon(plumpy.Process): + def run(self): self.call_soon(self.do_except) @@ -647,6 +686,7 @@ def test_exception_during_on_entered(self): """Test that an exception raised during ``on_entered`` will cause the process to be excepted.""" class RaisingProcess(Process): + def on_entered(self, from_state): if from_state is not None and from_state.label == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') @@ -662,7 +702,9 @@ def on_entered(self, from_state): assert str(process.exception()) == 'exception during on_entered' def test_exception_during_run(self): + class RaisingProcess(Process): + def run(self): raise RuntimeError('exception during run') @@ -720,7 +762,7 @@ async def async_test(): await proc_unbundled.step_until_terminated() self.assertEqual([SavePauseProc.step2.__name__], proc_unbundled.steps_ran) - loop.create_task(nsync_comeback.step_until_terminated()) # noqa: RUF006 + loop.create_task(nsync_comeback.step_until_terminated()) loop.run_until_complete(async_test()) def test_save_future(self): @@ -743,7 +785,7 @@ async def async_test(): self.assertListEqual([SavePauseProc.run.__name__, SavePauseProc.step2.__name__], proc_unbundled.steps_ran) - loop.create_task(proc_unbundled.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc_unbundled.step_until_terminated()) loop.run_until_complete(async_test()) def test_created_bundle(self): @@ -797,7 +839,7 @@ async def async_test(): await loaded_proc.step_until_terminated() self.assertEqual(loaded_proc.outputs, {'finished': True}) - loop.create_task(proc.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) def test_double_restart(self): @@ -822,11 +864,11 @@ async def async_test(): await loaded_proc.step_until_terminated() self.assertEqual(loaded_proc.outputs, {'finished': True}) - loop.create_task(proc.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) def test_wait_save_continue(self): - """Test that process saved while in WAITING state restarts correctly when loaded""" + """ Test that process saved while in WAITING state restarts correctly when loaded """ loop = asyncio.get_event_loop() proc = utils.WaitForSignalProcess() @@ -842,14 +884,14 @@ async def async_test(): # Load from saved state and run again loader = plumpy.get_object_loader() proc2 = saved_state.unbundle(plumpy.LoadSaveContext(loader)) - asyncio.ensure_future(proc2.step_until_terminated()) # noqa: RUF006 + asyncio.ensure_future(proc2.step_until_terminated()) proc2.resume() result2 = await proc2.future() # Check results match self.assertEqual(result1, result2) - loop.create_task(proc.step_until_terminated()) # noqa: RUF006 + loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) def test_killed(self): @@ -869,6 +911,7 @@ def _check_round_trip(self, proc1): class TestProcessNamespace(unittest.TestCase): + def test_namespaced_process(self): """ Test that inputs in nested namespaces are properly validated and the returned @@ -876,6 +919,7 @@ def test_namespaced_process(self): """ class NameSpacedProcess(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -900,6 +944,7 @@ def test_namespaced_process_inputs(self): """ class NameSpacedProcess(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -925,6 +970,7 @@ def test_namespaced_process_dynamic(self): namespace = 'name.space' class DummyDynamicProcess(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -934,7 +980,7 @@ def define(cls, spec): original_inputs = [1, 2, 3, 4] - inputs = {'name': {'space': {str(l): l for l in original_inputs}}} # noqa: E741 + inputs = {'name': {'space': {str(l): l for l in original_inputs}}} proc = DummyDynamicProcess(inputs=inputs) for label, value in proc.inputs['name']['space'].items(): @@ -951,12 +997,14 @@ def test_namespaced_process_outputs(self): namespace_nested = f'{namespace}.nested' class OutputMode(enum.Enum): + NONE = 0 DYNAMIC_PORT_NAMESPACE = 1 SINGLE_REQUIRED_PORT = 2 BOTH_SINGLE_AND_NAMESPACE = 3 class DummyDynamicProcess(Process): + @classmethod def define(cls, spec): super().define(spec) @@ -1015,6 +1063,7 @@ def run(self): class TestProcessEvents(unittest.TestCase): + def test_basic_events(self): proc = utils.DummyProcessWithOutput() events_tester = utils.ProcessListenerTester( @@ -1034,14 +1083,11 @@ def test_killed(self): def test_excepted(self): proc = utils.ExceptionProcess() - events_tester = utils.ProcessListenerTester( - proc, - ( - 'excepted', - 'running', - 'output_emitted', - ), - ) + events_tester = utils.ProcessListenerTester(proc, ( + 'excepted', + 'running', + 'output_emitted', + )) with self.assertRaises(RuntimeError): proc.execute() proc.result() @@ -1080,6 +1126,7 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): class _RestartProcess(utils.WaitForSignalProcess): + @classmethod def define(cls, spec): super().define(spec) diff --git a/test/test_utils.py b/tests/test_utils.py similarity index 100% rename from test/test_utils.py rename to tests/test_utils.py diff --git a/test/test_waiting_process.py b/tests/test_waiting_process.py similarity index 100% rename from test/test_waiting_process.py rename to tests/test_waiting_process.py diff --git a/test/test_workchains.py b/tests/test_workchains.py similarity index 100% rename from test/test_workchains.py rename to tests/test_workchains.py diff --git a/test/utils.py b/tests/utils.py similarity index 90% rename from test/utils.py rename to tests/utils.py index 66f4f1c3..3aaaf61a 100644 --- a/test/utils.py +++ b/tests/utils.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- """Utilities for tests""" - import asyncio import collections -import unittest from collections.abc import Mapping +import copy +import unittest + +import kiwipy.rmq +import shortuuid import plumpy from plumpy import persistence, process_states, processes, utils +from plumpy.process_comms import KILL_MSG, MESSAGE_KEY Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -22,9 +26,7 @@ class DummyProcess(processes.Process): """ EXPECTED_STATE_SEQUENCE = [ - process_states.ProcessState.CREATED, - process_states.ProcessState.RUNNING, - process_states.ProcessState.FINISHED, + process_states.ProcessState.CREATED, process_states.ProcessState.RUNNING, process_states.ProcessState.FINISHED ] EXPECTED_OUTPUTS = {} @@ -58,12 +60,14 @@ def run(self, **kwargs): class KeyboardInterruptProc(processes.Process): + @utils.override def run(self): raise KeyboardInterrupt() class ProcessWithCheckpoint(processes.Process): + @utils.override def run(self): return process_states.Continue(self.last_step) @@ -73,6 +77,7 @@ def last_step(self): class WaitForSignalProcess(processes.Process): + @utils.override def run(self): return process_states.Wait(self.last_step) @@ -82,13 +87,16 @@ def last_step(self): class KillProcess(processes.Process): + @utils.override def run(self): - return process_states.Kill('killed') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'killed' + return process_states.Kill(msg=msg) class MissingOutputProcess(processes.Process): - """A process that does not generate a required output""" + """ A process that does not generate a required output """ @classmethod def define(cls, spec): @@ -97,6 +105,7 @@ def define(cls, spec): class NewLoopProcess(processes.Process): + def __init__(self, *args, **kwargs): kwargs['loop'] = plumpy.new_event_loop() super().__init__(*args, **kwargs) @@ -113,7 +122,8 @@ def called(cls, event): cls.called_events.append(event) def __init__(self, *args, **kwargs): - assert isinstance(self, processes.Process), 'Mixin has to be used with a type derived from a Process' + assert isinstance(self, processes.Process), \ + 'Mixin has to be used with a type derived from a Process' super().__init__(*args, **kwargs) self.__class__.called_events = [] @@ -159,6 +169,7 @@ def on_terminate(self): class ProcessEventsTester(EventsTesterMixin, processes.Process): + @classmethod def define(cls, spec): super().define(spec) @@ -186,6 +197,7 @@ def last_step(self): class TwoCheckpointNoFinish(ProcessEventsTester): + def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) @@ -195,18 +207,21 @@ def middle_step(self): class ExceptionProcess(ProcessEventsTester): + def run(self): self.out('test', 5) raise RuntimeError('Great scott!') class ThreeStepsThenException(ThreeSteps): + @utils.override def last_step(self): raise RuntimeError('Great scott!') class ProcessListenerTester(plumpy.ProcessListener): + def __init__(self, process, expected_events): process.add_process_listener(self) self.expected_events = set(expected_events) @@ -238,6 +253,7 @@ def on_process_killed(self, process, msg): class Saver: + def __init__(self): self.snapshots = [] self.outputs = [] @@ -267,23 +283,23 @@ class ProcessSaver(plumpy.ProcessListener): """ def __del__(self): - global _ProcessSaver_Saver # noqa: PLW0602 - global _ProcessSaverProcReferences # noqa: PLW0602 + global _ProcessSaver_Saver + global _ProcessSaverProcReferences if _ProcessSaverProcReferences is not None and id(self) in _ProcessSaverProcReferences: del _ProcessSaverProcReferences[id(self)] if _ProcessSaver_Saver is not None and id(self) in _ProcessSaver_Saver: del _ProcessSaver_Saver[id(self)] def get_process(self): - global _ProcessSaverProcReferences # noqa: PLW0602 + global _ProcessSaverProcReferences return _ProcessSaverProcReferences[id(self)] def _save(self, p): - global _ProcessSaver_Saver # noqa: PLW0602 + global _ProcessSaver_Saver _ProcessSaver_Saver[id(self)]._save(p) def set_process(self, process): - global _ProcessSaverProcReferences # noqa: PLW0602 + global _ProcessSaverProcReferences _ProcessSaverProcReferences[id(self)] = process def __init__(self, proc): @@ -292,7 +308,7 @@ def __init__(self, proc): self.init_not_persistent(proc) def init_not_persistent(self, proc): - global _ProcessSaver_Saver # noqa: PLW0602 + global _ProcessSaver_Saver _ProcessSaver_Saver[id(self)] = Saver() self.set_process(proc) @@ -306,12 +322,12 @@ def capture(self): @property def snapshots(self): - global _ProcessSaver_Saver # noqa: PLW0602 + global _ProcessSaver_Saver return _ProcessSaver_Saver[id(self)].snapshots @property def outputs(self): - global _ProcessSaver_Saver # noqa: PLW0602 + global _ProcessSaver_Saver return _ProcessSaver_Saver[id(self)].outputs @utils.override @@ -345,11 +361,7 @@ def on_process_killed(self, process, msg): TEST_PROCESSES = [DummyProcess, DummyProcessWithOutput, DummyProcessWithDynamicOutput, ThreeSteps] TEST_WAITING_PROCESSES = [ - ProcessWithCheckpoint, - TwoCheckpointNoFinish, - ExceptionProcess, - ProcessEventsTester, - ThreeStepsThenException, + ProcessWithCheckpoint, TwoCheckpointNoFinish, ExceptionProcess, ProcessEventsTester, ThreeStepsThenException ] TEST_EXCEPTION_PROCESSES = [ExceptionProcess, ThreeStepsThenException, MissingOutputProcess] @@ -375,7 +387,7 @@ def check_process_against_snapshots(loop, proc_class, snapshots): for i, bundle in zip(list(range(0, len(snapshots))), snapshots): loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop)) # the process listeners are persisted - saver = next(iter(loaded._event_helper._listeners)) + saver = list(loaded._event_helper._listeners)[0] assert isinstance(saver, ProcessSaver) # the process reference inside this particular implementation of process listener # cannot be persisted because of a circular reference. So we load it there @@ -394,7 +406,7 @@ def check_process_against_snapshots(loop, proc_class, snapshots): saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], - exclude={'exception', '_listeners'}, + exclude={'exception', '_listeners'} ) j += 1 @@ -430,8 +442,9 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): compare_value(bundle1, bundle2, list(v1), list(v2), exclude) elif isinstance(v1, set) and isinstance(v2, set): raise NotImplementedError('Comparison between sets not implemented') - elif v1 != v2: - raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}') + else: + if v1 != v2: + raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}') class TestPersister(persistence.Persister): @@ -440,7 +453,7 @@ class TestPersister(persistence.Persister): """ def save_checkpoint(self, process, tag=None): - """Create the checkpoint bundle""" + """ Create the checkpoint bundle """ persistence.Bundle(process) def load_checkpoint(self, pid, tag=None): @@ -460,7 +473,7 @@ def delete_process_checkpoints(self, pid): def run_until_waiting(proc): - """Set up a future that will be resolved on entering the WAITING state""" + """ Set up a future that will be resolved on entering the WAITING state """ from plumpy import ProcessState listener = plumpy.ProcessListener() @@ -481,7 +494,7 @@ def on_waiting(_waiting_proc): def run_until_paused(proc): - """Set up a future that will be resolved when the process is paused""" + """ Set up a future that will be resolved when the process is paused """ listener = plumpy.ProcessListener() paused = plumpy.Future()