Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SNOW-1502893]: Add support for pd.crosstab #1837

Merged
merged 32 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3ce073a
[SNOW-1502893]: Add support for `pd.crosstab`
sfc-gh-rdurrani Jun 25, 2024
cca9bbe
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 13, 2024
d9e5b79
Add initial tests + implementation
sfc-gh-rdurrani Aug 13, 2024
b883db7
Add values tests
sfc-gh-rdurrani Aug 13, 2024
794b592
Add more support
sfc-gh-rdurrani Aug 20, 2024
2d9f9d2
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 20, 2024
f76a034
Add support for everything but normalize + margins together
sfc-gh-rdurrani Aug 21, 2024
d9c759f
Add support for normalize + margins on rows
sfc-gh-rdurrani Aug 21, 2024
c84fe32
Wrap up implementation of crosstab
sfc-gh-rdurrani Aug 21, 2024
8df981f
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 21, 2024
f668414
Fix small bug after merge
sfc-gh-rdurrani Aug 21, 2024
f2ba2f4
Add docstrings
sfc-gh-rdurrani Aug 21, 2024
92ac274
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 28, 2024
da12522
Fix tests, address review comments
sfc-gh-rdurrani Aug 28, 2024
e8bcc63
Use eval_snowpark...
sfc-gh-rdurrani Aug 28, 2024
6399f4c
Remove crosstab from unsupported tests
sfc-gh-rdurrani Aug 28, 2024
53a4125
Fix docs, add all aggfuncs to tests
sfc-gh-rdurrani Aug 28, 2024
76c01e2
Fix docs
sfc-gh-rdurrani Aug 28, 2024
906e00b
Add tests
sfc-gh-rdurrani Aug 28, 2024
5cdc712
Address review comments
sfc-gh-rdurrani Aug 28, 2024
a38bcb1
Address review comments
sfc-gh-rdurrani Aug 28, 2024
d007558
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 28, 2024
e835540
Fix doc
sfc-gh-rdurrani Aug 29, 2024
e4a302b
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 29, 2024
8fb1a3d
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
d41f30a
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
4e7464c
Update coverage
sfc-gh-rdurrani Aug 30, 2024
5723756
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
1430d1b
Fix tests
sfc-gh-rdurrani Aug 30, 2024
cf1b2fb
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
796445a
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
5c19d5b
Merge branch 'main' into rdurrani-SNOW-1502893
sfc-gh-rdurrani Aug 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
- Added support for `Index.is_boolean`, `Index.is_integer`, `Index.is_floating`, `Index.is_numeric`, and `Index.is_object`.
- Added support for `DatetimeIndex.round`, `DatetimeIndex.floor` and `DatetimeIndex.ceil`.
- Added support for `Series.dt.days_in_month` and `Series.dt.daysinmonth`.
- Added support for `pd.crosstab`.

#### Improvements

Expand Down
1 change: 1 addition & 0 deletions docs/source/modin/general_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ General functions
:toctree: pandas_api/

melt
crosstab
pivot
pivot_table
cut
Expand Down
5 changes: 4 additions & 1 deletion docs/source/modin/supported/general_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ Data manipulations
| ``concat`` | P | ``levels`` is not supported, | |
| | | ``copy`` is ignored | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``crosstab`` | N | | |
| ``crosstab`` | P | | ``N`` if ``aggfunc`` is not one of |
| | | | "count", "mean", "min", "max", or "sum", or |
| | | | margins is True, normalize is "all" or True, |
| | | | and values is passed. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``cut`` | P | ``retbins``, ``labels`` | ``N`` if ``retbins=True``or ``labels!=False`` |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
330 changes: 313 additions & 17 deletions src/snowflake/snowpark/modin/pandas/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"""Implement pandas general API."""
from __future__ import annotations

