diff --git a/LICENSE b/LICENSE index ef7cf27..99b19b6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2019 - 2024 Max Fischer +Copyright (c) 2019 - 2024 Max Kühn Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/asyncstdlib/__init__.py b/asyncstdlib/__init__.py index a4a1e89..df9cab7 100644 --- a/asyncstdlib/__init__.py +++ b/asyncstdlib/__init__.py @@ -45,7 +45,7 @@ from .asynctools import borrow, scoped_iter, await_each, any_iter, apply, sync from .heapq import merge, nlargest, nsmallest -__version__ = "3.12.5" +__version__ = "3.13.0" __all__ = [ "anext", diff --git a/asyncstdlib/functools.py b/asyncstdlib/functools.py index d1f9355..a035b12 100644 --- a/asyncstdlib/functools.py +++ b/asyncstdlib/functools.py @@ -7,7 +7,6 @@ Generic, Generator, Optional, - Coroutine, AsyncContextManager, Type, cast, @@ -66,25 +65,25 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value!r})" -class _FutureCachedValue(Generic[R, T]): - """A placeholder object to control concurrent access to a cached awaitable value. +class _FutureCachedPropertyValue(Generic[R, T]): + """ + A placeholder object to control concurrent access to a cached awaitable value When given a lock to coordinate access, only the first task to await on a cached property triggers the underlying coroutine. Once a value has been produced, all tasks are unblocked and given the same, single value. - """ - __slots__ = ("_get_attribute", "_instance", "_name", "_lock") + __slots__ = ("_func", "_instance", "_name", "_lock") def __init__( self, - get_attribute: Callable[[T], Coroutine[Any, Any, R]], + func: Callable[[T], Awaitable[R]], instance: T, name: str, lock: AsyncContextManager[Any], ): - self._get_attribute = get_attribute + self._func = func self._instance = instance self._name = name self._lock = lock @@ -98,7 +97,6 @@ def _instance_value(self) -> Awaitable[R]: If the instance (no longer) has this attribute, it was deleted and the process is restarted by delegating to the descriptor. - """ try: return self._instance.__dict__[self._name] @@ -116,12 +114,17 @@ async def _await_impl(self) -> R: # the instance attribute is still this placeholder, and we # hold the lock. Start the getter to store the value on the # instance and return the value. - return await self._get_attribute(self._instance) + return await self._get_attribute() # another task produced a value, or the instance.__dict__ object was # deleted in the interim. return await stored + async def _get_attribute(self) -> R: + value = await self._func(self._instance) + self._instance.__dict__[self._name] = AwaitableValue(value) + return value + def __repr__(self) -> str: return ( f"<{type(self).__name__} for '{type(self._instance).__name__}." @@ -135,9 +138,10 @@ def __init__( getter: Callable[[T], Awaitable[R]], asynccontextmanager_type: Type[AsyncContextManager[Any]] = nullcontext, ): - self.func = getter + self.func = self.__wrapped__ = getter self.attrname = None self.__doc__ = getter.__doc__ + self.__module__ = getter.__module__ self._asynccontextmanager_type = asynccontextmanager_type def __set_name__(self, owner: Any, name: str) -> None: @@ -175,19 +179,12 @@ def __get__( # on this instance. It takes care of coordinating between different # tasks awaiting on the placeholder until the cached value has been # produced. - wrapper = _FutureCachedValue( - self._get_attribute, instance, name, self._asynccontextmanager_type() + wrapper = _FutureCachedPropertyValue( + self.func, instance, name, self._asynccontextmanager_type() ) cache[name] = wrapper return wrapper - async def _get_attribute(self, instance: T) -> R: - value = await self.func(instance) - name = self.attrname - assert name is not None # enforced in __get__ - instance.__dict__[name] = AwaitableValue(value) - return value - def cached_property( type_or_getter: Union[Type[AsyncContextManager[Any]], Callable[[T], Awaitable[R]]], diff --git a/asyncstdlib/itertools.py b/asyncstdlib/itertools.py index cd737ba..006da37 100644 --- a/asyncstdlib/itertools.py +++ b/asyncstdlib/itertools.py @@ -32,7 +32,6 @@ zip, enumerate as aenumerate, iter as aiter, - tuple as atuple, ) S = TypeVar("S") @@ -122,17 +121,31 @@ async def accumulate(iterable, function, *, initial): yield value -async def batched(iterable: AnyIterable[T], n: int) -> AsyncIterator[Tuple[T, ...]]: +async def batched( + iterable: AnyIterable[T], n: int, strict: bool = False +) -> AsyncIterator[Tuple[T, ...]]: """ Batch the ``iterable`` to tuples of the length ``n``. - This lazily exhausts ``iterable`` and returns each batch as soon as it's ready. + This lazily exhausts ``iterable`` and returns each batch as soon as it is ready. + If ``strict`` is :py:data:`True` and the last batch is smaller than ``n``, + :py:exc:`ValueError` is raised. """ if n < 1: raise ValueError("n must be at least one") async with ScopedIter(iterable) as item_iter: - while batch := await atuple(islice(_borrow(item_iter), n)): - yield batch + batch: list[T] = [] + try: + while True: + batch.clear() + for _ in range(n): + batch.append(await anext(item_iter)) + yield tuple(batch) + except StopAsyncIteration: + if batch: + if strict and len(batch) < n: + raise ValueError("batched(): incomplete batch") from None + yield tuple(batch) class chain(AsyncIterator[T]): diff --git a/asyncstdlib/itertools.pyi b/asyncstdlib/itertools.pyi index a525b6a..f65ff6d 100644 --- a/asyncstdlib/itertools.pyi +++ b/asyncstdlib/itertools.pyi @@ -32,27 +32,33 @@ def accumulate( initial: T1, ) -> AsyncIterator[T1]: ... @overload -def batched(iterable: AnyIterable[T], n: Literal[1]) -> AsyncIterator[tuple[T]]: ... +def batched( + iterable: AnyIterable[T], n: Literal[1], strict: bool = ... +) -> AsyncIterator[tuple[T]]: ... @overload -def batched(iterable: AnyIterable[T], n: Literal[2]) -> AsyncIterator[tuple[T, T]]: ... +def batched( + iterable: AnyIterable[T], n: Literal[2], strict: bool = ... +) -> AsyncIterator[tuple[T, T]]: ... @overload def batched( - iterable: AnyIterable[T], n: Literal[3] + iterable: AnyIterable[T], n: Literal[3], strict: bool = ... ) -> AsyncIterator[tuple[T, T, T]]: ... @overload def batched( - iterable: AnyIterable[T], n: Literal[4] + iterable: AnyIterable[T], n: Literal[4], strict: bool = ... ) -> AsyncIterator[tuple[T, T, T, T]]: ... @overload def batched( - iterable: AnyIterable[T], n: Literal[5] + iterable: AnyIterable[T], n: Literal[5], strict: bool = ... ) -> AsyncIterator[tuple[T, T, T, T, T]]: ... @overload def batched( - iterable: AnyIterable[T], n: Literal[6] + iterable: AnyIterable[T], n: Literal[6], strict: bool = ... ) -> AsyncIterator[tuple[T, T, T, T, T, T]]: ... @overload -def batched(iterable: AnyIterable[T], n: int) -> AsyncIterator[tuple[T, ...]]: ... +def batched( + iterable: AnyIterable[T], n: int, strict: bool = ... +) -> AsyncIterator[tuple[T, ...]]: ... class chain(AsyncIterator[T]): __slots__: tuple[str, ...] diff --git a/docs/conf.py b/docs/conf.py index 6500c39..b54edd3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,7 @@ # -- Project information ----------------------------------------------------- project = "asyncstdlib" -author = "Max Fischer" +author = "Max Kühn" copyright = f"2019-2024 {author}" # The short X.Y version diff --git a/docs/source/api/itertools.rst b/docs/source/api/itertools.rst index 4eb0ca3..733332f 100644 --- a/docs/source/api/itertools.rst +++ b/docs/source/api/itertools.rst @@ -86,11 +86,15 @@ Iterator splitting .. versionadded:: 3.10.0 -.. autofunction:: batched(iterable: (async) iter T, n: int) +.. autofunction:: batched(iterable: (async) iter T, n: int, strict: bool = False) :async-for: :T .. versionadded:: 3.11.0 + .. versionadded:: 3.13.0 + + The ``strict`` parameter. + .. py:function:: groupby(iterable: (async) iter T) :async-for: :(T, async iter T) :noindex: diff --git a/pyproject.toml b/pyproject.toml index c47e790..703b31c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "flit_core.buildapi" dynamic = ["version", "description"] name = "asyncstdlib" authors = [ - {name = "Max Fischer", email = "maxfischer2781@gmail.com"}, + {name = "Max Kühn", email = "maxfischer2781@gmail.com"}, ] readme = "README.rst" classifiers = [ diff --git a/unittests/test_itertools.py b/unittests/test_itertools.py index 8fa8349..5f88e96 100644 --- a/unittests/test_itertools.py +++ b/unittests/test_itertools.py @@ -79,6 +79,19 @@ async def test_batched_invalid(length): await a.list(a.batched(range(10), length)) +@sync +@pytest.mark.parametrize("values", ([1, 2, 3, 4], [1, 2, 3, 4, 5], [1])) +async def test_batched_strict(values: "list[int]"): + for n in range(1, len(values) + 1): + batches = a.batched(values, n, strict=True) + if len(values) % n == 0: + assert values == list(await a.reduce(lambda a, b: a + b, batches)) + else: + assert await a.anext(batches) + with pytest.raises(ValueError): + await a.list(batches) + + @sync async def test_cycle(): async for _ in a.cycle([]):