[SNOW-1502893]: Add support for pd.crosstab #1837

Merged
merged 32 commits into from
Aug 31, 2024
Changes from 11 commits
3ce073a
[SNOW-1502893]: Add support for `pd.crosstab`
sfc-gh-rdurrani Jun 25, 2024
cca9bbe
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 13, 2024
d9e5b79
Add initial tests + implementation
sfc-gh-rdurrani Aug 13, 2024
b883db7
Add values tests
sfc-gh-rdurrani Aug 13, 2024
794b592
Add more support
sfc-gh-rdurrani Aug 20, 2024
2d9f9d2
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 20, 2024
f76a034
Add support for everything but normalize + margins together
sfc-gh-rdurrani Aug 21, 2024
d9c759f
Add support for normalize + margins on rows
sfc-gh-rdurrani Aug 21, 2024
c84fe32
Wrap up implementation of crosstab
sfc-gh-rdurrani Aug 21, 2024
8df981f
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 21, 2024
f668414
Fix small bug after merge
sfc-gh-rdurrani Aug 21, 2024
f2ba2f4
Add docstrings
sfc-gh-rdurrani Aug 21, 2024
92ac274
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 28, 2024
da12522
Fix tests, address review comments
sfc-gh-rdurrani Aug 28, 2024
e8bcc63
Use eval_snowpark...
sfc-gh-rdurrani Aug 28, 2024
6399f4c
Remove crosstab from unsupported tests
sfc-gh-rdurrani Aug 28, 2024
53a4125
Fix docs, add all aggfuncs to tests
sfc-gh-rdurrani Aug 28, 2024
76c01e2
Fix docs
sfc-gh-rdurrani Aug 28, 2024
906e00b
Add tests
sfc-gh-rdurrani Aug 28, 2024
5cdc712
Address review comments
sfc-gh-rdurrani Aug 28, 2024
a38bcb1
Address review comments
sfc-gh-rdurrani Aug 28, 2024
d007558
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 28, 2024
e835540
Fix doc
sfc-gh-rdurrani Aug 29, 2024
e4a302b
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 29, 2024
8fb1a3d
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
d41f30a
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
4e7464c
Update coverage
sfc-gh-rdurrani Aug 30, 2024
5723756
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
1430d1b
Fix tests
sfc-gh-rdurrani Aug 30, 2024
cf1b2fb
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
796445a
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
5c19d5b
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -36,6 +36,7 @@
- Added support for `DatetimeIndex.month_name` and `DatetimeIndex.day_name`.
- Added support for `Series.dt.weekday`, `Series.dt.time`, and `DatetimeIndex.time`.
- Added support for subtracting two timestamps to get a Timedelta.
- Added support for `pd.crosstab`.

#### Bug Fixes

Expand Down
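For readers unfamiliar with the function named in this changelog entry, here is a minimal usage sketch. It calls native pandas so that it runs without a Snowflake session; the input data is made up for illustration, and how the same call is issued through Snowpark pandas (imports, session setup) is not shown in this diff and is not assumed here.

```python
import pandas as pd

# Made-up example data: two factors observed over five rows.
a = pd.Series(["foo", "foo", "bar", "bar", "foo"], name="a")
b = pd.Series(["one", "two", "one", "one", "one"], name="b")

# Count how often each (a, b) pair occurs.
table = pd.crosstab(index=a, columns=b)
print(table)
```

The row and column axis names are taken from the `name` of each input Series, which is the behaviour that the `rownames`/`colnames` handling in `general.py` has to reproduce.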
1 change: 1 addition & 0 deletions docs/source/modin/general_functions.rst
Expand Up @@ -11,6 +11,7 @@ General functions
:toctree: pandas_api/

melt
crosstab
pivot
pivot_table
cut
Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/general_supported.rst
Expand Up @@ -18,7 +18,7 @@ Data manipulations
| ``concat`` | P | ``levels`` is not supported, | |
| | | ``copy`` is ignored | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``crosstab`` | N | | |
| ``crosstab`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``cut`` | P | ``retbins``, ``labels`` | ``N`` if ``retbins=True`` or ``labels!=False`` |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
1 change: 0 additions & 1 deletion src/snowflake/snowpark/modin/pandas/base.py
Expand Up @@ -801,7 +801,6 @@ def aggregate(
# TypeError: got an unexpected keyword argument 'skipna'
if is_dict_like(func) and not uses_named_kwargs:
kwargs.clear()

result = self.__constructor__(
query_compiler=self._query_compiler.agg(
func=func,
Expand Down
240 changes: 224 additions & 16 deletions src/snowflake/snowpark/modin/pandas/general.py
Expand Up @@ -22,7 +22,7 @@
"""Implement pandas general API."""
from __future__ import annotations

from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
from datetime import date, datetime, tzinfo
from logging import getLogger
from typing import TYPE_CHECKING, Any, Literal, Union
Expand All @@ -47,7 +47,7 @@
_infer_tz_from_endpoints,
_maybe_normalize_endpoints,
)
from pandas.core.dtypes.common import is_list_like
from pandas.core.dtypes.common import is_list_like, is_nested_list_like
from pandas.core.dtypes.inference import is_array_like
from pandas.core.tools.datetimes import (
ArrayConvertible,
Expand Down Expand Up @@ -1831,7 +1831,6 @@ def melt(


@snowpark_pandas_telemetry_standalone_function_decorator
@pandas_module_level_function_not_implemented()
@_inherit_docstrings(pandas.crosstab, apilink="pandas.crosstab")
def crosstab(
index,
Expand All @@ -1848,20 +1847,229 @@ def crosstab(
"""
Compute a simple cross tabulation of two (or more) factors.
"""
# TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py
pandas_crosstab = pandas.crosstab(
index,
columns,
values,
rownames,
colnames,
aggfunc,
margins,
margins_name,
dropna,
normalize,
if values is None and aggfunc is not None:
raise ValueError("aggfunc cannot be used without values.")

if values is not None and aggfunc is None:
raise ValueError("values cannot be used without an aggfunc.")

if not is_nested_list_like(index):
index = [index]
if not is_nested_list_like(columns):
columns = [columns]

user_passed_rownames = rownames is not None
user_passed_colnames = colnames is not None

from pandas.core.reshape.pivot import _build_names_mapper, _get_names

def _get_names_wrapper(list_of_objs, names, prefix):
"""
Helper method to expand DataFrame objects containing
multiple columns into Series, since `_get_names` expects
one column per entry.
"""
expanded_list_of_objs = []
for obj in list_of_objs:
if isinstance(obj, DataFrame):
for col in obj.columns:
expanded_list_of_objs.append(obj[col])
else:
expanded_list_of_objs.append(obj)
return _get_names(expanded_list_of_objs, names, prefix)

rownames = _get_names_wrapper(index, rownames, prefix="row")
colnames = _get_names_wrapper(columns, colnames, prefix="col")

(
rownames_mapper,
unique_rownames,
colnames_mapper,
unique_colnames,
) = _build_names_mapper(rownames, colnames)

pass_objs = [x for x in index + columns if isinstance(x, (Series, DataFrame))]
row_idx_names = None
col_idx_names = None
if pass_objs:
# If we have any Snowpark pandas objects in the index or columns, then we
# need to find the intersection of their indices, and only pick rows from
# the objects that have indices in the intersection of their indices.
# After we do that, we then need to append the non Snowpark pandas objects
# using the intersection of indices as the final index for the DataFrame object.
# First, we separate the objects into Snowpark pandas objects, and non-Snowpark
# pandas objects (while renaming them so that they have unique names).
rownames_idx = 0
row_idx_names = []
dfs = []
arrays = []
for obj in index:
if isinstance(obj, Series):
row_idx_names.append(obj.name)
df = pd.DataFrame(obj)
df.columns = [unique_rownames[rownames_idx]]
rownames_idx += 1
dfs.append(df)
elif isinstance(obj, DataFrame):
row_idx_names.extend(obj.columns)
obj.columns = unique_rownames[
rownames_idx : rownames_idx + len(obj.columns)
]
rownames_idx += len(obj.columns)
dfs.append(obj)
else:
row_idx_names.append(None)
df = pd.DataFrame(obj)
df.columns = unique_rownames[
rownames_idx : rownames_idx + len(df.columns)
]
rownames_idx += len(df.columns)
arrays.append(df)

colnames_idx = 0
col_idx_names = []
for obj in columns:
if isinstance(obj, Series):
col_idx_names.append(obj.name)
df = pd.DataFrame(obj)
df.columns = [unique_colnames[colnames_idx]]
colnames_idx += 1
dfs.append(df)
elif isinstance(obj, DataFrame):
col_idx_names.extend(obj.columns)
obj.columns = unique_colnames[
colnames_idx : colnames_idx + len(obj.columns)
]
colnames_idx += len(obj.columns)
dfs.append(obj)
else:
col_idx_names.append(None)
df = pd.DataFrame(obj)
df.columns = unique_colnames[
colnames_idx : colnames_idx + len(df.columns)
]
colnames_idx += len(df.columns)
arrays.append(df)

# Now, we have two lists - a list of Snowpark pandas objects, and a list of objects
# that were not passed in as Snowpark pandas objects, but that we have converted
# to Snowpark pandas objects to give them column names. We can perform inner joins
# on the dfs list to get a DataFrame with the final index (that is only an intersection
# of indices.)
df = dfs[0]
for right in dfs[1:]:
df = df.merge(right, left_index=True, right_index=True)

if len(arrays) > 0:
index = df.index
right_df = pd.concat(arrays, axis=1)
right_df.index = index
df = df.merge(right_df, left_index=True, right_index=True)
else:
data = {
**dict(zip(unique_rownames, index)),
**dict(zip(unique_colnames, columns)),
}
df = DataFrame(data)

if values is None:
df["__dummy__"] = 0
kwargs = {"aggfunc": "count"}
else:
df["__dummy__"] = values
kwargs = {"aggfunc": aggfunc}

table = df.pivot_table(
"__dummy__",
index=unique_rownames,
columns=unique_colnames,
margins=margins,
margins_name=margins_name,
dropna=dropna,
**kwargs, # type: ignore[arg-type]
)
return DataFrame(pandas_crosstab)

if row_idx_names is not None and not user_passed_rownames:
table.index = table.index.set_names(row_idx_names)

if col_idx_names is not None and not user_passed_colnames:
table.columns = table.columns.set_names(col_idx_names)

if aggfunc is None:
# If no aggfunc is provided, we are computing frequencies. Since we use
# pivot_table above, pairs that are not observed will get a NaN value,
# so we need to fill all NaN values with 0.
table = table.fillna(0)

# We must explicitly check that the value of normalize is not False here,
# as a valid value of normalize is `0` (for normalizing index).
if normalize is not False:
if normalize not in [0, 1, "index", "columns", "all", True]:
raise ValueError(f"Not a valid normalize argument: {normalize}")
if normalize is True:
normalize = "all"
normalize = {0: "index", 1: "columns"}.get(normalize, normalize)

# Actual Normalizations
normalizers: dict[bool | str, Callable] = {
"all": lambda x: x / x.sum(axis=0).sum(),
"columns": lambda x: x / x.sum(),
"index": lambda x: x.div(x.sum(axis=1), axis="index"),
}

if margins is False:

f = normalizers[normalize]
names = table.columns.names
table = f(table)
table.columns.names = names
table = table.fillna(0)
else:
# keep index and column of pivoted table
table_index = table.index
table_columns = table.columns

column_margin = table.iloc[:-1, -1]

if normalize == "columns":
# keep the core table
table = table.iloc[:-1, :-1]

# Normalize core
f = normalizers[normalize]
table = f(table)
table = table.fillna(0)
# Fix Margins
column_margin = column_margin / column_margin.sum()
table = pd.concat([table, column_margin], axis=1)
table = table.fillna(0)
table.columns = table_columns

elif normalize == "index":
table = table.iloc[:, :-1]

# Normalize core
f = normalizers[normalize]
table = f(table)
table = table.fillna(0).reindex(index=table_index)

elif normalize == "all":
# Normalize core
f = normalizers[normalize]
table = f(table.iloc[:, :-1]) * 2.0

column_margin = column_margin / column_margin.sum()
table = pd.concat([table, column_margin], axis=1)
table.iloc[-1, -1] = 1

table = table.fillna(0)
table.index = table_index
table.columns = table_columns

table = table.rename_axis(index=rownames_mapper, axis=0)
table = table.rename_axis(columns=colnames_mapper, axis=1)

return table


# Adding docstring since pandas docs don't have web section for this function.
Expand Down
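The core idea of the new `crosstab` body above (attach a dummy column, count it with `pivot_table`, fill unobserved pairs with 0, then optionally normalize) can be sketched in isolation. This is a simplified illustration in native pandas, not the PR's implementation: `crosstab_via_pivot` is a hypothetical helper, it accepts only two plain sequences, and it omits `values`, `aggfunc`, `margins`, and the index-intersection logic for Snowpark pandas inputs.

```python
import pandas as pd


def crosstab_via_pivot(index, columns, normalize=False):
    """Frequency table of two equal-length sequences (hypothetical helper)."""
    df = pd.DataFrame({"row_0": list(index), "col_0": list(columns)})
    # Any non-null value works here; the column only exists to be counted.
    df["__dummy__"] = 0
    table = df.pivot_table(
        "__dummy__", index="row_0", columns="col_0", aggfunc="count"
    )
    # Pairs that never occur come back as NaN; a frequency table wants 0.
    table = table.fillna(0)
    if normalize == "index":
        # Each row sums to 1.
        table = table.div(table.sum(axis=1), axis="index")
    elif normalize == "columns":
        # Each column sums to 1.
        table = table / table.sum()
    elif normalize == "all":
        # The whole table sums to 1.
        table = table / table.sum(axis=0).sum()
    return table


a = ["foo", "foo", "bar", "bar", "foo"]
b = ["one", "two", "one", "one", "one"]
counts = crosstab_via_pivot(a, b)
by_row = crosstab_via_pivot(a, b, normalize="index")
print(counts)
print(by_row)
```

The three normalization branches mirror the `normalizers` dictionary in the diff for the `margins=False` case; the margin-aware branches need the extra slicing shown in the PR and are left out of this sketch.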
Expand Up @@ -195,7 +195,10 @@
compute_bin_indices,
preprocess_bins_for_cut,
)
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.frame import (
InternalFrame,
LabelIdentifierPair,
)
from snowflake.snowpark.modin.plugin._internal.groupby_utils import (
check_is_groupby_supported_by_snowflake,
extract_groupby_column_pandas_labels,
Expand Down Expand Up @@ -5459,11 +5462,14 @@ def agg(
)
for agg_arg in agg_args
}
pandas_labels = list(agg_col_map.keys())
if self.is_multiindex(axis=1):
pandas_labels = [
(label,) * len(self.columns.names) for label in pandas_labels
]
single_agg_func_query_compilers.append(
SnowflakeQueryCompiler(
frame.project_columns(
list(agg_col_map.keys()), list(agg_col_map.values())
)
frame.project_columns(pandas_labels, list(agg_col_map.values()))
)
)
else: # axis == 0
Expand Down Expand Up @@ -13703,7 +13709,6 @@ def create_lazy_type_functions(
assert len(right_result_data_identifiers) == 1, "other must be a Series"
right = right_result_data_identifiers[0]
right_datatype = right_datatypes[0]

# now replace in result frame identifiers with binary op result
replace_mapping = {}
snowpark_pandas_types = []
Expand All @@ -13725,10 +13730,19 @@ def create_lazy_type_functions(
identifiers_to_keep = set(
new_frame.index_column_snowflake_quoted_identifiers
) | set(update_result.old_id_to_new_id_mappings.values())
self_is_column_mi = len(self._modin_frame.data_column_pandas_index_names)
label_to_snowflake_quoted_identifier = []
snowflake_quoted_identifier_to_snowpark_pandas_type = {}
for pair in new_frame.label_to_snowflake_quoted_identifier:
if pair.snowflake_quoted_identifier in identifiers_to_keep:
if (
self_is_column_mi
and isinstance(pair.label, tuple)
and isinstance(pair.label[0], tuple)
):
pair = LabelIdentifierPair(
pair.label[0], pair.snowflake_quoted_identifier
)
label_to_snowflake_quoted_identifier.append(pair)
snowflake_quoted_identifier_to_snowpark_pandas_type[
pair.snowflake_quoted_identifier
Expand All @@ -13742,7 +13756,7 @@ def create_lazy_type_functions(
label_to_snowflake_quoted_identifier
),
num_index_columns=new_frame.num_index_columns,
data_column_index_names=new_frame.data_column_index_names,
data_column_index_names=self._modin_frame.data_column_index_names,
snowflake_quoted_identifier_to_snowpark_pandas_type=snowflake_quoted_identifier_to_snowpark_pandas_type,
)

Expand Down Expand Up @@ -14153,9 +14167,7 @@ def infer_sorted_column_labels(
new_frame = InternalFrame.create(
ordered_dataframe=expanded_ordered_frame,
data_column_pandas_labels=sorted_column_labels,
data_column_pandas_index_names=[
None
], # operation removes column index name always.
data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names,
data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers
+ new_identifiers,
index_column_pandas_labels=index_column_pandas_labels,
Expand Down Expand Up @@ -14202,7 +14214,7 @@ def infer_sorted_column_labels(
new_frame = InternalFrame.create(
ordered_dataframe=expanded_ordered_frame,
data_column_pandas_labels=expanded_data_column_pandas_labels,
data_column_pandas_index_names=[None], # operation removes names
data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names,
data_column_snowflake_quoted_identifiers=expanded_data_column_snowflake_quoted_identifiers,
index_column_pandas_labels=index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers,
Expand Down
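The `agg` hunk above pads each aggregation label into a tuple when the frame has MultiIndex columns, presumably so that every result label has one entry per column level. A small sketch of that padding step in native pandas follows; `pad_labels` is a hypothetical helper written for illustration, and the column labels are made up.

```python
import pandas as pd


def pad_labels(labels, column_names):
    """Repeat each flat label once per column level (hypothetical helper)."""
    return [(label,) * len(column_names) for label in labels]


# A frame whose columns have two levels.
columns = pd.MultiIndex.from_tuples(
    [("a", "x"), ("a", "y")], names=["outer", "inner"]
)

# Flat aggregation labels become tuples with one entry per level.
padded = pad_labels(["sum", "max"], columns.names)
result_columns = pd.MultiIndex.from_tuples(padded, names=list(columns.names))
print(list(result_columns))
```

With single-level columns the padding is skipped in the diff, so flat labels stay as they are.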
3 changes: 3 additions & 0 deletions tests/integ/modin/crosstab/__init__.py
@@ -0,0 +1,3 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#