Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-966485 Fix reader schema with metadata columns #1143

Merged
merged 11 commits into from
Dec 1, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
### Bug Fixes

- DataFrame column names quoting check now supports newline characters.
- Fix a bug where a DataFrame generated by `session.read.with_metadata` creates inconsistent table when doing `df.write.save_as_table`.

## 1.10.0 (2023-11-03)

Expand Down
38 changes: 35 additions & 3 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
GeneratorTableFunction,
TableFunctionRelation,
)
from snowflake.snowpark._internal.analyzer.unary_expression import Alias
from snowflake.snowpark._internal.type_utils import ColumnOrName
from snowflake.snowpark.column import METADATA_COLUMN_TYPES, Column

if TYPE_CHECKING:
from snowflake.snowpark._internal.analyzer.select_statement import (
Expand Down Expand Up @@ -819,7 +822,7 @@ def read_file(
schema: List[Attribute],
schema_to_cast: Optional[List[Tuple[str, str]]] = None,
transformations: Optional[List[str]] = None,
metadata_project: Optional[List[str]] = None,
metadata_columns: Optional[Iterable[ColumnOrName]] = None,
):
format_type_options, copy_options = get_copy_into_table_options(options)
pattern = options.get("PATTERN")
Expand Down Expand Up @@ -883,7 +886,13 @@ def read_file(
else:
schema_project = schema_cast_seq(schema)

metadata_project = [] if metadata_project is None else metadata_project
if metadata_columns:
metadata_project = [
self.session._analyzer.analyze(col._expression, {})
for col in metadata_columns
]
else:
metadata_project = []
queries.append(
Query(
select_from_path_with_format_statement(
Expand All @@ -894,9 +903,32 @@ def read_file(
)
)
)

def _get_unaliased_name(unaliased: ColumnOrName):
    """Return the underlying (unaliased) SQL name for a metadata column.

    Accepts either a plain string (returned unchanged) or a ``Column``;
    for an aliased ``Column`` the child expression's SQL is returned so
    the original ``METADATA$...`` name can be looked up.
    """
    if not isinstance(unaliased, Column):
        return unaliased
    expression = unaliased._expression
    if isinstance(expression, Alias):
        return expression.child.sql
    return unaliased.get_name()

try:
combined_schema = [
Attribute(
metadata_col.get_name(),
METADATA_COLUMN_TYPES[
_get_unaliased_name(metadata_col).upper()
],
)
for metadata_col in metadata_columns
] + schema
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for metadata_col in metadata_columns
] + schema
for metadata_col in metadata_columns or []
] + schema

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be why some tests are failing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I also realized the same after observing the test failure lol

except KeyError:
raise ValueError(
f"Metadata column name is not supported. Supported {METADATA_COLUMN_TYPES.keys()}, Got {metadata_project}"
)

return SnowflakePlan(
queries,
schema_value_statement(schema),
schema_value_statement(combined_schema),
post_queries,
{},
None,
Expand Down
16 changes: 15 additions & 1 deletion src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,13 @@
type_string_to_type_object,
)
from snowflake.snowpark._internal.utils import parse_positional_args_to_list, quote_name
from snowflake.snowpark.types import DataType
from snowflake.snowpark.types import (
DataType,
IntegerType,
StringType,
TimestampTimeZone,
TimestampType,
)
from snowflake.snowpark.window import Window, WindowSpec

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
Expand Down Expand Up @@ -828,3 +834,11 @@ def otherwise(self, value: ColumnOrLiteral) -> "CaseExpr":
METADATA_FILE_LAST_MODIFIED = Column("METADATA$FILE_LAST_MODIFIED")
METADATA_START_SCAN_TIME = Column("METADATA$START_SCAN_TIME")
METADATA_FILENAME = Column("METADATA$FILENAME")

# Maps each supported METADATA$ column name to the Snowflake data type it is
# read back as. Used by the reader to validate requested metadata columns and
# to build the combined schema for the resulting DataFrame.
# NOTE: use the snake_case get_name() accessor consistently (the original
# mixed get_name() and its camelCase alias getName()).
METADATA_COLUMN_TYPES = {
    METADATA_FILE_ROW_NUMBER.get_name(): IntegerType(),
    METADATA_FILE_CONTENT_KEY.get_name(): StringType(),
    METADATA_FILE_LAST_MODIFIED.get_name(): TimestampType(TimestampTimeZone.NTZ),
    METADATA_START_SCAN_TIME.get_name(): TimestampType(TimestampTimeZone.LTZ),
    METADATA_FILENAME.get_name(): StringType(),
}
24 changes: 4 additions & 20 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,14 +389,6 @@ def csv(self, path: str) -> DataFrame:
else:
schema = self._user_schema._to_attributes()

if self._metadata_cols:
metadata_project = [
self._session._analyzer.analyze(col._expression, {})
for col in self._metadata_cols
]
else:
metadata_project = []

if self._session.sql_simplifier_enabled:
df = DataFrame(
self._session,
Expand All @@ -410,7 +402,7 @@ def csv(self, path: str) -> DataFrame:
schema,
schema_to_cast=schema_to_cast,
transformations=transformations,
metadata_project=metadata_project,
metadata_columns=self._metadata_cols,
),
analyzer=self._session._analyzer,
),
Expand All @@ -428,7 +420,7 @@ def csv(self, path: str) -> DataFrame:
schema,
schema_to_cast=schema_to_cast,
transformations=transformations,
metadata_project=metadata_project,
metadata_columns=self._metadata_cols,
),
)
df._reader = self
Expand Down Expand Up @@ -637,14 +629,6 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
if new_schema:
schema = new_schema

