Skip to content

Commit

Permalink
SNOW-1846962: remove type conversion when calling a system function (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yuwang authored Dec 19, 2024
1 parent 0488021 commit d2cc2b8
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
- Added support for mixed case field names in struct type columns.
- Added support for `SeriesGroupBy.unique`

#### Bug Fixes

- Fixed a bug where system functions called through `session.call` had an incorrect type conversion.

#### Improvements
- Improve performance of `DataFrame.map`, `Series.apply` and `Series.map` methods by mapping numpy functions to snowpark functions if possible.
- Updated integration testing for `session.lineage.trace` to exclude deleted objects
Expand Down
50 changes: 48 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,55 @@ def float_nan_inf_to_sql(value: float) -> str:
return f"{cast_value} :: FLOAT"


def to_sql(value: Any, datatype: DataType, from_values_statement: bool = False) -> str:
"""Convert a value with DataType to a snowflake compatible sql"""
def to_sql_no_cast(
    value: Any,
    datatype: DataType,
) -> str:
    """Render ``value`` as Snowflake SQL text WITHOUT appending a ``:: TYPE`` cast.

    Companion to ``to_sql`` for call sites (e.g. system functions/procedures)
    where an explicit cast is unwanted. Anything not matched by a rule below
    falls through to plain ``str(value)``.
    """

    def _parse_json(obj: Any) -> str:
        # PARSE_JSON returns VARIANT, so no need to append :: VARIANT here explicitly.
        return f"PARSE_JSON({str_to_sql(json.dumps(obj, cls=PythonObjJSONEncoder))})"

    if value is None:
        return "NULL"

    if isinstance(datatype, VariantType):
        return _parse_json(value)

    if isinstance(value, str):
        if isinstance(datatype, GeographyType):
            return f"TO_GEOGRAPHY({str_to_sql(value)})"
        if isinstance(datatype, GeometryType):
            return f"TO_GEOMETRY({str_to_sql(value)})"
        return str_to_sql(value)

    if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
        # float_nan_inf_to_sql emits "... :: FLOAT"; drop that 9-char cast suffix.
        return float_nan_inf_to_sql(value)[:-9]

    if isinstance(datatype, BinaryType) and isinstance(value, (list, bytes, bytearray)):
        return str(bytes(value))

    if isinstance(datatype, ArrayType) and isinstance(value, (list, tuple, array)):
        return _parse_json(value)

    if isinstance(datatype, MapType) and isinstance(value, dict):
        return _parse_json(value)

    if isinstance(datatype, DateType):
        if isinstance(value, int):
            # Interpret the int as a number of days since 1970-01-01.
            as_date = date(1970, 1, 1) + timedelta(days=value)
            return f"'{as_date.isoformat()}'"
        if isinstance(value, date):
            return f"'{value.isoformat()}'"

    if isinstance(datatype, TimestampType) and isinstance(value, (int, datetime)):
        if isinstance(value, int):
            # Interpret the int as microseconds since the UTC epoch.
            value = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(
                microseconds=value
            )
        return f"'{value}'"

    return f"{value}"


def to_sql(
value: Any,
datatype: DataType,
from_values_statement: bool = False,
) -> str:
"""Convert a value with DataType to a snowflake compatible sql"""
# Handle null values
if isinstance(
datatype,
Expand Down
4 changes: 3 additions & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import snowflake.snowpark
from snowflake.connector.options import installed_pandas, pandas
from snowflake.snowpark._internal import code_generation, type_utils
from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql
from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql, to_sql_no_cast
from snowflake.snowpark._internal.telemetry import TelemetryField
from snowflake.snowpark._internal.type_utils import (
NoneType,
Expand Down Expand Up @@ -1481,6 +1481,8 @@ def generate_call_python_sp_sql(
for arg in args:
if isinstance(arg, snowflake.snowpark.Column):
sql_args.append(session._analyzer.analyze(arg._expression, {}))
elif "system$" in sproc_name.lower():
sql_args.append(to_sql_no_cast(arg, infer_type(arg)))
else:
sql_args.append(to_sql(arg, infer_type(arg)))
return f"CALL {sproc_name}({', '.join(sql_args)})"
116 changes: 116 additions & 0 deletions tests/unit/test_datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@

import datetime
from decimal import Decimal
from unittest.mock import MagicMock

import pytest

from snowflake.snowpark import Session
from snowflake.snowpark._internal.analyzer.datatype_mapper import (
numeric_to_sql_without_cast,
schema_expression,
to_sql,
to_sql_no_cast,
)
from snowflake.snowpark._internal.udf_utils import generate_call_python_sp_sql
from snowflake.snowpark.types import (
ArrayType,
BinaryType,
Expand Down Expand Up @@ -156,6 +160,118 @@ def test_to_sql():
)


def test_to_sql_system_function():
    """Verify ``to_sql_no_cast`` output for every supported value/datatype combo.

    Unlike ``to_sql``, none of the expected strings carry a ``:: TYPE`` cast
    suffix — that is the contract under test.
    """
    # Test nulls
    # None maps to the bare SQL NULL literal regardless of the declared datatype.
    assert to_sql_no_cast(None, NullType()) == "NULL"
    assert to_sql_no_cast(None, ArrayType(DoubleType())) == "NULL"
    assert to_sql_no_cast(None, MapType(IntegerType(), ByteType())) == "NULL"
    assert to_sql_no_cast(None, StructType([])) == "NULL"
    assert to_sql_no_cast(None, GeographyType()) == "NULL"
    assert to_sql_no_cast(None, GeometryType()) == "NULL"

    assert to_sql_no_cast(None, IntegerType()) == "NULL"
    assert to_sql_no_cast(None, ShortType()) == "NULL"
    assert to_sql_no_cast(None, ByteType()) == "NULL"
    assert to_sql_no_cast(None, LongType()) == "NULL"
    assert to_sql_no_cast(None, FloatType()) == "NULL"
    assert to_sql_no_cast(None, StringType()) == "NULL"
    assert to_sql_no_cast(None, DoubleType()) == "NULL"
    assert to_sql_no_cast(None, BooleanType()) == "NULL"

    # Even a bogus datatype argument yields NULL — the None check comes first.
    assert to_sql_no_cast(None, "Not any of the previous types") == "NULL"

    # Test non-nulls
    # Strings are quoted/escaped but get no :: STRING cast appended.
    assert (
        to_sql_no_cast("\\ ' ' abc \n \\", StringType())
        == "'\\\\ '' '' abc \\n \\\\'"
    )
    assert (
        to_sql_no_cast("\\ ' ' abc \n \\", StringType())
        == "'\\\\ '' '' abc \\n \\\\'"
    )
    # Integral values pass through as bare literals for every integral datatype.
    assert to_sql_no_cast(1, ByteType()) == "1"
    assert to_sql_no_cast(1, ShortType()) == "1"
    assert to_sql_no_cast(1, IntegerType()) == "1"
    assert to_sql_no_cast(1, LongType()) == "1"
    assert to_sql_no_cast(1, BooleanType()) == "1"
    assert to_sql_no_cast(0, ByteType()) == "0"
    assert to_sql_no_cast(0, ShortType()) == "0"
    assert to_sql_no_cast(0, IntegerType()) == "0"
    assert to_sql_no_cast(0, LongType()) == "0"
    assert to_sql_no_cast(0, BooleanType()) == "0"

    # Non-finite floats become quoted 'NAN'/'INF' with the :: FLOAT suffix stripped.
    assert to_sql_no_cast(float("nan"), FloatType()) == "'NAN'"
    assert to_sql_no_cast(float("inf"), FloatType()) == "'INF'"
    assert to_sql_no_cast(float("-inf"), FloatType()) == "'-INF'"
    assert to_sql_no_cast(1.2, FloatType()) == "1.2"

    assert to_sql_no_cast(float("nan"), DoubleType()) == "'NAN'"
    assert to_sql_no_cast(float("inf"), DoubleType()) == "'INF'"
    assert to_sql_no_cast(float("-inf"), DoubleType()) == "'-INF'"
    assert to_sql_no_cast(1.2, DoubleType()) == "1.2"

    assert to_sql_no_cast(Decimal(0.5), DecimalType(2, 1)) == "0.5"

    # An int with DateType is interpreted as days since 1970-01-01.
    assert to_sql_no_cast(397, DateType()) == "'1971-02-02'"

    assert to_sql_no_cast(datetime.date(1971, 2, 2), DateType()) == "'1971-02-02'"

    # An int with TimestampType is interpreted as microseconds since the UTC epoch.
    assert (
        to_sql_no_cast(1622002533000000, TimestampType())
        == "'2021-05-26 04:15:33+00:00'"
    )

    # Binary values render as the repr of a Python bytes object.
    assert (
        to_sql_no_cast(bytearray.fromhex("2Ef0 F1f2 "), BinaryType())
        == "b'.\\xf0\\xf1\\xf2'"
    )

    # Arrays and maps serialize through PARSE_JSON with no trailing cast.
    assert to_sql_no_cast([1, "2", 3.5], ArrayType()) == "PARSE_JSON('[1, \"2\", 3.5]')"
    assert to_sql_no_cast({"'": '"'}, MapType()) == 'PARSE_JSON(\'{"\'\'": "\\\\""}\')'
    assert to_sql_no_cast([{1: 2}], ArrayType()) == "PARSE_JSON('[{\"1\": 2}]')"
    assert to_sql_no_cast({1: [2]}, MapType()) == "PARSE_JSON('{\"1\": [2]}')"

    assert to_sql_no_cast([1, bytearray(1)], ArrayType()) == "PARSE_JSON('[1, \"00\"]')"

    assert (
        to_sql_no_cast(["2", Decimal(0.5)], ArrayType()) == "PARSE_JSON('[\"2\", 0.5]')"
    )

    dt = datetime.datetime.today()
    assert (
        to_sql_no_cast({1: dt}, MapType())
        == 'PARSE_JSON(\'{"1": "' + dt.isoformat() + "\"}')"
    )

    # Vectors fall through to the default str() rendering of the Python list.
    assert to_sql_no_cast([1, 2, 3.5], VectorType(float, 3)) == "[1, 2, 3.5]"
    assert (
        to_sql_no_cast("POINT(-122.35 37.55)", GeographyType())
        == "TO_GEOGRAPHY('POINT(-122.35 37.55)')"
    )
    assert (
        to_sql_no_cast("POINT(-122.35 37.55)", GeometryType())
        == "TO_GEOMETRY('POINT(-122.35 37.55)')"
    )
    assert to_sql_no_cast("1", VariantType()) == "PARSE_JSON('\"1\"')"
    assert (
        to_sql_no_cast([1, 2, 3.5, 4.1234567, -3.8], VectorType("float", 5))
        == "[1, 2, 3.5, 4.1234567, -3.8]"
    )
    assert to_sql_no_cast([1, 2, 3], VectorType(int, 3)) == "[1, 2, 3]"
    assert (
        to_sql_no_cast([1, 2, 31234567, -1928, 0, -3], VectorType(int, 5))
        == "[1, 2, 31234567, -1928, 0, -3]"
    )


def test_generate_call_python_sp_sql():
    """Procedure names containing ``system$`` must render args without casts."""
    mocked_session = MagicMock(Session)
    generated = generate_call_python_sp_sql(mocked_session, "system$wait", 1)
    assert generated == "CALL system$wait(1)"


@pytest.mark.parametrize(
"timezone, expected",
[
Expand Down

0 comments on commit d2cc2b8

Please sign in to comment.