[SNOW-935457] Support variant type for Literal (#1082)
Support creating literals of `VariantType`; e.g. `Column(Literal(10, VariantType()))` creates a literal corresponding to `10::VARIANT`. Also change `datatype_mapper` to use `PythonObjJSONEncoder` so that more Python objects can be used in `ARRAY`/`OBJECT` literals.
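A minimal usage sketch (not part of the commit) of the behavior described above; it assumes an existing Snowpark `session` and mirrors the new integration test added in `tests/integ/scala/test_literal_suite.py`:

```python
# Hedged sketch: VariantType literals via the internal Literal expression.
# `Literal` lives in an internal module; typical user code would use public
# helpers such as `lit` instead.
from snowflake.snowpark import Column
from snowflake.snowpark._internal.analyzer.expression import Literal
from snowflake.snowpark.types import VariantType

df = session.range(1).with_column("v", Column(Literal(10, VariantType())))
# The resulting "V" column is typed VARIANT and holds the value 10,
# i.e. the literal corresponds to 10::VARIANT.
df.show()
```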
sfc-gh-lspiegelberg authored Oct 10, 2023
1 parent ff2a386 commit 3bd6378
Showing 5 changed files with 79 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,9 @@
 - Fixed a bug where imports from permanent stage locations were ignored for temporary stored procedures, UDTFs, UDFs, and UDAFs.
 - Revert back to using CTAS (create table as select) statement for `Dataframe.writer.save_as_table` which does not need insert permission for writing tables.
 
+### New Features
+- Support `PythonObjJSONEncoder` json-serializable objects for `ARRAY` and `OBJECT` literals.
+
 ## 1.8.0 (2023-09-14)
 
 ### New Features
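The `PythonObjJSONEncoder` entry above is what allows values that `json.dumps` cannot serialize natively (for example `datetime.date`) to appear inside `ARRAY`/`OBJECT` literals. A hedged sketch, assuming an existing Snowpark `session`; it mirrors the `list4` case added to the integration test below:

```python
import datetime

from snowflake.snowpark.functions import lit

# Before this change, json.dumps raised "is not JSON serializable" for the date;
# with PythonObjJSONEncoder the literal becomes an ARRAY containing "2023-04-05".
df = session.range(1).with_column("list4", lit([datetime.date(2023, 4, 5)]))
```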
12 changes: 10 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py
@@ -13,6 +13,7 @@
 
 import snowflake.snowpark._internal.analyzer.analyzer_utils as analyzer_utils
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
+from snowflake.snowpark._internal.utils import PythonObjJSONEncoder
 from snowflake.snowpark.types import (
     ArrayType,
     BinaryType,
@@ -69,6 +70,9 @@ def to_sql(value: Any, datatype: DataType, from_values_statement: bool = False)
     if isinstance(datatype, BooleanType):
         if value is None:
             return "NULL :: BOOLEAN"
+    if isinstance(datatype, VariantType):
+        if value is None:
+            return "NULL :: VARIANT"
     if value is None:
         return "NULL"
 
Expand Down Expand Up @@ -136,10 +140,14 @@ def to_sql(value: Any, datatype: DataType, from_values_statement: bool = False)
return f"'{binascii.hexlify(value).decode()}' :: BINARY"

if isinstance(value, (list, tuple, array)) and isinstance(datatype, ArrayType):
return f"PARSE_JSON({str_to_sql(json.dumps(value))}) :: ARRAY"
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: ARRAY"

if isinstance(value, dict) and isinstance(datatype, MapType):
return f"PARSE_JSON({str_to_sql(json.dumps(value))}) :: OBJECT"
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: OBJECT"

if isinstance(datatype, VariantType):
# PARSE_JSON returns VARIANT, so no need to append :: VARIANT here explicitly.
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))})"

raise TypeError(f"Unsupported datatype {datatype}, value {value} by to_sql()")

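For illustration only (not part of the diff): based on the new branches above and the quoting already exercised by the unit tests, the internal `to_sql` helper should behave roughly as follows for `VariantType`. The exact quoting comes from `str_to_sql`, so treat the strings as approximations.

```python
from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql
from snowflake.snowpark.types import VariantType

to_sql(None, VariantType())      # NULL :: VARIANT
to_sql(10, VariantType())        # PARSE_JSON('10')  -- PARSE_JSON already yields VARIANT
to_sql({"a": 1}, VariantType())  # PARSE_JSON('{"a": 1}')
```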
1 change: 1 addition & 0 deletions src/snowflake/snowpark/_internal/type_utils.py
@@ -223,6 +223,7 @@ def convert_sp_to_sf_type(datatype: DataType) -> str:
     _NumericType,
     ArrayType,
     MapType,
+    VariantType,
 )
 
 # Mapping Python array types to DataType
54 changes: 54 additions & 0 deletions tests/integ/scala/test_literal_suite.py
@@ -3,10 +3,12 @@
 #
 
 import datetime
+import json
 from decimal import Decimal
 
 from snowflake.snowpark import Column, Row
 from snowflake.snowpark._internal.analyzer.expression import Literal
+from snowflake.snowpark._internal.utils import PythonObjJSONEncoder
 from snowflake.snowpark.functions import lit
 from snowflake.snowpark.types import (
     DecimalType,
@@ -16,6 +18,7 @@
     TimestampTimeZone,
     TimestampType,
     TimeType,
+    VariantType,
 )
 from tests.utils import Utils
 
@@ -204,6 +207,7 @@ def test_array_object(session):
         .with_column("list1", lit([1, 2, 3]))
         .with_column("list2", lit([]))
         .with_column("list3", lit([1, "1", 2.5, None]))
+        .with_column("list4", lit([datetime.date(2023, 4, 5)]))
         .with_column("tuple1", lit((1, 2, 3)))
         .with_column("tuple2", lit(()))
         .with_column("tuple3", lit((1, "1", 2.5, None)))
@@ -218,6 +222,7 @@
         "StructField('LIST1', ArrayType(StringType()), nullable=True), "
         "StructField('LIST2', ArrayType(StringType()), nullable=True), "
         "StructField('LIST3', ArrayType(StringType()), nullable=True), "
+        "StructField('LIST4', ArrayType(StringType()), nullable=True), "
         "StructField('TUPLE1', ArrayType(StringType()), nullable=True), "
         "StructField('TUPLE2', ArrayType(StringType()), nullable=True), "
         "StructField('TUPLE3', ArrayType(StringType()), nullable=True), "
@@ -232,6 +237,7 @@
             LIST1="[\n 1,\n 2,\n 3\n]",
             LIST2="[]",
             LIST3='[\n 1,\n "1",\n 2.5,\n null\n]',
+            LIST4='[\n "2023-04-05"\n]',
             TUPLE1="[\n 1,\n 2,\n 3\n]",
             TUPLE2="[]",
             TUPLE3='[\n 1,\n "1",\n 2.5,\n null\n]',
@@ -240,3 +246,51 @@
             DICT3='{\n "a": [\n 1,\n "\'"\n ],\n "b": {\n "1": null\n }\n}',
         ),
     )
+
+
+def test_literal_variant(session):
+    LITERAL_VALUES = [
+        None,
+        1,
+        3.141,
+        "hello world",
+        True,
+        [1, 2, 3],
+        (2, 3, 4),
+        {4: 5, 6: 1},
+        {"a": 10},
+        datetime.datetime.now(),
+        datetime.date(2023, 4, 5),
+    ]
+    df = session.range(1)
+
+    for i, value in enumerate(LITERAL_VALUES):
+        df = df.with_column(f"x{i}", Column(Literal(value, VariantType())))
+
+    field_str = str(df.schema.fields)
+    ref_field_str = (
+        "[StructField('ID', LongType(), nullable=False), "
+        + ", ".join(
+            [
+                f"StructField('X{i}', VariantType(), nullable=True)"
+                for i in range(len(LITERAL_VALUES))
+            ]
+        )
+        + "]"
+    )
+    assert field_str == ref_field_str
+    kwargs = {
+        f"X{i}": json.dumps(value, cls=PythonObjJSONEncoder)
+        if value is not None
+        else None
+        for i, value in enumerate(LITERAL_VALUES)
+    }
+    ans = (
+        str(df.collect()[0])
+        .replace("\\n ", "")
+        .replace("\\n", "")
+        .replace(", ", ",")
+        .replace(",", ", ")
+    )  # normalize Snowflake formatting for easier comparison
+    ref = str(Row(ID=0, **kwargs))
+    assert ans == ref
17 changes: 11 additions & 6 deletions tests/unit/test_datatype_mapper.py
@@ -128,15 +128,20 @@ def test_to_sql():
     assert to_sql([{1: 2}], ArrayType()) == "PARSE_JSON('[{\"1\": 2}]') :: ARRAY"
     assert to_sql({1: [2]}, MapType()) == "PARSE_JSON('{\"1\": [2]}') :: OBJECT"
 
-    # value must be json serializable
-    with pytest.raises(TypeError, match="is not JSON serializable"):
-        to_sql([1, bytearray(1)], ArrayType())
+    assert (
+        to_sql([1, bytearray(1)], ArrayType()) == "PARSE_JSON('[1, \"00\"]') :: ARRAY"
+    )
 
-    with pytest.raises(TypeError, match="is not JSON serializable"):
+    assert (
         to_sql(["2", Decimal(0.5)], ArrayType())
+        == "PARSE_JSON('[\"2\", 0.5]') :: ARRAY"
+    )
 
-    with pytest.raises(TypeError, match="is not JSON serializable"):
-        to_sql({1: datetime.datetime.today()}, MapType())
+    dt = datetime.datetime.today()
+    assert (
+        to_sql({1: dt}, MapType())
+        == 'PARSE_JSON(\'{"1": "' + dt.isoformat() + "\"}') :: OBJECT"
+    )
 
 
 @pytest.mark.parametrize(
