Skip to content

Commit

Permalink
SNOW-1569916: fix local testing default timestamp timezone issue (#2114)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aling authored Aug 22, 2024
1 parent cd27c52 commit 70310ea
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#### Bug Fixes

- Fixed a bug where the truncate mode in `DataFrameWriter.save_as_table` incorrectly handled DataFrames containing only a subset of columns from the existing table.
- Fixed a bug where the `to_timestamp` function did not set the default timezone of the column data type.

### Snowpark pandas API Updates

Expand Down
8 changes: 7 additions & 1 deletion src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from snowflake.snowpark._internal.analyzer.expression import FunctionExpression
from snowflake.snowpark.mock._options import numpy, pandas
from snowflake.snowpark.mock._snowflake_data_type import (
_TIMESTAMP_TYPE_MAPPING,
_TIMESTAMP_TYPE_TIMEZONE_MAPPING,
ColumnEmulator,
ColumnType,
TableEmulator,
Expand Down Expand Up @@ -943,7 +945,11 @@ def mock_to_timestamp(
try_cast: bool = False,
):
result = mock_to_timestamp_ntz(column, fmt, try_cast)
result.sf_type = ColumnType(TimestampType(), column.sf_type.nullable)

result.sf_type = ColumnType(
TimestampType(_TIMESTAMP_TYPE_TIMEZONE_MAPPING[_TIMESTAMP_TYPE_MAPPING]),
column.sf_type.nullable,
)
return result


Expand Down
19 changes: 19 additions & 0 deletions src/snowflake/snowpark/mock/_snowflake_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MapType,
NullType,
StringType,
TimestampTimeZone,
TimestampType,
TimeType,
VariantType,
Expand All @@ -32,6 +33,17 @@
PandasDataframeType = object if not installed_pandas else pd.DataFrame
PandasSeriesType = object if not installed_pandas else pd.Series

# Local-testing stand-in for Snowflake's TIMESTAMP_TYPE_MAPPING parameter:
# https://docs.snowflake.com/en/sql-reference/parameters#label-timestamp-type-mapping
# The value is fixed here; SNOW-1630258 covers session parameter support for
# local testing.
_TIMESTAMP_TYPE_MAPPING = "TIMESTAMP_NTZ"

# Lookup from each TIMESTAMP_TYPE_MAPPING value to the matching
# TimestampTimeZone member.
_TIMESTAMP_TYPE_TIMEZONE_MAPPING = {
    "TIMESTAMP_NTZ": TimestampTimeZone.NTZ,
    "TIMESTAMP_LTZ": TimestampTimeZone.LTZ,
    "TIMESTAMP_TZ": TimestampTimeZone.TZ,
}


class Operator:
def op(self, *operands):
Expand Down Expand Up @@ -302,6 +314,13 @@ def coerce_t1_into_t2(t1: DataType, t2: DataType) -> Optional[DataType]:
elif isinstance(t1, (TimeType, TimestampType, MapType, ArrayType)):
if isinstance(t2, VariantType):
return t2
if isinstance(t1, TimestampType) and isinstance(t2, TimestampType):
if (
t1.tz is TimestampTimeZone.DEFAULT
and t2.tz is TimestampTimeZone.NTZ
and _TIMESTAMP_TYPE_MAPPING == "TIMESTAMP_NTZ"
):
return t2
return None


Expand Down
13 changes: 4 additions & 9 deletions tests/integ/scala/test_dataframe_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,7 @@ def test_flatten_in_session(session):
)


def test_createDataFrame_with_given_schema(session, local_testing_mode):
def test_createDataFrame_with_given_schema(session):
schema = StructType(
[
StructField("string", StringType(84)),
Expand Down Expand Up @@ -1734,12 +1734,7 @@ def test_createDataFrame_with_given_schema(session, local_testing_mode):
StructField("number", DecimalType(10, 3)),
StructField("boolean", BooleanType()),
StructField("binary", BinaryType()),
StructField(
"timestamp",
TimestampType(TimestampTimeZone.NTZ)
if not local_testing_mode
else TimestampType(),
), # depends on TIMESTAMP_TYPE_MAPPING
StructField("timestamp", TimestampType(TimestampTimeZone.NTZ)),
StructField("timestamp_ntz", TimestampType(TimestampTimeZone.NTZ)),
StructField("timestamp_ltz", TimestampType(TimestampTimeZone.LTZ)),
StructField("timestamp_tz", TimestampType(TimestampTimeZone.TZ)),
Expand All @@ -1765,7 +1760,7 @@ def test_createDataFrame_with_given_schema_time(session):
assert df.collect() == data


def test_createDataFrame_with_given_schema_timestamp(session, local_testing_mode):
def test_createDataFrame_with_given_schema_timestamp(session):
schema = StructType(
[
StructField("timestamp", TimestampType()),
Expand All @@ -1786,7 +1781,7 @@ def test_createDataFrame_with_given_schema_timestamp(session, local_testing_mode

assert (
schema_str
== f"StructType([StructField('TIMESTAMP', TimestampType({'' if local_testing_mode else 'tz=ntz'}), nullable=True), "
== "StructType([StructField('TIMESTAMP', TimestampType(tz=ntz), nullable=True), "
"StructField('TIMESTAMP_NTZ', TimestampType(tz=ntz), nullable=True), "
"StructField('TIMESTAMP_LTZ', TimestampType(tz=ltz), nullable=True), "
"StructField('TIMESTAMP_TZ', TimestampType(tz=tz), nullable=True)])"
Expand Down
29 changes: 29 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3857,6 +3857,35 @@ def test_convert_timezone(session, local_testing_mode):
],
)

df = TestData.datetime_primitives1(session).select("timestamp", "timestamp_ntz")

Utils.check_answer(
df.select(
*[
convert_timezone(lit("UTC"), col, lit("Asia/Shanghai"))
for col in df.columns
]
),
[
Row(
datetime(2024, 2, 1, 4, 0),
datetime(2017, 2, 24, 4, 0, 0, 456000),
)
],
)

df = TestData.datetime_primitives1(session).select(
"timestamp_ltz", "timestamp_tz"
)
with pytest.raises(SnowparkSQLException):
# convert_timezone function does not accept non-TimestampTimeZone.NTZ datetime
df.select(
*[
convert_timezone(lit("UTC"), col, lit("Asia/Shanghai"))
for col in df.columns
]
).collect()

LocalTimezone.set_local_timezone()


Expand Down

0 comments on commit 70310ea

Please sign in to comment.