From 881de7fc42500ab1c26e294917ff4bae5c1fd002 Mon Sep 17 00:00:00 2001 From: Matthias Veit Date: Thu, 7 Nov 2024 10:14:54 +0100 Subject: [PATCH] completely remove aiostream --- fixcore/.pylintrc | 2 +- fixcore/fixcore/cli/cli.py | 34 ++--- fixcore/fixcore/cli/command.py | 22 ++-- fixcore/fixcore/db/graphdb.py | 7 +- fixcore/fixcore/infra_apps/local_runtime.py | 2 +- fixcore/fixcore/model/db_updater.py | 54 ++++---- fixcore/fixcore/report/benchmark_renderer.py | 36 +++--- fixcore/fixcore/report/inspector_service.py | 14 +- fixcore/fixcore/task/task_handler.py | 6 +- fixcore/pyproject.toml | 1 - fixcore/tests/fixcore/hypothesis_extension.py | 7 +- .../fixcore/report/benchmark_renderer_test.py | 4 +- fixcore/tests/fixcore/util_test.py | 16 +-- .../fixcore/web/content_renderer_test.py | 122 ++++++++---------- fixlib/fixlib/asynchronous/stream.py | 100 +++++++------- fixlib/test/asynchronous/stream_test.py | 6 +- fixshell/.pylintrc | 2 +- requirements-all.txt | 1 - requirements-extra.txt | 1 - requirements.txt | 1 - 20 files changed, 200 insertions(+), 238 deletions(-) diff --git a/fixcore/.pylintrc b/fixcore/.pylintrc index 91536aa44b..94fd5c5778 100644 --- a/fixcore/.pylintrc +++ b/fixcore/.pylintrc @@ -246,7 +246,7 @@ ignored-modules= # List of classes names for which member attributes should not be checked # (useful for classes with attributes dynamically set). -ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local, aiostream.pipe +ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular diff --git a/fixcore/fixcore/cli/cli.py b/fixcore/fixcore/cli/cli.py index 8c535bef82..178bb7233c 100644 --- a/fixcore/fixcore/cli/cli.py +++ b/fixcore/fixcore/cli/cli.py @@ -10,14 +10,13 @@ from typing import Dict, List, Tuple, Union, Sequence from typing import Optional, Any, TYPE_CHECKING -from aiostream import stream from attrs import evolve from parsy import Parser from rich.padding import Padding from fixcore import version from fixcore.analytics import CoreEvent -from fixcore.cli import cmd_with_args_parser, key_values_parser, T, Sink, args_values_parser, JsGen +from fixcore.cli import cmd_with_args_parser, key_values_parser, T, Sink, args_values_parser, JsStream from fixcore.cli.command import ( SearchPart, PredecessorsPart, @@ -78,6 +77,7 @@ from fixcore.types import JsonElement from fixcore.user.model import Permission from fixcore.util import group_by +from fixlib.asynchronous.stream import Stream from fixlib.parse_util import make_parser, pipe_p, semicolon_p if TYPE_CHECKING: @@ -104,7 +104,7 @@ def command_line_parser() -> Parser: return ParsedCommands(commands, maybe_env if maybe_env else {}) -# multiple piped commands are separated by semicolon +# semicolon separates multiple piped commands multi_command_parser = command_line_parser.sep_by(semicolon_p) @@ -187,7 +187,7 @@ def overview() -> str: logo = ctx.render_console(Padding(WelcomeCommand.ck, pad=(0, 0, 0, middle))) if ctx.supports_color() else "" return headline + logo + ctx.render_console(result) - def help_command() -> JsGen: + def help_command() -> JsStream: if not arg: result = overview() elif arg == "placeholders": @@ -209,7 +209,7 @@ def help_command() -> JsGen: else: result = f"No command found with this name: {arg}" - return stream.just(result) + return Stream.just(result) return CLISource.single(help_command, required_permissions={Permission.read}) @@ -352,11 +352,11 @@ def command( self, name: str, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any ) -> ExecutableCommand: """ - Create an executable command for given command name, args and context. - :param name: the name of the command to execute (must be a known command) - :param arg: the arg of the command (must be parsable by the command) - :param ctx: the context of this command. - :return: the ready to run executable command. + Create an executable command for given command name, args, and context. + :param name: The name of the command to execute (must be a known command). + :param arg: The arg of the command (must be parsable by the command). + :param ctx: The context of this command. + :return: The ready to run executable command. :raises: CLIParseError: if the name of the command is not known, or the argument fails to parse. """ @@ -377,9 +377,9 @@ async def create_query( Takes a list of query part commands and combine them to a single executable query command. This process can also introduce new commands that should run after the query is finished. Therefore, a list of executable commands is returned. - :param commands: the incoming executable commands, which actions are all instances of SearchCLIPart. - :param ctx: the context to execute within. - :return: the resulting list of commands to execute. + :param commands: The incoming executable commands, which actions are all instances of SearchCLIPart. + :param ctx: The context to execute within. + :return: The resulting list of commands to execute. """ # Pass parsed options to execute query @@ -484,8 +484,8 @@ async def parse_query(query_arg: str) -> Query: first_head_tail_in_a_row = None head_tail_keep_order = True - # Define default sort order, if not already defined - # A sort order is required to always return the result in a deterministic way to the user. + # Define default sort order, if not already defined. + # A sort order is required to always return the result deterministically to the user. # Deterministic order is required for head/tail to work if query.is_simple_fulltext_search(): # Do not define any additional sort order for fulltext searches @@ -494,7 +494,7 @@ async def parse_query(query_arg: str) -> Query: parts = [pt if pt.sort else evolve(pt, sort=default_sort) for pt in query.parts] query = evolve(query, parts=parts) - # If the last part is a navigation, we need to add sort which will ingest a new part. + # If the last part is a navigation, we need to add a sort which will ingest a new part. with_sort = query.set_sort(*default_sort) if query.current_part.navigation else query section = ctx.env.get("section", PathRoot) # If this is an aggregate query, the default sort needs to be changed @@ -534,7 +534,7 @@ def rewrite_command_line(cmds: List[ExecutableCommand], ctx: CLIContext) -> List Rules: - add the list command if no output format is defined - add a format to write commands if no output format is defined - - report benchmark run will be formatted as benchmark result automatically + - report benchmark run will be formatted as a benchmark result automatically """ if ctx.env.get("no_rewrite") or len(cmds) == 0: return cmds diff --git a/fixcore/fixcore/cli/command.py b/fixcore/fixcore/cli/command.py index acb01fd6fc..53b8000b58 100644 --- a/fixcore/fixcore/cli/command.py +++ b/fixcore/fixcore/cli/command.py @@ -1593,7 +1593,7 @@ def args_info(self) -> ArgsInfo: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIFlow: size = int(arg) if arg else 100 return CLIFlow( - lambda in_stream: Stream(in_stream).chunks(size).map(Stream.as_list), + lambda in_stream: Stream(in_stream).chunks(size), required_permissions={Permission.read}, ) @@ -1978,12 +1978,12 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa return CLIFlow(lambda i: Stream(i).chunks(buffer_size).flatmap(func), required_permissions={Permission.write}) async def set_desired( - self, arg: Optional[str], graph_name: GraphName, patch: Json, items: Stream[Json] + self, arg: Optional[str], graph_name: GraphName, patch: Json, items: List[Json] ) -> AsyncIterator[JsonElement]: model = await self.dependencies.model_handler.load_model(graph_name) db = self.dependencies.db_access.get_graph_db(graph_name) node_ids = [] - async for item in items: + for item in items: if "id" in item: node_ids.append(item["id"]) elif isinstance(item, str): @@ -2090,7 +2090,7 @@ def patch(self, arg: Optional[str], ctx: CLIContext) -> Json: return {"clean": True} async def set_desired( - self, arg: Optional[str], graph_name: GraphName, patch: Json, items: Stream[Json] + self, arg: Optional[str], graph_name: GraphName, patch: Json, items: List[Json] ) -> AsyncIterator[JsonElement]: reason = f"Reason: {strip_quotes(arg)}" if arg else "No reason provided." async for elem in super().set_desired(arg, graph_name, patch, items): @@ -2113,11 +2113,11 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa func = partial(self.set_metadata, ctx.graph_name, self.patch(arg, ctx)) return CLIFlow(lambda i: Stream(i).chunks(buffer_size).flatmap(func), required_permissions={Permission.write}) - async def set_metadata(self, graph_name: GraphName, patch: Json, items: Stream[Json]) -> AsyncIterator[JsonElement]: + async def set_metadata(self, graph_name: GraphName, patch: Json, items: List[Json]) -> AsyncIterator[JsonElement]: model = await self.dependencies.model_handler.load_model(graph_name) db = self.dependencies.db_access.get_graph_db(graph_name) node_ids = [] - async for item in items: + for item in items: if "id" in item: node_ids.append(item["id"]) elif isinstance(item, str): @@ -2864,7 +2864,7 @@ def extract_values(elem: JsonElement) -> List[Any | None]: result.append(value) return result - async def generate_markdown(chunk: Tuple[int, Stream[List[Any]]]) -> JsGen: + async def generate_markdown(chunk: Tuple[int, List[List[Any]]]) -> JsGen: idx, rows = chunk def to_str(elem: Any) -> str: @@ -2896,7 +2896,7 @@ def to_str(elem: Any) -> str: line += "|" yield line - async for row in rows: + for row in rows: line = "" for value, padding in zip(row, columns_padding): line += f"|{to_str(value).ljust(padding)}" @@ -3260,12 +3260,12 @@ def load_by_id_merged( expected_kind: Optional[str] = None, **env: str, ) -> JsStream: - async def load_element(items: JsStream) -> AsyncIterator[JsonElement]: + async def load_element(items: List[JsonElement]) -> AsyncIterator[JsonElement]: # collect ids either from json dict or string - ids: List[str] = [i["id"] if is_node(i) else i async for i in items] # type: ignore + ids: List[str] = [i["id"] if is_node(i) else i for i in items] # type: ignore # if there is an entry which is not a string, use the list as is (e.g. chunked) if any(a for a in ids if not isinstance(a, str)): - async for a in items: + for a in items: yield a else: # one query to load all items that match given ids (max 1000 as defined in chunk size) diff --git a/fixcore/fixcore/db/graphdb.py b/fixcore/fixcore/db/graphdb.py index 14b997c8a5..0424414776 100644 --- a/fixcore/fixcore/db/graphdb.py +++ b/fixcore/fixcore/db/graphdb.py @@ -23,7 +23,6 @@ Union, ) -from aiostream import stream, pipe from arango import AnalyzerGetError from arango.collection import VertexCollection, StandardCollection, EdgeCollection from arango.graph import Graph @@ -67,6 +66,7 @@ set_value_in_path, if_set, ) +from fixlib.asynchronous.stream import Stream log = logging.getLogger(__name__) @@ -675,9 +675,8 @@ async def move_security_temp_to_proper() -> None: try: # stream updates to the temp collection - async with (stream.iterate(iterator) | pipe.chunks(1000)).stream() as streamer: - async for part in streamer: - await update_chunk(dict(part)) + async for part in Stream.iterate(iterator).chunks(1000): + await update_chunk(dict(part)) # move temp collection to proper and history collection await move_security_temp_to_proper() finally: diff --git a/fixcore/fixcore/infra_apps/local_runtime.py b/fixcore/fixcore/infra_apps/local_runtime.py index 22506071a0..4ca160e8a1 100644 --- a/fixcore/fixcore/infra_apps/local_runtime.py +++ b/fixcore/fixcore/infra_apps/local_runtime.py @@ -116,4 +116,4 @@ async def _interpret_line(self, line: str, ctx: CLIContext) -> JsStream: total_nr_outputs = total_nr_outputs + (src_ctx.count or 0) command_streams.append(command_output_stream) - return Stream.iterate(command_streams).concat(task_limit=1) # type: ignore + return Stream.iterate(command_streams).concat() # type: ignore diff --git a/fixcore/fixcore/model/db_updater.py b/fixcore/fixcore/model/db_updater.py index cab00efb37..9a29c95573 100644 --- a/fixcore/fixcore/model/db_updater.py +++ b/fixcore/fixcore/model/db_updater.py @@ -13,11 +13,9 @@ from multiprocessing import Process, Queue from pathlib import Path from queue import Empty -from typing import Optional, Union, Any, Generator, List, AsyncIterator, Dict +from typing import Optional, Union, Any, List, AsyncIterator, Dict import aiofiles -from aiostream import stream, pipe -from aiostream.core import Stream from attrs import define from fixcore.analytics import AnalyticsEventSender, InMemoryEventSender, AnalyticsEvent @@ -36,6 +34,7 @@ from fixcore.system_start import db_access, setup_process, reset_process_start_method from fixcore.types import Json from fixcore.util import utc, uuid_str, shutdown_process +from fixlib.asynchronous.stream import Stream log = logging.getLogger(__name__) @@ -56,9 +55,9 @@ class ReadFile(ProcessAction): path: Path task_id: Optional[str] - def jsons(self) -> Generator[Json, Any, None]: - with open(self.path, "r", encoding="utf-8") as f: - for line in f: + async def jsons(self) -> AsyncIterator[Json]: + async with aiofiles.open(self.path, "r", encoding="utf-8") as f: + async for line in f: if line.strip(): yield json.loads(line) @@ -75,8 +74,8 @@ class ReadElement(ProcessAction): elements: List[Union[bytes, Json]] task_id: Optional[str] - def jsons(self) -> Generator[Json, Any, None]: - return (e if isinstance(e, dict) else json.loads(e) for e in self.elements) + def jsons(self) -> AsyncIterator[Json]: + return Stream.iterate(self.elements).map(lambda e: e if isinstance(e, dict) else json.loads(e)) @define @@ -125,15 +124,15 @@ def get_value(self) -> GraphUpdate: class DbUpdaterProcess(Process): """ - This update class implements Process and is supposed to run as separate process. + This update class implements Process and is supposed to run as a separate process. Note: default starting method is supposed to be "spawn". This process has 2 queues to read input from and write output to. - All elements in either queues are of type ProcessAction. + All elements in all queues are of type ProcessAction. The parent process should stream the raw commands of graph to this process via ReadElement objects. Once the MergeGraph action is received, the graph gets imported. - From here the parent expects result messages from the child. + From here, the parent expects result messages from the child. All events happen in the child are forwarded to the parent via EmitEvent. Once the graph update is done, a result is send. The result is either an exception in case of failure or a graph update in success case. @@ -156,8 +155,8 @@ def __init__( def next_action(self) -> ProcessAction: try: - # graph is read into memory. If the sender does not send data in a given amount of time, - # we raise an exception and abort the update. + # The graph is read into memory. + # If the sender does not send data in a given amount of time, we raise an exception and abort the update. return self.read_queue.get(True, 90) except Empty as ex: raise ImportAborted("Merge process did not receive any data for more than 90 seconds. Abort.") from ex @@ -168,12 +167,12 @@ async def merge_graph(self, db: DbAccess) -> GraphUpdate: # type: ignore builder = GraphBuilder(model, self.change_id) nxt = self.next_action() if isinstance(nxt, ReadFile): - for element in nxt.jsons(): + async for element in nxt.jsons(): builder.add_from_json(element) nxt = self.next_action() elif isinstance(nxt, ReadElement): while isinstance(nxt, ReadElement): - for element in nxt.jsons(): + async for element in nxt.jsons(): builder.add_from_json(element) log.debug(f"Read {int(BatchSize / 1000)}K elements in process") nxt = self.next_action() @@ -276,16 +275,11 @@ async def __process_item(self, item: GraphUpdateTask) -> Union[GraphUpdate, Exce async def start(self) -> None: async def wait_for_update() -> None: log.info("Start waiting for graph updates") - fl = ( - stream.call(self.update_queue.get) # type: ignore - | pipe.cycle() - | pipe.map(self.__process_item, task_limit=self.config.graph.parallel_imports) # type: ignore - ) + fl = Stream.for_ever(self.update_queue.get).map(self.__process_item, task_limit=self.config.graph.parallel_imports) # type: ignore # noqa with suppress(CancelledError): - async with fl.stream() as streamer: - async for update in streamer: - if isinstance(update, GraphUpdate): - log.info(f"Finished spawned graph merge: {update}") + async for update in fl: + if isinstance(update, GraphUpdate): + log.info(f"Finished spawned graph merge: {update}") self.handler_task = asyncio.create_task(wait_for_update()) @@ -373,19 +367,17 @@ async def read_forever() -> GraphUpdate: task: Optional[Task[GraphUpdate]] = None result: Optional[GraphUpdate] = None try: - reset_process_start_method() # other libraries might have tampered the value in the mean time + reset_process_start_method() # other libraries might have tampered the value in the meantime updater.start() task = read_results() # concurrently read result queue # Either send a file or stream the content directly if isinstance(content, Path): await send_to_child(ReadFile(content, task_id)) else: - chunked: Stream[List[Union[bytes, Json]]] = stream.chunks(content, BatchSize) # type: ignore - async with chunked.stream() as streamer: - async for lines in streamer: - if not await send_to_child(ReadElement(lines, task_id)): - # in case the child is dead, we should stop - break + async for lines in Stream.iterate(content).chunks(BatchSize): + if not await send_to_child(ReadElement(lines, task_id)): + # in case the child is dead, we should stop + break await send_to_child(MergeGraph(db.name, change_id, maybe_batch is not None, task_id)) result = await task # wait for final result await self.model_handler.load_model(db.name, force=True) # reload model to get the latest changes diff --git a/fixcore/fixcore/report/benchmark_renderer.py b/fixcore/fixcore/report/benchmark_renderer.py index af40975da7..babd149d5d 100644 --- a/fixcore/fixcore/report/benchmark_renderer.py +++ b/fixcore/fixcore/report/benchmark_renderer.py @@ -1,6 +1,5 @@ from typing import AsyncGenerator, List, AsyncIterable -from aiostream import stream from networkx import DiGraph from rich._emoji_codes import EMOJI @@ -91,27 +90,26 @@ def render_check_result(check_result: CheckResult, account: str) -> str: async def respond_benchmark_result(gen: AsyncIterable[JsonElement]) -> AsyncGenerator[str, None]: - # step 1: read graph + # step 1: read graph graph = DiGraph() - async with stream.iterate(gen).stream() as streamer: - async for item in streamer: - if isinstance(item, dict): - type_name = item.get("type") - if type_name == "node": - uid = value_in_path(item, NodePath.node_id) - reported = value_in_path(item, NodePath.reported) - kind = value_in_path(item, NodePath.reported_kind) - if uid and reported and kind and (reader := kind_reader.get(kind)): - graph.add_node(uid, data=reader(item)) - elif type_name == "edge": - from_node = value_in_path(item, NodePath.from_node) - to_node = value_in_path(item, NodePath.to_node) - if from_node and to_node: - graph.add_edge(from_node, to_node) - else: - raise AttributeError(f"Expect json object but got: {type(item)}: {item}") + async for item in gen: + if isinstance(item, dict): + type_name = item.get("type") + if type_name == "node": + uid = value_in_path(item, NodePath.node_id) + reported = value_in_path(item, NodePath.reported) + kind = value_in_path(item, NodePath.reported_kind) + if uid and reported and kind and (reader := kind_reader.get(kind)): + graph.add_node(uid, data=reader(item)) + elif type_name == "edge": + from_node = value_in_path(item, NodePath.from_node) + to_node = value_in_path(item, NodePath.to_node) + if from_node and to_node: + graph.add_edge(from_node, to_node) else: raise AttributeError(f"Expect json object but got: {type(item)}: {item}") + else: + raise AttributeError(f"Expect json object but got: {type(item)}: {item}") # step 2: read benchmark result from graph def traverse(node_id: str, collection: CheckCollectionResult) -> None: diff --git a/fixcore/fixcore/report/inspector_service.py b/fixcore/fixcore/report/inspector_service.py index ceb9c8f5b4..f3b70cc77e 100644 --- a/fixcore/fixcore/report/inspector_service.py +++ b/fixcore/fixcore/report/inspector_service.py @@ -3,8 +3,6 @@ from functools import lru_cache from typing import Optional, List, Dict, Tuple, Callable, AsyncIterator, cast, Set -from aiostream import stream, pipe -from aiostream.core import Stream from attr import define from fixcore.analytics import CoreEvent @@ -40,6 +38,7 @@ from fixcore.service import Service from fixcore.types import Json from fixcore.util import value_in_path, uuid_str, value_in_path_get +from fixlib.asynchronous.stream import Stream from fixlib.json_bender import Bender, S, bend log = logging.getLogger(__name__) @@ -380,7 +379,7 @@ async def list_failing_resources( async def __list_failing_resources( self, graph: GraphName, model: Model, inspection: ReportCheck, context: CheckContext ) -> AsyncIterator[Json]: - # final environment: defaults are coming from the check and are eventually overriden in the config + # final environment: defaults are coming from the check and are eventually overridden in the config env = inspection.environment(context.override_values()) account_id_prop = "ancestors.account.reported.id" ignore_prop = "metadata.security_ignore" @@ -484,7 +483,7 @@ def to_result(cc: CheckCollection) -> CheckCollectionResult: node_id=next_node_id(), ) - async def __perform_checks( # type: ignore + async def __perform_checks( self, graph: GraphName, checks: List[ReportCheck], context: CheckContext ) -> Dict[str, SingleCheckResult]: # load model @@ -493,11 +492,10 @@ async def __perform_checks( # type: ignore async def perform_single(check: ReportCheck) -> Tuple[str, SingleCheckResult]: return check.id, await self.__perform_check(graph, model, check, context) - check_results: Stream[Tuple[str, SingleCheckResult]] = stream.iterate(checks) | pipe.map( - perform_single, ordered=False, task_limit=context.parallel_checks # type: ignore + check_results: Stream[Tuple[str, SingleCheckResult]] = Stream.iterate(checks).map( + perform_single, ordered=False, task_limit=context.parallel_checks ) - async with check_results.stream() as streamer: - return {key: value async for key, value in streamer} + return {key: value async for key, value in check_results} async def __perform_check( self, graph: GraphName, model: Model, inspection: ReportCheck, context: CheckContext diff --git a/fixcore/fixcore/task/task_handler.py b/fixcore/fixcore/task/task_handler.py index d8677fee54..4d7073afe1 100644 --- a/fixcore/fixcore/task/task_handler.py +++ b/fixcore/fixcore/task/task_handler.py @@ -8,7 +8,6 @@ from copy import copy from datetime import timedelta from typing import Optional, Any, Callable, Union, Sequence, Dict, List, Tuple -from aiostream import stream from attrs import evolve from fixcore.analytics import AnalyticsEventSender, CoreEvent @@ -57,6 +56,7 @@ ) from fixcore.util import first, Periodic, group_by, utc_str, utc, partition_by from fixcore.types import Json +from fixlib.asynchronous.stream import Stream log = logging.getLogger(__name__) @@ -89,7 +89,7 @@ def __init__( # note: the waiting queue is kept in memory and lost when the service is restarted. self.start_when_done: Dict[str, TaskDescription] = {} - # Step1: define all workflows and jobs in code: later it will be persisted and read from database + # Step1: define all workflows and jobs in code: later it will be persisted and read from the database self.task_descriptions: Sequence[TaskDescription] = [*self.known_workflows(config), *self.known_jobs()] self.tasks: Dict[TaskId, RunningTask] = {} self.message_bus_watcher: Optional[Task[None]] = None @@ -496,7 +496,7 @@ async def execute_commands() -> None: results[command] = None elif isinstance(command, ExecuteOnCLI): ctx = evolve(self.cli_context, env={**command.env, **wi.descriptor.environment}) - result = await self.cli.execute_cli_command(command.command, stream.list, ctx) # type: ignore + result = await self.cli.execute_cli_command(command.command, Stream.as_list, ctx) results[command] = result else: raise AttributeError(f"Does not understand this command: {wi.descriptor.name}: {command}") diff --git a/fixcore/pyproject.toml b/fixcore/pyproject.toml index 5c85874c80..2ef84831f1 100644 --- a/fixcore/pyproject.toml +++ b/fixcore/pyproject.toml @@ -35,7 +35,6 @@ dependencies = [ "aiohttp-jinja2", "aiohttp-swagger3", "aiohttp[speedups]", - "aiostream", "cryptography", "deepdiff", "detect_secrets", diff --git a/fixcore/tests/fixcore/hypothesis_extension.py b/fixcore/tests/fixcore/hypothesis_extension.py index 297376f6ea..db2cecf6b4 100644 --- a/fixcore/tests/fixcore/hypothesis_extension.py +++ b/fixcore/tests/fixcore/hypothesis_extension.py @@ -1,9 +1,7 @@ import string from datetime import datetime -from typing import TypeVar, Callable, Any, cast, Optional, List, Generator +from typing import TypeVar, Any, cast, Optional, List, Generator -from aiostream import stream -from aiostream.core import Stream from hypothesis.strategies import ( SearchStrategy, just, @@ -20,6 +18,7 @@ from fixcore.model.resolve_in_graph import NodePath from fixcore.types import JsonElement, Json from fixcore.util import value_in_path, interleave +from fixlib.asynchronous.stream import Stream T = TypeVar("T") @@ -71,4 +70,4 @@ def from_node() -> Generator[Json, Any, None]: for from_n, to_n in interleave(node_ids): yield {"type": "edge", "from": from_n, "to": to_n} - return stream.iterate(from_node()) + return Stream.iterate(from_node()) diff --git a/fixcore/tests/fixcore/report/benchmark_renderer_test.py b/fixcore/tests/fixcore/report/benchmark_renderer_test.py index 080740695c..f34c711124 100644 --- a/fixcore/tests/fixcore/report/benchmark_renderer_test.py +++ b/fixcore/tests/fixcore/report/benchmark_renderer_test.py @@ -1,16 +1,16 @@ import pytest -from aiostream import stream from fixcore.report.benchmark_renderer import respond_benchmark_result from fixcore.report.inspector_service import InspectorService from fixcore.ids import GraphName +from fixlib.asynchronous.stream import Stream @pytest.mark.asyncio async def test_benchmark_renderer(inspector_service: InspectorService) -> None: bench_results = await inspector_service.perform_benchmarks(GraphName("ns"), ["test"]) bench_result = bench_results["test"] - render_result = [elem async for elem in respond_benchmark_result(stream.iterate(bench_result.to_graph()))] + render_result = [elem async for elem in respond_benchmark_result(Stream.iterate(bench_result.to_graph()))] assert len(render_result) == 1 assert ( render_result[0] diff --git a/fixcore/tests/fixcore/util_test.py b/fixcore/tests/fixcore/util_test.py index bf64f33d8e..5a688f3366 100644 --- a/fixcore/tests/fixcore/util_test.py +++ b/fixcore/tests/fixcore/util_test.py @@ -5,7 +5,6 @@ import pytest import pytz -from aiostream import stream from fixcore.util import ( AccessJson, @@ -21,6 +20,7 @@ utc_str, parse_utc, ) +from fixlib.asynchronous.stream import Stream def not_in_path(name: str, *other: str) -> bool: @@ -107,17 +107,9 @@ def test_del_value_in_path() -> None: @pytest.mark.asyncio async def test_async_gen() -> None: - async with stream.empty().stream() as empty: - async for _ in await force_gen(empty): - pass - - with pytest.raises(Exception): - async with stream.throw(Exception(";)")).stream() as err: - async for _ in await force_gen(err): - pass - - async with stream.iterate(range(0, 100)).stream() as elems: - assert [x async for x in await force_gen(elems)] == list(range(0, 100)) + async for _ in await force_gen(Stream.empty()): + pass + assert [x async for x in await force_gen(Stream.iterate(range(0, 100)))] == list(range(0, 100)) def test_deep_merge() -> None: diff --git a/fixcore/tests/fixcore/web/content_renderer_test.py b/fixcore/tests/fixcore/web/content_renderer_test.py index 4d3c5c6724..f37cf276e5 100644 --- a/fixcore/tests/fixcore/web/content_renderer_test.py +++ b/fixcore/tests/fixcore/web/content_renderer_test.py @@ -4,7 +4,6 @@ import pytest import yaml -from aiostream import stream from hypothesis import given, settings, HealthCheck from hypothesis.strategies import lists @@ -18,6 +17,7 @@ respond_cytoscape, respond_graphml, ) +from fixlib.asynchronous.stream import Stream from tests.fixcore.hypothesis_extension import ( json_array_gen, json_simple_element_gen, @@ -30,79 +30,72 @@ @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_json(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = "" - async for elem in respond_json(streamer): - result += elem - assert json.loads(result) == elements + result = "" + async for elem in respond_json(Stream.iterate(elements)): + result += elem + assert json.loads(result) == elements @given(json_array_gen) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_ndjson(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = [] - async for elem in respond_ndjson(streamer): - result.append(json.loads(elem.strip())) - assert result == elements + result = [] + async for elem in respond_ndjson(Stream.iterate(elements)): + result.append(json.loads(elem.strip())) + assert result == elements @given(json_array_gen) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_yaml(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = "" - async for elem in respond_yaml(streamer): - result += elem + "\n" - assert [a for a in yaml.full_load_all(result)] == elements + result = "" + async for elem in respond_yaml(Stream.iterate(elements)): + result += elem + "\n" + assert [a for a in yaml.full_load_all(result)] == elements @given(lists(json_simple_element_gen, min_size=1, max_size=10)) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_text_simple_elements(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = "" - async for elem in respond_text(streamer): - result += elem + "\n" - # every element is rendered as one or more line (string with \n is rendered as multiple lines) - assert len(elements) + 1 <= len(result.split("\n")) + result = "" + async for elem in respond_text(Stream.iterate(elements)): + result += elem + "\n" + # every element is rendered as one or more line (string with \n is rendered as multiple lines) + assert len(elements) + 1 <= len(result.split("\n")) @given(lists(node_gen(), min_size=1, max_size=10)) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_text_complex_elements(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = "" - async for elem in respond_text(streamer): - result += elem - # every element is rendered as yaml with --- as object deliminator - assert len(elements) == len(result.split("---")) + result = "" + async for elem in respond_text(Stream.iterate(elements)): + result += elem + # every element is rendered as yaml with --- as object deliminator + assert len(elements) == len(result.split("---")) @given(lists(node_gen(), min_size=1, max_size=10)) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_cytoscape(elements: List[Json]) -> None: - async with graph_stream(elements).stream() as streamer: - result = "" - async for elem in respond_cytoscape(streamer): - result += elem - # The resulting string can be parsed as json - assert json.loads(result) + result = "" + async for elem in respond_cytoscape(Stream.iterate(elements)): + result += elem + # The resulting string can be parsed as json + assert json.loads(result) @given(lists(node_gen(), min_size=1, max_size=10)) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_graphml(elements: List[Json]) -> None: - async with graph_stream(elements).stream() as streamer: - result = "" - async for elem in respond_graphml(streamer): - result += elem + result = "" + async for elem in respond_graphml(Stream.iterate(elements)): + result += elem # The resulting string can be parsed as xml assert ElementTree.fromstring(result) is not None @@ -119,30 +112,29 @@ def edge(from_node: str, to_node: str) -> Json: nodes = [node("a", "acc1"), node("b", "acc1"), node("c", "acc2")] edges = [edge("a", "b"), edge("a", "c"), edge("b", "c")] - async with stream.iterate(nodes + edges).stream() as streamer: - result = "" - async for elem in respond_dot(streamer): - result += elem + "\n" - expected = ( - "digraph {\n" - "rankdir=LR\n" - "overlap=false\n" - "splines=true\n" - "node [shape=Mrecord colorscheme=paired12]\n" - "edge [arrowsize=0.5]\n" - ' "a" [label="a|a", style=filled fillcolor=1];\n' - ' "b" [label="b|b", style=filled fillcolor=2];\n' - ' "c" [label="c|c", style=filled fillcolor=3];\n' - ' "a" -> "b" [label="delete"]\n' - ' "a" -> "c" [label="delete"]\n' - ' "b" -> "c" [label="delete"]\n' - ' subgraph "acc1" {\n' - ' "a"\n' - ' "b"\n' - " }\n" - ' subgraph "acc2" {\n' - ' "c"\n' - " }\n" - "}\n" - ) - assert result == expected + result = "" + async for elem in respond_dot(Stream.iterate(nodes + edges)): + result += elem + "\n" + expected = ( + "digraph {\n" + "rankdir=LR\n" + "overlap=false\n" + "splines=true\n" + "node [shape=Mrecord colorscheme=paired12]\n" + "edge [arrowsize=0.5]\n" + ' "a" [label="a|a", style=filled fillcolor=1];\n' + ' "b" [label="b|b", style=filled fillcolor=2];\n' + ' "c" [label="c|c", style=filled fillcolor=3];\n' + ' "a" -> "b" [label="delete"]\n' + ' "a" -> "c" [label="delete"]\n' + ' "b" -> "c" [label="delete"]\n' + ' subgraph "acc1" {\n' + ' "a"\n' + ' "b"\n' + " }\n" + ' subgraph "acc2" {\n' + ' "c"\n' + " }\n" + "}\n" + ) + assert result == expected diff --git a/fixlib/fixlib/asynchronous/stream.py b/fixlib/fixlib/asynchronous/stream.py index e0ff050211..516f5afed3 100644 --- a/fixlib/fixlib/asynchronous/stream.py +++ b/fixlib/fixlib/asynchronous/stream.py @@ -13,8 +13,6 @@ DirectOrAwaitable: TypeAlias = Union[T, Awaitable[T]] IterOrAsyncIter: TypeAlias = Union[Iterable[T], AsyncIterable[T]] -DefaultTaskLimit = 1 - def _async_iter(x: Iterable[T]) -> AsyncIterator[T]: async def gen() -> AsyncIterator[T]: @@ -36,17 +34,35 @@ def _flatmap( task_limit: Optional[int], ordered: bool, ) -> AsyncIterator[T]: - if ordered: + if task_limit is None or task_limit == 1: + return _flatmap_direct(source) + elif ordered: return _flatmap_ordered(source, task_limit) else: return _flatmap_unordered(source, task_limit) +async def _flatmap_direct( + source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], +) -> AsyncIterator[T]: + async for sub_iter in source: + if isinstance(sub_iter, AsyncIterable): + async for item in sub_iter: + if isinstance(item, Awaitable): + item = await item + yield item + else: + for item in sub_iter: + if isinstance(item, Awaitable): + item = await item + yield item + + async def _flatmap_unordered( source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], - task_limit: Optional[int] = None, + task_limit: int, ) -> AsyncIterator[T]: - semaphore = asyncio.Semaphore(task_limit or DefaultTaskLimit) + semaphore = asyncio.Semaphore(task_limit) queue: asyncio.Queue[T | Exception] = asyncio.Queue() tasks_in_flight = 0 @@ -91,10 +107,9 @@ async def worker(sub_iter: IterOrAsyncIter[DirectOrAwaitable[T]]) -> None: async def _flatmap_ordered( source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], - task_limit: Optional[int] = None, + task_limit: int, ) -> AsyncIterator[T]: - tlf = task_limit or DefaultTaskLimit - semaphore = asyncio.Semaphore(tlf) + semaphore = asyncio.Semaphore(task_limit) tasks: Dict[int, Task[None]] = {} results: Dict[int, List[T] | Exception] = {} next_index_to_yield = 0 @@ -123,7 +138,7 @@ async def worker(sub_iter: IterOrAsyncIter[T | Awaitable[T]], index: int) -> Non while True: # Start new tasks up to task_limit ahead of next_index_to_yield - while not source_exhausted and (max_index_started - next_index_to_yield + 1) < tlf: + while not source_exhausted and (max_index_started - next_index_to_yield + 1) < task_limit: try: await semaphore.acquire() si = await anext(source_iter) @@ -190,26 +205,30 @@ def map( task_limit: Optional[int] = None, ordered: bool = True, ) -> Stream[R]: - async def gen() -> AsyncIterator[AsyncIterator[R | Awaitable[R]]]: + async def gen() -> AsyncIterator[IterOrAsyncIter[DirectOrAwaitable[R]]]: async for item in self: res = fn(item) - yield _async_iter([res]) + yield [res] + # in the case of a synchronous function, task_limit is ignored + task_limit = task_limit if asyncio.iscoroutinefunction(fn) else 1 return Stream(_flatmap(gen(), task_limit, ordered)) def flatmap( self, - fn: Callable[[T], DirectOrAwaitable[IterOrAsyncIter[R]]], + fn: Callable[[T], DirectOrAwaitable[IterOrAsyncIter[DirectOrAwaitable[R]]]], task_limit: Optional[int] = None, ordered: bool = True, ) -> Stream[R]: - async def gen() -> AsyncIterator[IterOrAsyncIter[R]]: + async def gen() -> AsyncIterator[IterOrAsyncIter[DirectOrAwaitable[R]]]: async for item in self: res = fn(item) if isinstance(res, Awaitable): res = await res yield res + # in the case of a synchronous function, task_limit is ignored + task_limit = task_limit if asyncio.iscoroutinefunction(fn) else 1 return Stream(_flatmap(gen(), task_limit, ordered)) def concat(self: Stream[Stream[T]], task_limit: Optional[int] = None, ordered: bool = True) -> Stream[T]: @@ -256,35 +275,19 @@ async def gen() -> AsyncIterator[Tuple[int, T]]: return Stream(gen()) - def chunks(self, num: int) -> Stream[Stream[T]]: - def take_n(iterator: AsyncIterator[T], n: int) -> AsyncIterator[T]: - async def n_gen() -> AsyncIterator[T]: - count = 0 - try: - while count < n: - item = await anext(iterator) - yield item - count += 1 - except StopAsyncIteration: - return - - return n_gen() - - async def gen() -> AsyncIterator[Stream[T]]: - iterator = aiter(self.iterator) + def chunks(self, num: int) -> Stream[List[T]]: + async def gen() -> AsyncIterator[List[T]]: while True: - chunk_iterator = take_n(iterator, num) + chunk_items: List[T] = [] try: - first_item = await anext(chunk_iterator) + for _ in range(num): + item = await anext(self.iterator) + chunk_items.append(item) + yield chunk_items except StopAsyncIteration: - break # No more items - - async def chunk_with_first() -> AsyncIterator[T]: - yield first_item - async for item in chunk_iterator: - yield item - - yield Stream(chunk_with_first()) + if chunk_items: + yield chunk_items + break return Stream(gen()) @@ -302,18 +305,6 @@ async def gen() -> AsyncIterator[T]: return Stream(gen()) - def cycle(self) -> Stream[T]: - async def gen() -> AsyncIterator[T]: - items = [] - async for item in self: - yield item - items.append(item) - while items: - for item in items: - yield item - - return Stream(gen()) - async def collect(self) -> List[T]: return [item async for item in self] @@ -345,10 +336,13 @@ async def empty() -> AsyncIterator[Never]: return Stream(empty()) @staticmethod - def for_ever(fn: Callable[[], T]) -> Stream[T]: + def for_ever(fn: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Stream[T]: async def gen() -> AsyncIterator[T]: while True: - yield fn() + if asyncio.iscoroutinefunction(fn): + yield await fn(*args, **kwargs) + else: + yield fn(*args, **kwargs) # type: ignore return Stream(gen()) diff --git a/fixlib/test/asynchronous/stream_test.py b/fixlib/test/asynchronous/stream_test.py index 5b4e6efc00..68cba980b1 100644 --- a/fixlib/test/asynchronous/stream_test.py +++ b/fixlib/test/asynchronous/stream_test.py @@ -68,5 +68,7 @@ def with_int(foo: int) -> int: assert await Stream.call(fn, 1, "bla").map(with_int).collect() == [124] -# async def test_chunks(example_stream: Stream) -> None: -# assert await example_stream.chunks(2).collect() == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] +async def test_chunks() -> None: + assert len([chunk async for chunk in example_stream().chunks(2)]) == 3 + assert [chunk async for chunk in example_stream().chunks(2)] == await example_stream().chunks(2).collect() + assert await example_stream().chunks(2).map(Stream.as_list).collect() == [[0, 1], [2, 3], [4]] diff --git a/fixshell/.pylintrc b/fixshell/.pylintrc index fd1655a604..b2bce42c1c 100644 --- a/fixshell/.pylintrc +++ b/fixshell/.pylintrc @@ -245,7 +245,7 @@ ignored-modules= # List of classes names for which member attributes should not be checked # (useful for classes with attributes dynamically set). -ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local, aiostream.pipe +ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular diff --git a/requirements-all.txt b/requirements-all.txt index 50791c998d..df9cb7a7f1 100644 --- a/requirements-all.txt +++ b/requirements-all.txt @@ -5,7 +5,6 @@ aiohttp[speedups]==3.10.10 aiohttp-jinja2==1.6 aiohttp-swagger3==0.9.0 aiosignal==1.3.1 -aiostream==0.6.3 appdirs==1.4.4 apscheduler==3.10.4 asn1crypto==1.5.1 diff --git a/requirements-extra.txt b/requirements-extra.txt index 98cebdde10..551a443508 100644 --- a/requirements-extra.txt +++ b/requirements-extra.txt @@ -5,7 +5,6 @@ aiohttp[speedups]==3.10.10 aiohttp-jinja2==1.6 aiohttp-swagger3==0.9.0 aiosignal==1.3.1 -aiostream==0.6.3 appdirs==1.4.4 apscheduler==3.10.4 asn1crypto==1.5.1 diff --git a/requirements.txt b/requirements.txt index 4b37f2983f..1af1fd3cf8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ aiohttp[speedups]==3.10.10 aiohttp-jinja2==1.6 aiohttp-swagger3==0.9.0 aiosignal==1.3.1 -aiostream==0.6.3 appdirs==1.4.4 apscheduler==3.10.4 attrs==24.2.0