[SNOW-1502893]: Add support for pd.crosstab (#1837)

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.


   Fixes SNOW-1502893

2. Fill out the following pre-review checklist:

   - [ ] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.

3. Please describe how your code solves the related issue.

   Add support for `pd.crosstab`, implemented on top of `DataFrame.pivot_table`.
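
   For illustration, here is a minimal usage sketch of the newly supported
   API. The data, names, and the `modin.pandas` + Snowpark plugin import
   convention are assumptions for the example, not part of this PR:

   ```python
   import numpy as np
   import modin.pandas as pd
   import snowflake.snowpark.modin.plugin  # noqa: F401  (assumes a configured Snowflake session)

   region = pd.Series(["east", "east", "west", "west", "east"], name="region")
   product = pd.Series(["a", "b", "a", "b", "a"], name="product")
   sales = np.array([10, 20, 30, 40, 50])

   # Frequency table: counts of each (region, product) pair.
   freq = pd.crosstab(region, product)

   # Aggregated table: sum of `sales` per pair. aggfunc must be one of
   # "count", "mean", "min", "max", or "sum" in Snowpark pandas.
   totals = pd.crosstab(region, product, values=sales, aggfunc="sum")

   # Row-normalized frequencies (each row sums to 1).
   shares = pd.crosstab(region, product, normalize="index")
   ```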
sfc-gh-rdurrani authored Aug 31, 2024
1 parent e1149ca commit 3c1db07
Showing 8 changed files with 1,071 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -80,6 +80,7 @@
- Added support for `DataFrameGroupBy.value_counts` and `SeriesGroupBy.value_counts`.
- Added support for `Series.is_monotonic_increasing` and `Series.is_monotonic_decreasing`.
- Added support for `Index.is_monotonic_increasing` and `Index.is_monotonic_decreasing`.
- Added support for `pd.crosstab`.

#### Improvements

1 change: 1 addition & 0 deletions docs/source/modin/general_functions.rst
@@ -11,6 +11,7 @@ General functions
:toctree: pandas_api/

melt
crosstab
pivot
pivot_table
cut
5 changes: 4 additions & 1 deletion docs/source/modin/supported/general_supported.rst
@@ -18,7 +18,10 @@ Data manipulations
| ``concat`` | P | ``levels`` is not supported, | |
| | | ``copy`` is ignored | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``crosstab`` | P | | ``N`` if ``aggfunc`` is not one of |
| | | | "count", "mean", "min", "max", or "sum"; or if |
| | | | ``margins`` is True, ``normalize`` is "all" or |
| | | | True, and ``values`` is passed. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``cut`` | P | ``retbins``, ``labels`` | ``N`` if ``retbins=True`` or ``labels!=False`` |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
330 changes: 313 additions & 17 deletions src/snowflake/snowpark/modin/pandas/general.py
@@ -22,7 +22,7 @@
"""Implement pandas general API."""
from __future__ import annotations

from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
from datetime import date, datetime, timedelta, tzinfo
from logging import getLogger
from typing import TYPE_CHECKING, Any, Literal, Union
@@ -49,7 +49,7 @@
_infer_tz_from_endpoints,
_maybe_normalize_endpoints,
)
from pandas.core.dtypes.common import is_list_like, is_nested_list_like
from pandas.core.dtypes.inference import is_array_like
from pandas.core.tools.datetimes import (
ArrayConvertible,
@@ -1982,8 +1982,6 @@ def melt(


@snowpark_pandas_telemetry_standalone_function_decorator
def crosstab(
    index,
    columns,
@@ -1998,21 +1996,319 @@
) -> DataFrame: # noqa: PR01, RT01, D200
"""
Compute a simple cross tabulation of two (or more) factors.
By default, computes a frequency table of the factors unless an array
of values and an aggregation function are passed.
Parameters
----------
index : array-like, Series, or list of arrays/Series
Values to group by in the rows.
columns : array-like, Series, or list of arrays/Series
Values to group by in the columns.
values : array-like, optional
Array of values to aggregate according to the factors.
Requires aggfunc be specified.
rownames : sequence, default None
If passed, must match number of row arrays passed.
colnames : sequence, default None
If passed, must match number of column arrays passed.
aggfunc : function, optional
If specified, requires values be specified as well.
margins : bool, default False
Add row/column margins (subtotals).
margins_name : str, default 'All'
Name of the row/column that will contain the totals when margins is True.
dropna : bool, default True
Do not include columns whose entries are all NaN.
normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
Normalize by dividing all values by the sum of values.
* If passed 'all' or True, will normalize over all values.
* If passed 'index' will normalize over each row.
* If passed 'columns' will normalize over each column.
* If margins is True, will also normalize margin values.
Returns
-------
Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame`
Cross tabulation of the data.
Notes
-----
Raises NotImplementedError if aggfunc is not one of "count", "mean", "min", "max", or "sum", or
margins is True, normalize is True or all, and values is passed.
Examples
--------
>>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
... "bar", "bar", "foo", "foo", "foo"], dtype=object)
>>> b = np.array(["one", "one", "one", "two", "one", "one",
... "one", "two", "two", "two", "one"], dtype=object)
>>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
... "shiny", "dull", "shiny", "shiny", "shiny"],
... dtype=object)
>>> pd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c']) # doctest: +NORMALIZE_WHITESPACE
b one two
c dull shiny dull shiny
a
bar 1 2 1 0
foo 2 2 1 2
"""
    # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py
    if values is None and aggfunc is not None:
        raise ValueError("aggfunc cannot be used without values.")

    if values is not None and aggfunc is None:
        raise ValueError("values cannot be used without an aggfunc.")

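    # Wrap single array-like inputs in a list so that rows and columns can be
    # processed uniformly below.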
    if not is_nested_list_like(index):
        index = [index]
    if not is_nested_list_like(columns):
        columns = [columns]

    if (
        values is not None
        and margins is True
        and (normalize is True or normalize == "all")
    ):
        raise NotImplementedError(
            'Snowpark pandas does not yet support passing in margins=True, normalize="all", and values.'
        )

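    # Remember whether the caller passed rownames/colnames explicitly; if not,
    # the names inferred from the inputs are restored after pivoting.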
    user_passed_rownames = rownames is not None
    user_passed_colnames = colnames is not None

    from pandas.core.reshape.pivot import _build_names_mapper, _get_names

    def _get_names_wrapper(list_of_objs, names, prefix):
        """
        Helper method to expand DataFrame objects containing
        multiple columns into Series, since `_get_names` expects
        one column per entry.
        """
        expanded_list_of_objs = []
        for obj in list_of_objs:
            if isinstance(obj, DataFrame):
                for col in obj.columns:
                    expanded_list_of_objs.append(obj[col])
            else:
                expanded_list_of_objs.append(obj)
        return _get_names(expanded_list_of_objs, names, prefix)

    rownames = _get_names_wrapper(index, rownames, prefix="row")
    colnames = _get_names_wrapper(columns, colnames, prefix="col")

    (
        rownames_mapper,
        unique_rownames,
        colnames_mapper,
        unique_colnames,
    ) = _build_names_mapper(rownames, colnames)
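    # The "unique" names are collision-free placeholder labels used during the
    # pivot; the mappers restore the original row/column names at the end.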

    pass_objs = [x for x in index + columns if isinstance(x, (Series, DataFrame))]
    row_idx_names = None
    col_idx_names = None
    if pass_objs:
        # If we have any Snowpark pandas objects in the index or columns, then we
        # need to find the intersection of their indices, and only pick rows from
        # the objects that have indices in the intersection of their indices.
        # After we do that, we then need to append the non Snowpark pandas objects
        # using the intersection of indices as the final index for the DataFrame object.
        # First, we separate the objects into Snowpark pandas objects, and non-Snowpark
        # pandas objects (while renaming them so that they have unique names).
        rownames_idx = 0
        row_idx_names = []
        dfs = []
        arrays = []
        array_lengths = []
        for obj in index:
            if isinstance(obj, Series):
                row_idx_names.append(obj.name)
                df = pd.DataFrame(obj)
                df.columns = [unique_rownames[rownames_idx]]
                rownames_idx += 1
                dfs.append(df)
            elif isinstance(obj, DataFrame):
                row_idx_names.extend(obj.columns)
                obj.columns = unique_rownames[
                    rownames_idx : rownames_idx + len(obj.columns)
                ]
                rownames_idx += len(obj.columns)
                dfs.append(obj)
            else:
                row_idx_names.append(None)
                array_lengths.append(len(obj))
                df = pd.DataFrame(obj)
                df.columns = unique_rownames[
                    rownames_idx : rownames_idx + len(df.columns)
                ]
                rownames_idx += len(df.columns)
                arrays.append(df)

        colnames_idx = 0
        col_idx_names = []
        for obj in columns:
            if isinstance(obj, Series):
                col_idx_names.append(obj.name)
                df = pd.DataFrame(obj)
                df.columns = [unique_colnames[colnames_idx]]
                colnames_idx += 1
                dfs.append(df)
            elif isinstance(obj, DataFrame):
                col_idx_names.extend(obj.columns)
                obj.columns = unique_colnames[
                    colnames_idx : colnames_idx + len(obj.columns)
                ]
                colnames_idx += len(obj.columns)
                dfs.append(obj)
            else:
                col_idx_names.append(None)
                array_lengths.append(len(obj))
                df = pd.DataFrame(obj)
                df.columns = unique_colnames[
                    colnames_idx : colnames_idx + len(df.columns)
                ]
                colnames_idx += len(df.columns)
                arrays.append(df)

        if len(set(array_lengths)) > 1:
            raise ValueError("All arrays must be of the same length")

        # Now, we have two lists - a list of Snowpark pandas objects, and a list of objects
        # that were not passed in as Snowpark pandas objects, but that we have converted
        # to Snowpark pandas objects to give them column names. We can perform inner joins
        # on the dfs list to get a DataFrame with the final index (that is only an intersection
        # of indices.)
        df = dfs[0]
        for right in dfs[1:]:
            df = df.merge(right, left_index=True, right_index=True)
        if len(arrays) > 0:
            index = df.index
            right_df = pd.concat(arrays, axis=1)
            # Increases query count by 1, but necessary for error checking.
            index_length = len(df)
            if index_length != array_lengths[0]:
                raise ValueError(
                    f"Length mismatch: Expected {array_lengths[0]} rows, received array of length {index_length}"
                )
            right_df.index = index
            df = df.merge(right_df, left_index=True, right_index=True)
    else:
        data = {
            **dict(zip(unique_rownames, index)),
            **dict(zip(unique_colnames, columns)),
        }
        df = DataFrame(data)

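    # Pivot on a sentinel "__dummy__" column: counting a constant column yields
    # the pair frequencies, while aggregating the user-provided values
    # reproduces the requested aggfunc semantics.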
    if values is None:
        df["__dummy__"] = 0
        kwargs = {"aggfunc": "count"}
    else:
        df["__dummy__"] = values
        kwargs = {"aggfunc": aggfunc}

    table = df.pivot_table(
        "__dummy__",
        index=unique_rownames,
        columns=unique_colnames,
        margins=margins,
        margins_name=margins_name,
        dropna=dropna,
        **kwargs,  # type: ignore[arg-type]
    )

    if row_idx_names is not None and not user_passed_rownames:
        table.index = table.index.set_names(row_idx_names)

    if col_idx_names is not None and not user_passed_colnames:
        table.columns = table.columns.set_names(col_idx_names)

    if aggfunc is None:
        # If no aggfunc is provided, we are computing frequencies. Since we use
        # pivot_table above, pairs that are not observed will get a NaN value,
        # so we need to fill all NaN values with 0.
        table = table.fillna(0)

    # We must explicitly check that the value of normalize is not False here,
    # as a valid value of normalize is `0` (for normalizing the index).
    if normalize is not False:
        if normalize not in [0, 1, "index", "columns", "all", True]:
            raise ValueError("Not a valid normalize argument")
        if normalize is True:
            normalize = "all"
        normalize = {0: "index", 1: "columns"}.get(normalize, normalize)

        # Actual Normalizations
        normalizers: dict[bool | str, Callable] = {
            "all": lambda x: x / x.sum(axis=0).sum(),
            "columns": lambda x: x / x.sum(),
            "index": lambda x: x.div(x.sum(axis=1), axis="index"),
        }

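        # With no margins, the whole table can be normalized directly;
        # otherwise the margin row/column must be split off, normalized
        # separately, and re-attached below.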
        if margins is False:
            f = normalizers[normalize]
            names = table.columns.names
            table = f(table)
            table.columns.names = names
            table = table.fillna(0)
        else:
            # keep index and column of pivoted table
            table_index = table.index
            table_columns = table.columns

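            # With margins=True, the last column of the pivoted table holds the
            # row totals (excluding the grand total in its final row).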
            column_margin = table.iloc[:-1, -1]

if normalize == "columns":
# keep the core table
table = table.iloc[:-1, :-1]

# Normalize core
f = normalizers[normalize]
table = f(table)
table = table.fillna(0)
# Fix Margins
column_margin = column_margin / column_margin.sum()
table = pd.concat([table, column_margin], axis=1)
table = table.fillna(0)
table.columns = table_columns

elif normalize == "index":
table = table.iloc[:, :-1]

# Normalize core
f = normalizers[normalize]
table = f(table)
table = table.fillna(0).reindex(index=table_index)

elif normalize == "all":
# Normalize core
f = normalizers[normalize]

# When we perform the normalization function, we take the sum over
# the rows, and divide every value by the sum. Since margins is included
# though, the result of the sum is actually 2 * the sum of the original
# values (since the margin itself is the sum of the original values),
# so we need to multiply by 2 here to account for that.
# The alternative would be to apply normalization to the main table
# and the index margins separately, but that would require additional joins
# to get the final table, which we want to avoid.
table = f(table.iloc[:, :-1]) * 2.0

column_margin = column_margin / column_margin.sum()
table = pd.concat([table, column_margin], axis=1)
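                # The grand total of a fully normalized table is 1 by
                # definition, so set the bottom-right margin cell directly.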
                table.iloc[-1, -1] = 1

            table = table.fillna(0)
            table.index = table_index
            table.columns = table_columns

    table = table.rename_axis(index=rownames_mapper, axis=0)
    table = table.rename_axis(columns=colnames_mapper, axis=1)

    return table


# Adding docstring since pandas docs don't have web section for this function.