from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
from datetime import date, datetime, timedelta, tzinfo
from logging import getLogger
from typing import TYPE_CHECKING, Any, Literal, Union
Expand All @@ -49,7 +49,7 @@
_infer_tz_from_endpoints,
_maybe_normalize_endpoints,
)
from pandas.core.dtypes.common import is_list_like
from pandas.core.dtypes.common import is_list_like, is_nested_list_like
from pandas.core.dtypes.inference import is_array_like
from pandas.core.tools.datetimes import (
ArrayConvertible,
Expand Down Expand Up @@ -1982,8 +1982,6 @@ def melt(


@snowpark_pandas_telemetry_standalone_function_decorator
@pandas_module_level_function_not_implemented()
@_inherit_docstrings(pandas.crosstab, apilink="pandas.crosstab")
def crosstab(
index,
columns,
Expand All @@ -1998,21 +1996,319 @@ def crosstab(
) -> DataFrame: # noqa: PR01, RT01, D200
"""
Compute a simple cross tabulation of two (or more) factors.
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved

By default, computes a frequency table of the factors unless an array
of values and an aggregation function are passed.

Parameters
----------
index : array-like, Series, or list of arrays/Series
Values to group by in the rows.
columns : array-like, Series, or list of arrays/Series
Values to group by in the columns.
values : array-like, optional
Array of values to aggregate according to the factors.
Requires aggfunc be specified.
rownames : sequence, default None
If passed, must match number of row arrays passed.
colnames : sequence, default None
If passed, must match number of column arrays passed.
aggfunc : function, optional
If specified, requires values be specified as well.
margins : bool, default False
Add row/column margins (subtotals).
margins_name : str, default 'All'
Name of the row/column that will contain the totals when margins is True.
dropna : bool, default True
Do not include columns whose entries are all NaN.

normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
Normalize by dividing all values by the sum of values.

* If passed 'all' or True, will normalize over all values.
* If passed 'index' will normalize over each row.
* If passed 'columns' will normalize over each column.
* If margins is True, will also normalize margin values.

Returns
-------
Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame`
Cross tabulation of the data.

Notes
-----

Raises NotImplementedError if aggfunc is not one of "count", "mean", "min", "max", or "sum", or
margins is True, normalize is True or all, and values is passed.
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved

Examples
--------
>>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
... "bar", "bar", "foo", "foo", "foo"], dtype=object)
>>> b = np.array(["one", "one", "one", "two", "one", "one",
... "one", "two", "two", "two", "one"], dtype=object)
>>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
... "shiny", "dull", "shiny", "shiny", "shiny"],
... dtype=object)
>>> pd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c']) # doctest: +NORMALIZE_WHITESPACE
b one two
c dull shiny dull shiny
a
bar 1 2 1 0
foo 2 2 1 2
"""
# TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py
pandas_crosstab = pandas.crosstab(
index,
columns,
values,
rownames,
colnames,
aggfunc,
margins,
margins_name,
dropna,
normalize,
if values is None and aggfunc is not None:
raise ValueError("aggfunc cannot be used without values.")

if values is not None and aggfunc is None:
raise ValueError("values cannot be used without an aggfunc.")

if not is_nested_list_like(index):
index = [index]
if not is_nested_list_like(columns):
columns = [columns]

if (
values is not None
and margins is True
and (normalize is True or normalize == "all")
):
raise NotImplementedError(
'Snowpark pandas does not yet support passing in margins=True, normalize="all", and values.'
)

user_passed_rownames = rownames is not None
user_passed_colnames = colnames is not None

from pandas.core.reshape.pivot import _build_names_mapper, _get_names

def _get_names_wrapper(list_of_objs, names, prefix):
"""
Helper method to expand DataFrame objects containing
multiple columns into Series, since `_get_names` expects
one column per entry.
"""
expanded_list_of_objs = []
for obj in list_of_objs:
if isinstance(obj, DataFrame):
for col in obj.columns:
expanded_list_of_objs.append(obj[col])
else:
expanded_list_of_objs.append(obj)
return _get_names(expanded_list_of_objs, names, prefix)

rownames = _get_names_wrapper(index, rownames, prefix="row")
colnames = _get_names_wrapper(columns, colnames, prefix="col")

(
rownames_mapper,
unique_rownames,
colnames_mapper,
unique_colnames,
) = _build_names_mapper(rownames, colnames)

pass_objs = [x for x in index + columns if isinstance(x, (Series, DataFrame))]
row_idx_names = None
col_idx_names = None
if pass_objs:
# If we have any Snowpark pandas objects in the index or columns, then we
# need to find the intersection of their indices, and only pick rows from
# the objects that have indices in the intersection of their indices.
# After we do that, we then need to append the non Snowpark pandas objects
# using the intersection of indices as the final index for the DataFrame object.
# First, we separate the objects into Snowpark pandas objects, and non-Snowpark
# pandas objects (while renaming them so that they have unique names).
rownames_idx = 0
row_idx_names = []
dfs = []
arrays = []
array_lengths = []
for obj in index:
if isinstance(obj, Series):
row_idx_names.append(obj.name)
df = pd.DataFrame(obj)
df.columns = [unique_rownames[rownames_idx]]
rownames_idx += 1
dfs.append(df)
elif isinstance(obj, DataFrame):
row_idx_names.extend(obj.columns)
obj.columns = unique_rownames[
rownames_idx : rownames_idx + len(obj.columns)
]
rownames_idx += len(obj.columns)
dfs.append(obj)
else:
row_idx_names.append(None)
array_lengths.append(len(obj))
df = pd.DataFrame(obj)
df.columns = unique_rownames[
rownames_idx : rownames_idx + len(df.columns)
]
rownames_idx += len(df.columns)
arrays.append(df)

colnames_idx = 0
col_idx_names = []
for obj in columns:
if isinstance(obj, Series):
col_idx_names.append(obj.name)
df = pd.DataFrame(obj)
df.columns = [unique_colnames[colnames_idx]]
colnames_idx += 1
dfs.append(df)
elif isinstance(obj, DataFrame):
col_idx_names.extend(obj.columns)
obj.columns = unique_colnames[
colnames_idx : colnames_idx + len(obj.columns)
]
colnames_idx += len(obj.columns)
dfs.append(obj)
else:
col_idx_names.append(None)
array_lengths.append(len(obj))
df = pd.DataFrame(obj)
df.columns = unique_colnames[
colnames_idx : colnames_idx + len(df.columns)
]
colnames_idx += len(df.columns)
arrays.append(df)

if len(set(array_lengths)) > 1:
raise ValueError("All arrays must be of the same length")

# Now, we have two lists - a list of Snowpark pandas objects, and a list of objects
# that were not passed in as Snowpark pandas objects, but that we have converted
# to Snowpark pandas objects to give them column names. We can perform inner joins
# on the dfs list to get a DataFrame with the final index (that is only an intersection
# of indices.)
df = dfs[0]
for right in dfs[1:]:
df = df.merge(right, left_index=True, right_index=True)
if len(arrays) > 0:
index = df.index
right_df = pd.concat(arrays, axis=1)
# Increases query count by 1, but necessary for error checking.
index_length = len(df)
if index_length != array_lengths[0]:
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"Length mismatch: Expected {array_lengths[0]} rows, received array of length {index_length}"
)
right_df.index = index
df = df.merge(right_df, left_index=True, right_index=True)
else:
data = {
**dict(zip(unique_rownames, index)),
**dict(zip(unique_colnames, columns)),
}
df = DataFrame(data)

if values is None:
df["__dummy__"] = 0
kwargs = {"aggfunc": "count"}
else:
df["__dummy__"] = values
kwargs = {"aggfunc": aggfunc}

table = df.pivot_table(
sfc-gh-helmeleegy marked this conversation as resolved.
Show resolved Hide resolved
"__dummy__",
index=unique_rownames,
columns=unique_colnames,
margins=margins,
margins_name=margins_name,
dropna=dropna,
**kwargs, # type: ignore[arg-type]
)
return DataFrame(pandas_crosstab)

if row_idx_names is not None and not user_passed_rownames:
table.index = table.index.set_names(row_idx_names)

if col_idx_names is not None and not user_passed_colnames:
table.columns = table.columns.set_names(col_idx_names)

if aggfunc is None:
# If no aggfunc is provided, we are computing frequencies. Since we use
# pivot_table above, pairs that are not observed will get a NaN value,
# so we need to fill all NaN values with 0.
table = table.fillna(0)

# We must explicitly check that the value of normalize is not False here,
# as a valid value of normalize is `0` (for normalizing index).
if normalize is not False:
if normalize not in [0, 1, "index", "columns", "all", True]:
raise ValueError(f"Not a valid normalize argument: {normalize}")
if normalize is True:
normalize = "all"
normalize = {0: "index", 1: "columns"}.get(normalize, normalize)

# Actual Normalizations
normalizers: dict[bool | str, Callable] = {
"all": lambda x: x / x.sum(axis=0).sum(),
"columns": lambda x: x / x.sum(),
"index": lambda x: x.div(x.sum(axis=1), axis="index"),
}

if margins is False:

f = normalizers[normalize]
names = table.columns.names
table = f(table)
table.columns.names = names
table = table.fillna(0)
else:
# keep index and column of pivoted table
table_index = table.index
table_columns = table.columns

column_margin = table.iloc[:-1, -1]

if normalize == "columns":
# keep the core table
table = table.iloc[:-1, :-1]

# Normalize core
f = normalizers[normalize]
table = f(table)
table = table.fillna(0)
# Fix Margins
column_margin = column_margin / column_margin.sum()
table = pd.concat([table, column_margin], axis=1)
table = table.fillna(0)
table.columns = table_columns

elif normalize == "index":
table = table.iloc[:, :-1]

# Normalize core
f = normalizers[normalize]
table = f(table)
table = table.fillna(0).reindex(index=table_index)

elif normalize == "all":
# Normalize core
f = normalizers[normalize]

# When we perform the normalization function, we take the sum over
# the rows, and divide every value by the sum. Since margins is included
# though, the result of the sum is actually 2 * the sum of the original
# values (since the margin itself is the sum of the original values),
# so we need to multiply by 2 here to account for that.
# The alternative would be to apply normalization to the main table
# and the index margins separately, but that would require additional joins
# to get the final table, which we want to avoid.
table = f(table.iloc[:, :-1]) * 2.0
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved

column_margin = column_margin / column_margin.sum()
table = pd.concat([table, column_margin], axis=1)
table.iloc[-1, -1] = 1

table = table.fillna(0)
table.index = table_index
table.columns = table_columns

table = table.rename_axis(index=rownames_mapper, axis=0)
table = table.rename_axis(columns=colnames_mapper, axis=1)

return table


# Adding docstring since pandas docs don't have web section for this function.
Expand Down
Loading
Loading