Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1460351: support case when string type coercion #1731

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Fixed a bug in convert_timezone that made the setting the source_timezone parameter return an error.
- Fixed a bug where creating DataFrame with empty data of type `DateType` raises `AttributeError`.
- Fixed a bug that table merge fails when update clause exists but no update takes place.
- Fixed a bug that case when expression did not handle string type coercion.

### Snowpark pandas API Updates

Expand Down
29 changes: 23 additions & 6 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
UnboundedPreceding,
WindowExpression,
)
from snowflake.snowpark.mock._snowflake_data_type import coerce_type_if_needed
from snowflake.snowpark.mock._udf_utils import (
coerce_variant_input,
remove_null_wrapper,
Expand Down Expand Up @@ -1992,13 +1993,21 @@ def _match_pattern(row) -> bool:
remaining = remaining[~remaining.index.isin(true_index)]

if output_data.sf_type:
if (
not isinstance(output_data.sf_type.datatype, NullType)
and output_data.sf_type != value.sf_type
if not isinstance(output_data.sf_type.datatype, NullType) and not (
types_are_compatible(
output_data.sf_type.datatype, value.sf_type.datatype
)
):
raise SnowparkLocalTestingException(
f"CaseWhen expressions have conflicting data types: {output_data.sf_type} != {value.sf_type}"
)
else:
output_data.sf_type = ColumnType(
coerce_type_if_needed(
output_data.sf_type.datatype, value.sf_type.datatype
),
output_data.sf_type.nullable,
)
else:
output_data.sf_type = value.sf_type

Expand All @@ -2008,13 +2017,21 @@ def _match_pattern(row) -> bool:
)
output_data[remaining.index] = value[remaining.index]
if output_data.sf_type:
if (
not isinstance(output_data.sf_type.datatype, NullType)
and output_data.sf_type.datatype != value.sf_type.datatype
if not isinstance(output_data.sf_type.datatype, NullType) and not (
types_are_compatible(
output_data.sf_type.datatype, value.sf_type.datatype
)
):
raise SnowparkLocalTestingException(
f"CaseWhen expressions have conflicting data types: {output_data.sf_type.datatype} != {value.sf_type.datatype}"
)
else:
output_data.sf_type = ColumnType(
coerce_type_if_needed(
output_data.sf_type.datatype, value.sf_type.datatype
),
output_data.sf_type.nullable,
)
else:
output_data.sf_type = value.sf_type
return output_data
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/mock/_snowflake_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
FloatType,
IntegerType,
LongType,
StringType,
_IntegralType,
_NumericType,
)
Expand Down Expand Up @@ -530,3 +531,12 @@ def isnull(self):
result = super().isnull()
result.sf_type = ColumnType(BooleanType(), True)
return result


def coerce_type_if_needed(type1: DataType, type2: DataType) -> DataType:
from snowflake.snowpark.mock._udf_utils import types_are_compatible

if types_are_compatible(type1, type2):
if isinstance(type1, StringType) and isinstance(type2, StringType):
return StringType(max(type1.length, type2.length))
return type1
11 changes: 10 additions & 1 deletion tests/integ/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from snowflake.snowpark import DataFrame, Row
from snowflake.snowpark.functions import lit
from snowflake.snowpark.functions import col, lit, when
from snowflake.snowpark.types import (
BooleanType,
DecimalType,
Expand Down Expand Up @@ -423,3 +423,12 @@ def test_join_basic(session):
]
)
)


def test_case_when_type_coerce(session):
df_input = session.create_dataframe(
[1], StructType([StructField("col", LongType())])
)
assert df_input.withColumn(
"new_col", when((col("col").is_null()), lit("abc")).otherwise(lit("abcdef"))
).collect() == [Row(1, "abcdef")]
Loading