Skip to content

Commit

Permalink
SNOW-1651234: Fix create_dataframe throwing an exception for structur…
Browse files Browse the repository at this point in the history
…ed dtypes (#2240)
  • Loading branch information
sfc-gh-jrose authored Sep 20, 2024
1 parent 208ad7a commit 5514105
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 46 deletions.
23 changes: 12 additions & 11 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,10 @@
column,
lit,
parse_json,
to_array,
to_date,
to_decimal,
to_geography,
to_geometry,
to_object,
to_time,
to_timestamp,
to_timestamp_ltz,
Expand Down Expand Up @@ -2832,14 +2830,15 @@ def convert_row_to_list(
if isinstance(
field.datatype,
(
VariantType,
ArrayType,
MapType,
TimeType,
DateType,
TimestampType,
GeographyType,
GeometryType,
MapType,
StructType,
TimeType,
TimestampType,
VariantType,
VectorType,
),
)
Expand Down Expand Up @@ -2877,7 +2876,9 @@ def convert_row_to_list(
data_type, ArrayType
):
converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder))
elif isinstance(value, dict) and isinstance(data_type, MapType):
elif isinstance(value, dict) and isinstance(
data_type, (MapType, StructType)
):
converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder))
elif isinstance(data_type, VariantType):
converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder))
Expand Down Expand Up @@ -2925,10 +2926,10 @@ def convert_row_to_list(
project_columns.append(to_geography(column(name)).as_(name))
elif isinstance(field.datatype, GeometryType):
project_columns.append(to_geometry(column(name)).as_(name))
elif isinstance(field.datatype, ArrayType):
project_columns.append(to_array(parse_json(column(name))).as_(name))
elif isinstance(field.datatype, MapType):
project_columns.append(to_object(parse_json(column(name))).as_(name))
elif isinstance(field.datatype, (ArrayType, MapType, StructType)):
project_columns.append(
parse_json(column(name)).cast(field.datatype).as_(name)
)
elif isinstance(field.datatype, VectorType):
project_columns.append(
parse_json(column(name)).cast(field.datatype).as_(name)
Expand Down
74 changes: 39 additions & 35 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def test_structured_dtypes_pandas(structured_type_session, structured_type_suppo

@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="strucutred types do not fully support structured types yet.",
reason="local testing does not fully support structured types yet.",
)
def test_structured_dtypes_iceberg(
structured_type_session, local_testing_mode, structured_type_support
Expand All @@ -445,20 +445,9 @@ def test_structured_dtypes_iceberg(
query, expected_dtypes, expected_schema = STRUCTURED_TYPES_EXAMPLES[True]

table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}"
save_table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}"
try:
structured_type_session.sql(
f"""
create iceberg table if not exists {table_name} (
map map(varchar, int),
obj object(A varchar, B float),
arr array(float)
)
CATALOG = 'SNOWFLAKE'
EXTERNAL_VOLUME = 'python_connector_iceberg_exvol'
BASE_LOCATION = 'python_connector_merge_gate';
"""
).collect()
create_df = structured_type_session.create_dataframe([], schema=expected_schema)
create_df.write.save_as_table(table_name, iceberg_config=ICEBERG_CONFIG)
structured_type_session.sql(
f"""
insert into {table_name}
Expand All @@ -469,29 +458,54 @@ def test_structured_dtypes_iceberg(
assert df.schema == expected_schema
assert df.dtypes == expected_dtypes

# Try to save_as_table
structured_type_session.table(table_name).write.save_as_table(
save_table_name, iceberg_config=ICEBERG_CONFIG
)

save_ddl = structured_type_session._run_query(
f"select get_ddl('table', '{save_table_name}')"
f"select get_ddl('table', '{table_name}')"
)
assert save_ddl[0][0] == (
f"create or replace ICEBERG TABLE {save_table_name.upper()} (\n\t"
f"create or replace ICEBERG TABLE {table_name.upper()} (\n\t"
"MAP MAP(STRING, LONG),\n\tOBJ OBJECT(A STRING, B DOUBLE),\n\tARR ARRAY(DOUBLE)\n)\n "
"EXTERNAL_VOLUME = 'PYTHON_CONNECTOR_ICEBERG_EXVOL'\n CATALOG = 'SNOWFLAKE'\n "
"BASE_LOCATION = 'python_connector_merge_gate/';"
)

finally:
structured_type_session.sql(f"drop table if exists {table_name}")
structured_type_session.sql(f"drop table if exists {save_table_name}")


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="local testing does not fully support structured types yet.",
)
def test_structured_dtypes_iceberg_create_from_values(
    structured_type_session, local_testing_mode, structured_type_support
):
    """Regression test (SNOW-1651234): ``create_dataframe`` with literal
    structured-type values (map, object, array) must round-trip through an
    iceberg table created via ``save_as_table`` without raising.
    """
    if not (
        structured_type_support
        and iceberg_supported(structured_type_session, local_testing_mode)
    ):
        pytest.skip("Test requires iceberg support and structured type support.")

    _, __, expected_schema = STRUCTURED_TYPES_EXAMPLES[True]
    # Random suffix keeps concurrent test runs from colliding on the table name.
    table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}"
    # Rows shaped to expected_schema: (map, object, array) per row.
    data = [
        ({"x": 1}, {"A": "a", "B": 1}, [1, 1, 1]),
        ({"x": 2}, {"A": "b", "B": 2}, [2, 2, 2]),
    ]
    try:
        create_df = structured_type_session.create_dataframe(
            data, schema=expected_schema
        )
        create_df.write.save_as_table(table_name, iceberg_config=ICEBERG_CONFIG)
        assert structured_type_session.table(table_name).order_by(
            col("ARR"), ascending=True
        ).collect() == [Row(*d) for d in data]
    finally:
        # Fix: Session.sql() is lazy — without .collect() the DROP TABLE is
        # never sent to the server and the test table leaks.
        structured_type_session.sql(f"drop table if exists {table_name}").collect()


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="local testing does not fully support structured types yet.",
)
def test_structured_dtypes_iceberg_udf(
structured_type_session, local_testing_mode, structured_type_support
Expand Down Expand Up @@ -520,18 +534,8 @@ def nop(x):
)

try:
structured_type_session.sql(
f"""
create iceberg table if not exists {table_name} (
map map(varchar, int),
obj object(A varchar, B float),
arr array(float)
)
CATALOG = 'SNOWFLAKE'
EXTERNAL_VOLUME = 'python_connector_iceberg_exvol'
BASE_LOCATION = 'python_connector_merge_gate';
"""
).collect()
create_df = structured_type_session.create_dataframe([], schema=expected_schema)
create_df.write.save_as_table(table_name, iceberg_config=ICEBERG_CONFIG)
structured_type_session.sql(
f"""
insert into {table_name}
Expand Down

0 comments on commit 5514105

Please sign in to comment.