diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 08d6c61efc4..5a469211743 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -756,12 +756,6 @@ def to_pandas( 2. If you use :func:`Session.sql` with this method, the input query of :func:`Session.sql` can only be a SELECT statement. """ - from snowflake.snowpark.mock.connection import MockServerConnection - - if isinstance(self._session._conn, MockServerConnection): - raise NotImplementedError( - "[Local Testing] DataFrame.to_pandas is not implemented." - ) result = self._session._conn.execute( self._plan, to_pandas=True, diff --git a/src/snowflake/snowpark/mock/connection.py b/src/snowflake/snowpark/mock/connection.py index 472b1912856..6e5f9b9fe91 100644 --- a/src/snowflake/snowpark/mock/connection.py +++ b/src/snowflake/snowpark/mock/connection.py @@ -16,7 +16,6 @@ import pandas as pd import snowflake.connector -from snowflake.connector.constants import FIELD_ID_TO_NAME from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor from snowflake.connector.errors import NotSupportedError, ProgrammingError from snowflake.connector.network import ReauthenticationRequest @@ -46,7 +45,13 @@ from snowflake.snowpark.mock.snowflake_data_type import TableEmulator from snowflake.snowpark.mock.util import parse_table_name from snowflake.snowpark.row import Row -from snowflake.snowpark.types import ArrayType, MapType, VariantType +from snowflake.snowpark.types import ( + ArrayType, + DecimalType, + MapType, + VariantType, + _IntegralType, +) logger = getLogger(__name__) @@ -405,8 +410,22 @@ def execute( elif isinstance(res, list): rows = res + if to_pandas: + pandas_df = pd.DataFrame() + for col_name in res.columns: + pandas_df[unquote_if_quoted(col_name)] = res[col_name].tolist() + rows = _fix_pandas_df_integer(res) + + # the following implementation is just to make DataFrame.to_pandas_batches API workable + # in snowflake, large 
data result are split into multiple data chunks +        # and sent back to the client, thus it makes sense to have the generator +        # however, local testing does not mock that chunk-splitting behavior and is +        # not intended for large data results +        rows = [rows] if to_iter else rows +         if to_iter:             return iter(rows) +         return rows       @SnowflakePlan.Decorator.wrap_exception @@ -539,19 +558,44 @@ def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str:         return result_set["sfqid"]   -def _fix_pandas_df_integer( -    pd_df: "pandas.DataFrame", results_cursor: SnowflakeCursor -) -> "pandas.DataFrame": -    for column_metadata, pandas_dtype, pandas_col_name in zip( -        results_cursor.description, pd_df.dtypes, pd_df.columns -    ): +def _fix_pandas_df_integer(table_res: TableEmulator) -> "pandas.DataFrame": +    pd_df = pd.DataFrame() +    for col_name in table_res.columns: +        col_sf_type = table_res.sf_types[col_name] +        pd_df_col_name = unquote_if_quoted(col_name)         if ( -            FIELD_ID_TO_NAME.get(column_metadata.type_code) == "FIXED" -            and column_metadata.precision is not None -            and column_metadata.scale == 0 -            and not str(pandas_dtype).startswith("int") +            isinstance(col_sf_type.datatype, DecimalType) +            and col_sf_type.datatype.precision is not None +            and col_sf_type.datatype.scale == 0 +            and not str(table_res[col_name].dtype).startswith("int")         ): -            pd_df[pandas_col_name] = pandas.to_numeric( -                pd_df[pandas_col_name], downcast="integer" +            # if decimal is set to default 38, we auto-detect the dtype, see the following code +            # df = session.create_dataframe( +            #     data=[[decimal.Decimal(1)]], +            #     schema=StructType([StructField("d", DecimalType())]) +            # ) +            # df.to_pandas() # the returned df is of dtype int8, instead of dtype int64 +            if col_sf_type.datatype.precision == 38: +                pd_df[pd_df_col_name] = pandas.to_numeric( +                    table_res[col_name], downcast="integer" +                ) +                continue + +            # this is to mock the behavior when precision is explicitly set to a non-default value (not 38) +            # optimize pd.DataFrame dtype of
integer to align the behavior with live connection + if col_sf_type.datatype.precision <= 2: + pd_df[pd_df_col_name] = table_res[col_name].astype("int8") + elif col_sf_type.datatype.precision <= 4: + pd_df[pd_df_col_name] = table_res[col_name].astype("int16") + elif col_sf_type.datatype.precision <= 8: + pd_df[pd_df_col_name] = table_res[col_name].astype("int32") + else: + pd_df[pd_df_col_name] = table_res[col_name].astype("int64") + elif isinstance(col_sf_type.datatype, _IntegralType): + pd_df[pd_df_col_name] = pandas.to_numeric( + table_res[col_name].tolist(), downcast="integer" ) + else: + pd_df[pd_df_col_name] = table_res[col_name].tolist() + return pd_df diff --git a/src/snowflake/snowpark/mock/functions.py b/src/snowflake/snowpark/mock/functions.py index 8d3a6a22bb2..18624d7f41c 100644 --- a/src/snowflake/snowpark/mock/functions.py +++ b/src/snowflake/snowpark/mock/functions.py @@ -524,7 +524,9 @@ def mock_to_timestamp( if data is None: res.append(None) continue - if auto_detect and data.isnumeric(): + if auto_detect and ( + isinstance(data, int) or (isinstance(data, str) and data.isnumeric()) + ): res.append( datetime.datetime.utcfromtimestamp(process_numeric_time(data)) ) diff --git a/tests/integ/test_df_to_pandas.py b/tests/integ/test_df_to_pandas.py index 6bfd48d5dc6..d9ebb69f721 100644 --- a/tests/integ/test_df_to_pandas.py +++ b/tests/integ/test_df_to_pandas.py @@ -21,6 +21,7 @@ ) +@pytest.mark.localtest def test_to_pandas_new_df_from_range(session): # Single column snowpark_df = session.range(3, 8) @@ -46,6 +47,7 @@ def test_to_pandas_new_df_from_range(session): assert all(pandas_df["OTHER"][i] == i + 3 for i in range(5)) +@pytest.mark.localtest @pytest.mark.parametrize("to_pandas_api", ["to_pandas", "to_pandas_batches"]) def test_to_pandas_cast_integer(session, to_pandas_api): snowpark_df = session.create_dataframe( @@ -122,14 +124,18 @@ def check_fetch_data_exception(query: str) -> None: @pytest.mark.skipif( IS_IN_STORED_PROC, reason="SNOW-507565: 
Need localaws for large result" ) -def test_to_pandas_batches(session): +@pytest.mark.localtest +def test_to_pandas_batches(session, local_testing_mode): df = session.range(100000).cache_result() iterator = df.to_pandas_batches() assert isinstance(iterator, Iterator) entire_pandas_df = df.to_pandas() pandas_df_list = list(df.to_pandas_batches()) - assert len(pandas_df_list) > 1 + if not local_testing_mode: + # in live session, large data result will be split into multiple chunks by snowflake + # local test does not split the data result chunk/is not intended for large data result chunk + assert len(pandas_df_list) > 1 assert_frame_equal(pd.concat(pandas_df_list, ignore_index=True), entire_pandas_df) for df_batch in df.to_pandas_batches(): diff --git a/tests/integ/test_pandas_to_df.py b/tests/integ/test_pandas_to_df.py index 12f257305bd..3ed272ecc3d 100644 --- a/tests/integ/test_pandas_to_df.py +++ b/tests/integ/test_pandas_to_df.py @@ -11,7 +11,6 @@ from pandas.testing import assert_frame_equal from snowflake.connector.errors import ProgrammingError -from snowflake.snowpark import Row from snowflake.snowpark._internal.utils import ( TempObjectType, random_name_for_temp_object, @@ -284,17 +283,8 @@ def test_create_dataframe_from_pandas(session, local_testing_mode): ) df = session.create_dataframe(pd) - # TODO: after to pandas support, we do not need if-else check - # https://snowflakecomputing.atlassian.net/browse/SNOW-786887 - if local_testing_mode: - assert df.collect() == [ - Row(1, 4.5, "t1", True), - Row(2, 7.5, "t2", False), - Row(3, 10.5, "t3", True), - ] - else: - results = df.to_pandas() - assert_frame_equal(results, pd, check_dtype=False) + results = df.to_pandas() + assert_frame_equal(results, pd, check_dtype=False) # pd = PandasDF( # [ diff --git a/tests/mock/test_to_pandas.py b/tests/mock/test_to_pandas.py new file mode 100644 index 00000000000..b1a53141206 --- /dev/null +++ b/tests/mock/test_to_pandas.py @@ -0,0 +1,161 @@ +# +# Copyright (c) 2012-2023 
Snowflake Computing Inc. All rights reserved. +# +import datetime +import decimal + +import numpy as np +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + +from snowflake.snowpark import Session +from snowflake.snowpark.mock.connection import MockServerConnection +from snowflake.snowpark.types import ( + ArrayType, + BinaryType, + BooleanType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + MapType, + NullType, + ShortType, + StringType, + StructField, + StructType, + TimestampType, + TimeType, + VariantType, +) + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_df_to_pandas_df(): + df = session.create_dataframe( + [ + [ + 1, + 1234567890, + True, + 1.23, + "abc", + b"abc", + datetime.datetime( + year=2023, month=10, day=30, hour=12, minute=12, second=12 + ), + ] + ], + schema=[ + "aaa", + "BBB", + "cCc", + "DdD", + "e e", + "ff ", + " gg", + ], + ) + + to_compare_df = pd.DataFrame( + { + "AAA": pd.Series( + [1], dtype=np.int8 + ), # int8 is the snowpark behavior, by default pandas use int64 + "BBB": pd.Series( + [1234567890], dtype=np.int32 + ), # int32 is the snowpark behavior, by default pandas use int64 + "CCC": pd.Series([True]), + "DDD": pd.Series([1.23]), + "e e": pd.Series(["abc"]), + "ff ": pd.Series([b"abc"]), + " gg": pd.Series( + [ + datetime.datetime( + year=2023, month=10, day=30, hour=12, minute=12, second=12 + ) + ] + ), + } + ) + + # assert_frame_equal also checks dtype + assert_frame_equal(df.to_pandas(), to_compare_df) + assert_frame_equal(list(df.to_pandas_batches())[0], to_compare_df) + + # check snowflake types explicitly + df = session.create_dataframe( + data=[ + [ + [1, 2, 3, 4], + b"123", + True, + 1, + datetime.date(year=2023, month=10, day=30), + decimal.Decimal(1), + 1.23, + 1.23, + 100, + 100, + None, + 100, + "abc", + datetime.datetime(2023, 10, 30, 12, 12, 12), + datetime.time(12, 12, 12), + {"a": "b"}, + {"a": "b"}, + ], 
+ ], + schema=StructType( + [ + StructField("a", ArrayType()), + StructField("b", BinaryType()), + StructField("c", BooleanType()), + StructField("d", ByteType()), + StructField("e", DateType()), + StructField("f", DecimalType()), + StructField("g", DoubleType()), + StructField("h", FloatType()), + StructField("i", IntegerType()), + StructField("j", LongType()), + StructField("k", NullType()), + StructField("l", ShortType()), + StructField("m", StringType()), + StructField("n", TimestampType()), + StructField("o", TimeType()), + StructField("p", VariantType()), + StructField("q", MapType(StringType(), StringType())), + ] + ), + ) + + pandas_df = pd.DataFrame( + { + "A": pd.Series(["[\n 1,\n 2,\n 3,\n 4\n]"], dtype=object), + "B": pd.Series([b"123"], dtype=object), + "C": pd.Series([True], dtype=bool), + "D": pd.Series([1], dtype=np.int8), + "E": pd.Series([datetime.date(year=2023, month=10, day=30)], dtype=object), + "F": pd.Series([decimal.Decimal(1)], dtype=np.int8), + "G": pd.Series([1.23], dtype=np.float64), + "H": pd.Series([1.23], dtype=np.float64), + "I": pd.Series([100], dtype=np.int8), + "J": pd.Series([100], dtype=np.int8), + "K": pd.Series([None], dtype=object), + "L": pd.Series([100], dtype=np.int8), + "M": pd.Series(["abc"], dtype=object), + "N": pd.Series( + [datetime.datetime(2023, 10, 30, 12, 12, 12)], dtype="datetime64[ns]" + ), + "O": pd.Series([datetime.time(12, 12, 12)], dtype=object), + "P": pd.Series(['{\n "a": "b"\n}'], dtype=object), + "Q": pd.Series(['{\n "a": "b"\n}'], dtype=object), + } + ) + assert_frame_equal(df.to_pandas(), pandas_df)