Skip to content

Commit

Permalink
[SNOW-1445834, SNOW-1445726]: Implemented DataFrame.unstack and `Se…
Browse files Browse the repository at this point in the history
…ries.unstack` (#1848)

Signed-off-by: Naren Krishna <[email protected]>
  • Loading branch information
sfc-gh-nkrishna authored Jul 30, 2024
1 parent 53a8cbc commit 8a1b8d8
Show file tree
Hide file tree
Showing 15 changed files with 509 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
- Added support for `Index.value_counts`.
- Added support for `Series.dt.day_name` and `Series.dt.month_name`.
- Added support for indexing on Index, e.g., `df.index[:10]`.
- Added support for `DataFrame.unstack` and `Series.unstack`.

#### Improvements
- Removed the public preview warning message upon importing Snowpark pandas.
Expand Down
1 change: 1 addition & 0 deletions docs/source/modin/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ DataFrame
DataFrame.stack
DataFrame.T
DataFrame.transpose
DataFrame.unstack

.. rubric:: Combining / comparing / joining / merging

Expand Down
1 change: 1 addition & 0 deletions docs/source/modin/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ Series

Series.sort_values
Series.sort_index
Series.unstack
Series.nlargest
Series.nsmallest
Series.squeeze
Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/dataframe_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``tz_localize`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``unstack`` | N | | |
| ``unstack`` | P | ``sort`` | ``N`` for non-integer ``level``. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``update`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/series_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``unique`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``unstack`` | N | | |
| ``unstack`` | P | ``sort`` | ``N`` for non-integer ``level``. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``update`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
26 changes: 18 additions & 8 deletions src/snowflake/snowpark/modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,23 +1800,33 @@ def nsmallest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200
)
)

