Skip to content

Commit

Permalink
feat(api): grouping sets/rollup/cube
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 12, 2024
1 parent 321a382 commit 91612a7
Show file tree
Hide file tree
Showing 23 changed files with 1,181 additions and 107 deletions.
2 changes: 1 addition & 1 deletion .codespellrc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[codespell]
# local codespell matches `./docs`, pre-commit codespell matches `docs`
skip = *.lock,.direnv,.git,./docs/_freeze,./docs/_output/**,./docs/_inv/**,docs/_freeze/**,*.svg,*.css,*.html,*.js,ibis/backends/tests/tpc/queries/duckdb/ds/*.sql
ignore-regex = \b(i[if]f|I[IF]F|AFE|alls)\b
ignore-regex = \b(i[if]f|I[IF]F|AFE|alls|ND)\b
builtin = clear,rare,names
ignore-words-list = tim,notin,ang
42 changes: 37 additions & 5 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,13 +1400,42 @@ def visit_JoinLink(self, op, *, how, table, predicates):
def _generate_groups(groups):
return map(sge.convert, range(1, len(groups) + 1))

def visit_Aggregate(self, op, *, parent, groups, metrics):
sel = sg.select(
*self._cleanup_names(groups), *self._cleanup_names(metrics), copy=False
def _compile_agg_select(self, op, *, parent, keys, metrics):
return sg.select(
*self._cleanup_names(keys), *self._cleanup_names(metrics), copy=False
).from_(parent, copy=False)

if groups:
sel = sel.group_by(*self._generate_groups(groups.values()), copy=False)
def _compile_group_by(self, sel, *, groups, grouping_sets, rollups, cubes):
expressions = list(self._generate_groups(groups.values()))
group = sge.Group(
expressions=expressions,
grouping_sets=[
sge.GroupingSets(
expressions=[
sge.Tuple(expressions=expressions)
for expressions in grouping_set
]
)
for grouping_set in grouping_sets
],
rollup=[sge.Rollup(expressions=rollup) for rollup in rollups],
cube=[sge.Cube(expressions=cube) for cube in cubes],
)
return sel.group_by(group, copy=False)

def visit_Aggregate(
self, op, *, parent, keys, groups, metrics, grouping_sets, rollups, cubes
):
sel = self._compile_agg_select(op, parent=parent, keys=keys, metrics=metrics)

if groups or grouping_sets or rollups or cubes:
sel = self._compile_group_by(
sel,
groups=groups,
grouping_sets=grouping_sets,
rollups=rollups,
cubes=cubes,
)

return sel

Expand Down Expand Up @@ -1609,6 +1638,9 @@ def _make_sample_backwards_compatible(self, *, sample, parent):
parent.args["sample"] = sample
return sg.select(STAR).from_(parent)

def visit_GroupID(self, op, *, arg):
return self.f.grouping(*arg)


# `__init_subclass__` is uncalled for subclasses - we manually call it here to
# autogenerate the base class implementations as well.
Expand Down
17 changes: 6 additions & 11 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,12 @@ def visit_ArgMax(self, op, *, arg, key, where):
arg, where=where, order_by=[sge.Ordered(this=key, desc=True)]
)

def visit_Aggregate(self, op, *, parent, groups, metrics):
def _compile_agg_select(self, op, *, parent, keys, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
quoted = self.quoted
metrics = tuple(self._cleanup_names(metrics))

if groups:
if keys:
# datafusion doesn't support count distinct aggregations alongside
# computed grouping keys so create a projection of the key and all
# existing columns first, followed by the usual group by
Expand All @@ -484,11 +484,11 @@ def visit_Aggregate(self, op, *, parent, groups, metrics):
),
# can't use set subtraction here since the schema keys'
# order matters and set subtraction doesn't preserve order
(k for k in op.parent.schema.keys() if k not in groups),
(k for k in op.parent.schema.keys() if k not in keys),
)
)
table = (
sg.select(*cols, *self._cleanup_names(groups))
sg.select(*cols, *self._cleanup_names(keys))
.from_(parent)
.subquery(parent.alias)
)
Expand All @@ -497,19 +497,14 @@ def visit_Aggregate(self, op, *, parent, groups, metrics):
# quoted=True is required here for correctness
by_names_quoted = tuple(
sg.column(key, table=getattr(value, "table", None), quoted=quoted)
for key, value in groups.items()
for key, value in keys.items()
)
selections = by_names_quoted + metrics
else:
selections = metrics or (STAR,)
table = parent

sel = sg.select(*selections).from_(table)

if groups:
sel = sel.group_by(*by_names_quoted)

return sel
return sg.select(*selections).from_(table)

def visit_StructColumn(self, op, *, names, values):
args = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ FROM (
FROM "countries" AS "t0"
) AS t0
GROUP BY
"cont"
1
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ WITH "t5" AS (
) AS "t4"
) AS t4
GROUP BY
"t4"."field_of_study"
1
)
SELECT
*
Expand Down
Loading

0 comments on commit 91612a7

Please sign in to comment.