diff --git a/goless/__init__.py b/goless/__init__.py index 27ad68f..92d5480 100644 --- a/goless/__init__.py +++ b/goless/__init__.py @@ -19,7 +19,7 @@ # noinspection PyUnresolvedReferences from .channels import chan, ChannelClosed # noinspection PyUnresolvedReferences -from .selecting import dcase, rcase, scase, select +from .selecting import dcase, rcase, scase, select, select_ok version_info = 0, 7, 2 diff --git a/goless/selecting.py b/goless/selecting.py index 29c0114..956e38a 100644 --- a/goless/selecting.py +++ b/goless/selecting.py @@ -1,4 +1,5 @@ from .backends import current as _be, Deadlock as _Deadlock +from .channels import ChannelClosed # noinspection PyPep8Naming,PyShadowingNames @@ -10,7 +11,7 @@ def __init__(self, chan): self.chan = chan def ready(self): - return self.chan.recv_ready() + return self.chan is not None and (self.chan._closed or self.chan.recv_ready()) def exec_(self): return self.chan.recv() @@ -25,7 +26,7 @@ def __init__(self, chan, value): self.value = value def ready(self): - return self.chan.send_ready() + return self.chan is not None and (self.chan._closed or self.chan.send_ready()) def exec_(self): self.chan.send(self.value) @@ -38,20 +39,15 @@ def ready(self): return False -def select(*cases): +def select_ok(*cases): """ - Select the first case that becomes ready. - If a default case (:class:`goless.dcase`) is present, - return that if no other cases are ready. - If there is no default case and no case is ready, - block until one becomes ready. - - See Go's ``reflect.Select`` method for an analog - (http://golang.org/pkg/reflect/#Select). + Select the first case that becomes ready, including an ``ok`` indication. + This is the same as the ``select`` method except than an ``ok`` indication + is included, allowing checks for closed channels. :param cases: List of case instances, such as :class:`goless.rcase`, :class:`goless.scase`, or :class:`goless.dcase`. - :return: ``(chosen case, received value)``. + :return: ``(chosen case, received value, ok indication)``. If the chosen case is not an :class:`goless.rcase`, it will be None. """ if len(cases) == 0: @@ -70,13 +66,16 @@ def select(*cases): default = None for c in cases: if c.ready(): - return c, c.exec_() + try: + return c, c.exec_(), True + except ChannelClosed: + return c, None, False if isinstance(c, dcase): assert default is None, 'Only one default case is allowd.' default = c if default is not None: # noinspection PyCallingNonCallable - return default, None + return default, None, True # We need to check for deadlocks before selecting. # We can't rely on the underlying backend to do it, @@ -89,5 +88,33 @@ def select(*cases): while True: for c in cases: if c.ready(): - return c, c.exec_() + try: + return c, c.exec_(), True + except ChannelClosed: + return c, None, False _be.yield_() + + +def select(*cases): + """ + Select the first case that becomes ready. + If a default case (:class:`goless.dcase`) is present, + return that if no other cases are ready. + If there is no default case and no case is ready, + block until one becomes ready. + + See Go's ``reflect.Select`` method for an analog + (http://golang.org/pkg/reflect/#Select). + + :param cases: List of case instances, such as + :class:`goless.rcase`, :class:`goless.scase`, or :class:`goless.dcase`. + :return: ``(chosen case, received value)``. + If the chosen case is not an :class:`goless.rcase`, it will be None. + """ + result = select_ok(*cases) + if result is not None: + chosen, value, ok = result + if not ok: + raise ChannelClosed() + result = chosen, value + return result diff --git a/tests/test_select.py b/tests/test_select.py index aa1a659..e717e3d 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -1,5 +1,6 @@ import goless from goless.backends import current as be +from goless.channels import ChannelClosed from . import BaseTests @@ -93,6 +94,45 @@ def test_select_chooses_ready_selection(self): self.assertIs(result, cases[1]) self.assertEqual(val, 3) + def test_select_ok_default_is_ok(self): + cases = [goless.rcase(self.chan1), goless.dcase()] + result, val, ok = goless.select_ok(cases) + self.assertIs(result, cases[1]) + self.assertTrue(ok) + + def test_select_ok_ignores_null_chan(self): + cases = [goless.scase(None, None), goless.rcase(None), goless.dcase()] + result, val, ok = goless.select_ok(cases) + self.assertIs(result, cases[2]) + self.assertTrue(ok) + + def test_select_ok_chooses_closed_over_default(self): + readychan = goless.chan(1) + readychan.send(3) + readychan.close() + cases = [goless.rcase(readychan), goless.dcase()] + + result, val, ok = goless.select_ok(cases) + self.assertIs(result, cases[0]) + self.assertEqual(val, 3) + self.assertTrue(ok) + + result, val, ok = goless.select_ok(cases) + self.assertIs(result, cases[0]) + self.assertIsNone(val) + self.assertFalse(ok) + + result, val, ok = goless.select_ok(cases) + self.assertIs(result, cases[0]) + self.assertIsNone(val) + self.assertFalse(ok) + + def test_select_raises_if_closed(self): + self.chan1.close() + cases = [goless.rcase(self.chan1), goless.dcase()] + with self.assertRaises(ChannelClosed): + goless.select(cases) + def test_select_no_default_no_ready_blocks(self): chan1 = goless.chan() chan2 = goless.chan()