Skip to content

Commit

Permalink
test_statemachine.py WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Nov 29, 2024
1 parent 103193c commit 9619f49
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 71 deletions.
7 changes: 2 additions & 5 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,7 @@ def do_exit(self) -> None:


@runtime_checkable
class StateP(Protocol):
LABEL: ClassVar[str]

# FIXME: fix the LABEL_TYPE
ALLOWED: ClassVar[set[LABEL_TYPE]]
class StateProtocol(Protocol):

def do_enter(self) -> None:
...
Expand All @@ -183,6 +179,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'Sta

@classmethod
def is_terminal(cls) -> bool:
# FIXME: not necessary only for fit with legacy state
...


Expand Down
152 changes: 86 additions & 66 deletions test/base/test_statemachine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# -*- coding: utf-8 -*-
import time
from typing import Any, Hashable
import unittest

from plumpy.base import state_machine
from plumpy.exceptions import InvalidStateError
from plumpy.process_states import State

# Events
PLAY = 'Play'
Expand All @@ -15,39 +18,51 @@
STOPPED = 'Stopped'


class Playing(state_machine.State):
class Playing:
LABEL = PLAYING
ALLOWED = {PAUSED, STOPPED}
TRANSITIONS = {STOP: STOPPED}

def __init__(self, player, track):
assert track is not None, 'Must provide a track name'
super().__init__(player)
self.state_machine: state_machine.StateMachine = player
self.track = track
self._last_time = None
self._played = 0.0
self.in_state = False

def __str__(self):
if self.in_state:
self._update_time()
return f'> {self.track} ({self._played}s)'

def enter(self):
super().enter()
self._last_time = time.time()

def exit(self):
super().exit()
self._update_time()
def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
return self.state_machine.create_state(state_label, *args, **kwargs)

def play(self, track=None):
return False

def do_enter(self) -> None:
self.in_state = True
self._last_time = time.time()

def do_exit(self) -> None:
"""Exiting the state"""
if self.is_terminal():
raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

self._update_time()
self.in_state = False

def _update_time(self):
current_time = time.time()
self._played += current_time - self._last_time
self._last_time = current_time

@classmethod
def is_terminal(cls) -> bool:
return False


class Paused(state_machine.State):
LABEL = PAUSED
Expand Down Expand Up @@ -82,60 +97,65 @@ def __str__(self):
def play(self, track):
self.state_machine.transition_to(Playing, track)


class CdPlayer(state_machine.StateMachine):
STATES = (Stopped, Playing, Paused)

def __init__(self):
super().__init__()
self.add_state_event_callback(
state_machine.StateEventHook.ENTERING_STATE, lambda _s, _h, state: self.entering(state)
)
self.add_state_event_callback(state_machine.StateEventHook.EXITING_STATE, lambda _s, _h, _st: self.exiting())

def entering(self, state):
print(f'Entering {state}')
print(self._state)

def exiting(self):
print(f'Exiting {self.state}')
print(self._state)

@state_machine.event(to_states=Playing)
def play(self, track=None):
return self._state.play(track)

@state_machine.event(from_states=Playing, to_states=Paused)
def pause(self):
self.transition_to(Paused, self._state)
return True

@state_machine.event(from_states=(Playing, Paused), to_states=Stopped)
def stop(self):
self.transition_to(Stopped)


class TestStateMachine(unittest.TestCase):
def test_basic(self):
cd_player = CdPlayer()
self.assertEqual(cd_player.state, STOPPED)

cd_player.play('Eminem - The Real Slim Shady')
self.assertEqual(cd_player.state, PLAYING)
time.sleep(1.0)

cd_player.pause()
self.assertEqual(cd_player.state, PAUSED)

cd_player.play()
self.assertEqual(cd_player.state, PLAYING)

self.assertEqual(cd_player.play(), False)

cd_player.stop()
self.assertEqual(cd_player.state, STOPPED)

def test_invalid_event(self):
cd_player = CdPlayer()
with self.assertRaises(AssertionError):
cd_player.play()
def test_state_protocol():
assert issubclass(Playing, state_machine.StateProtocol)
assert issubclass(Paused, state_machine.StateProtocol)
assert issubclass(Stopped, state_machine.StateProtocol)


# class CdPlayer(state_machine.StateMachine):
# STATES = (Stopped, Playing, Paused)
#
# def __init__(self):
# super().__init__()
# self.add_state_event_callback(
# state_machine.StateEventHook.ENTERING_STATE, lambda _s, _h, state: self.entering(state)
# )
# self.add_state_event_callback(state_machine.StateEventHook.EXITING_STATE, lambda _s, _h, _st: self.exiting())
#
# def entering(self, state):
# print(f'Entering {state}')
# print(self._state)
#
# def exiting(self):
# print(f'Exiting {self.state}')
# print(self._state)
#
# @state_machine.event(to_states=Playing)
# def play(self, track=None):
# return self._state.play(track)
#
# @state_machine.event(from_states=Playing, to_states=Paused)
# def pause(self):
# self.transition_to(Paused, self._state)
# return True
#
# @state_machine.event(from_states=(Playing, Paused), to_states=Stopped)
# def stop(self):
# self.transition_to(Stopped)


# class TestStateMachine(unittest.TestCase):
# def test_basic(self):
# cd_player = CdPlayer()
# self.assertEqual(cd_player.state, STOPPED)
#
# cd_player.play('Eminem - The Real Slim Shady')
# self.assertEqual(cd_player.state, PLAYING)
# time.sleep(1.0)
#
# cd_player.pause()
# self.assertEqual(cd_player.state, PAUSED)
#
# cd_player.play()
# self.assertEqual(cd_player.state, PLAYING)
#
# self.assertEqual(cd_player.play(), False)
#
# cd_player.stop()
# self.assertEqual(cd_player.state, STOPPED)
#
# def test_invalid_event(self):
# cd_player = CdPlayer()
# with self.assertRaises(AssertionError):
# cd_player.play()

0 comments on commit 9619f49

Please sign in to comment.