Skip to content

Commit

Permalink
SNOW-786887: implement df.to_pandas (#1117)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aling authored Nov 17, 2023
1 parent 017cfbf commit cb4d1e4
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 35 deletions.
6 changes: 0 additions & 6 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,12 +756,6 @@ def to_pandas(
2. If you use :func:`Session.sql` with this method, the input query of
:func:`Session.sql` can only be a SELECT statement.
"""
from snowflake.snowpark.mock.connection import MockServerConnection

if isinstance(self._session._conn, MockServerConnection):
raise NotImplementedError(
"[Local Testing] DataFrame.to_pandas is not implemented."
)
result = self._session._conn.execute(
self._plan,
to_pandas=True,
Expand Down
72 changes: 58 additions & 14 deletions src/snowflake/snowpark/mock/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pandas as pd

import snowflake.connector
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor
from snowflake.connector.errors import NotSupportedError, ProgrammingError
from snowflake.connector.network import ReauthenticationRequest
Expand Down Expand Up @@ -46,7 +45,13 @@
from snowflake.snowpark.mock.snowflake_data_type import TableEmulator
from snowflake.snowpark.mock.util import parse_table_name
from snowflake.snowpark.row import Row
from snowflake.snowpark.types import ArrayType, MapType, VariantType
from snowflake.snowpark.types import (
ArrayType,
DecimalType,
MapType,
VariantType,
_IntegralType,
)

logger = getLogger(__name__)

Expand Down Expand Up @@ -405,8 +410,22 @@ def execute(
elif isinstance(res, list):
rows = res

if to_pandas:
pandas_df = pd.DataFrame()
for col_name in res.columns:
pandas_df[unquote_if_quoted(col_name)] = res[col_name].tolist()
rows = _fix_pandas_df_integer(res)

# the following implementation is just to make DataFrame.to_pandas_batches API workable
# in snowflake, large data result are split into multiple data chunks
# and sent back to the client, thus it makes sense to have the generator
# however, local testing is not intended for large data results, so
# we do not mock the splitting into data chunks behavior
rows = [rows] if to_iter else rows

if to_iter:
return iter(rows)

return rows

@SnowflakePlan.Decorator.wrap_exception
Expand Down Expand Up @@ -539,19 +558,44 @@ def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str:
return result_set["sfqid"]


def _fix_pandas_df_integer(
pd_df: "pandas.DataFrame", results_cursor: SnowflakeCursor
) -> "pandas.DataFrame":
for column_metadata, pandas_dtype, pandas_col_name in zip(
results_cursor.description, pd_df.dtypes, pd_df.columns
):
def _fix_pandas_df_integer(table_res: TableEmulator) -> "pandas.DataFrame":
    """Convert a ``TableEmulator`` into a pandas DataFrame whose integer
    dtypes match what a live Snowflake connection would return.

    Args:
        table_res: the mock result table; ``table_res.sf_types`` maps each
            (possibly quoted) column name to its Snowflake column type.

    Returns:
        A new ``pd.DataFrame`` with unquoted column names. Scale-0 decimal
        and integral columns are narrowed to the smallest integer dtype
        consistent with their declared precision; all other columns are
        copied through unchanged.
    """
    pd_df = pd.DataFrame()
    for col_name in table_res.columns:
        col_sf_type = table_res.sf_types[col_name]
        pd_df_col_name = unquote_if_quoted(col_name)
        if (
            isinstance(col_sf_type.datatype, DecimalType)
            and col_sf_type.datatype.precision is not None
            and col_sf_type.datatype.scale == 0
            and not str(table_res[col_name].dtype).startswith("int")
        ):
            # if decimal is set to default 38, we auto-detect the dtype, see the following code
            # df = session.create_dataframe(
            #     data=[[decimal.Decimal(1)]],
            #     schema=StructType([StructField("d", DecimalType())])
            # )
            # df.to_pandas()  # the returned df is of dtype int8, instead of dtype int64
            if col_sf_type.datatype.precision == 38:
                # NOTE: use the module's `pd` alias (the file imports
                # `pandas as pd`); `pandas.to_numeric` would be an
                # unresolved name here.
                pd_df[pd_df_col_name] = pd.to_numeric(
                    table_res[col_name], downcast="integer"
                )
                continue

            # precision was explicitly set to a non-default value (< 38):
            # pick the narrowest integer dtype wide enough for the declared
            # precision, aligning with the live connection's behavior
            precision = col_sf_type.datatype.precision
            if precision <= 2:
                target_dtype = "int8"
            elif precision <= 4:
                target_dtype = "int16"
            elif precision <= 8:
                target_dtype = "int32"
            else:
                target_dtype = "int64"
            pd_df[pd_df_col_name] = table_res[col_name].astype(target_dtype)
        elif isinstance(col_sf_type.datatype, _IntegralType):
            # integral columns: let pandas downcast to the smallest dtype
            pd_df[pd_df_col_name] = pd.to_numeric(
                table_res[col_name].tolist(), downcast="integer"
            )
        else:
            # non-integer columns are copied through as-is
            pd_df[pd_df_col_name] = table_res[col_name].tolist()

    return pd_df
4 changes: 3 additions & 1 deletion src/snowflake/snowpark/mock/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,9 @@ def mock_to_timestamp(
if data is None:
res.append(None)
continue
if auto_detect and data.isnumeric():
if auto_detect and (
isinstance(data, int) or (isinstance(data, str) and data.isnumeric())
):
res.append(
datetime.datetime.utcfromtimestamp(process_numeric_time(data))
)
Expand Down
10 changes: 8 additions & 2 deletions tests/integ/test_df_to_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)


@pytest.mark.localtest
def test_to_pandas_new_df_from_range(session):
# Single column
snowpark_df = session.range(3, 8)
Expand All @@ -46,6 +47,7 @@ def test_to_pandas_new_df_from_range(session):
assert all(pandas_df["OTHER"][i] == i + 3 for i in range(5))


@pytest.mark.localtest
@pytest.mark.parametrize("to_pandas_api", ["to_pandas", "to_pandas_batches"])
def test_to_pandas_cast_integer(session, to_pandas_api):
snowpark_df = session.create_dataframe(
Expand Down Expand Up @@ -122,14 +124,18 @@ def check_fetch_data_exception(query: str) -> None:
@pytest.mark.skipif(
IS_IN_STORED_PROC, reason="SNOW-507565: Need localaws for large result"
)
def test_to_pandas_batches(session):
@pytest.mark.localtest
def test_to_pandas_batches(session, local_testing_mode):
df = session.range(100000).cache_result()
iterator = df.to_pandas_batches()
assert isinstance(iterator, Iterator)

entire_pandas_df = df.to_pandas()
pandas_df_list = list(df.to_pandas_batches())
assert len(pandas_df_list) > 1
if not local_testing_mode:
# in live session, large data result will be split into multiple chunks by snowflake
# local test does not split the data result chunk/is not intended for large data result chunk
assert len(pandas_df_list) > 1
assert_frame_equal(pd.concat(pandas_df_list, ignore_index=True), entire_pandas_df)

for df_batch in df.to_pandas_batches():
Expand Down
14 changes: 2 additions & 12 deletions tests/integ/test_pandas_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pandas.testing import assert_frame_equal

from snowflake.connector.errors import ProgrammingError
from snowflake.snowpark import Row
from snowflake.snowpark._internal.utils import (
TempObjectType,
random_name_for_temp_object,
Expand Down Expand Up @@ -284,17 +283,8 @@ def test_create_dataframe_from_pandas(session, local_testing_mode):
)

df = session.create_dataframe(pd)
# TODO: after to pandas support, we do not need if-else check
# https://snowflakecomputing.atlassian.net/browse/SNOW-786887
if local_testing_mode:
assert df.collect() == [
Row(1, 4.5, "t1", True),
Row(2, 7.5, "t2", False),
Row(3, 10.5, "t3", True),
]
else:
results = df.to_pandas()
assert_frame_equal(results, pd, check_dtype=False)
results = df.to_pandas()
assert_frame_equal(results, pd, check_dtype=False)

# pd = PandasDF(
# [
Expand Down
161 changes: 161 additions & 0 deletions tests/mock/test_to_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
import datetime
import decimal

import numpy as np
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from snowflake.snowpark import Session
from snowflake.snowpark.mock.connection import MockServerConnection
from snowflake.snowpark.types import (
ArrayType,
BinaryType,
BooleanType,
ByteType,
DateType,
DecimalType,
DoubleType,
FloatType,
IntegerType,
LongType,
MapType,
NullType,
ShortType,
StringType,
StructField,
StructType,
TimestampType,
TimeType,
VariantType,
)

# Module-level mock session: backed by MockServerConnection, so the tests
# below run in local testing mode without contacting a Snowflake server.
session = Session(MockServerConnection())


@pytest.mark.localtest
def test_df_to_pandas_df():
    """Verify that the mocked ``DataFrame.to_pandas`` / ``to_pandas_batches``
    produce pandas DataFrames whose column names, values, and dtypes match
    the behavior of a live Snowflake connection.

    Part 1 exercises identifier quoting/casing (unquoted names are
    upper-cased, quoted names keep their exact spelling) plus the integer
    downcasting behavior.  Part 2 pins the pandas dtype produced for each
    explicitly-declared Snowflake type.
    """
    df = session.create_dataframe(
        [
            [
                1,
                1234567890,
                True,
                1.23,
                "abc",
                b"abc",
                datetime.datetime(
                    year=2023, month=10, day=30, hour=12, minute=12, second=12
                ),
            ]
        ],
        # mix of unquoted (normalized to upper case) and quoted
        # (whitespace-containing, case-preserved) column names
        schema=[
            "aaa",
            "BBB",
            "cCc",
            "DdD",
            "e e",
            "ff ",
            " gg",
        ],
    )

    to_compare_df = pd.DataFrame(
        {
            "AAA": pd.Series(
                [1], dtype=np.int8
            ),  # int8 is the snowpark behavior, by default pandas use int64
            "BBB": pd.Series(
                [1234567890], dtype=np.int32
            ),  # int32 is the snowpark behavior, by default pandas use int64
            "CCC": pd.Series([True]),
            "DDD": pd.Series([1.23]),
            "e e": pd.Series(["abc"]),
            "ff ": pd.Series([b"abc"]),
            " gg": pd.Series(
                [
                    datetime.datetime(
                        year=2023, month=10, day=30, hour=12, minute=12, second=12
                    )
                ]
            ),
        }
    )

    # assert_frame_equal also checks dtype
    assert_frame_equal(df.to_pandas(), to_compare_df)
    # to_pandas_batches yields a single batch in local testing mode
    assert_frame_equal(list(df.to_pandas_batches())[0], to_compare_df)

    # check snowflake types explicitly
    df = session.create_dataframe(
        data=[
            [
                [1, 2, 3, 4],
                b"123",
                True,
                1,
                datetime.date(year=2023, month=10, day=30),
                decimal.Decimal(1),
                1.23,
                1.23,
                100,
                100,
                None,
                100,
                "abc",
                datetime.datetime(2023, 10, 30, 12, 12, 12),
                datetime.time(12, 12, 12),
                {"a": "b"},
                {"a": "b"},
            ],
        ],
        schema=StructType(
            [
                StructField("a", ArrayType()),
                StructField("b", BinaryType()),
                StructField("c", BooleanType()),
                StructField("d", ByteType()),
                StructField("e", DateType()),
                StructField("f", DecimalType()),
                StructField("g", DoubleType()),
                StructField("h", FloatType()),
                StructField("i", IntegerType()),
                StructField("j", LongType()),
                StructField("k", NullType()),
                StructField("l", ShortType()),
                StructField("m", StringType()),
                StructField("n", TimestampType()),
                StructField("o", TimeType()),
                StructField("p", VariantType()),
                StructField("q", MapType(StringType(), StringType())),
            ]
        ),
    )

    # expected dtypes per Snowflake type; semi-structured values (array,
    # variant, map) come back as their JSON string representation
    pandas_df = pd.DataFrame(
        {
            "A": pd.Series(["[\n 1,\n 2,\n 3,\n 4\n]"], dtype=object),
            "B": pd.Series([b"123"], dtype=object),
            "C": pd.Series([True], dtype=bool),
            "D": pd.Series([1], dtype=np.int8),
            "E": pd.Series([datetime.date(year=2023, month=10, day=30)], dtype=object),
            "F": pd.Series([decimal.Decimal(1)], dtype=np.int8),
            "G": pd.Series([1.23], dtype=np.float64),
            "H": pd.Series([1.23], dtype=np.float64),
            "I": pd.Series([100], dtype=np.int8),
            "J": pd.Series([100], dtype=np.int8),
            "K": pd.Series([None], dtype=object),
            "L": pd.Series([100], dtype=np.int8),
            "M": pd.Series(["abc"], dtype=object),
            "N": pd.Series(
                [datetime.datetime(2023, 10, 30, 12, 12, 12)], dtype="datetime64[ns]"
            ),
            "O": pd.Series([datetime.time(12, 12, 12)], dtype=object),
            "P": pd.Series(['{\n "a": "b"\n}'], dtype=object),
            "Q": pd.Series(['{\n "a": "b"\n}'], dtype=object),
        }
    )
    assert_frame_equal(df.to_pandas(), pandas_df)

0 comments on commit cb4d1e4

Please sign in to comment.