From edbe7a90edcc931f00c64ec1697258458dd2414f Mon Sep 17 00:00:00 2001 From: Matthias Veit Date: Tue, 21 Nov 2023 18:43:59 +0100 Subject: [PATCH] [resotocore][feat] Enable count and total count for search queries (#1832) * [resotocore][feat] Enable count and total count for search queries * apply limit in inner loop in case of an aggregation query --- resotocore/resotocore/cli/command.py | 27 +++++----- resotocore/resotocore/cli/model.py | 41 +++++++++++---- resotocore/resotocore/db/arango_query.py | 13 +++-- resotocore/resotocore/db/async_arangodb.py | 3 ++ resotocore/resotocore/db/graphdb.py | 12 ++++- .../resotocore/infra_apps/local_runtime.py | 4 +- resotocore/resotocore/query/model.py | 3 ++ resotocore/resotocore/web/api.py | 52 ++++++++++++++----- .../tests/resotocore/cli/command_test.py | 5 ++ .../tests/resotocore/db/graphdb_test.py | 8 ++- 10 files changed, 123 insertions(+), 45 deletions(-) diff --git a/resotocore/resotocore/cli/command.py b/resotocore/resotocore/cli/command.py index da0e631c06..3881645679 100644 --- a/resotocore/resotocore/cli/command.py +++ b/resotocore/resotocore/cli/command.py @@ -96,6 +96,7 @@ ArgInfo, EntityProvider, FilePath, + CLISourceContext, ) from resotocore.user.model import Permission, AllowedRoleNames from resotocore.cli.tip_of_the_day import SuggestionPolicy, SuggestionStrategy, get_suggestion_strategy @@ -1456,7 +1457,7 @@ async def explain_search() -> AsyncIterator[Json]: explanation = await db.explain(query_model, with_edges) yield to_js(explanation) - async def prepare() -> Tuple[Optional[int], AsyncIterator[Json]]: + async def prepare() -> Tuple[CLISourceContext, AsyncIterator[Json]]: db, graph_name = await get_db(at, current_graph_name) query_model = await load_query_model(db, graph_name) @@ -1484,7 +1485,7 @@ async def iterate_and_close() -> AsyncIterator[Json]: finally: cursor.close() - return cursor.count(), iterate_and_close() + return CLISourceContext(cursor.count(), cursor.full_count()), iterate_and_close() return ( CLISource.single(explain_search, required_permissions={Permission.read}) @@ -1951,7 +1952,7 @@ def show(k: ComplexKind) -> bool: result = sorted([k.fqn for k in model.kinds.values() if isinstance(k, ComplexKind) and show(k)]) return len(model.kinds), stream.iterate(result) - return CLISource(source, required_permissions={Permission.read}) + return CLISource.only_count(source, required_permissions={Permission.read}) class SetDesiredStateBase(CLICommand, EntityProvider, ABC): @@ -3107,9 +3108,9 @@ async def show_help() -> AsyncIterator[str]: elif arg and len(args) == 2: raise CLIParseError(f"Does not understand action {args[0]}. Allowed: add, update, delete.") elif arg and len(args) == 1 and args[0] == "running": - return CLISource(running_jobs, required_permissions={Permission.read}) + return CLISource.only_count(running_jobs, required_permissions={Permission.read}) elif arg and len(args) == 1 and args[0] == "list": - return CLISource(list_jobs, required_permissions={Permission.read}) + return CLISource.only_count(list_jobs, required_permissions={Permission.read}) else: return CLISource.single(show_help, required_permissions={Permission.read}) @@ -3895,7 +3896,7 @@ async def get_template(name: str) -> AsyncIterator[JsonElement]: maybe_template = await self.dependencies.template_expander.get_template(name) yield maybe_template.template if maybe_template else f"No template with this name: {name}" - async def list_templates() -> Tuple[Optional[int], AsyncIterator[Json]]: + async def list_templates() -> Tuple[int, AsyncIterator[Json]]: templates = await self.dependencies.template_expander.list_templates() return len(templates), stream.iterate(template_str(t) for t in templates) @@ -3932,7 +3933,7 @@ async def expand_template(spec: str) -> AsyncIterator[str]: elif arg and len(args) == 1: return CLISource.single(partial(get_template, arg.strip()), required_permissions={Permission.read}) elif not arg: - return CLISource(list_templates, required_permissions={Permission.read}) + return CLISource.only_count(list_templates, required_permissions={Permission.read}) else: raise CLIParseError(f"Can not parse arguments: {arg}") @@ -4431,17 +4432,19 @@ async def stop_workflow(task_id: TaskId) -> AsyncIterator[str]: elif arg and len(args) == 1 and args[0] == "history": return CLISource.single(history_aggregation, required_permissions={Permission.read}) elif arg and len(args) == 2 and args[0] == "history": - return CLISource(partial(history_of, re.split("\\s+", args[1])), required_permissions={Permission.read}) + return CLISource.only_count( + partial(history_of, re.split("\\s+", args[1])), required_permissions={Permission.read} + ) elif arg and len(args) == 2 and args[0] == "log": - return CLISource(partial(show_log, args[1].strip()), required_permissions={Permission.read}) + return CLISource.only_count(partial(show_log, args[1].strip()), required_permissions={Permission.read}) elif arg and len(args) == 2 and args[0] == "run": return CLISource.single(partial(run_workflow, args[1].strip()), required_permissions={Permission.admin}) elif arg and len(args) == 2 and args[0] == "stop": return CLISource.single(partial(stop_workflow, args[1].strip()), required_permissions={Permission.admin}) elif arg and len(args) == 1 and args[0] == "running": - return CLISource(running_workflows, required_permissions={Permission.read}) + return CLISource.only_count(running_workflows, required_permissions={Permission.read}) elif arg and len(args) == 1 and args[0] == "list": - return CLISource(list_workflows, required_permissions={Permission.read}) + return CLISource.only_count(list_workflows, required_permissions={Permission.read}) else: return CLISource.single( lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} @@ -4646,7 +4649,7 @@ async def list_configs() -> Tuple[int, JsStream]: required_permissions={Permission.admin}, ) elif arg and len(args) == 1 and args[0] == "list": - return CLISource(list_configs, required_permissions={Permission.read}) + return CLISource.only_count(list_configs, required_permissions={Permission.read}) else: return CLISource.single( lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} diff --git a/resotocore/resotocore/cli/model.py b/resotocore/resotocore/cli/model.py index 72c076ccb3..77c2241755 100644 --- a/resotocore/resotocore/cli/model.py +++ b/resotocore/resotocore/cli/model.py @@ -230,10 +230,16 @@ def make_stream(in_stream: JsGen) -> JsStream: return in_stream if isinstance(in_stream, Stream) else stream.iterate(in_stream) +@define +class CLISourceContext: + count: Optional[int] = None + total_count: Optional[int] = None + + class CLISource(CLIAction): def __init__( self, - fn: Callable[[], Union[Tuple[Optional[int], JsGen], Awaitable[Tuple[Optional[int], JsGen]]]], + fn: Callable[[], Union[Tuple[CLISourceContext, JsGen], Awaitable[Tuple[CLISourceContext, JsGen]]]], produces: MediaType = MediaType.Json, requires: Optional[List[CLICommandRequirement]] = None, envelope: Optional[Dict[str, str]] = None, @@ -242,10 +248,25 @@ def __init__( super().__init__(produces, requires, envelope, required_permissions) self._fn = fn - async def source(self) -> Tuple[Optional[int], JsStream]: + async def source(self) -> Tuple[CLISourceContext, JsStream]: res = self._fn() - count, gen = await res if iscoroutine(res) else res - return count, self.make_stream(await gen if iscoroutine(gen) else gen) + context, gen = await res if iscoroutine(res) else res + return context, self.make_stream(await gen if iscoroutine(gen) else gen) + + @staticmethod + def only_count( + fn: Callable[[], Union[Tuple[int, JsGen], Awaitable[Tuple[int, JsGen]]]], + produces: MediaType = MediaType.Json, + requires: Optional[List[CLICommandRequirement]] = None, + envelope: Optional[Dict[str, str]] = None, + required_permissions: Optional[Set[Permission]] = None, + ) -> CLISource: + async def combine() -> Tuple[CLISourceContext, JsGen]: + res = fn() + count, gen = await res if iscoroutine(res) else res + return CLISourceContext(count=count, total_count=count), gen + + return CLISource(combine, produces, requires, envelope, required_permissions) @staticmethod def no_count( @@ -266,10 +287,10 @@ def with_count( envelope: Optional[Dict[str, str]] = None, required_permissions: Optional[Set[Permission]] = None, ) -> CLISource: - async def combine() -> Tuple[Optional[int], JsGen]: + async def combine() -> Tuple[CLISourceContext, JsGen]: res = fn() gen = await res if iscoroutine(res) else res - return count, gen + return CLISourceContext(count=count), gen return CLISource(combine, produces, requires, envelope, required_permissions) @@ -695,16 +716,16 @@ def is_allowed_to_execute(self) -> bool: return False return all(self.ctx.user.has_permission(cmd.action.required_permissions) for cmd in self.executable_commands) - async def execute(self) -> Tuple[Optional[int], JsStream]: + async def execute(self) -> Tuple[CLISourceContext, JsStream]: if self.executable_commands: source_action = cast(CLISource, self.executable_commands[0].action) - count, flow = await source_action.source() + context, flow = await source_action.source() for command in self.executable_commands[1:]: flow_action = cast(CLIFlow, command.action) flow = await flow_action.flow(flow) - return count, flow + return context, flow else: - return 0, stream.empty() + return CLISourceContext(count=0), stream.empty() class CLI(ABC): diff --git a/resotocore/resotocore/db/arango_query.py b/resotocore/resotocore/db/arango_query.py index 8eac145918..ec1dd1b6c6 100644 --- a/resotocore/resotocore/db/arango_query.py +++ b/resotocore/resotocore/db/arango_query.py @@ -94,7 +94,10 @@ def to_query( bind_vars: Json = {} start = from_collection or f"`{db.vertex_name}`" cursor, query_str = query_string(db, query, query_model, start, with_edges, bind_vars, count, id_column=id_column) - return f"""{query_str} FOR result in {cursor} RETURN UNSET(result, {unset_props})""", bind_vars + last_limit = ( + f" LIMIT {ll.offset}, {ll.length}" if ((ll := query.current_part.limit) and not query.is_aggregate()) else "" + ) + return f"""{query_str} FOR result in {cursor}{last_limit} RETURN UNSET(result, {unset_props})""", bind_vars def query_string( @@ -467,6 +470,7 @@ def merge_part_result(d: Json) -> str: def part(p: Part, in_cursor: str, part_idx: int) -> Tuple[Part, str, str, str]: query_part = "" filtered_out = "" + last_part = len(query.parts) == (part_idx + 1) def filter_statement(current_cursor: str, part_term: Term, limit: Optional[Limit]) -> str: if isinstance(part_term, AllTerm) and limit is None and not p.sort: @@ -665,9 +669,10 @@ def navigation(in_crsr: str, nav: Navigation) -> str: query_part += f"LET {nav_crsr} = UNION_DISTINCT({all_walks_combined})" return nav_crsr - # apply the limit in the filter statement only, when no with clause is present - # otherwise the limit is applied in the with clause - filter_limit = p.limit if p.with_clause is None else None + # Skip the limit in case of + # - with clause: the limit is applied in the with clause + # - last part of a non aggregation query: the limit is applied in the outermost for loop + filter_limit = p.limit if (p.with_clause is None and (not last_part or query.is_aggregate())) else None cursor = in_cursor part_term = p.term if isinstance(p.term, MergeTerm): diff --git a/resotocore/resotocore/db/async_arangodb.py b/resotocore/resotocore/db/async_arangodb.py index 6dd7de7b21..0f219fb0ee 100644 --- a/resotocore/resotocore/db/async_arangodb.py +++ b/resotocore/resotocore/db/async_arangodb.py @@ -73,6 +73,9 @@ def close(self) -> None: def count(self) -> Optional[int]: return self.cursor.count() + def full_count(self) -> Optional[int]: + return stats.get("fullCount") if (stats := self.cursor.statistics()) else None + async def next_filtered(self) -> Optional[Json]: element = await self.next_from_db() vertex: Optional[Json] = None diff --git a/resotocore/resotocore/db/graphdb.py b/resotocore/resotocore/db/graphdb.py index 0f3696149d..3cdf7ce9d8 100644 --- a/resotocore/resotocore/db/graphdb.py +++ b/resotocore/resotocore/db/graphdb.py @@ -567,6 +567,7 @@ async def list_possible_values( return await self.db.aql_cursor( query=q_string, count=with_count, + full_count=with_count, bind_vars=bind, batch_size=10000, ttl=cast(Number, int(timeout.total_seconds())) if timeout else None, @@ -581,6 +582,7 @@ async def search_list( query=q_string, trafo=self.document_to_instance_fn(query.model, query), count=with_count, + full_count=with_count, bind_vars=bind, batch_size=10000, ttl=cast(Number, int(timeout.total_seconds())) if timeout else None, @@ -623,7 +625,13 @@ async def search_history( ) ttl = cast(Number, int(timeout.total_seconds())) if timeout else None return await self.db.aql_cursor( - query=q_string, trafo=trafo, count=with_count, bind_vars=bind, batch_size=10000, ttl=ttl + query=q_string, + trafo=trafo, + count=with_count, + full_count=with_count, + bind_vars=bind, + batch_size=10000, + ttl=ttl, ) async def search_graph_gen( @@ -636,6 +644,7 @@ async def search_graph_gen( trafo=self.document_to_instance_fn(query.model, query), bind_vars=bind, count=with_count, + full_count=with_count, batch_size=10000, ttl=cast(Number, int(timeout.total_seconds())) if timeout else None, ) @@ -990,6 +999,7 @@ def combine_dict(left: Dict[K, List[V]], right: Dict[K, List[V]]) -> Dict[K, Lis for num, (root, graph) in enumerate(graphs): root_kind = GraphResolver.resolved_kind(graph_to_merge.nodes[root]) if root_kind: + # noinspection PyTypeChecker log.info(f"Update subgraph: root={root} ({root_kind}, {num+1} of {len(roots)})") node_query = self.query_update_nodes(root_kind), {"update_id": root} edge_query = partial(merge_edges, root, root_kind) diff --git a/resotocore/resotocore/infra_apps/local_runtime.py b/resotocore/resotocore/infra_apps/local_runtime.py index a6ef1acc70..15fa8d0229 100644 --- a/resotocore/resotocore/infra_apps/local_runtime.py +++ b/resotocore/resotocore/infra_apps/local_runtime.py @@ -113,8 +113,8 @@ async def _interpret_line(self, line: str, ctx: CLIContext) -> Stream: total_nr_outputs: int = 0 parsed_commands = await self.cli.evaluate_cli_command(line, ctx, True) for parsed in parsed_commands: - nr_outputs, command_output_stream = await parsed.execute() - total_nr_outputs = total_nr_outputs + (nr_outputs or 0) + src_ctx, command_output_stream = await parsed.execute() + total_nr_outputs = total_nr_outputs + (src_ctx.count or 0) command_streams.append(command_output_stream) return stream.concat(stream.iterate(command_streams), task_limit=1) diff --git a/resotocore/resotocore/query/model.py b/resotocore/resotocore/query/model.py index f22a01b196..ff794a7c55 100644 --- a/resotocore/resotocore/query/model.py +++ b/resotocore/resotocore/query/model.py @@ -916,6 +916,9 @@ def merge_names(self) -> Set[str]: def merge_query_by_name(self) -> List[MergeQuery]: return [mt for part in self.parts if isinstance(part.term, MergeTerm) for mt in part.term.merge] + def is_aggregate(self) -> bool: + return self.aggregate is not None + def filter(self, term: Union[str, Term], *terms: Union[str, Term]) -> Query: res = Query.mk_term(term, *terms) parts = self.parts.copy() diff --git a/resotocore/resotocore/web/api.py b/resotocore/resotocore/web/api.py index a0118c9083..78e864926b 100644 --- a/resotocore/resotocore/web/api.py +++ b/resotocore/resotocore/web/api.py @@ -658,7 +658,7 @@ async def perform_benchmark(self, request: Request, deps: TenantDependencies) -> raise ValueError(f"Unknown action {action}. One of run or load is expected.") result_graph = results[benchmark].to_graph() async with stream.iterate(result_graph).stream() as streamer: - return await self.stream_response_from_gen(request, streamer, len(result_graph)) + return await self.stream_response_from_gen(request, streamer, count=len(result_graph)) async def inspection_checks(self, request: Request, deps: TenantDependencies) -> StreamResponse: provider = request.query.get("provider") @@ -1076,6 +1076,7 @@ async def property_path_complete(self, request: Request, deps: TenantDependencie async def possible_values(self, request: Request, deps: TenantDependencies) -> StreamResponse: graph_db, query_model = await self.graph_query_model_from_request(request, deps) section = section_of(request) + # noinspection PyTypeChecker detail: Literal["attributes", "values"] = "attributes" if request.path.endswith("attributes") else "values" root_or_section = None if section is None or section == PathRoot else section fn = partial(variable_to_absolute, root_or_section) @@ -1090,7 +1091,9 @@ async def possible_values(self, request: Request, deps: TenantDependencies) -> S async with await graph_db.list_possible_values( query_model, prop_or_predicate, detail, limit, skip, count ) as cursor: - return await self.stream_response_from_gen(request, cursor, cursor.count()) + return await self.stream_response_from_gen( + request, cursor, count=cursor.count(), total_count=cursor.full_count() + ) async def query_structure(self, request: Request, deps: TenantDependencies) -> StreamResponse: _, query_model = await self.graph_query_model_from_request(request, deps) @@ -1101,7 +1104,9 @@ async def query_list(self, request: Request, deps: TenantDependencies) -> Stream count = request.query.get("count", "true").lower() != "false" timeout = if_set(request.query.get("search_timeout"), duration) async with await graph_db.search_list(query_model, count, timeout) as cursor: - return await self.stream_response_from_gen(request, cursor, cursor.count()) + return await self.stream_response_from_gen( + request, cursor, count=cursor.count(), total_count=cursor.full_count() + ) async def cytoscape(self, request: Request, deps: TenantDependencies) -> StreamResponse: graph_db, query_model = await self.graph_query_model_from_request(request, deps) @@ -1114,12 +1119,16 @@ async def query_graph_stream(self, request: Request, deps: TenantDependencies) - count = request.query.get("count", "true").lower() != "false" timeout = if_set(request.query.get("search_timeout"), duration) async with await graph_db.search_graph_gen(query_model, count, timeout) as cursor: - return await self.stream_response_from_gen(request, cursor, cursor.count()) + return await self.stream_response_from_gen( + request, cursor, count=cursor.count(), total_count=cursor.full_count() + ) async def query_aggregation(self, request: Request, deps: TenantDependencies) -> StreamResponse: graph_db, query_model = await self.graph_query_model_from_request(request, deps) - async with await graph_db.search_aggregation(query_model) as gen: - return await self.stream_response_from_gen(request, gen) + async with await graph_db.search_aggregation(query_model) as cursor: + return await self.stream_response_from_gen( + request, cursor, count=cursor.count(), total_count=cursor.full_count() + ) async def query_history(self, request: Request, deps: TenantDependencies) -> StreamResponse: graph_db, query_model = await self.graph_query_model_from_request(request, deps) @@ -1131,8 +1140,10 @@ async def query_history(self, request: Request, deps: TenantDependencies) -> Str change=HistoryChange[change] if change else None, before=parse_utc(before) if before else None, after=parse_utc(after) if after else None, - ) as gen: - return await self.stream_response_from_gen(request, gen) + ) as cursor: + return await self.stream_response_from_gen( + request, cursor, count=cursor.count(), total_count=cursor.full_count() + ) async def serve_debug_ui(self, request: Request) -> FileResponse: """ @@ -1285,13 +1296,19 @@ async def execute_parsed( return web.json_response(data, status=424) elif len(parsed) == 1: first_result = parsed[0] - count, generator = await first_result.execute() + src_ctx, generator = await first_result.execute() # flat the results from 0 or 1 async with generator.stream() as streamer: gen = await force_gen(streamer) if first_result.produces.text: text_gen = ctx.text_generator(first_result, gen) - return await self.stream_response_from_gen(request, text_gen, count, first_result.envelope) + return await self.stream_response_from_gen( + request, + text_gen, + count=src_ctx.count, + total_count=src_ctx.total_count, + additional_header=first_result.envelope, + ) elif first_result.produces.file_path: await mp_response.prepare(request) await Api.multi_file_response(first_result, gen, boundary, mp_response) @@ -1302,7 +1319,7 @@ async def execute_parsed( elif len(parsed) > 1: await mp_response.prepare(request) for single in parsed: - count, generator = await single.execute() + _, generator = await single.execute() async with generator.stream() as streamer: gen = await force_gen(streamer) if single.produces.text: @@ -1373,15 +1390,22 @@ def optional_json(o: Any, hint: str) -> StreamResponse: async def stream_response_from_gen( request: Request, gen_in: AsyncIterator[JsonElement], + *, count: Optional[int] = None, + total_count: Optional[int] = None, additional_header: Optional[Dict[str, str]] = None, ) -> StreamResponse: # force the async generator, to get an early exception in case of failure gen = await force_gen(gen_in) content_type, result_gen = await result_binary_gen(request, gen) - count_header = {"Resoto-Shell-Element-Count": str(count)} if count else {} - hdr = additional_header or {} - response = web.StreamResponse(status=200, headers={**hdr, "Content-Type": content_type, **count_header}) + headers = {"Content-Type": content_type} + if additional_header: + headers.update(additional_header) + if count is not None: + headers["Result-Count"] = str(count) + if total_count is not None: + headers["Total-Count"] = str(total_count) + response = web.StreamResponse(status=200, headers=headers) enable_compression(request, response) writer: AbstractStreamWriter = await response.prepare(request) # type: ignore cr = "\n".encode("utf-8") diff --git a/resotocore/tests/resotocore/cli/command_test.py b/resotocore/tests/resotocore/cli/command_test.py index 70dab88810..54e53c3bf6 100644 --- a/resotocore/tests/resotocore/cli/command_test.py +++ b/resotocore/tests/resotocore/cli/command_test.py @@ -625,6 +625,11 @@ async def test_list_command(cli: CLI) -> None: ) # b as a ==> b, c as a ==> c, c ==> c_1, ancestors.account.reported.a ==> account_a, again ==> _1 assert result[0][0] == "a=a, b=true, c=false, c_1=false, account_a=a, account_a_1=a, foo=a" + # source context is passed correctly + parsed = await cli.evaluate_cli_command("search is (bla) | head 10 | list") + src_ctx, gen = await parsed[0].execute() + assert src_ctx.count == 10 + assert src_ctx.total_count == 100 @pytest.mark.asyncio diff --git a/resotocore/tests/resotocore/db/graphdb_test.py b/resotocore/tests/resotocore/db/graphdb_test.py index ca51692cbd..1aceb31d7b 100644 --- a/resotocore/tests/resotocore/db/graphdb_test.py +++ b/resotocore/tests/resotocore/db/graphdb_test.py @@ -373,11 +373,13 @@ async def assert_result(query: str, nodes: int, edges: int) -> None: @mark.asyncio async def test_query_nested(filled_graph_db: ArangoGraphDB, foo_model: Model) -> None: - async def assert_count(query: str, count: int) -> None: + async def assert_count(query: str, count: int, total_count: Optional[int] = None) -> None: q = parse_query(query).on_section("reported") async with await filled_graph_db.search_list(QueryModel(q, foo_model), with_count=True) as gen: - assert gen.cursor.count() == count + assert gen.count() == count assert len([a async for a in gen]) == count + if total_count: + assert gen.full_count() == total_count await assert_count("is(bla) and h.inner[*].inner[*].name=in_0_1", 100) await assert_count("is(bla) and h.inner[*].inner[*].inner == []", 100) @@ -388,6 +390,8 @@ async def assert_count(query: str, count: int) -> None: await assert_count("is(bla) and g[*] any = 2", 100) await assert_count("is(bla) and g[*] all = 2", 0) await assert_count("is(bla) and g[*] none = 2", 0) + await assert_count("is(bla) limit 1", 1, 100) + await assert_count("is(bla) limit 10, 10", 10, 100) @mark.asyncio