Skip to content

Commit

Permalink
[resotocore][feat] Enable count and total count for search queries (#1832)
Browse files Browse the repository at this point in the history

* [resotocore][feat] Enable count and total count for search queries

* apply limit in inner loop in case of an aggregation query
  • Loading branch information
aquamatthias authored Nov 21, 2023
1 parent 38d43d6 commit edbe7a9
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 45 deletions.
27 changes: 15 additions & 12 deletions resotocore/resotocore/cli/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
ArgInfo,
EntityProvider,
FilePath,
CLISourceContext,
)
from resotocore.user.model import Permission, AllowedRoleNames
from resotocore.cli.tip_of_the_day import SuggestionPolicy, SuggestionStrategy, get_suggestion_strategy
Expand Down Expand Up @@ -1456,7 +1457,7 @@ async def explain_search() -> AsyncIterator[Json]:
explanation = await db.explain(query_model, with_edges)
yield to_js(explanation)

async def prepare() -> Tuple[Optional[int], AsyncIterator[Json]]:
async def prepare() -> Tuple[CLISourceContext, AsyncIterator[Json]]:
db, graph_name = await get_db(at, current_graph_name)

query_model = await load_query_model(db, graph_name)
Expand Down Expand Up @@ -1484,7 +1485,7 @@ async def iterate_and_close() -> AsyncIterator[Json]:
finally:
cursor.close()

return cursor.count(), iterate_and_close()
return CLISourceContext(cursor.count(), cursor.full_count()), iterate_and_close()

return (
CLISource.single(explain_search, required_permissions={Permission.read})
Expand Down Expand Up @@ -1951,7 +1952,7 @@ def show(k: ComplexKind) -> bool:
result = sorted([k.fqn for k in model.kinds.values() if isinstance(k, ComplexKind) and show(k)])
return len(model.kinds), stream.iterate(result)

return CLISource(source, required_permissions={Permission.read})
return CLISource.only_count(source, required_permissions={Permission.read})


class SetDesiredStateBase(CLICommand, EntityProvider, ABC):
Expand Down Expand Up @@ -3107,9 +3108,9 @@ async def show_help() -> AsyncIterator[str]:
elif arg and len(args) == 2:
raise CLIParseError(f"Does not understand action {args[0]}. Allowed: add, update, delete.")
elif arg and len(args) == 1 and args[0] == "running":
return CLISource(running_jobs, required_permissions={Permission.read})
return CLISource.only_count(running_jobs, required_permissions={Permission.read})
elif arg and len(args) == 1 and args[0] == "list":
return CLISource(list_jobs, required_permissions={Permission.read})
return CLISource.only_count(list_jobs, required_permissions={Permission.read})
else:
return CLISource.single(show_help, required_permissions={Permission.read})

Expand Down Expand Up @@ -3895,7 +3896,7 @@ async def get_template(name: str) -> AsyncIterator[JsonElement]:
maybe_template = await self.dependencies.template_expander.get_template(name)
yield maybe_template.template if maybe_template else f"No template with this name: {name}"

async def list_templates() -> Tuple[Optional[int], AsyncIterator[Json]]:
async def list_templates() -> Tuple[int, AsyncIterator[Json]]:
templates = await self.dependencies.template_expander.list_templates()
return len(templates), stream.iterate(template_str(t) for t in templates)

Expand Down Expand Up @@ -3932,7 +3933,7 @@ async def expand_template(spec: str) -> AsyncIterator[str]:
elif arg and len(args) == 1:
return CLISource.single(partial(get_template, arg.strip()), required_permissions={Permission.read})
elif not arg:
return CLISource(list_templates, required_permissions={Permission.read})
return CLISource.only_count(list_templates, required_permissions={Permission.read})
else:
raise CLIParseError(f"Can not parse arguments: {arg}")

Expand Down Expand Up @@ -4431,17 +4432,19 @@ async def stop_workflow(task_id: TaskId) -> AsyncIterator[str]:
elif arg and len(args) == 1 and args[0] == "history":
return CLISource.single(history_aggregation, required_permissions={Permission.read})
elif arg and len(args) == 2 and args[0] == "history":
return CLISource(partial(history_of, re.split("\\s+", args[1])), required_permissions={Permission.read})
return CLISource.only_count(
partial(history_of, re.split("\\s+", args[1])), required_permissions={Permission.read}
)
elif arg and len(args) == 2 and args[0] == "log":
return CLISource(partial(show_log, args[1].strip()), required_permissions={Permission.read})
return CLISource.only_count(partial(show_log, args[1].strip()), required_permissions={Permission.read})
elif arg and len(args) == 2 and args[0] == "run":
return CLISource.single(partial(run_workflow, args[1].strip()), required_permissions={Permission.admin})
elif arg and len(args) == 2 and args[0] == "stop":
return CLISource.single(partial(stop_workflow, args[1].strip()), required_permissions={Permission.admin})
elif arg and len(args) == 1 and args[0] == "running":
return CLISource(running_workflows, required_permissions={Permission.read})
return CLISource.only_count(running_workflows, required_permissions={Permission.read})
elif arg and len(args) == 1 and args[0] == "list":
return CLISource(list_workflows, required_permissions={Permission.read})
return CLISource.only_count(list_workflows, required_permissions={Permission.read})
else:
return CLISource.single(
lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read}
Expand Down Expand Up @@ -4646,7 +4649,7 @@ async def list_configs() -> Tuple[int, JsStream]:
required_permissions={Permission.admin},
)
elif arg and len(args) == 1 and args[0] == "list":
return CLISource(list_configs, required_permissions={Permission.read})
return CLISource.only_count(list_configs, required_permissions={Permission.read})
else:
return CLISource.single(
lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read}
Expand Down
41 changes: 31 additions & 10 deletions resotocore/resotocore/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,16 @@ def make_stream(in_stream: JsGen) -> JsStream:
return in_stream if isinstance(in_stream, Stream) else stream.iterate(in_stream)


