
Commit

Merge branch 'main' into vbudati/SNOW-1643304-record-set-index-length-checking-behavior-change
sfc-gh-vbudati authored Aug 28, 2024
2 parents 9f543ac + 572d01d commit 3301c7d
Showing 44 changed files with 1,834 additions and 424 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -12,6 +12,7 @@

#### Improvements

- Added support for `ln` in `snowflake.snowpark.functions`
- Added support for specifying the following to `DataFrameWriter.save_as_table`:
- `enable_schema_evolution`
- `data_retention_time`
@@ -51,7 +52,7 @@
#### New Features

- Added limited support for the `Timedelta` type, including the following features. Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases.
- supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`.
- supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `assign`, `bfill`, `ffill`, `fillna`, `compare`, `diff`, `drop`, `dropna`, `duplicated`, `empty`, `equals`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `mask`, `melt`, `merge`, `nlargest`, `nsmallest`.
- converting non-timedelta to timedelta via `astype`.
- `NotImplementedError` will be raised for the remaining methods that do not support `Timedelta`.
- support for subtracting two timestamps to get a Timedelta.
1 change: 1 addition & 0 deletions docs/source/snowpark/functions.rst
@@ -196,6 +196,7 @@ Functions
length
listagg
lit
ln
locate
log
lower
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/udf_utils.py
@@ -1062,15 +1062,15 @@ def resolve_imports_and_packages(
packages,
include_pandas=is_pandas_udf,
statement_params=statement_params,
)[0]
)
if packages is not None
else session._resolve_packages(
[],
session._packages,
validate_package=False,
include_pandas=is_pandas_udf,
statement_params=statement_params,
)[0]
)
)

if session is not None:
3 changes: 2 additions & 1 deletion src/snowflake/snowpark/functions.py
@@ -5982,7 +5982,8 @@ def vector_inner_product(v1: ColumnOrName, v2: ColumnOrName) -> Column:


def ln(c: ColumnOrLiteral) -> Column:
"""Returns the natrual log product of given column expression
"""Returns the natrual logarithm of given column expression.
Example::
>>> from snowflake.snowpark.functions import ln
>>> from math import e
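For context, a minimal usage sketch of the new function (not part of the diff; assumes a configured Snowpark `Session` — the connection setup is elided here):

```python
from math import e

from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, ln

session = Session.builder.create()  # assumes connection parameters are configured elsewhere
df = session.create_dataframe([[e], [e**2]], schema=["x"])
df.select(ln(col("x")).alias("ln_x")).show()
# ln(e) -> 1.0, ln(e**2) -> 2.0
```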
@@ -512,10 +512,8 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool:
Returns:
True if given types are equal, False otherwise.
"""
if isinstance(type1, TimedeltaType) and not isinstance(type2, TimedeltaType):
return False
if isinstance(type2, TimedeltaType) and not isinstance(type1, TimedeltaType):
return False
if isinstance(type1, TimedeltaType) or isinstance(type2, TimedeltaType):
return type1 == type2
if isinstance(type1, _IntegralType) and isinstance(type2, _IntegralType):
return True
if isinstance(type1, _FractionalType) and isinstance(type2, _FractionalType):
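The practical effect of this simplification, given the `__eq__`/`__ne__` methods added to `TimedeltaType` later in this commit, is that a Timedelta column only type-matches another Timedelta column, even though both are stored as `LongType`. A small sketch (import path taken from the isin_utils hunk below):

```python
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
    TimedeltaType,
)
from snowflake.snowpark.types import LongType

assert TimedeltaType() == TimedeltaType()     # equal Snowpark pandas types
assert TimedeltaType() != LongType()          # the storage type alone no longer matches
assert isinstance(TimedeltaType(), LongType)  # even though Timedelta is stored as LongType
```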
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py
@@ -14,6 +14,9 @@
)
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.indexing_utils import set_frame_2d_labels
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
SnowparkPandasType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_series_type
from snowflake.snowpark.modin.plugin._internal.utils import (
append_columns,
@@ -100,6 +103,13 @@ def scalar_isin_expression(
for literal_expr in values
]

# Case 4: the column's and the values' data types differ, and either one is a SnowparkPandasType
elif values_dtype != column_dtype and (
isinstance(values_dtype, SnowparkPandasType)
or isinstance(column_dtype, SnowparkPandasType)
):
return pandas_lit(False)

values = array_construct(*values)

# to_variant is a requirement for array_contains, else an error is produced.
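At the pandas API level, the new case means that testing a Timedelta column against values of a different type now short-circuits to `False` instead of accidentally matching the underlying nanosecond integers. A hedged illustration (assumes an active Snowpark pandas session):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401 -- registers the Snowflake backend

s = pd.Series(pd.to_timedelta([1, 2], unit="ns"))
print(s.isin([1, 2]))  # all False: plain ints never match a Timedelta column
```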
24 changes: 24 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
@@ -172,6 +172,30 @@ def join(
JoinTypeLit
), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}"

def assert_snowpark_pandas_types_match() -> None:
"""If Snowpark pandas types do not match, then a ValueError will be raised."""
left_types = [
left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in left_on
]
right_types = [
right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in right_on
]
for i, (lt, rt) in enumerate(zip(left_types, right_types)):
if lt != rt:
left_on_id = left_on[i]
idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
key = left.data_column_pandas_labels[idx]
lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
rt = rt if rt is not None else right.get_snowflake_type(right_on[i])
raise ValueError(
f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
f"If you wish to proceed you should use pd.concat"
)

assert_snowpark_pandas_types_match()

# Re-project the active columns to make sure all active columns of the internal frame participate
# in the join operation, and unnecessary columns are dropped from the projected columns.
left = left.select_active_columns()
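Observed from the public API, the guard surfaces like this (a sketch, assuming an active Snowpark pandas session; the exact type names in the message come from the frames being merged):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

left = pd.DataFrame({"key": pd.to_timedelta([1, 2], unit="h"), "a": [10, 20]})
right = pd.DataFrame({"key": [1, 2], "b": [30, 40]})
try:
    left.merge(right, on="key")
except ValueError as err:
    # e.g. "You are trying to merge on TimedeltaType and LongType columns for key 'key'. ..."
    print(err)
```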
@@ -12,13 +12,8 @@
import pandas as native_pd

from snowflake.snowpark.column import Column
from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
from snowflake.snowpark.types import DataType, LongType

TIMEDELTA_WARNING_MESSAGE = (
"Snowpark pandas support for Timedelta is not currently available."
)

"""Map Python type to its from_pandas method"""
_python_type_to_from_pandas: dict[type, Callable[[Any], Any]] = {}

@@ -101,6 +96,13 @@ def get_snowpark_pandas_type_for_pandas_type(
return _type_to_snowpark_pandas_type[pandas_type]()
return None

def type_match(self, value: Any) -> bool:
"""Return True if the value's type matches self."""
val_type = SnowparkPandasType.get_snowpark_pandas_type_for_pandas_type(
type(value)
)
return self == val_type


class SnowparkPandasColumn(NamedTuple):
"""A Snowpark Column that has an optional SnowparkPandasType."""
@@ -128,11 +130,14 @@ class TimedeltaType(SnowparkPandasType, LongType):
)

def __init__(self) -> None:
# TODO(SNOW-1620452): Remove this warning message before releasing
# Timedelta support.
WarningMessage.single_warning(TIMEDELTA_WARNING_MESSAGE)
super().__init__()

def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

@staticmethod
def to_pandas(value: int) -> native_pd.Timedelta:
"""
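A short sketch of the two additions working together: strict equality between type instances, and `type_match` mapping a candidate value's Python type back to the Snowpark pandas type:

```python
import pandas as native_pd

from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
    TimedeltaType,
)

td = TimedeltaType()
assert td == TimedeltaType() and td != object()
# type_match: does the Python type of the value map back to this type?
assert td.type_match(native_pd.Timedelta("1 hour"))
assert not td.type_match(1)  # a plain int is not a Timedelta
```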
21 changes: 17 additions & 4 deletions src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py
@@ -735,12 +735,16 @@ def _simple_unpivot(
# create the initial set of columns to be retained as identifiers and those
# which will be unpivoted. Collect data type information.
unpivot_quoted_columns = []
unpivot_quoted_column_types = []

ordering_decode_conditions = []
id_col_names = []
id_col_quoted_identifiers = []
for (pandas_label, snowflake_quoted_identifier) in zip(
id_col_types = []
for (pandas_label, snowflake_quoted_identifier, sp_pandas_type) in zip(
frame.data_column_pandas_labels,
frame.data_column_snowflake_quoted_identifiers,
frame.cached_data_column_snowpark_pandas_types,
):
is_id_col = pandas_label in pandas_id_columns
is_var_col = pandas_label in pandas_value_columns
@@ -752,9 +756,11 @@
col(var_quoted) == pandas_lit(pandas_label)
)
unpivot_quoted_columns.append(snowflake_quoted_identifier)
unpivot_quoted_column_types.append(sp_pandas_type)
if is_id_col:
id_col_names.append(pandas_label)
id_col_quoted_identifiers.append(snowflake_quoted_identifier)
id_col_types.append(sp_pandas_type)

# create the case expressions used for the final result set ordering based
# on the column position. This clause will be applied after the unpivot
@@ -787,7 +793,7 @@
pandas_labels=[unquoted_col_name],
)[0]
)
# coalese the values to unpivot and preserve null values This code
# coalesce the values to unpivot and preserve null values. This code
# can be removed when UNPIVOT_INCLUDE_NULLS is enabled
unpivot_columns_normalized_types.append(
coalesce(to_variant(c), to_variant(pandas_lit(null_replace_value))).alias(
@@ -870,6 +876,13 @@ def _simple_unpivot(
var_quoted,
corrected_value_column_name,
]
corrected_value_column_type = None
if len(set(unpivot_quoted_column_types)) == 1:
corrected_value_column_type = unpivot_quoted_column_types[0]
final_snowflake_quoted_col_types = id_col_types + [
None,
corrected_value_column_type,
]

# Create the new frame and compiler
return InternalFrame.create(
index_column_snowflake_quoted_identifiers=[
ordered_dataframe.row_position_snowflake_quoted_identifier
],
data_column_types=None,
index_column_types=None,
data_column_types=final_snowflake_quoted_col_types,
index_column_types=[None],
)


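The net effect: when every unpivoted column carries the same Snowpark pandas type, `melt` now propagates that type to the value column instead of dropping it. A hedged end-to-end sketch (assumes an active Snowpark pandas session):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame({
    "id": [1, 2],
    "x": pd.to_timedelta([1, 2], unit="h"),
    "y": pd.to_timedelta([3, 4], unit="h"),
})
print(df.melt(id_vars=["id"], value_vars=["x", "y"]).dtypes)
# the value column stays timedelta64[ns] because both unpivoted columns agree
```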
@@ -1499,7 +1499,7 @@ def _shift_values_axis_0(
row_position_quoted_identifier = frame.row_position_snowflake_quoted_identifier

fill_value_dtype = infer_object_type(fill_value)
fill_value = pandas_lit(fill_value) if fill_value is not None else None
fill_value = None if pd.isna(fill_value) else pandas_lit(fill_value)

def shift_expression_and_type(
quoted_identifier: str, dtype: DataType
@@ -5757,8 +5757,6 @@ def insert(
Returns:
A new SnowflakeQueryCompiler instance with new column.
"""
self._raise_not_implemented_error_for_timedelta()

if not isinstance(value, SnowflakeQueryCompiler):
# Scalar value
new_internal_frame = self._modin_frame.append_column(
@@ -5848,7 +5846,9 @@ def move_last_element(arr: list, index: int) -> None:
data_column_snowflake_quoted_identifiers = (
new_internal_frame.data_column_snowflake_quoted_identifiers
)
data_column_types = new_internal_frame.cached_data_column_snowpark_pandas_types
move_last_element(data_column_snowflake_quoted_identifiers, loc)
move_last_element(data_column_types, loc)

new_internal_frame = InternalFrame.create(
ordered_dataframe=new_internal_frame.ordered_dataframe,
@@ -5857,8 +5857,8 @@ def move_last_element(arr: list, index: int) -> None:
data_column_pandas_index_names=new_internal_frame.data_column_pandas_index_names,
index_column_pandas_labels=new_internal_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=new_internal_frame.index_column_snowflake_quoted_identifiers,
data_column_types=None,
index_column_types=None,
data_column_types=data_column_types,
index_column_types=new_internal_frame.cached_index_column_snowpark_pandas_types,
)
return SnowflakeQueryCompiler(new_internal_frame)

@@ -6645,8 +6645,6 @@ def melt(
Notes:
melt does not yet handle multiindex or ignore index
"""
self._raise_not_implemented_error_for_timedelta()

if col_level is not None:
raise NotImplementedError(
"Snowpark Pandas doesn't support 'col_level' argument in melt API"
@@ -6749,8 +6747,6 @@ def merge(
Returns:
SnowflakeQueryCompiler instance with merged result.
"""
self._raise_not_implemented_error_for_timedelta()

if validate:
ErrorMessage.not_implemented(
"Snowpark pandas merge API doesn't yet support 'validate' parameter"
@@ -9815,6 +9811,10 @@ def _fillna_with_masking(

# case 2: fillna with a method
if method is not None:
# no Snowpark pandas type change in this case
data_column_snowpark_pandas_types = (
self._modin_frame.cached_data_column_snowpark_pandas_types
)
method = FillNAMethod.get_enum_for_string_method(method)
method_is_ffill = method is FillNAMethod.FFILL_METHOD
if axis == 0:
@@ -9921,6 +9921,7 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
include_index=False,
)
fillna_column_map = {}
data_column_snowpark_pandas_types = []
if columns_mask is not None:
columns_to_ignore = itertools.compress(
self._modin_frame.data_column_pandas_labels,
@@ -9940,10 +9941,18 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
col(id),
coalesce(id, pandas_lit(val)),
)
col_type = self._modin_frame.get_snowflake_type(id)
col_pandas_type = (
col_type
if isinstance(col_type, SnowparkPandasType)
and col_type.type_match(val)
else None
)
data_column_snowpark_pandas_types.append(col_pandas_type)

return SnowflakeQueryCompiler(
self._modin_frame.update_snowflake_quoted_identifiers_with_expressions(
fillna_column_map
fillna_column_map, data_column_snowpark_pandas_types
).frame
)

@@ -10217,7 +10226,8 @@ def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler":
}
return SnowflakeQueryCompiler(
self._modin_frame.update_snowflake_quoted_identifiers_with_expressions(
diff_label_to_value_map
diff_label_to_value_map,
self._modin_frame.cached_data_column_snowpark_pandas_types,
).frame
)

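Taken together with the `insert`, `melt`, and `merge` gates removed above, these query-compiler changes let Timedelta flow through fill and diff operations. A sketch of the observable behavior (assumes an active Snowpark pandas session):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

s = pd.Series(pd.to_timedelta([1, None, 3], unit="s"))
print(s.fillna(pd.Timedelta(seconds=2)).dtype)  # timedelta64[ns]: the value type-matches the column
print(s.ffill().dtype)                          # timedelta64[ns]: method fill reuses the cached types
print(s.diff().dtype)                           # timedelta64[ns]: diff carries the column types through
```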