@dataframe_not_implemented()
def unstack(self, level=-1, fill_value=None): # noqa: PR01, RT01, D200
def unstack(
self,
level: int | str | list = -1,
fill_value: int | str | dict = None,
sort: bool = True,
):
"""
Pivot a level of the (necessarily hierarchical) index labels.
"""
# TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
if not isinstance(self.index, pandas.MultiIndex) or (
isinstance(self.index, pandas.MultiIndex)
and is_list_like(level)
and len(level) == self.index.nlevels
# This ensures that non-pandas MultiIndex objects are caught.
nlevels = self._query_compiler.nlevels()
is_multiindex = nlevels > 1

if not is_multiindex or (
is_multiindex and is_list_like(level) and len(level) == nlevels
):
return self._reduce_dimension(
query_compiler=self._query_compiler.unstack(level, fill_value)
query_compiler=self._query_compiler.unstack(
level, fill_value, sort, is_series_input=False
)
)
else:
return self.__constructor__(
query_compiler=self._query_compiler.unstack(level, fill_value)
query_compiler=self._query_compiler.unstack(
level, fill_value, sort, is_series_input=False
)
)

def pivot(
Expand Down
24 changes: 18 additions & 6 deletions src/snowflake/snowpark/modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,19 +1562,31 @@ def set_axis(
copy=copy,
)

@series_not_implemented()
def unstack(self, level=-1, fill_value=None): # noqa: PR01, RT01, D200
def unstack(
self,
level: int | str | list = -1,
fill_value: int | str | dict = None,
sort: bool = True,
):
"""
Unstack, also known as pivot, Series with MultiIndex to produce DataFrame.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
from snowflake.snowpark.modin.pandas.dataframe import DataFrame

result = DataFrame(
query_compiler=self._query_compiler.unstack(level, fill_value)
)
# A Series can only be unstacked if it has a MultiIndex.
if self._query_compiler.has_multiindex:
result = DataFrame(
query_compiler=self._query_compiler.unstack(
level, fill_value, sort, is_series_input=True
)
)
else:
raise ValueError( # pragma: no cover
f"index must be a MultiIndex to unstack, {type(self.index)} was passed"
)

return result.droplevel(0, axis=1) if result.columns.nlevels > 1 else result
return result

@series_not_implemented()
@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import typing
from collections.abc import Hashable
from enum import Enum
from typing import Optional

from snowflake.snowpark._internal.analyzer.analyzer_utils import (
Expand Down Expand Up @@ -58,6 +59,11 @@
DEFAULT_PANDAS_UNPIVOT_VALUE_NAME = "value"


class StackOperation(Enum):
    """Selects which reshaping operation a shared stack/unstack helper performs."""

    # Request the "stack" behaviour.
    STACK = "stack"
    # Request the "unstack" behaviour.
    UNSTACK = "unstack"


class UnpivotResultInfo(typing.NamedTuple):
"""
Structure that stores information about the unpivot result.
Expand Down
209 changes: 189 additions & 20 deletions src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@
)
from snowflake.snowpark.modin.plugin._internal.unpivot_utils import (
UNPIVOT_NULL_REPLACE_VALUE,
StackOperation,
unpivot,
unpivot_empty_df,
)
Expand Down Expand Up @@ -314,6 +315,7 @@
parse_object_construct_snowflake_quoted_identifier_and_extract_pandas_label,
parse_snowflake_object_construct_identifier_to_map,
snowpark_to_pandas_helper,
unquote_name_if_quoted,
)
from snowflake.snowpark.modin.plugin._internal.where_utils import (
validate_expected_boolean_data_columns,
Expand Down Expand Up @@ -15863,34 +15865,201 @@ def stack(
"Snowpark pandas doesn't support multiindex columns in stack API"
)

index_names = ["index"]
# Stack is equivalent to doing df.melt() with index reset, sorting the values, then setting the index
# Note that we always use sort_rows_by_column_values even if sort is False
qc = (
self.reset_index()
.melt(
id_vars=index_names,
value_vars=self.columns,
var_name="index_second_level",
value_name=MODIN_UNNAMED_SERIES_LABEL,
ignore_index=False,
qc = self._stack_helper(operation=StackOperation.STACK)

if dropna:
return qc.dropna(axis=0, how="any", thresh=None)
else:
return qc

def unstack(
self,
level: Union[int, str, list] = -1,
fill_value: Optional[Union[int, str, dict]] = None,
sort: bool = True,
is_series_input: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Pivot a level of the (necessarily hierarchical) index labels.

Returns a DataFrame having a new level of column labels whose
inner-most level consists of the pivoted index labels.

If the index is not a MultiIndex, the output will be a Series
(the analogue of stack when the columns are not a MultiIndex).

Parameters
----------
level : int, str, list, default -1
Level(s) of index to unstack, can pass level name.

fill_value : int, str, dict, optional
Replace NaN with this value if the unstack produces missing values.

sort : bool, default True
Sort the level(s) in the resulting MultiIndex columns.

is_series_input : bool, default False
Whether the input is a Series. If so, and the result has MultiIndex columns, the top column level is dropped.
"""
if not isinstance(level, int):
# TODO: SNOW-1558364: Support index name passed to level parameter
ErrorMessage.not_implemented(
"Snowpark pandas DataFrame/Series.unstack does not yet support a non-integer `level` parameter"
)
if not sort:
ErrorMessage.not_implemented(
"Snowpark pandas DataFrame/Series.unstack does not yet support the `sort` parameter"
)
.sort_rows_by_column_values(
columns=index_names, # type: ignore
if self._modin_frame.is_multiindex(axis=1):
ErrorMessage.not_implemented(
"Snowpark pandas doesn't support multiindex columns in the unstack API"
)

level = [level]

index_names = self.get_index_names()

# Check to see if we have a MultiIndex, if we do, make sure we remove
# the appropriate level(s), and we pivot accordingly.
if len(index_names) > 1:
# Resetting the index keeps the index columns as the first n data columns
qc = self.reset_index()
index_cols = qc._modin_frame.data_column_pandas_labels[0 : len(index_names)]
pivot_cols = [index_cols[lev] for lev in level] # type: ignore
res_index_cols = []
column_names_to_reset_to_none = []
for i in range(len(index_names)):
if index_names[i] is None:
# We need to track the names where the index and columns originally had no name
# in order to reset those names back to None after the operation
column_names_to_reset_to_none.append(
qc._modin_frame.data_column_pandas_labels[i]
)
col = index_cols[i]
if col not in pivot_cols:
res_index_cols.append(col)
vals = [
c
for c in self.columns
if c not in res_index_cols and c not in pivot_cols
]

qc = qc.pivot_table(
columns=pivot_cols,
index=res_index_cols,
values=vals,
aggfunc="min",
fill_value=fill_value,
margins=False,
dropna=True,
margins_name="All",
observed=False,
sort=sort,
)

# Reset the names of the originally unnamed index levels back to None
output_index_names = qc.get_index_names()
output_index_names_replace_level_with_none = [
None
if output_index_names[i] in column_names_to_reset_to_none
else output_index_names[i]
for i in range(len(output_index_names))
]
qc = qc.set_index_names(output_index_names_replace_level_with_none)
# Reset the names of the originally unnamed column levels back to None
output_column_names = qc.columns.names
output_column_names_replace_level_with_none = [
None
if output_column_names[i] in column_names_to_reset_to_none
else output_column_names[i]
for i in range(len(output_column_names))
]
qc = qc.set_columns(
qc.columns.set_names(output_column_names_replace_level_with_none)
)
else:
qc = self._stack_helper(operation=StackOperation.UNSTACK)

if is_series_input and qc.columns.nlevels > 1:
# If input is Series and output is MultiIndex, drop the top level of the MultiIndex
qc = qc.set_columns(qc.columns.droplevel())
return qc

def _stack_helper(
self,
operation: StackOperation,
) -> "SnowflakeQueryCompiler":
"""
Helper function that performs a stack or unstack operation on a single-index DataFrame/Series.

Parameters
----------
operation : StackOperation.STACK or StackOperation.UNSTACK
The operation being performed.
"""
index_names = self.get_index_names()
# Resetting the index keeps the index columns as the first n data columns
qc = self.reset_index()
index_cols = qc._modin_frame.data_column_pandas_labels[0 : len(index_names)]
column_names_to_reset_to_none = []
for i in range(len(index_names)):
if index_names[i] is None:
# We need to track the names where the index and columns originally had no name
# in order to reset those names back to None after the operation
column_names_to_reset_to_none.append(
qc._modin_frame.data_column_pandas_labels[i]
)

# Track the new column name for the original unnamed column
if self.columns.name is None:
quoted_col_label = (
qc._modin_frame.ordered_dataframe.generate_snowflake_quoted_identifiers(
pandas_labels=["index_second_level"]
)[0]
)
col_label = unquote_name_if_quoted(quoted_col_label)
column_names_to_reset_to_none.append(col_label)
else:
col_label = self.columns.name

qc = qc.melt(
id_vars=index_cols, # type: ignore
value_vars=self.columns,
var_name=col_label,
value_name=MODIN_UNNAMED_SERIES_LABEL,
ignore_index=False,
)

if operation == StackOperation.STACK:
# Only sort rows by column values in case of 'stack'
# For 'unstack' maintain the row position order
qc = qc.sort_rows_by_column_values(
columns=index_cols, # type: ignore
ascending=[True],
kind="stable",
na_position="last",
ignore_index=False,
)
.replace(to_replace=UNPIVOT_NULL_REPLACE_VALUE, value=np.nan)
.set_index_from_columns(index_names + ["index_second_level"]) # type: ignore
.set_index_names([None, None])
)

if dropna:
return qc.dropna(axis=0, how="any", thresh=None)
# TODO: SNOW-1524695: Remove the following replace once "NULL_REPLACE" values are fixed for 'melt'
qc = qc.replace(to_replace=UNPIVOT_NULL_REPLACE_VALUE, value=np.nan)

if operation == StackOperation.STACK:
qc = qc.set_index_from_columns(index_cols + [col_label]) # type: ignore
else:
return qc
qc = qc.set_index_from_columns([col_label] + index_cols) # type: ignore

# Reset the names of the originally unnamed index and column levels back to None
output_index_names = qc.get_index_names()
output_index_names = [
None
if output_index_names[i] in column_names_to_reset_to_none
else output_index_names[i]
for i in range(len(output_index_names))
]
qc = qc.set_index_names(output_index_names)
return qc

def corr(
self,
Expand Down
Loading

0 comments on commit 8a1b8d8

Please sign in to comment.