@define
class CLISourceContext:
    """Count metadata that a CLI source hands back together with its result stream.

    Both fields default to ``None``.
    """

    # Element count reported for the source (the search command passes cursor.count()).
    count: Optional[int] = None
    # Overall count (the search command passes cursor.full_count());
    # presumably the number of matches ignoring the final limit - confirm.
    total_count: Optional[int] = None


class CLISource(CLIAction):
def __init__(
self,
fn: Callable[[], Union[Tuple[Optional[int], JsGen], Awaitable[Tuple[Optional[int], JsGen]]]],
fn: Callable[[], Union[Tuple[CLISourceContext, JsGen], Awaitable[Tuple[CLISourceContext, JsGen]]]],
produces: MediaType = MediaType.Json,
requires: Optional[List[CLICommandRequirement]] = None,
envelope: Optional[Dict[str, str]] = None,
Expand All @@ -242,10 +248,25 @@ def __init__(
super().__init__(produces, requires, envelope, required_permissions)
self._fn = fn

async def source(self) -> Tuple[Optional[int], JsStream]:
async def source(self) -> Tuple[CLISourceContext, JsStream]:
res = self._fn()
count, gen = await res if iscoroutine(res) else res
return count, self.make_stream(await gen if iscoroutine(gen) else gen)
context, gen = await res if iscoroutine(res) else res
return context, self.make_stream(await gen if iscoroutine(gen) else gen)

@staticmethod
def only_count(
    fn: Callable[[], Union[Tuple[int, JsGen], Awaitable[Tuple[int, JsGen]]]],
    produces: MediaType = MediaType.Json,
    requires: Optional[List[CLICommandRequirement]] = None,
    envelope: Optional[Dict[str, str]] = None,
    required_permissions: Optional[Set[Permission]] = None,
) -> CLISource:
    """Build a CLISource from a function that yields a plain count and a generator.

    :param fn: callable returning ``(count, generator)`` either directly or as a coroutine.
    :param produces: media type handed to the CLISource constructor.
    :param requires: command requirements handed to the CLISource constructor.
    :param envelope: envelope handed to the CLISource constructor.
    :param required_permissions: permissions handed to the CLISource constructor.
    :return: a CLISource whose context carries the count as both ``count`` and ``total_count``.
    """

    async def with_context() -> Tuple[CLISourceContext, JsGen]:
        pending = fn()
        # fn may be synchronous or a coroutine function: only await when needed.
        if iscoroutine(pending):
            known_count, generator = await pending
        else:
            known_count, generator = pending
        # A single number is available, so it is used for both context fields.
        source_context = CLISourceContext(count=known_count, total_count=known_count)
        return source_context, generator

    return CLISource(with_context, produces, requires, envelope, required_permissions)

@staticmethod
def no_count(
Expand All @@ -266,10 +287,10 @@ def with_count(
envelope: Optional[Dict[str, str]] = None,
required_permissions: Optional[Set[Permission]] = None,
) -> CLISource:
async def combine() -> Tuple[Optional[int], JsGen]:
async def combine() -> Tuple[CLISourceContext, JsGen]:
res = fn()
gen = await res if iscoroutine(res) else res
return count, gen
return CLISourceContext(count=count), gen

return CLISource(combine, produces, requires, envelope, required_permissions)

Expand Down Expand Up @@ -695,16 +716,16 @@ def is_allowed_to_execute(self) -> bool:
return False
return all(self.ctx.user.has_permission(cmd.action.required_permissions) for cmd in self.executable_commands)

async def execute(self) -> Tuple[Optional[int], JsStream]:
async def execute(self) -> Tuple[CLISourceContext, JsStream]:
if self.executable_commands:
source_action = cast(CLISource, self.executable_commands[0].action)
count, flow = await source_action.source()
context, flow = await source_action.source()
for command in self.executable_commands[1:]:
flow_action = cast(CLIFlow, command.action)
flow = await flow_action.flow(flow)
return count, flow
return context, flow
else:
return 0, stream.empty()
return CLISourceContext(count=0), stream.empty()


class CLI(ABC):
Expand Down
13 changes: 9 additions & 4 deletions resotocore/resotocore/db/arango_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def to_query(
bind_vars: Json = {}
start = from_collection or f"`{db.vertex_name}`"
cursor, query_str = query_string(db, query, query_model, start, with_edges, bind_vars, count, id_column=id_column)
return f"""{query_str} FOR result in {cursor} RETURN UNSET(result, {unset_props})""", bind_vars
last_limit = (
f" LIMIT {ll.offset}, {ll.length}" if ((ll := query.current_part.limit) and not query.is_aggregate()) else ""
)
return f"""{query_str} FOR result in {cursor}{last_limit} RETURN UNSET(result, {unset_props})""", bind_vars


def query_string(
Expand Down Expand Up @@ -467,6 +470,7 @@ def merge_part_result(d: Json) -> str:
def part(p: Part, in_cursor: str, part_idx: int) -> Tuple[Part, str, str, str]:
query_part = ""
filtered_out = ""
last_part = len(query.parts) == (part_idx + 1)

def filter_statement(current_cursor: str, part_term: Term, limit: Optional[Limit]) -> str:
if isinstance(part_term, AllTerm) and limit is None and not p.sort:
Expand Down Expand Up @@ -665,9 +669,10 @@ def navigation(in_crsr: str, nav: Navigation) -> str:
query_part += f"LET {nav_crsr} = UNION_DISTINCT({all_walks_combined})"
return nav_crsr

# apply the limit in the filter statement only, when no with clause is present
# otherwise the limit is applied in the with clause
filter_limit = p.limit if p.with_clause is None else None
# Skip the limit in case of
# - with clause: the limit is applied in the with clause
# - last part of a non aggregation query: the limit is applied in the outermost for loop
filter_limit = p.limit if (p.with_clause is None and (not last_part or query.is_aggregate())) else None
cursor = in_cursor
part_term = p.term
if isinstance(p.term, MergeTerm):
Expand Down
3 changes: 3 additions & 0 deletions resotocore/resotocore/db/async_arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def close(self) -> None:
def count(self) -> Optional[int]:
    """Return whatever the wrapped cursor reports as its count (may be None)."""
    reported = self.cursor.count()
    return reported

def full_count(self) -> Optional[int]:
    """Return the ``fullCount`` entry of the wrapped cursor's statistics.

    Yields ``None`` when the cursor reports no (or empty) statistics, or when
    the statistics do not contain a ``fullCount`` entry.
    """
    statistics = self.cursor.statistics()
    if not statistics:
        return None
    return statistics.get("fullCount")

async def next_filtered(self) -> Optional[Json]:
element = await self.next_from_db()
vertex: Optional[Json] = None
Expand Down
12 changes: 11 additions & 1 deletion resotocore/resotocore/db/graphdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ async def list_possible_values(
return await self.db.aql_cursor(
query=q_string,
count=with_count,
full_count=with_count,
bind_vars=bind,
batch_size=10000,
ttl=cast(Number, int(timeout.total_seconds())) if timeout else None,
Expand All @@ -581,6 +582,7 @@ async def search_list(
query=q_string,
trafo=self.document_to_instance_fn(query.model, query),
count=with_count,
full_count=with_count,
bind_vars=bind,
batch_size=10000,
ttl=cast(Number, int(timeout.total_seconds())) if timeout else None,
Expand Down Expand Up @@ -623,7 +625,13 @@ async def search_history(
)
ttl = cast(Number, int(timeout.total_seconds())) if timeout else None
return await self.db.aql_cursor(
query=q_string, trafo=trafo, count=with_count, bind_vars=bind, batch_size=10000, ttl=ttl
query=q_string,
trafo=trafo,
count=with_count,
full_count=with_count,
bind_vars=bind,
batch_size=10000,
ttl=ttl,
)

async def search_graph_gen(
Expand All @@ -636,6 +644,7 @@ async def search_graph_gen(
trafo=self.document_to_instance_fn(query.model, query),
bind_vars=bind,
count=with_count,
full_count=with_count,
batch_size=10000,
ttl=cast(Number, int(timeout.total_seconds())) if timeout else None,
)
Expand Down Expand Up @@ -990,6 +999,7 @@ def combine_dict(left: Dict[K, List[V]], right: Dict[K, List[V]]) -> Dict[K, Lis
for num, (root, graph) in enumerate(graphs):
root_kind = GraphResolver.resolved_kind(graph_to_merge.nodes[root])
if root_kind:
# noinspection PyTypeChecker
log.info(f"Update subgraph: root={root} ({root_kind}, {num+1} of {len(roots)})")
node_query = self.query_update_nodes(root_kind), {"update_id": root}
edge_query = partial(merge_edges, root, root_kind)
Expand Down
4 changes: 2 additions & 2 deletions resotocore/resotocore/infra_apps/local_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ async def _interpret_line(self, line: str, ctx: CLIContext) -> Stream:
total_nr_outputs: int = 0
parsed_commands = await self.cli.evaluate_cli_command(line, ctx, True)
for parsed in parsed_commands:
nr_outputs, command_output_stream = await parsed.execute()
total_nr_outputs = total_nr_outputs + (nr_outputs or 0)
src_ctx, command_output_stream = await parsed.execute()
total_nr_outputs = total_nr_outputs + (src_ctx.count or 0)
command_streams.append(command_output_stream)

return stream.concat(stream.iterate(command_streams), task_limit=1)
3 changes: 3 additions & 0 deletions resotocore/resotocore/query/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,9 @@ def merge_names(self) -> Set[str]:
def merge_query_by_name(self) -> List[MergeQuery]:
    """Collect the merge queries of every part whose term is a MergeTerm, in part order."""
    collected: List[MergeQuery] = []
    for part in self.parts:
        # Parts with any other kind of term contribute nothing.
        if isinstance(part.term, MergeTerm):
            collected.extend(part.term.merge)
    return collected

def is_aggregate(self) -> bool:
    """Tell whether this query has an aggregate set (``aggregate`` is not None)."""
    aggregate_missing = self.aggregate is None
    return not aggregate_missing

def filter(self, term: Union[str, Term], *terms: Union[str, Term]) -> Query:
res = Query.mk_term(term, *terms)
parts = self.parts.copy()
Expand Down
Loading

0 comments on commit edbe7a9

Please sign in to comment.