From 30665d98cbba98191b6665fffca589d421f72c41 Mon Sep 17 00:00:00 2001 From: Shahriar Heidrich Date: Sun, 21 Jul 2024 17:26:56 +0200 Subject: [PATCH] Add strict parameter to stream.zip (issue #118, PR #119) * Add strict parameter to stream.zip (issue #118) * Use shortcut for anext called without default * Add tests for exception passthrough in zip * Add (failing) test case for early exit from zip * Exit from non-strict zip as early as possible Fixes failing test from previous commit. * Make UNSET an enum Co-authored-by: Vincent Michel * Make STOP_SENTINEL an enum Co-authored-by: Vincent Michel * Fix imports for enums * Move strict condition further up & fix typing * Update aiostream/stream/combine.py Fix Python 3.8 compat Co-authored-by: Vincent Michel * Move STOP_SENTINEL construction out of function * Type inner anext wrapper function * Improve un-overloaded anext() type signature * Use ellipsis instead of pass --------- Co-authored-by: Vincent Michel --- aiostream/aiter_utils.py | 28 ++++++++++++++++++++++++-- aiostream/stream/combine.py | 31 +++++++++++++++++++++------- tests/test_combine.py | 40 +++++++++++++++++++++++++++++++++++++ tests/test_core.py | 8 ++++---- 4 files changed, 94 insertions(+), 13 deletions(-) diff --git a/aiostream/aiter_utils.py b/aiostream/aiter_utils.py index 9f5e455..72194c7 100644 --- a/aiostream/aiter_utils.py +++ b/aiostream/aiter_utils.py @@ -5,6 +5,7 @@ import sys from types import TracebackType +import enum import warnings import functools from typing import ( @@ -19,6 +20,7 @@ AsyncIterator, Any, cast, + overload, ) if TYPE_CHECKING: @@ -46,16 +48,37 @@ # Magic method shorcuts +_UnsetType = enum.Enum("_UnsetType", "UNSET") +UNSET = _UnsetType.UNSET + + def aiter(obj: AsyncIterable[T]) -> AsyncIterator[T]: """Access aiter magic method.""" assert_async_iterable(obj) return obj.__aiter__() -def anext(obj: AsyncIterator[T]) -> Awaitable[T]: +@overload +def anext(obj: AsyncIterator[T]) -> Awaitable[T]: ... 
+ + +@overload +def anext(obj: AsyncIterator[T], default: U) -> Awaitable[T | U]: ... + + +def anext(obj: AsyncIterator[T], default: U | _UnsetType = UNSET) -> Awaitable[T | U]: """Access anext magic method.""" assert_async_iterator(obj) - return obj.__anext__() + if default is UNSET: + return obj.__anext__() + + async def anext_default_handling_wrapper() -> T | U: + try: + return await obj.__anext__() + except StopAsyncIteration: + return default + + return anext_default_handling_wrapper() # Async / await helper functions @@ -109,6 +132,7 @@ def assert_async_iterator(obj: object) -> None: T = TypeVar("T", covariant=True) Self = TypeVar("Self", bound="AsyncIteratorContext[Any]") +U = TypeVar("U") class AsyncIteratorContext( diff --git a/aiostream/stream/combine.py b/aiostream/stream/combine.py index e7a4dbf..0ea7b43 100644 --- a/aiostream/stream/combine.py +++ b/aiostream/stream/combine.py @@ -4,6 +4,7 @@ import asyncio import builtins +import enum from typing import ( Awaitable, @@ -46,8 +47,14 @@ async def chain(*sources: AsyncIterable[T]) -> AsyncIterator[T]: yield item +_StopSentinelType = enum.Enum("_StopSentinelType", "STOP_SENTINEL") +STOP_SENTINEL = _StopSentinelType.STOP_SENTINEL + + @sources_operator -async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]: +async def zip( + *sources: AsyncIterable[T], strict: bool = False +) -> AsyncIterator[tuple[T, ...]]: """Combine and forward the elements of several asynchronous sequences. 
Each generated value is a tuple of elements, using the same order as @@ -76,14 +83,24 @@ async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]: await stack.enter_async_context(streamcontext(source)) for source in sources ] # Loop over items + items: list[T] while True: - try: - coros = builtins.map(anext, streamers) - items = await asyncio.gather(*coros) - except StopAsyncIteration: - break + if strict: + coros = (anext(streamer, STOP_SENTINEL) for streamer in streamers) + _items = await asyncio.gather(*coros) + if all(item == STOP_SENTINEL for item in _items): + break + elif any(item == STOP_SENTINEL for item in _items): + raise ValueError("iterables have different lengths") + # This holds because we've ruled out STOP_SENTINEL above: + items = cast("list[T]", _items) else: - yield tuple(items) + coros = (anext(streamer) for streamer in streamers) + try: + items = await asyncio.gather(*coros) + except StopAsyncIteration: + break + yield tuple(items) X = TypeVar("X", contravariant=True) diff --git a/tests/test_combine.py b/tests/test_combine.py index 27892a0..faa2184 100644 --- a/tests/test_combine.py +++ b/tests/test_combine.py @@ -29,10 +29,50 @@ async def test_zip(assert_run): expected = [(x,) * 3 for x in range(5)] await assert_run(ys, expected) + # Exceptions from iterables are propagated + xs = stream.zip(stream.range(2), stream.throw(AttributeError)) + with pytest.raises(AttributeError): + await xs + # Empty zip (issue #95) xs = stream.zip() await assert_run(xs, []) + # Strict mode (issue #118): Iterable length mismatch raises + xs = stream.zip(stream.range(2), stream.range(1), strict=True) + with pytest.raises(ValueError): + await xs + + # Strict mode (issue #118): No raise for matching-length iterables + xs = stream.zip(stream.range(2), stream.range(2), strict=True) + await assert_run(xs, [(0, 0), (1, 1)]) + + # Strict mode (issue #118): Exceptions from iterables are propagated + xs = stream.zip(stream.range(2), 
stream.throw(AttributeError), strict=True) + with pytest.raises(AttributeError): + await xs + + # Strict mode (issue #118): Non-strict mode works as before + xs = stream.zip(stream.range(2), stream.range(1)) + await assert_run(xs, [(0, 0)]) + + # Strict mode (issue #118): In particular, we stop immediately if any + # one iterable is exhausted, not waiting for the others + slow_iterable_continued_after_sleep = asyncio.Event() + + async def fast_iterable(): + yield 0 + await asyncio.sleep(1) + + async def slow_iterable(): + yield 0 + await asyncio.sleep(2) + slow_iterable_continued_after_sleep.set() + + xs = stream.zip(fast_iterable(), slow_iterable()) + await assert_run(xs, [(0, 0)]) + assert not slow_iterable_continued_after_sleep.is_set() + @pytest.mark.asyncio async def test_map(assert_run, assert_cleanup): diff --git a/tests/test_core.py b/tests/test_core.py index fabb97a..7e68a6d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -237,7 +237,7 @@ def test_introspection_for_sources_operator(): ) assert ( str(inspect.signature(original)) - == "(*sources: 'AsyncIterable[T]') -> 'AsyncIterator[tuple[T, ...]]'" + == "(*sources: 'AsyncIterable[T]', strict: 'bool' = False) -> 'AsyncIterator[tuple[T, ...]]'" ) # Check the stream operator @@ -251,7 +251,7 @@ def test_introspection_for_sources_operator(): assert stream.zip.raw.__doc__ == original_doc assert ( str(inspect.signature(stream.zip.raw)) - == "(*sources: 'AsyncIterable[T]') -> 'AsyncIterator[tuple[T, ...]]'" + == "(*sources: 'AsyncIterable[T]', strict: 'bool' = False) -> 'AsyncIterator[tuple[T, ...]]'" ) # Check the __call__ method @@ -260,7 +260,7 @@ def test_introspection_for_sources_operator(): assert stream.zip.__call__.__doc__ == original_doc assert ( str(inspect.signature(stream.zip.__call__)) - == "(*sources: 'AsyncIterable[T]') -> 'Stream[tuple[T, ...]]'" + == "(*sources: 'AsyncIterable[T]', strict: 'bool' = False) -> 'Stream[tuple[T, ...]]'" ) # Check the pipe method @@ -272,5 +272,5 @@ def 
test_introspection_for_sources_operator(): ) assert ( str(inspect.signature(stream.zip.pipe)) - == "(*sources: 'AsyncIterable[T]') -> 'Callable[[AsyncIterable[Any]], Stream[tuple[T, ...]]]'" + == "(*sources: 'AsyncIterable[T]', strict: 'bool' = False) -> 'Callable[[AsyncIterable[Any]], Stream[tuple[T, ...]]]'" )