if self._metadata_cols:
metadata_project = [
self._session._analyzer.analyze(col._expression, {})
for col in self._metadata_cols
]
else:
metadata_project = []

if self._session.sql_simplifier_enabled:
df = DataFrame(
self._session,
Expand All @@ -658,7 +642,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
schema,
schema_to_cast=schema_to_cast,
transformations=read_file_transformations,
metadata_project=metadata_project,
metadata_columns=self._metadata_cols,
),
analyzer=self._session._analyzer,
),
Expand All @@ -676,7 +660,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
schema,
schema_to_cast=schema_to_cast,
transformations=read_file_transformations,
metadata_project=metadata_project,
metadata_columns=self._metadata_cols,
),
)
df._reader = self
Expand Down
47 changes: 47 additions & 0 deletions tests/integ/scala/test_dataframe_reader_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,25 +565,72 @@ def test_read_metadata_column_from_stage(session, file_format):
assert isinstance(res[0]["METADATA$FILE_LAST_MODIFIED"], datetime.datetime)
assert isinstance(res[0]["METADATA$START_SCAN_TIME"], datetime.datetime)

table_name = Utils.random_table_name()
df.write.save_as_table(table_name, mode="append")
with session.table(table_name) as table_df:
table_res = table_df.collect()
assert table_res[0]["METADATA$FILENAME"] == res[0]["METADATA$FILENAME"]
assert (
table_res[0]["METADATA$FILE_ROW_NUMBER"]
== res[0]["METADATA$FILE_ROW_NUMBER"]
)
assert (
table_res[0]["METADATA$FILE_CONTENT_KEY"]
== res[0]["METADATA$FILE_CONTENT_KEY"]
)
assert (
table_res[0]["METADATA$FILE_LAST_MODIFIED"]
== res[0]["METADATA$FILE_LAST_MODIFIED"]
)
assert isinstance(res[0]["METADATA$START_SCAN_TIME"], datetime.datetime)

# test single column works
reader = session.read.with_metadata(METADATA_FILENAME)
df = get_df_from_reader_and_file_format(reader, file_format)
res = df.collect()
assert res[0]["METADATA$FILENAME"] == filename

table_name = Utils.random_table_name()
df.write.save_as_table(table_name, mode="append")
with session.table(table_name) as table_df:
table_res = table_df.collect()
assert table_res[0]["METADATA$FILENAME"] == res[0]["METADATA$FILENAME"]

# test that alias works
reader = session.read.with_metadata(METADATA_FILENAME.alias("filename"))
df = get_df_from_reader_and_file_format(reader, file_format)
res = df.collect()
assert res[0]["FILENAME"] == filename

table_name = Utils.random_table_name()
df.write.save_as_table(table_name, mode="append")
with session.table(table_name) as table_df:
table_res = table_df.collect()
assert table_res[0]["FILENAME"] == res[0]["FILENAME"]

# test that column name with str works
reader = session.read.with_metadata("metadata$filename", "metadata$file_row_number")
df = get_df_from_reader_and_file_format(reader, file_format)
res = df.collect()
assert res[0]["METADATA$FILENAME"] == filename
assert res[0]["METADATA$FILE_ROW_NUMBER"] >= 0

table_name = Utils.random_table_name()
df.write.save_as_table(table_name, mode="append")
with session.table(table_name) as table_df:
table_res = table_df.collect()
assert table_res[0]["METADATA$FILENAME"] == res[0]["METADATA$FILENAME"]
assert (
table_res[0]["METADATA$FILE_ROW_NUMBER"]
== res[0]["METADATA$FILE_ROW_NUMBER"]
)

# test non-existing metadata column
with pytest.raises(ValueError, match="Metadata column name is not supported"):
get_df_from_reader_and_file_format(
session.read.with_metadata("metadata$non-existing"), file_format
)


@pytest.mark.parametrize("mode", ["select", "copy"])
def test_read_json_with_no_schema(session, mode):
Expand Down
Loading