Skip to content

Commit

Permalink
[Local Testing] SNOW-904260 support dataframe na functions (#1069)
Browse files Browse the repository at this point in the history
* Add function iff

* Add support for na functions, blocked by type coercion

* Fix tests

* Workaround type coercion

* Workaround coercion

* Address comments

* Delete dead code

* Fix tests
  • Loading branch information
sfc-gh-stan authored Nov 6, 2023
1 parent 43967ef commit b24d529
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 166 deletions.
6 changes: 0 additions & 6 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3213,12 +3213,6 @@ def na(self) -> DataFrameNaFunctions:
Returns a :class:`DataFrameNaFunctions` object that provides functions for
handling missing values in the DataFrame.
"""
from snowflake.snowpark.mock.connection import MockServerConnection

if isinstance(self._session._conn, MockServerConnection):
raise NotImplementedError(
"[Local Testing] DataFrameNaFunctions is not implemented."
)
return self._na

def describe(self, *cols: Union[str, List[str]]) -> "DataFrame":
Expand Down
24 changes: 3 additions & 21 deletions src/snowflake/snowpark/dataframe_na_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#

import copy
import math
from logging import getLogger
from typing import Dict, Optional, Union

Expand Down Expand Up @@ -153,13 +154,6 @@ def drop(
# iff(float_col = 'NaN' or float_col is null, 0, 1)
# iff(non_float_col is null, 0, 1) >= thresh

from snowflake.snowpark.mock.connection import MockServerConnection

if isinstance(self._df._session._conn, MockServerConnection):
raise NotImplementedError(
"[Local Testing] DataFrame NA functions are currently not supported."
)

if how is not None and how not in ["any", "all"]:
raise ValueError(f"how ('{how}') should be 'any' or 'all'")

Expand Down Expand Up @@ -206,7 +200,7 @@ def drop(
df_col_type_dict[normalized_col_name], (FloatType, DoubleType)
):
# iff(col = 'NaN' or col is null, 0, 1)
is_na = iff((col == "NaN") | col.is_null(), 0, 1)
is_na = iff((col == math.nan) | col.is_null(), 0, 1)
else:
# iff(col is null, 0, 1)
is_na = iff(col.is_null(), 0, 1)
Expand Down Expand Up @@ -309,13 +303,6 @@ def fill(
# select col, iff(float_col = 'NaN' or float_col is null, replacement, float_col)
# iff(non_float_col is null, replacement, non_float_col) from table where

from snowflake.snowpark.mock.connection import MockServerConnection

if isinstance(self._df._session._conn, MockServerConnection):
raise NotImplementedError(
"[Local Testing] DataFrame NA functions are currently not supported."
)

if subset is None:
subset = self._df.columns
elif isinstance(subset, str):
Expand Down Expand Up @@ -368,7 +355,7 @@ def fill(
if isinstance(datatype, (FloatType, DoubleType)):
# iff(col = 'NaN' or col is null, value, col)
res_columns.append(
iff((col == "NaN") | col.is_null(), value, col).as_(
iff((col == math.nan) | col.is_null(), value, col).as_(
col_name
)
)
Expand Down Expand Up @@ -486,12 +473,7 @@ def replace(
See Also:
:func:`DataFrame.replace`
"""
from snowflake.snowpark.mock.connection import MockServerConnection

if isinstance(self._df._session._conn, MockServerConnection):
raise NotImplementedError(
"[Local Testing] DataFrame NA functions are currently not supported."
)
if subset is None:
subset = self._df.columns
elif isinstance(subset, str):
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,7 +2742,7 @@ def char(col: ColumnOrName) -> Column:
return builtin("char")(c)


def to_char(c: ColumnOrName, format: Optional[ColumnOrLiteralStr] = None) -> Column:
def to_char(c: ColumnOrName, format: Optional[str] = None) -> Column:
"""Converts a Unicode code point (including 7-bit ASCII) into the character that
matches the input Unicode.
Expand Down
75 changes: 49 additions & 26 deletions src/snowflake/snowpark/mock/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,12 @@ def mock_avg(column: ColumnEmulator) -> ColumnEmulator:


@patch("count")
def mock_count(column: ColumnEmulator) -> ColumnEmulator:
count_column = column.count()
if isinstance(count_column, ColumnEmulator):
count_column.sf_type = ColumnType(LongType(), False)
return ColumnEmulator(
data=round(count_column, 5), sf_type=ColumnType(LongType(), False)
)
def mock_count(column: Union[TableEmulator, ColumnEmulator]) -> ColumnEmulator:
if isinstance(column, ColumnEmulator):
count_column = column.count()
return ColumnEmulator(data=count_column, sf_type=ColumnType(LongType(), False))
else: # TableEmulator
return ColumnEmulator(data=len(column), sf_type=ColumnType(LongType(), False))


@patch("count_distinct")
Expand Down Expand Up @@ -232,7 +231,7 @@ def mock_covar_pop(column1: ColumnEmulator, column2: ColumnEmulator) -> ColumnEm


@patch("listagg")
def mock_listagg(column: ColumnEmulator, delimiter, is_distinct):
def mock_listagg(column: ColumnEmulator, delimiter: str, is_distinct: bool):
columns_data = ColumnEmulator(column.unique()) if is_distinct else column
# nit todo: returns a string that includes all the non-NULL input values, separated by the delimiter.
return ColumnEmulator(
Expand All @@ -244,7 +243,7 @@ def mock_listagg(column: ColumnEmulator, delimiter, is_distinct):
@patch("to_date")
def mock_to_date(
column: ColumnEmulator,
fmt: Union[ColumnEmulator, str] = None,
fmt: str = None,
try_cast: bool = False,
):
"""
Expand Down Expand Up @@ -297,7 +296,7 @@ def mock_to_date(


@patch("contains")
def mock_contains(expr1: ColumnEmulator, expr2: Union[str, ColumnEmulator]):
def mock_contains(expr1: ColumnEmulator, expr2: ColumnEmulator):
if isinstance(expr1, str) and isinstance(expr2, str):
return ColumnEmulator(data=[bool(str(expr2) in str(expr1))])
if isinstance(expr1, ColumnEmulator) and isinstance(expr2, ColumnEmulator):
Expand Down Expand Up @@ -399,7 +398,7 @@ def mock_to_decimal(
@patch("to_time")
def mock_to_time(
column: ColumnEmulator,
fmt: Union[ColumnEmulator, str] = None,
fmt: Optional[str] = None,
try_cast: bool = False,
):
"""
Expand Down Expand Up @@ -471,7 +470,7 @@ def mock_to_time(
@patch("to_timestamp")
def mock_to_timestamp(
column: ColumnEmulator,
fmt: Union[ColumnEmulator, str] = None,
fmt: Optional[str] = None,
try_cast: bool = False,
):
"""
Expand Down Expand Up @@ -580,7 +579,7 @@ def try_convert(convert: Callable, try_cast: bool, val: Any):
@patch("to_char")
def mock_to_char(
column: ColumnEmulator,
fmt: Union[ColumnEmulator, str] = None,
fmt: Optional[str] = None,
try_cast: bool = False,
) -> ColumnEmulator: # TODO: support more input types
source_datatype = column.sf_type.datatype
Expand Down Expand Up @@ -741,16 +740,28 @@ def mock_to_binary(
@patch("iff")
def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmulator):
assert isinstance(condition.sf_type.datatype, BooleanType)
condition = condition.array
res = ColumnEmulator(data=[None] * len(condition), dtype=object)
if not all(condition) and expr1.sf_type != expr2.sf_type:
if (
all(condition)
or all(~condition)
or expr1.sf_type.datatype == expr2.sf_type.datatype
or isinstance(expr1.sf_type.datatype, NullType)
or isinstance(expr2.sf_type.datatype, NullType)
):
res = ColumnEmulator(data=[None] * len(condition), dtype=object)
sf_data_type = (
expr1.sf_type.datatype
if any(condition) and not isinstance(expr1.sf_type.datatype, NullType)
else expr2.sf_type.datatype
)
nullability = expr1.sf_type.nullable and expr2.sf_type.nullable
res.sf_type = ColumnType(sf_data_type, nullability)
res.where(condition, other=expr2, inplace=True)
res.where([not x for x in condition], other=expr1, inplace=True)
return res
else:
raise SnowparkSQLException(
f"iff expr1 and expr2 have conflicting data types: {expr1.sf_type} != {expr2.sf_type}"
f"[Local Testing] does not support coercion currently, iff expr1 and expr2 have conflicting data types: {expr1.sf_type} != {expr2.sf_type}"
)
res.sf_type = expr1.sf_type if any(condition) else expr2.sf_type
res.where(condition, other=expr2, inplace=True)
res.where([not x for x in condition], other=expr1, inplace=True)
return res


@patch("coalesce")
Expand All @@ -773,20 +784,31 @@ def mock_coalesce(*exprs):
def mock_substring(
base_expr: ColumnEmulator, start_expr: ColumnEmulator, length_expr: ColumnEmulator
):
return base_expr.str.slice(start=start_expr - 1, stop=start_expr - 1 + length_expr)
res = [
x[y - 1 : y + z - 1] if x is not None else None
for x, y, z in zip(base_expr, start_expr, length_expr)
]
res = ColumnEmulator(
res, sf_type=ColumnType(StringType(), base_expr.sf_type.nullable), dtype=object
)
return res


@patch("startswith")
def mock_startswith(expr1: ColumnEmulator, expr2: ColumnEmulator):
res = expr1.str.startswith(expr2)
res.sf_type = ColumnType(StringType(), expr1.sf_type.nullable)
res = [x.startswith(y) if x is not None else None for x, y in zip(expr1, expr2)]
res = ColumnEmulator(
res, sf_type=ColumnType(BooleanType(), expr1.sf_type.nullable), dtype=bool
)
return res


@patch("endswith")
def mock_endswith(expr1: ColumnEmulator, expr2: ColumnEmulator):
res = expr1.str.endswith(expr2)
res.sf_type = ColumnType(StringType(), expr1.sf_type.nullable)
res = [x.endswith(y) if x is not None else None for x, y in zip(expr1, expr2)]
res = ColumnEmulator(
res, sf_type=ColumnType(BooleanType(), expr1.sf_type.nullable), dtype=bool
)
return res


Expand Down Expand Up @@ -879,3 +901,4 @@ def mock_to_variant(expr: ColumnEmulator):
res = expr.copy()
res.sf_type = ColumnType(VariantType(), expr.sf_type.nullable)
return res

92 changes: 57 additions & 35 deletions src/snowflake/snowpark/mock/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import math
import re
import typing
import uuid
from enum import Enum
from functools import cached_property, partial
Expand Down Expand Up @@ -888,60 +889,66 @@ def calculate_expression(
if isinstance(exp, (UnresolvedAlias, Alias)):
return calculate_expression(exp.child, input_data, analyzer, expr_to_alias)
if isinstance(exp, FunctionExpression):
# evaluated_children maps to parameters passed to the function call
evaluated_children = [
calculate_expression(
c, input_data, analyzer, expr_to_alias, keep_literal=True
)
for c in exp.children
]

# Special case for count_distinct
if exp.name.lower() == "count" and exp.is_distinct:
func_name = "count_distinct"
else:
func_name = exp.name.lower()

try:
original_func = getattr(
importlib.import_module("snowflake.snowpark.functions"),
exp.name.lower(),
importlib.import_module("snowflake.snowpark.functions"), func_name
)
except AttributeError:
raise NotImplementedError(
f"[Local Testing] Mocking function {exp.name.lower()} is not supported."
f"[Local Testing] Mocking function {func_name} is not supported."
)

signatures = inspect.signature(original_func)
spec = inspect.getfullargspec(original_func)
if exp.name not in _MOCK_FUNCTION_IMPLEMENTATION_MAP:
if func_name not in _MOCK_FUNCTION_IMPLEMENTATION_MAP:
raise NotImplementedError(
f"[Local Testing] Mocking function {exp.name} is not implemented."
f"[Local Testing] Mocking function {func_name} is not implemented."
)
to_pass_args = []
type_hints = typing.get_type_hints(original_func)
for idx, key in enumerate(signatures.parameters):
type_hint = str(type_hints[key])
keep_literal = "Column" not in type_hint
if key == spec.varargs:
to_pass_args.extend(evaluated_children[idx:])
to_pass_args.extend(
[
calculate_expression(
c,
input_data,
analyzer,
expr_to_alias,
keep_literal=keep_literal,
)
for c in exp.children[idx:]
]
)
else:
try:
to_pass_args.append(evaluated_children[idx])
to_pass_args.append(
calculate_expression(
exp.children[idx],
input_data,
analyzer,
expr_to_alias,
keep_literal=keep_literal,
)
)
except IndexError:
to_pass_args.append(None)

if exp.name == "count" and exp.is_distinct:
if "count_distinct" not in _MOCK_FUNCTION_IMPLEMENTATION_MAP:
raise NotImplementedError(
f"[Local Testing] Mocking function {exp.name} is not implemented."
)
return _MOCK_FUNCTION_IMPLEMENTATION_MAP["count_distinct"](
*evaluated_children
)
if (
exp.name == "count"
and isinstance(exp.children[0], Literal)
and exp.children[0].sql == "LITERAL()"
):
to_pass_args[0] = input_data
if exp.name == "array_agg":
if func_name == "array_agg":
to_pass_args[-1] = exp.is_distinct
if exp.name == "sum" and exp.is_distinct:
if func_name == "sum" and exp.is_distinct:
to_pass_args[0] = ColumnEmulator(
data=to_pass_args[0].unique(), sf_type=to_pass_args[0].sf_type
)
return _MOCK_FUNCTION_IMPLEMENTATION_MAP[exp.name](*to_pass_args)
return _MOCK_FUNCTION_IMPLEMENTATION_MAP[func_name](*to_pass_args)
if isinstance(exp, ListAgg):
column = calculate_expression(exp.col, input_data, analyzer, expr_to_alias)
column.sf_type = ColumnType(StringType(), exp.col.nullable)
Expand Down Expand Up @@ -1013,6 +1020,15 @@ def calculate_expression(
new_column = left**right
elif isinstance(exp, EqualTo):
new_column = left == right
if left.hasnans and right.hasnans:
new_column[
left.apply(lambda x: x is None) & right.apply(lambda x: x is None)
] = True
new_column[
left.apply(lambda x: x is not None and np.isnan(x))
& right.apply(lambda x: x is not None and np.isnan(x))
] = True
# NaN == NaN evaluates to False in pandas, but True in Snowflake
elif isinstance(exp, NotEqualTo):
new_column = left != right
elif isinstance(exp, GreaterThanOrEqual):
Expand Down Expand Up @@ -1148,7 +1164,10 @@ def calculate_expression(
remaining = remaining[~remaining.index.isin(true_index)]

if output_data.sf_type:
if output_data.sf_type != value.sf_type:
if (
not isinstance(output_data.sf_type.datatype, NullType)
and output_data.sf_type != value.sf_type
):
raise SnowparkSQLException(
f"CaseWhen expressions have conflicting data types: {output_data.sf_type} != {value.sf_type}"
)
Expand All @@ -1161,9 +1180,12 @@ def calculate_expression(
)
output_data[remaining.index] = value[remaining.index]
if output_data.sf_type:
if output_data.sf_type != value.sf_type:
if (
not isinstance(output_data.sf_type.datatype, NullType)
and output_data.sf_type.datatype != value.sf_type.datatype
):
raise SnowparkSQLException(
f"CaseWhen expressions have conflicting data types: {output_data.sf_type} != {value.sf_type}"
f"CaseWhen expressions have conflicting data types: {output_data.sf_type.datatype} != {value.sf_type.datatype}"
)
else:
output_data.sf_type = value.sf_type
Expand Down
Loading

0 comments on commit b24d529

Please sign in to comment.