Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-958008: fix local test division datatypes #1127

Merged
merged 9 commits into from
Nov 22, 2023
4 changes: 3 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
VALID_SNOWPARK_TYPES_FOR_LITERAL_VALUE,
infer_type,
)
from snowflake.snowpark.types import DataType
from snowflake.snowpark.types import DataType, StringType

COLUMN_DEPENDENCY_DOLLAR = frozenset(
"$"
Expand Down Expand Up @@ -214,6 +214,8 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None:
self.datatype = datatype
else:
self.datatype = infer_type(value)
if isinstance(self.datatype, StringType):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfc-gh-aalam Are you good with this change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I approved the other PR, which adds this change.

self.datatype = StringType(len(value))


class Like(Expression):
Expand Down
18 changes: 17 additions & 1 deletion src/snowflake/snowpark/mock/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
import time
from copy import copy
from decimal import Decimal
from logging import getLogger
from typing import IO, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from unittest.mock import Mock
Expand Down Expand Up @@ -403,8 +404,23 @@ def execute(
# to align with snowflake behavior, we unquote name here
columns = [unquote_if_quoted(col_name) for col_name in res.columns]
rows = []
# TODO: SNOW-976145, move to an index-based approach to store the column type mapping.
# For now we only use the index-based approach in aggregation functions.
if res.sf_types_by_col_index:
keys = sorted(res.sf_types_by_col_index.keys())
sf_types = [res.sf_types_by_col_index[key] for key in keys]
else:
sf_types = list(res.sf_types.values())
for pdr in res.itertuples(index=False, name=None):
row = Row(*pdr)
row = Row(
*[
Decimal(str(v))
if isinstance(sf_types[i].datatype, DecimalType)
and v is not None
else v
for i, v in enumerate(pdr)
]
)
row._fields = columns
rows.append(row)
elif isinstance(res, list):
Expand Down
22 changes: 16 additions & 6 deletions src/snowflake/snowpark/mock/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,16 +745,27 @@ def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmul
if (
all(condition)
or all(~condition)
or (
isinstance(expr1.sf_type.datatype, StringType)
and isinstance(expr2.sf_type.datatype, StringType)
)
or expr1.sf_type.datatype == expr2.sf_type.datatype
or isinstance(expr1.sf_type.datatype, NullType)
or isinstance(expr2.sf_type.datatype, NullType)
):
res = ColumnEmulator(data=[None] * len(condition), dtype=object)
sf_data_type = (
expr1.sf_type.datatype
if any(condition) and not isinstance(expr1.sf_type.datatype, NullType)
else expr2.sf_type.datatype
)
if isinstance(expr1.sf_type.datatype, StringType) and isinstance(
expr2.sf_type.datatype, StringType
):
l1 = expr1.sf_type.datatype.length or StringType._MAX_LENGTH
l2 = expr2.sf_type.datatype.length or StringType._MAX_LENGTH
sf_data_type = StringType(max(l1, l2))
else:
sf_data_type = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not using calculate_datatype to compute the return type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. In the end I wanted to limit the scope, so I fixed the problem locally.

expr1.sf_type.datatype
if any(condition) and not isinstance(expr1.sf_type.datatype, NullType)
else expr2.sf_type.datatype
)
nullability = expr1.sf_type.nullable and expr2.sf_type.nullable
res.sf_type = ColumnType(sf_data_type, nullability)
res.where(condition, other=expr2, inplace=True)
Expand Down Expand Up @@ -903,4 +914,3 @@ def mock_to_variant(expr: ColumnEmulator):
res = expr.copy()
res.sf_type = ColumnType(VariantType(), expr.sf_type.nullable)
return res

40 changes: 32 additions & 8 deletions src/snowflake/snowpark/mock/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def execute_mock_plan(
result_df.insert(len(result_df.columns), str(i), from_df.iloc[:, i])
result_df.columns = from_df.columns
result_df.sf_types = from_df.sf_types
result_df.sf_types_by_col_index = from_df.sf_types_by_col_index
elif (
isinstance(exp, UnresolvedAlias)
and exp.child
Expand Down Expand Up @@ -537,6 +538,8 @@ def execute_mock_plan(
)

result_df_sf_Types = {}
result_df_sf_Types_by_col_idx = {}

column_exps = [
(
plan.session._analyzer.analyze(exp),
Expand All @@ -547,10 +550,11 @@ def execute_mock_plan(
)
for exp in source_plan.grouping_expressions
]
for column_name, _, column_type in column_exps:
for idx, (column_name, _, column_type) in enumerate(column_exps):
result_df_sf_Types[
column_name
] = column_type # TODO: fix this, this does not work
result_df_sf_Types_by_col_idx[idx] = column_type
# Aggregate may not have column_exps, which is allowed in the case of `DataFrame.agg`; in this case we pass
# lambda x: True as the `by` parameter.
# Also, pandas group by treats None and NaN as the same, so we use .astype to differentiate the two.
Expand Down Expand Up @@ -607,10 +611,16 @@ def aggregate_by_groups(cur_group: TableEmulator):
values.append(cal_exp_res.iat[0])
result_df_sf_Types[
columns[idx + len(column_exps)]
] = result_df_sf_Types_by_col_idx[
idx + len(column_exps)
] = cal_exp_res.sf_type
else:
values.append(cal_exp_res)
result_df_sf_Types[columns[idx] + len(column_exps)] = ColumnType(
result_df_sf_Types[
columns[idx + len(column_exps)]
] = result_df_sf_Types_by_col_idx[
idx + len(column_exps)
] = ColumnType(
infer_type(cal_exp_res), nullable=True
)
data.append(values)
Expand All @@ -634,6 +644,7 @@ def aggregate_by_groups(cur_group: TableEmulator):
result_df[intermediate_mapped_column[col]] = series_data

result_df.sf_types = result_df_sf_Types
result_df.sf_types_by_col_index = result_df_sf_Types_by_col_idx
result_df.columns = columns
return result_df
if isinstance(source_plan, Range):
Expand Down Expand Up @@ -1126,6 +1137,12 @@ def describe(plan: MockExecutionPlan) -> List[Attribute]:
data_type = LongType()
elif isinstance(data_type, FloatType):
data_type = DoubleType()
elif (
isinstance(data_type, DecimalType)
and data_type.precision == 38
and data_type.scale == 0
):
data_type = LongType()
ret.append(
Attribute(
quote_name(result[c].name.strip()),
Expand Down Expand Up @@ -1243,15 +1260,15 @@ def calculate_expression(
)
return ColumnEmulator(
data=[bool(data is None) for data in child_column],
sf_type=ColumnType(BooleanType(), False),
sf_type=ColumnType(BooleanType(), True),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We observe that boolean-type columns are always nullable when using a live connection.

)
if isinstance(exp, IsNotNull):
child_column = calculate_expression(
exp.child, input_data, analyzer, expr_to_alias
)
return ColumnEmulator(
data=[bool(data is not None) for data in child_column],
sf_type=ColumnType(BooleanType(), False),
sf_type=ColumnType(BooleanType(), True),
)
if isinstance(exp, IsNaN):
child_column = calculate_expression(
Expand All @@ -1264,7 +1281,7 @@ def calculate_expression(
except TypeError:
res.append(False)
return ColumnEmulator(
data=res, dtype=object, sf_type=ColumnType(BooleanType(), False)
data=res, dtype=object, sf_type=ColumnType(BooleanType(), True)
)
if isinstance(exp, Not):
child_column = calculate_expression(
Expand All @@ -1286,6 +1303,8 @@ def calculate_expression(
if isinstance(exp, BinaryExpression):
left = calculate_expression(exp.left, input_data, analyzer, expr_to_alias)
right = calculate_expression(exp.right, input_data, analyzer, expr_to_alias)
# TODO: Address mixed-type calculation here. For instance, Snowflake allows adding a date to a number, but
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be enabled by type-coercion.

# pandas does not allow it. Type coercion will address it.
if isinstance(exp, Multiply):
new_column = left * right
elif isinstance(exp, Divide):
Expand Down Expand Up @@ -1354,7 +1373,7 @@ def calculate_expression(
except re.error:
raise SnowparkSQLException(f"Invalid regular expression {raw_pattern}")
result = lhs.str.match(pattern)
result.sf_type = ColumnType(BooleanType(), exp.nullable)
result.sf_type = ColumnType(BooleanType(), True)
return result
if isinstance(exp, Like):
lhs = calculate_expression(exp.expr, input_data, analyzer, expr_to_alias)
Expand All @@ -1366,7 +1385,7 @@ def calculate_expression(
)
)
result = lhs.str.match(pattern)
result.sf_type = ColumnType(BooleanType(), exp.nullable)
result.sf_type = ColumnType(BooleanType(), True)
return result
if isinstance(exp, InExpression):
lhs = calculate_expression(exp.columns, input_data, analyzer, expr_to_alias)
Expand Down Expand Up @@ -1661,7 +1680,12 @@ def calculate_expression(
).sf_type
if not calculated_sf_type:
calculated_sf_type = cur_windows_sf_type
elif calculated_sf_type.datatype != cur_windows_sf_type.datatype:
elif calculated_sf_type != cur_windows_sf_type and (
not (
isinstance(calculated_sf_type.datatype, StringType)
and isinstance(cur_windows_sf_type.datatype, StringType)
)
):
if isinstance(calculated_sf_type.datatype, NullType):
calculated_sf_type = sub_window_res.sf_type
# the result calculated upon a windows can be None, this is still valid and we can keep
Expand Down
Loading
Loading