Skip to content

Commit

Permalink
[feat][resotocore] Implement aggregate over arrays (#1780)
Browse files Browse the repository at this point in the history
* [feat][resotocore] Implement aggregate over arrays

* add default sort for aggregation queries

* fix test

* use hashable container

* sort on the correct value

* fix test
  • Loading branch information
aquamatthias authored Sep 26, 2023
1 parent 0a9a16d commit 8afbdfa
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 22 deletions.
9 changes: 7 additions & 2 deletions resotocore/resotocore/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ async def parse_query(query_arg: str) -> Query:
# we also add the aggregate_to_count command after the query
assert query.aggregate is None, "Can not combine aggregate and count!"
group_by_var = [AggregateVariable(AggregateVariableName(arg), "name")] if arg else []
aggregate = Aggregate(group_by_var, [AggregateFunction("sum", 1, [], "count")])
aggregate = Aggregate(group_by_var, [AggregateFunction("sum", 1, (), "count")])
# If the query should be explained, we want the output as is
if "explain" not in parsed_options:
additional_commands.append(self.command("aggregate_to_count", None, ctx))
Expand Down Expand Up @@ -490,8 +490,13 @@ async def parse_query(query_arg: str) -> Query:

# If the last part is a navigation, we need to add sort which will ingest a new part.
with_sort = query.set_sort(*DefaultSort) if query.current_part.navigation else query
section = ctx.env.get("section", PathRoot)
# If this is an aggregate query, the default sort needs to be changed
if query.aggregate is not None and query.current_part.sort == DefaultSort:
with_sort = query.set_sort(*query.aggregate.sort_by_fn(section))

# When all parts are combined, interpret the result on defined section.
final_query = with_sort.on_section(ctx.env.get("section", PathRoot))
final_query = with_sort.on_section(section)
options = ExecuteSearchCommand.argument_string(parsed_options)
query_string = str(final_query)
execute_search = self.command("execute_search", f"{options}'{query_string}'", ctx)
Expand Down
88 changes: 81 additions & 7 deletions resotocore/resotocore/db/arango_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,30 +153,104 @@ def escape_part(path_part: str) -> str:
return prop_name, resolved, merge_name

def aggregate(in_cursor: str, a: Aggregate) -> Tuple[str, str]:
cursor = next_crs("agg")

cursor_lookup: Dict[Tuple[str, ...], str] = {}
nested_function_lookup: Dict[AggregateFunction, str] = {}
nested = {name for agg in a.group_by for name in agg.all_names() if array_marker_in_path_regexp.search(name)}
# If we have a nested array, we need to unfold the array and create a new for loop for each array access.
if nested:
cursor = next_crs("agg")
for_loop = f"for {cursor} in {in_cursor}"
internals = []
for ag in nested:
inner_crsr = cursor
ars = [a.lstrip(".") for a in array_marker_in_path_regexp.split(ag)]
ar_parts = []
for ar in ars[0:-1]:
ar_parts.append(ar)
if tuple(ar_parts) in cursor_lookup:
continue
nxt_crs = next_crs("pre")
cursor_lookup[tuple(ar_parts)] = nxt_crs
for_loop += f" FOR {nxt_crs} IN APPEND(TO_ARRAY({inner_crsr}.{ar}), {{_internal: true}})"
internals.append(f"{nxt_crs}._internal!=true")
inner_crsr = nxt_crs
for_loop += f" FILTER {' OR '.join(internals)}"
else:
cursor = next_crs("agg")
for_loop = f"for {cursor} in {in_cursor}"

# the property needs to be accessed from the correct cursor
def prop_for(name: str) -> str:
ars = [a.lstrip(".") for a in array_marker_in_path_regexp.split(name)]
if len(ars) == 1: # no array access
return f"{cursor}.{name}"
else: # array access
return f"{cursor_lookup[tuple(ars[0:-1])]}.{ars[-1]}"

# the function needs to be accessed from the correct cursor or from a let expression
def function_value_for(fn: AggregateFunction, name: str) -> str:
ars = [a.lstrip(".") for a in array_marker_in_path_regexp.split(name)]
if len(ars) == 1: # no array access
return f"{cursor}.{fn.name}"
elif tuple(ars[0:-1]) in cursor_lookup: # array access with a related group variable
return f"{cursor_lookup[tuple(ars[0:-1])]}.{ars[-1]}"
else: # array access without a related group variable -> let expression
return nested_function_lookup[fn]

# compute the correct cursor name for the given variable
def var_name(n: Union[AggregateVariableName, AggregateVariableCombined]) -> str:
def comb_name(cb: Union[str, AggregateVariableName]) -> str:
return f'"{cb}"' if isinstance(cb, str) else f"{cursor}.{cb.name}"
return f'"{cb}"' if isinstance(cb, str) else prop_for(cb.name)

return (
f"{cursor}.{n.name}"
prop_for(n.name)
if isinstance(n, AggregateVariableName)
else f'CONCAT({",".join(comb_name(cp) for cp in n.parts)})'
)

# compute the correct function term for the given function
def func_term(fn: AggregateFunction) -> str:
name = f"{cursor}.{fn.name}" if isinstance(fn.name, str) else str(fn.name)
name = function_value_for(fn, fn.name) if isinstance(fn.name, str) else str(fn.name)
return f"{name} {fn.combined_ops()}" if fn.ops else name

# if the function accesses an array, we need to handle this specially
# - in case the property name is also used in the group by, we can simply use the variable cursor
# - if not we need to create a separate let expression before the collect statement
# inside the collect / aggregate we can refer to the let expression
# - if only a part of property name is used, use the last known cursor.
# example: a[*].c in var and a[*].b[*].d in group -> use the a[*] cursor
def unfold_array_func_term(fn: AggregateFunction) -> Optional[str]:
if isinstance(fn.name, int):
return None
ars = [a.lstrip(".") for a in array_marker_in_path_regexp.split(fn.name)]
if len(ars) == 1:
return None
# array access without a related group variable.
res = next_crs("agg_let")
nested_function_lookup[fn] = res
pre = ""
current = cursor
car = []
for ar in ars[0:-1]:
car.append(ar)
tcar = tuple(car)
if tcar in cursor_lookup:
current = cursor_lookup[tcar]
continue
nxt_crs = next_crs("inner")
pre += f" FOR {nxt_crs} IN TO_ARRAY({current}.{ar})"
current = nxt_crs
return f"LET {res} = {fn.function}({pre} RETURN {current}.{ars[-1]})"

variables = ", ".join(f"var_{num}={var_name(v.name)}" for num, v in enumerate(a.group_by))
funcs = ", ".join(f"fn_{num}={f.function}({func_term(f)})" for num, f in enumerate(a.group_func))
agg_vars = ", ".join(f'"{v.get_as_name()}": var_{num}' for num, v in enumerate(a.group_by))
array_functions = " ".join((af for af in (unfold_array_func_term(f) for f in a.group_func) if af is not None))
funcs = ", ".join(f"fn_{num}={f.function}({func_term(f)})" for num, f in enumerate(a.group_func))
agg_funcs = ", ".join(f'"{f.get_as_name()}": fn_{num}' for num, f in enumerate(a.group_func))
group_result = f'"group":{{{agg_vars}}},' if a.group_by else ""
aggregate_term = f"collect {variables} aggregate {funcs}"
return_result = f"{{{group_result} {agg_funcs}}}"
return "aggregated", f"LET aggregated = (for {cursor} in {in_cursor} {aggregate_term} RETURN {return_result})"
return "aggregated", f"LET aggregated = ({for_loop} {array_functions} {aggregate_term} RETURN {return_result})"

def predicate(cursor: str, p: Predicate, context_path: Optional[str] = None) -> Tuple[Optional[str], str]:
pre = ""
Expand Down
13 changes: 11 additions & 2 deletions resotocore/resotocore/query/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import abc
import re
from collections import defaultdict
from datetime import datetime, timedelta

Expand Down Expand Up @@ -798,11 +799,14 @@ def property_paths(self) -> Set[str]:
return set(self.all_names())


AggregateOp = Tuple[str, Union[int, float]] # (operation, value or variable). e.g. ("+", 1) or ("-", "var1")


@define(order=True, hash=True, frozen=True)
class AggregateFunction:
function: str
name: Union[str, int]
ops: List[Tuple[str, Union[int, float]]] = field(factory=list)
ops: Tuple[AggregateOp, ...] = field(factory=tuple) # tuple instead of list to be hashable
as_name: Optional[str] = None

def __str__(self) -> str:
Expand All @@ -814,7 +818,7 @@ def combined_ops(self) -> str:
return " ".join(f"{op} {value}" for op, value in self.ops)

def get_as_name(self) -> str:
return self.as_name if self.as_name else f"{self.function}_of_{self.name}"
return self.as_name if self.as_name else re.sub(r"\W+", "_", f"{self.function}_of_{self.name}")

def change_variable(self, fn: Callable[[str], str]) -> AggregateFunction:
return evolve(self, name=fn(self.name)) if isinstance(self.name, str) else self
Expand Down Expand Up @@ -844,6 +848,11 @@ def property_paths(self) -> Set[str]:
result.update(agg.property_paths()) # type: ignore
return result

def sort_by_fn(self, section: str) -> List[Sort]:
root_or_section = None if section == PathRoot else section
on_section = partial(variable_to_absolute, root_or_section)
return [Sort("/" + fn.change_variable(on_section).get_as_name()) for fn in self.group_func]


SimpleValue = Union[str, int, float, bool]

Expand Down
2 changes: 1 addition & 1 deletion resotocore/resotocore/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def aggregate_group_function_parser() -> Parser:
as_name = None
if with_as:
as_name = yield literal_p
return AggregateFunction(func, term_or_int, ops_list, as_name)
return AggregateFunction(func, term_or_int, tuple(ops_list), as_name)


@make_parser
Expand Down
2 changes: 1 addition & 1 deletion resotocore/resotocore/report/inspector_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ async def load_benchmarks(
# create query
term: Term = P("benchmark").is_in(benchmark_names)
if severity:
term = term & P("severity").is_in(context.severities_including(severity))
term = term & P("severity").is_in([s.value for s in context.severities_including(severity)])
term = P.context("security.issues[]", term)
if accounts:
term = term & P("ancestors.account.reported.id").is_in(accounts)
Expand Down
4 changes: 2 additions & 2 deletions resotocore/tests/resotocore/cli/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,10 @@ async def test_create_query_parts(cli: CLI) -> None:
assert "-delete[1:]->" in commands[0].executable_commands[0].arg # type: ignore
commands = await cli.evaluate_cli_command("search some_int==0 | ancestors delete")
assert "<-delete[1:]-" in commands[0].executable_commands[0].arg # type: ignore
commands = await cli.evaluate_cli_command("search some_int==0 | aggregate foo, bla as bla: sum(bar)")
commands = await cli.evaluate_cli_command("search some_int==0 | aggregate foo, bla as bla: sum(bar) as a")
assert (
commands[0].executable_commands[0].arg
== f"'aggregate(reported.foo, reported.bla as bla: sum(reported.bar)):reported.some_int == 0 {sort}'"
== f"'aggregate(reported.foo, reported.bla as bla: sum(reported.bar) as a):reported.some_int == 0 sort a asc'"
)

# multiple head/tail commands are combined correctly
Expand Down
43 changes: 43 additions & 0 deletions resotocore/tests/resotocore/db/arango_query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,46 @@ def test_usage(foo_model: Model, graph_db: GraphDB) -> None:
")\n"
' FOR result in with_usage0 RETURN UNSET(result, ["flat"])'
)


def test_aggregation(foo_model: Model, graph_db: GraphDB) -> None:
q, _ = to_query(graph_db, QueryModel(parse_query("aggregate(name: max(num)): is(foo)"), foo_model))
assert "collect var_0=agg0.name aggregate fn_0=max(agg0.num)" in q
# aggregate vars get expanded
q, _ = to_query(graph_db, QueryModel(parse_query("aggregate(name, a[*].b[*].c: max(num)): is(foo)"), foo_model))
assert (
"for agg0 in filter0 FOR pre0 IN APPEND(TO_ARRAY(agg0.a), {_internal: true}) "
"FOR pre1 IN APPEND(TO_ARRAY(pre0.b), {_internal: true}) "
"FILTER pre0._internal!=true OR pre1._internal!=true "
"collect var_0=agg0.name, var_1=pre1.c "
"aggregate fn_0=max(agg0.num) "
'RETURN {"group":{"name": var_0, "c": var_1}, "max_of_num": fn_0}' in q
)
q, _ = to_query(
graph_db,
QueryModel(parse_query("aggregate(name: max(num), min(a[*].x), sum(a[*].b[*].d)): is(foo)"), foo_model),
)
# no expansion on the main level, but expansion in subqueries (let expressions)
assert (
"for agg0 in filter0 "
"LET agg_let0 = min( FOR inner0 IN TO_ARRAY(agg0.a) RETURN inner0.x) "
"LET agg_let1 = sum( FOR inner1 IN TO_ARRAY(agg0.a) FOR inner2 IN TO_ARRAY(inner1.b) RETURN inner2.d) "
"collect var_0=agg0.name "
"aggregate fn_0=max(agg0.num), fn_1=min(agg_let0), fn_2=sum(agg_let1) "
'RETURN {"group":{"name": var_0}, "max_of_num": fn_0, '
'"min_of_a_x": fn_1, "sum_of_a_b_d": fn_2}' in q
)
q, _ = to_query(
graph_db,
QueryModel(parse_query("aggregate(name, a[*].c: max(num), min(a[*].x), sum(a[*].b[*].d)): is(foo)"), foo_model),
)
assert (
"for agg0 in filter0 FOR pre0 IN APPEND(TO_ARRAY(agg0.a), {_internal: true}) "
"FILTER pre0._internal!=true "
"LET agg_let0 = min( RETURN pre0.x) "
"LET agg_let1 = sum( FOR inner0 IN TO_ARRAY(pre0.b) RETURN inner0.d) "
"collect var_0=agg0.name, var_1=pre0.c "
"aggregate fn_0=max(agg0.num), fn_1=min(pre0.x), fn_2=sum(agg_let1) "
'RETURN {"group":{"name": var_0, "c": var_1}, "max_of_num": fn_0, '
'"min_of_a_x": fn_1, "sum_of_a_b_d": fn_2}' in q
)
12 changes: 6 additions & 6 deletions resotocore/tests/resotocore/query/query_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_not() -> None:
assert_round_trip(not_term, P.of_kind("bla").not_term())
assert_round_trip(not_term, term_parser.parse("not(is(a) or not is(b) and not a>1 or not b<2 or not(a>1))"))
# make sure not only negates the simple term, not the combined term
assert term_parser.parse("not a==b and b==c") == CombinedTerm(NotTerm(P("a") == "b"), "and", P("b") == "c")
assert term_parser.parse("not a==b and b==c") == CombinedTerm(NotTerm(P("a").eq("b")), "and", P("b").eq("c"))


def test_filter_term() -> None:
Expand All @@ -142,7 +142,7 @@ def test_filter_term() -> None:
assert_round_trip(term_parser, P.of_kind("foo") | P.of_kind("bla"))
assert_round_trip(
term_parser,
((P.of_kind("foo") | P.of_kind("bla")) & (P("a") > 23)) & (P("b").is_in([1, 2, 3])) & (P("c") == {"a": 123}),
((P.of_kind("foo") | P.of_kind("bla")) & (P("a") > 23)) & (P("b").is_in([1, 2, 3])) & (P("c").eq({"a": 123})),
)


Expand Down Expand Up @@ -200,9 +200,9 @@ def test_query() -> None:
)
.merge_with("cloud", Navigation(1, Navigation.Max, direction=Direction.inbound), Query.mk_term("cloud"))
.traverse_out()
.filter(P("some.int.value") < 1, P("some.other") == 23)
.filter(P("some.int.value") < 1, P("some.other").eq(23))
.traverse_out()
.filter(P("active") == 12, P.function("in_subnet").on("ip", "1.2.3.4/96"))
.filter(P("active").eq(12), P.function("in_subnet").on("ip", "1.2.3.4/96"))
.filter_with(WithClause(WithClauseFilter("==", 0), Navigation()))
.group_by([AggregateVariable(AggregateVariableName("foo"))], [AggregateFunction("sum", "cpu")])
.add_sort(Sort("test", "asc"))
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_aggregate_group_function() -> None:
assert bla.name == "bla"
assert bla.as_name == "bar"
boo = aggregate_group_function_parser.parse("sum(boo * 1024.12 + 1) as bar")
assert boo.ops == [("*", 1024.12), ("+", 1)]
assert boo.ops == (("*", 1024.12), ("+", 1))
with pytest.raises(Exception):
assert aggregate_group_function_parser.parse("sum(test / 3 +)")

