From aa6b0278645a9d2d91c908cb377c9c82fbf2af4c Mon Sep 17 00:00:00 2001
From: Rehan Durrani <rehan.durrani@snowflake.com>
Date: Tue, 17 Sep 2024 13:14:38 -0700
Subject: [PATCH 1/6] [SNOW-1429199]: Provide clearer error message for groupby
 aggregations

---
 src/snowflake/snowpark/modin/pandas/groupby.py  | 17 ++++++++++++++---
 .../modin/groupby/test_groupby_basic_agg.py     | 10 ++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/snowflake/snowpark/modin/pandas/groupby.py b/src/snowflake/snowpark/modin/pandas/groupby.py
index 72ae6a2b003..27f1ce4d042 100644
--- a/src/snowflake/snowpark/modin/pandas/groupby.py
+++ b/src/snowflake/snowpark/modin/pandas/groupby.py
@@ -1157,9 +1157,20 @@ def _wrap_aggregation(
             Returns the same type as `self._df`.
         """
         # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
-        numeric_only = validate_bool_kwarg(
-            numeric_only, "numeric_only", none_allowed=True
-        )
+        try:
+            numeric_only = validate_bool_kwarg(
+                numeric_only, "numeric_only", none_allowed=True
+            )
+        except ValueError:
+            # SNOW-1429199: Snowpark users expect to be able to pass in the column to aggregate
+            # on in the aggregation method, e.g. df.groupby("COL0").sum("COL1"), but the pandas
+            # API's only accept the numeric_only argument, so users get an error complaining that
+            # the numeric_only kwarg expects a bool argument, but a string was passed in. Instead
+            # of that error, we can throw this error instead that will make it more clear to users
+            # what went wrong.
+            raise ValueError(
+                f"GroupBy aggregations like sum take a numeric_only argument that needs to be a bool, but a {type(numeric_only).__name__} value was passed in."
+            )
 
         agg_args = tuple() if agg_args is None else agg_args
         agg_kwargs = dict() if agg_kwargs is None else agg_kwargs
diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py
index cbf5b75d48c..2d30de77996 100644
--- a/tests/integ/modin/groupby/test_groupby_basic_agg.py
+++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py
@@ -388,6 +388,16 @@ def test_string_sum(data):
     )
 
 
+@sql_count_checker(query_count=0)
+def test_groupby_sum_string_argument_exception():
+    snow_df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["key_col", "col1", "col2"])
+    with pytest.raises(
+        ValueError,
+        match="GroupBy aggregations like sum take a numeric_only argument that needs to be a bool, but a str value was passed in.",
+    ):
+        snow_df.groupby("key_col").sum("col1")
+
+
 @sql_count_checker(query_count=1)
 def test_string_sum_with_all_nulls_in_group_produces_empty_string():
     """

From ba02385200135679d958fcf214e5248cd4cba0a9 Mon Sep 17 00:00:00 2001
From: Rehan Durrani <rehan.durrani@snowflake.com>
Date: Tue, 17 Sep 2024 13:15:44 -0700
Subject: [PATCH 2/6] Add changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 829c027c527..2dadb920993 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@
 #### Improvements
 
 - Improved `to_pandas` to persist the original timezone offset for TIMESTAMP_TZ type.
+- Improved error message when passing non-bool value to `numeric_only` for groupby aggregations.
 
 #### New Features
 

From f2730be2e1516a39b8faa1af521538048c482b46 Mon Sep 17 00:00:00 2001
From: Rehan Durrani <rehan.durrani@snowflake.com>
Date: Tue, 17 Sep 2024 13:18:46 -0700
Subject: [PATCH 3/6] Update src/snowflake/snowpark/modin/pandas/groupby.py

Co-authored-by: Naren Krishna <naren.krishna@snowflake.com>
---
 src/snowflake/snowpark/modin/pandas/groupby.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/snowflake/snowpark/modin/pandas/groupby.py b/src/snowflake/snowpark/modin/pandas/groupby.py
index 27f1ce4d042..e238dabb603 100644
--- a/src/snowflake/snowpark/modin/pandas/groupby.py
+++ b/src/snowflake/snowpark/modin/pandas/groupby.py
@@ -1169,7 +1169,7 @@ def _wrap_aggregation(
             # of that error, we can throw this error instead that will make it more clear to users
             # what went wrong.
             raise ValueError(
-                f"GroupBy aggregations like sum take a numeric_only argument that needs to be a bool, but a {type(numeric_only).__name__} value was passed in."
+                f"GroupBy aggregations like 'sum' take a 'numeric_only' argument that needs to be a bool, but a {type(numeric_only).__name__} value was passed in."
             )
 
         agg_args = tuple() if agg_args is None else agg_args

From 731641256988d265d4b80e259e91825361c41376 Mon Sep 17 00:00:00 2001
From: Rehan Durrani <rehan.durrani@snowflake.com>
Date: Tue, 17 Sep 2024 13:19:55 -0700
Subject: [PATCH 4/6] Address review comments

---
 tests/integ/modin/groupby/test_groupby_basic_agg.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py
index 2d30de77996..87d3f7ad402 100644
--- a/tests/integ/modin/groupby/test_groupby_basic_agg.py
+++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py
@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #
 import logging
+import re
 from typing import Any
 
 import modin.pandas as pd
@@ -393,7 +394,9 @@ def test_groupby_sum_string_argument_exception():
     snow_df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["key_col", "col1", "col2"])
     with pytest.raises(
         ValueError,
-        match="GroupBy aggregations like sum take a numeric_only argument that needs to be a bool, but a str value was passed in.",
+        match=re.escape(
+            "GroupBy aggregations like 'sum' take a 'numeric_only' argument that needs to be a bool, but a str value was passed in."
+        ),
     ):
         snow_df.groupby("key_col").sum("col1")
 

From c1239d574f277ca7e36373aa4c75a70b51a0b246 Mon Sep 17 00:00:00 2001
From: Rehan Durrani <rehan.durrani@snowflake.com>
Date: Wed, 18 Sep 2024 14:20:50 -0700
Subject: [PATCH 5/6] Fix negative tests to check for new regex

---
 tests/integ/modin/groupby/test_groupby_negative.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py
index 0c9c056c2a7..faca799acad 100644
--- a/tests/integ/modin/groupby/test_groupby_negative.py
+++ b/tests/integ/modin/groupby/test_groupby_negative.py
@@ -536,7 +536,10 @@ def test_groupby_agg_invalid_numeric_only(
     # treated as True. This behavior is confusing to customer, in Snowpark pandas, we do an
     # explicit type check, an errors it out if an invalid value is given.
     with pytest.raises(
-        ValueError, match=re.escape('For argument "numeric_only" expected type bool')
+        ValueError,
+        match=re.escape(
+            "GroupBy aggregations like 'sum' take a 'numeric_only' argument that needs to be a bool, but a str value was passed in."
+        ),
     ):
         getattr(basic_snowpark_pandas_df.groupby("col1"), agg_method_name)(
             numeric_only=numeric_only

From c98970a44b1ce2e3234eaa58911d90b798a4fc6e Mon Sep 17 00:00:00 2001
From: Rehan Durrani <rehan.durrani@snowflake.com>
Date: Thu, 19 Sep 2024 09:00:10 -0700
Subject: [PATCH 6/6] Fix tests

---
 tests/integ/modin/groupby/test_groupby_basic_agg.py | 13 -------------
 tests/integ/modin/groupby/test_groupby_negative.py  |  2 +-
 2 files changed, 1 insertion(+), 14 deletions(-)

diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py
index 87d3f7ad402..cbf5b75d48c 100644
--- a/tests/integ/modin/groupby/test_groupby_basic_agg.py
+++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py
@@ -2,7 +2,6 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #
 import logging
-import re
 from typing import Any
 
 import modin.pandas as pd
@@ -389,18 +388,6 @@ def test_string_sum(data):
     )
 
 
-@sql_count_checker(query_count=0)
-def test_groupby_sum_string_argument_exception():
-    snow_df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["key_col", "col1", "col2"])
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "GroupBy aggregations like 'sum' take a 'numeric_only' argument that needs to be a bool, but a str value was passed in."
-        ),
-    ):
-        snow_df.groupby("key_col").sum("col1")
-
-
 @sql_count_checker(query_count=1)
 def test_string_sum_with_all_nulls_in_group_produces_empty_string():
     """
diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py
index faca799acad..45560862a30 100644
--- a/tests/integ/modin/groupby/test_groupby_negative.py
+++ b/tests/integ/modin/groupby/test_groupby_negative.py
@@ -538,7 +538,7 @@ def test_groupby_agg_invalid_numeric_only(
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "GroupBy aggregations like 'sum' take a 'numeric_only' argument that needs to be a bool, but a str value was passed in."
+            f"GroupBy aggregations like 'sum' take a 'numeric_only' argument that needs to be a bool, but a {type(numeric_only).__name__} value was passed in."
         ),
     ):
         getattr(basic_snowpark_pandas_df.groupby("col1"), agg_method_name)(