Skip to content

Commit

Permalink
Operator as object
Browse files Browse the repository at this point in the history
  • Loading branch information
vxgmichel committed May 6, 2024
1 parent 4952900 commit 46a5426
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 117 deletions.
217 changes: 100 additions & 117 deletions aiostream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Protocol,
Union,
TypeVar,
cast,
AsyncIterable,
Awaitable,
)
Expand Down Expand Up @@ -258,23 +257,26 @@ def streamcontext(aiterable: AsyncIterable[T]) -> Streamer[T]:
# Operator type protocol


class OperatorType(Protocol[P, T]):
class Operator(Protocol[P, T]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Stream[T]: ...

def raw(self, *args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]: ...
@staticmethod
def raw(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]: ...


class PipableOperatorType(Protocol[A, P, T]):
class PipableOperator(Protocol[A, P, T]):
def __call__(
self, source: AsyncIterable[A], /, *args: P.args, **kwargs: P.kwargs
) -> Stream[T]: ...

@staticmethod
def raw(
self, source: AsyncIterable[A], /, *args: P.args, **kwargs: P.kwargs
source: AsyncIterable[A], /, *args: P.args, **kwargs: P.kwargs
) -> AsyncIterator[T]: ...

@staticmethod
def pipe(
self, *args: P.args, **kwargs: P.kwargs
*args: P.args, **kwargs: P.kwargs
) -> Callable[[AsyncIterable[A]], Stream[T]]: ...


Expand All @@ -284,7 +286,7 @@ def pipe(
def operator(
func: Callable[P, AsyncIterator[T]] | None = None,
pipable: bool | None = None,
) -> OperatorType[P, T]:
) -> Operator[P, T]:
"""Create a stream operator from an asynchronous generator
(or any function returning an asynchronous iterable).
Expand Down Expand Up @@ -330,7 +332,6 @@ async def random(offset=0., width=1.):
)

# Gather data
bases = (Stream,)
name = func.__name__
module = func.__module__
extra_doc = func.__doc__
Expand All @@ -345,49 +346,46 @@ async def random(offset=0., width=1.):
"since the decorated function becomes an operator class"
)

# Injected parameters
self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
inspect.Parameter("cls", inspect.Parameter.POSITIONAL_OR_KEYWORD)

# Wrapped static method
original = func
original.__qualname__ = name + ".original"
original_func = func
original_func.__qualname__ = name + ".original"

# Raw static method
raw = func
raw.__qualname__ = name + ".raw"
raw_func = func
raw_func.__qualname__ = name + ".raw"

# Init method
def init(self: BaseStream[T], *args: P.args, **kwargs: P.kwargs) -> None:
factory = functools.partial(raw, *args, **kwargs)
return BaseStream.__init__(self, factory)
# Gather attributes
class OperatorImpl:
__qualname__ = name
__module__ = module
__doc__ = doc

# Customize init signature
new_parameters = [self_parameter] + parameters
init.__signature__ = signature.replace(parameters=new_parameters) # type: ignore[attr-defined]
original = staticmethod(original_func)

# Customize init method
init.__qualname__ = name + ".__init__"
init.__name__ = "__init__"
init.__module__ = module
init.__doc__ = f"Initialize the {name} stream."
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Stream[T]:
factory = functools.partial(raw_func, *args, **kwargs)
return Stream(factory)

# Gather attributes
attrs = {
"__init__": init,
"__module__": module,
"__doc__": doc,
"raw": staticmethod(raw),
"original": staticmethod(original),
}
@staticmethod
def raw(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]:
return raw_func(*args, **kwargs)

# Customize call method
self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
new_parameters = [self_parameter] + parameters
OperatorImpl.__call__.__signature__ = signature.replace(parameters=new_parameters) # type: ignore[attr-defined]
OperatorImpl.__call__.__qualname__ = name + ".__call__"
OperatorImpl.__call__.__name__ = "__call__"
OperatorImpl.__call__.__module__ = module
OperatorImpl.__call__.__doc__ = f"Initialize the {name} stream."

# Create operator class
return cast("OperatorType[P, T]", type(name, bases, attrs))
return OperatorImpl()


def pipable_operator(
func: Callable[Concatenate[AsyncIterable[X], P], AsyncIterator[T]],
) -> PipableOperatorType[X, P, T]:
) -> PipableOperator[X, P, T]:
"""Create a pipable stream operator from an asynchronous generator
(or any function returning an asynchronous iterable).
Expand Down Expand Up @@ -441,7 +439,6 @@ def double(source):
)

# Gather data
bases = (Stream,)
name = func.__name__
module = func.__module__
extra_doc = func.__doc__
Expand All @@ -456,6 +453,13 @@ def double(source):
"since the decorated function becomes an operator class"
)

# Check for positional first parameter
if not parameters or parameters[0].kind not in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
raise ValueError("The first parameter of the operator must be positional")

# Look for "more_sources"
for i, p in enumerate(parameters):
if p.name == "more_sources" and p.kind == inspect.Parameter.VAR_POSITIONAL:
Expand All @@ -464,89 +468,68 @@ def double(source):
else:
more_sources_index = None