Expand Down Expand Up @@ -317,7 +317,7 @@ def test_with_clause() -> None:
assert wc.navigation == Navigation(maybe_edge_types=["delete"])
assert str(wc.term) == '(foo == "bla" and test > 23)'
assert str(wc.with_clause) == "with(any, -delete->)"
term = Query.mk_term("foo", P("test") == 23)
term = Query.mk_term("foo", P("test").eq(23))
clause_filter = WithClauseFilter(">", 23)
nav = Navigation()

Expand Down
4 changes: 3 additions & 1 deletion resotoshell/resotoshell/promptsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,16 @@ def get_completions(self, document: Document, complete_event: CompleteEvent) ->
class AggregateCompleter(AbstractSearchCompleter):
def __init__(self, kinds: List[str], props: List[str]) -> None:
super().__init__(kinds, props)
self.aggregate_fns = ["sum(", "min(", "max(", "avg("]
self.aggregate_fns = ["sum(", "min(", "max(", "avg(", "stddev(", "variance("]
self.aggregate_fn_completer = FuzzyWordCompleter(
self.aggregate_fns,
meta_dict={
"sum(": "sum over all occurrences",
"min(": "use the smallest occurrence",
"max(": "use the biggest occurrence",
"avg(": "average over all occurrences",
"stddev(": "standard deviation over all occurrences",
"variance(": "variance over all occurrences",
},
)
self.as_completer = FuzzyWordCompleter(["as"], meta_dict=({"as": "rename this result"}))
Expand Down

0 comments on commit 8afbdfa

Please sign in to comment.