diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d1e0831..5c4b513 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,10 +15,11 @@ repos: hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.3.0 + rev: v1.6.1 hooks: - id: mypy - files: ^aiostream/ + files: ^(?!tests) + types: [python] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.0.272 diff --git a/aiostream/core.py b/aiostream/core.py index e224a8e..a11f082 100644 --- a/aiostream/core.py +++ b/aiostream/core.py @@ -109,14 +109,14 @@ def __await__(self) -> Generator[Any, None, T]: """ return wait_stream(self).__await__() - def __or__(self, func: Callable[[BaseStream[T]], BaseStream[X]]) -> BaseStream[X]: + def __or__(self, func: Callable[[BaseStream[T]], X]) -> X: """Pipe protocol. Allow to pipe stream operators. """ return func(self) - def __add__(self, value: BaseStream[X]) -> BaseStream[Union[X, T]]: + def __add__(self, value: AsyncIterable[X]) -> Stream[Union[X, T]]: """Addition protocol. Concatenate with a given asynchronous sequence. @@ -125,7 +125,7 @@ def __add__(self, value: BaseStream[X]) -> BaseStream[Union[X, T]]: return chain(self, value) - def __getitem__(self, value: Union[int, slice]) -> BaseStream[T]: + def __getitem__(self, value: Union[int, slice]) -> Stream[T]: """Get item protocol. Accept index or slice to extract the corresponding item(s) @@ -276,7 +276,9 @@ def raw( ) -> AsyncIterator[T]: ... - def pipe(self, source: AsyncIterable[A]) -> Stream[T]: + def pipe( + self, *args: P.args, **kwargs: P.kwargs + ) -> Callable[[AsyncIterable[A]], Stream[T]]: ... @@ -555,7 +557,7 @@ def pipe( "__doc__": doc, "raw": staticmethod(raw), "original": staticmethod(original), - "pipe": classmethod(pipe), + "pipe": classmethod(pipe), # type: ignore[arg-type] } # Create operator class diff --git a/aiostream/pipe.py b/aiostream/pipe.py index 1dd44bd..c27e101 100644 --- a/aiostream/pipe.py +++ b/aiostream/pipe.py @@ -3,19 +3,37 @@ from . import stream -__all__: list[str] = [] - - -def update_pipe_module() -> None: - """Populate the pipe module dynamically.""" - module_dir = __all__ - operators = stream.__dict__ - for key, value in operators.items(): - if getattr(value, "pipe", None): - globals()[key] = value.pipe - if key not in module_dir: - module_dir.append(key) - - -# Populate the module -update_pipe_module() +accumulate = stream.accumulate.pipe +action = stream.action.pipe +amap = stream.amap.pipe +chain = stream.chain.pipe +chunks = stream.chunks.pipe +concat = stream.concat.pipe +concatmap = stream.concatmap.pipe +cycle = stream.cycle.pipe +delay = stream.delay.pipe +dropwhile = stream.dropwhile.pipe +enumerate = stream.enumerate.pipe +filter = stream.filter.pipe +flatmap = stream.flatmap.pipe +flatten = stream.flatten.pipe +getitem = stream.getitem.pipe +list = stream.list.pipe +map = stream.map.pipe +merge = stream.merge.pipe +print = stream.print.pipe +reduce = stream.reduce.pipe +skip = stream.skip.pipe +skiplast = stream.skiplast.pipe +smap = stream.smap.pipe +spaceout = stream.spaceout.pipe +starmap = stream.starmap.pipe +switch = stream.switch.pipe +switchmap = stream.switchmap.pipe +take = stream.take.pipe +takelast = stream.takelast.pipe +takewhile = stream.takewhile.pipe +timeout = stream.timeout.pipe +until = stream.until.pipe +zip = stream.zip.pipe +ziplatest = stream.ziplatest.pipe diff --git a/examples/demo.py b/examples/demo.py index 1c6ca32..e1352c6 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -3,12 +3,16 @@ from aiostream import pipe, stream -async def main(): +def square(x: int, *_: int) -> int: + return x**2 + + +async def main() -> None: # Create a counting stream with a 0.2 seconds interval xs = stream.count(interval=0.2) # Operators can be piped using '|' - ys = xs | pipe.map(lambda x: x**2) + ys = xs | pipe.map(square) # Streams can be sliced zs = ys[1:10:2] diff --git a/examples/extra.py b/examples/extra.py index bab5e9c..5a2084e 100644 --- a/examples/extra.py +++ b/examples/extra.py @@ -1,11 +1,14 @@ import asyncio import random as random_module +from typing import AsyncIterable, AsyncIterator from aiostream import operator, pipable_operator, pipe, streamcontext @operator -async def random(offset=0.0, width=1.0, interval=0.1): +async def random( + offset: float = 0.0, width: float = 1.0, interval: float = 0.1 +) -> AsyncIterator[float]: """Generate a stream of random numbers.""" while True: await asyncio.sleep(interval) @@ -13,7 +16,9 @@ async def random(offset=0.0, width=1.0, interval=0.1): @pipable_operator -async def power(source, exponent): +async def power( + source: AsyncIterable[float], exponent: float | int +) -> AsyncIterator[float]: """Raise the elements of an asynchronous sequence to the given power.""" async with streamcontext(source) as streamer: async for item in streamer: @@ -21,12 +26,12 @@ async def power(source, exponent): @pipable_operator -def square(source): +def square(source: AsyncIterable[float]) -> AsyncIterator[float]: """Square the elements of an asynchronous sequence.""" return power.raw(source, 2) -async def main(): +async def main() -> None: xs = ( random() # Stream random numbers | square.pipe() # Square the values diff --git a/examples/norm_server.py b/examples/norm_server.py index 5aa955a..a1283f2 100644 --- a/examples/norm_server.py +++ b/examples/norm_server.py @@ -14,6 +14,7 @@ [...] """ +import math import asyncio from aiostream import pipe, stream @@ -40,22 +41,24 @@ # Client handler -async def euclidean_norm_handler(reader, writer): +async def euclidean_norm_handler( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter +) -> None: # Define lambdas - def strip(x): + def strip(x: bytes, *_: object) -> str: return x.decode().strip() - def nonempty(x): + def nonempty(x: str) -> bool: return x != "" - def square(x): + def square(x: float, *_: object) -> float: return x**2 - def write_cursor(x): + def write_cursor(_: float) -> None: return writer.write(b"> ") - def square_root(x): - return x**0.5 + def square_root(x: float, *_: object) -> float: + return math.sqrt(x) # Create awaitable handle_request = ( @@ -67,7 +70,7 @@ def square_root(x): | pipe.map(square) | pipe.print("square: {:.2f}") | pipe.action(write_cursor) - | pipe.accumulate(initializer=0) + | pipe.accumulate(initializer=0.0) | pipe.map(square_root) | pipe.print("norm -> {:.2f}") ) @@ -86,7 +89,7 @@ def square_root(x): # Main function -async def main(bind="127.0.0.1", port=8888): +async def main(bind: str = "127.0.0.1", port: int = 8888) -> None: # Start the server server = await asyncio.start_server(euclidean_norm_handler, bind, port) diff --git a/examples/preserve.py b/examples/preserve.py index d6faf48..9fedb0b 100644 --- a/examples/preserve.py +++ b/examples/preserve.py @@ -1,10 +1,11 @@ import asyncio +from typing import AsyncIterator from aiostream import operator, stream -async def main(): - async def agen(): +async def main() -> None: + async def agen() -> AsyncIterator[int]: yield 1 yield 2 yield 3 diff --git a/examples/simple.py b/examples/simple.py index dd9569d..2d5d74c 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -3,14 +3,22 @@ from aiostream import pipe, stream -async def main(): +def is_odd(x: int) -> bool: + return x % 2 == 1 + + +def square(x: int, *_: object) -> int: + return x**2 + + +async def main() -> None: # This stream computes 11² + 13² in 1.5 second xs = ( stream.count(interval=0.1) # Count from zero every 0.1 s | pipe.skip(10) # Skip the first 10 numbers | pipe.take(5) # Take the following 5 - | pipe.filter(lambda x: x % 2) # Keep odd numbers - | pipe.map(lambda x: x**2) # Square the results + | pipe.filter(is_odd) # Keep odd numbers + | pipe.map(square) # Square the results | pipe.accumulate() # Add the numbers together ) print("11² + 13² = ", await xs) diff --git a/setup.cfg b/setup.cfg index b939c7e..e9876ef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ ignore = F401, F403, E731, W503, E501, E203 [mypy] strict = True -packages = aiostream +packages = aiostream, examples [mypy-aiostream.test_utils] ignore_errors = True diff --git a/tests/test_pipe.py b/tests/test_pipe.py new file mode 100644 index 0000000..66f3596 --- /dev/null +++ b/tests/test_pipe.py @@ -0,0 +1,10 @@ +from aiostream import stream, pipe + + +def test_pipe_module(): + for name in dir(stream): + obj = getattr(stream, name) + pipe_method = getattr(obj, "pipe", None) + if pipe_method is None: + continue + assert getattr(pipe, name) == pipe_method