# Injected parameters
self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
cls_parameter = inspect.Parameter("cls", inspect.Parameter.POSITIONAL_OR_KEYWORD)

# Wrapped static method
original = func
original.__qualname__ = name + ".original"
original_func = func
original_func.__qualname__ = name + ".original"

# Raw static method
def raw(
arg: AsyncIterable[X], *args: P.args, **kwargs: P.kwargs
) -> AsyncIterator[T]:
assert_async_iterable(arg)
if more_sources_index is not None:
for source in args[more_sources_index - 1 :]:
assert_async_iterable(source)
return func(arg, *args, **kwargs)

# Custonize raw method
raw.__signature__ = signature # type: ignore[attr-defined]
raw.__qualname__ = name + ".raw"
raw.__module__ = module
raw.__doc__ = doc

# Init method
def init(
self: BaseStream[T], arg: AsyncIterable[X], *args: P.args, **kwargs: P.kwargs
) -> None:
assert_async_iterable(arg)
if more_sources_index is not None:
for source in args[more_sources_index - 1 :]:
assert_async_iterable(source)
factory = functools.partial(raw, arg, *args, **kwargs)
return BaseStream.__init__(self, factory)

# Customize init signature
# Gather attributes
class PipableOperatorImpl:
__qualname__ = name
__module__ = module
__doc__ = doc

original = staticmethod(original_func)

@staticmethod
def raw(
arg: AsyncIterable[X], /, *args: P.args, **kwargs: P.kwargs
) -> AsyncIterator[T]:
assert_async_iterable(arg)
if more_sources_index is not None:
for source in args[more_sources_index - 1 :]:
assert_async_iterable(source)
return func(arg, *args, **kwargs)

def __call__(
self, arg: AsyncIterable[X], /, *args: P.args, **kwargs: P.kwargs
) -> Stream[T]:
assert_async_iterable(arg)
if more_sources_index is not None:
for source in args[more_sources_index - 1 :]:
assert_async_iterable(source)
factory = functools.partial(self.raw, arg, *args, **kwargs)
return Stream(factory)

@staticmethod
def pipe(
*args: P.args,
**kwargs: P.kwargs,
) -> Callable[[AsyncIterable[X]], Stream[T]]:
return lambda source: operator_instance(source, *args, **kwargs)

# Customize raw method
PipableOperatorImpl.raw.__signature__ = signature # type: ignore[attr-defined]
PipableOperatorImpl.raw.__qualname__ = name + ".raw"
PipableOperatorImpl.raw.__module__ = module
PipableOperatorImpl.raw.__doc__ = doc

# Customize call method
self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
new_parameters = [self_parameter] + parameters
init.__signature__ = signature.replace(parameters=new_parameters) # type: ignore[attr-defined]

# Customize init method
init.__qualname__ = name + ".__init__"
init.__name__ = "__init__"
init.__module__ = module
init.__doc__ = f"Initialize the {name} stream."

# Pipe class method
def pipe(
cls: PipableOperatorType[X, P, T],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Callable[[AsyncIterable[X]], Stream[T]]:
return lambda source: cls(source, *args, **kwargs)

# Customize pipe signature
if parameters and parameters[0].kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
new_parameters = [cls_parameter] + parameters[1:]
else:
raise ValueError("The first parameter of the operator must be positional")
pipe.__signature__ = signature.replace(parameters=new_parameters) # type: ignore[attr-defined]
PipableOperatorImpl.__call__.__signature__ = signature.replace(parameters=new_parameters) # type: ignore[attr-defined]
PipableOperatorImpl.__call__.__qualname__ = name + ".__call__"
PipableOperatorImpl.__call__.__name__ = "__call__"
PipableOperatorImpl.__call__.__module__ = module
PipableOperatorImpl.__call__.__doc__ = f"Initialize the {name} stream."

# Customize pipe method
pipe.__qualname__ = name + ".pipe"
pipe.__module__ = module
pipe.__doc__ = f'Pipable "{name}" stream operator.'
PipableOperatorImpl.pipe.__signature__ = signature.replace(parameters=parameters[1:]) # type: ignore[attr-defined]
PipableOperatorImpl.pipe.__qualname__ = name + ".pipe"
PipableOperatorImpl.pipe.__module__ = module
PipableOperatorImpl.pipe.__doc__ = f'Pipable "{name}" stream operator.'
if extra_doc:
pipe.__doc__ += "\n\n " + extra_doc

# Gather attributes
attrs = {
"__init__": init,
"__module__": module,
"__doc__": doc,
"raw": staticmethod(raw),
"original": staticmethod(original),
"pipe": classmethod(pipe), # type: ignore[arg-type]
}
PipableOperatorImpl.pipe.__doc__ += "\n\n " + extra_doc

# Create operator class
return cast(
"PipableOperatorType[X, P, T]",
type(name, bases, attrs),
)
operator_instance = PipableOperatorImpl()
return operator_instance
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ strict = [
"aiostream/manager.py",
"aiostream/pipe.py",
"aiostream/test_utils.py",
"aiostream/core.py",
]

[tool.mypy]
Expand Down

0 comments on commit 46a5426

Please sign in to comment.