Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pre-compute window function operands to simplify Dask graph #1331

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions dask_sql/physical/rel/logical/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,22 @@ class DaskWindowPlugin(BaseRelPlugin):

def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
(dc,) = self.assert_inputs(rel, 1, context)
df = dc.df

# Output to the right field names right away
field_names = rel.getRowType().getFieldNames()

# Extract the operations here to avoid overly complex graph structure
operations_dict = dict()
for window in rel.window().getGroups():
operations, df = self._extract_operations(rel, window, df, dc, context)
operations_dict[window.toString()] = operations

for window in rel.window().getGroups():
dc = self._apply_window(rel, window, dc, field_names, context)
dc = self._apply_window(
rel, window, dc, df, field_names, context, operations_dict
)
df = dc.df

# Finally, fix the output schema if needed
df = dc.df
Expand All @@ -247,12 +257,14 @@ def _apply_window(
rel,
window,
dc: DataContainer,
df: dd.DataFrame,
field_names: list[str],
context: "dask_sql.Context",
operations_dict: dict[str, list[tuple[Callable, str, list[str]]]],
):
temporary_columns = []
newly_created_columns = []

df = dc.df
cc = dc.column_container

# Now extract the groupby and order information
Expand All @@ -270,11 +282,10 @@ def _apply_window(
f"Before applying the function, partitioning according to {group_columns}."
)

operations, df = self._extract_operations(rel, window, df, dc, context)
for _, _, cols in operations:
temporary_columns += cols

newly_created_columns = [new_column for _, new_column, _ in operations]
operations = operations_dict[window.toString()]
for _, result_col, operand_cols in operations:
temporary_columns += operand_cols
newly_created_columns.append(result_col)

logger.debug(f"Will create {newly_created_columns} new columns")

Expand Down Expand Up @@ -343,7 +354,7 @@ def _apply_window(
cc = cc.add(field_name, c)
dc = DataContainer(df, cc)
logger.debug(
f"Removed unneeded columns and registered new ones: {LoggableDataFrame(dc)}."
f"Removed unneeded columns and registered new ones: {LoggableDataFrame(df)}."
)
return dc

Expand Down
Loading