diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0ef69275ac8..c1f768ffddc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@
 
 ### Bug Fixes
 
 - DataFrame column names quoting check now supports newline characters.
+- Fixed a bug where a DataFrame generated by `session.read.with_metadata` creates an inconsistent table when doing `df.write.save_as_table`.
 
 ## 1.10.0 (2023-11-03)
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
index 61fbaa4cbfb..0b4a0c0157c 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -820,6 +820,7 @@ def read_file(
         schema_to_cast: Optional[List[Tuple[str, str]]] = None,
         transformations: Optional[List[str]] = None,
         metadata_project: Optional[List[str]] = None,
+        metadata_schema: Optional[List[Attribute]] = None,
     ):
         format_type_options, copy_options = get_copy_into_table_options(options)
         pattern = options.get("PATTERN")
@@ -879,24 +880,24 @@ def read_file(
 
         if infer_schema:
             assert schema_to_cast is not None
-            schema_project = schema_cast_named(schema_to_cast)
+            schema_project: List[str] = schema_cast_named(schema_to_cast)
         else:
-            schema_project = schema_cast_seq(schema)
+            schema_project: List[str] = schema_cast_seq(schema)
 
-        metadata_project = [] if metadata_project is None else metadata_project
         queries.append(
             Query(
                 select_from_path_with_format_statement(
-                    metadata_project + schema_project,
+                    (metadata_project or []) + schema_project,
                     path,
                     format_name,
                     pattern,
                 )
             )
         )
+
         return SnowflakePlan(
             queries,
-            schema_value_statement(schema),
+            schema_value_statement((metadata_schema or []) + schema),
             post_queries,
             {},
             None,
diff --git a/src/snowflake/snowpark/column.py b/src/snowflake/snowpark/column.py
index 5fc93ad58c5..d3bd09076fb 100644
--- a/src/snowflake/snowpark/column.py
+++ b/src/snowflake/snowpark/column.py
@@ -71,7 +71,13 @@
     type_string_to_type_object,
 )
 from snowflake.snowpark._internal.utils import parse_positional_args_to_list, quote_name
-from snowflake.snowpark.types import DataType
+from snowflake.snowpark.types import (
+    DataType,
+    IntegerType,
+    StringType,
+    TimestampTimeZone,
+    TimestampType,
+)
 from snowflake.snowpark.window import Window, WindowSpec
 
 # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
@@ -828,3 +834,11 @@ def otherwise(self, value: ColumnOrLiteral) -> "CaseExpr":
 METADATA_FILE_LAST_MODIFIED = Column("METADATA$FILE_LAST_MODIFIED")
 METADATA_START_SCAN_TIME = Column("METADATA$START_SCAN_TIME")
 METADATA_FILENAME = Column("METADATA$FILENAME")
+
+METADATA_COLUMN_TYPES = {
+    METADATA_FILE_ROW_NUMBER.get_name(): IntegerType(),
+    METADATA_FILE_CONTENT_KEY.get_name(): StringType(),
+    METADATA_FILE_LAST_MODIFIED.get_name(): TimestampType(TimestampTimeZone.NTZ),
+    METADATA_START_SCAN_TIME.get_name(): TimestampType(TimestampTimeZone.LTZ),
+    METADATA_FILENAME.get_name(): StringType(),
+}
diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py
index 254500030b7..08fa27582fd 100644
--- a/src/snowflake/snowpark/dataframe_reader.py
+++ b/src/snowflake/snowpark/dataframe_reader.py
@@ -18,6 +18,7 @@
     SelectSnowflakePlan,
     SelectStatement,
 )
+from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.telemetry import set_api_call_source
 from snowflake.snowpark._internal.type_utils import ColumnOrName, convert_sf_to_sp_type
@@ -27,7 +28,7 @@
     get_copy_into_table_options,
     random_name_for_temp_object,
 )
-from snowflake.snowpark.column import _to_col_if_str
+from snowflake.snowpark.column import METADATA_COLUMN_TYPES, Column, _to_col_if_str
 from snowflake.snowpark.dataframe import DataFrame
 from snowflake.snowpark.functions import sql_expr
 from snowflake.snowpark.table import Table
@@ -310,6 +311,39 @@ def _infer_schema(self):
             return self._cur_options.get("INFER_SCHEMA", True)
         return False
 
+    def _get_metadata_project_and_schema(self) -> Tuple[List[str], List[Attribute]]:
+        if self._metadata_cols:
+            metadata_project = [
+                self._session._analyzer.analyze(col._expression, {})
+                for col in self._metadata_cols
+            ]
+        else:
+            metadata_project = []
+
+        metadata_schema = []
+
+        def _get_unaliased_name(unaliased: ColumnOrName) -> str:
+            if isinstance(unaliased, Column):
+                if isinstance(unaliased._expression, Alias):
+                    return unaliased._expression.child.sql
+                return unaliased._named().name
+            return unaliased
+
+        try:
+            metadata_schema = [
+                Attribute(
+                    metadata_col._named().name,
+                    METADATA_COLUMN_TYPES[_get_unaliased_name(metadata_col).upper()],
+                )
+                for metadata_col in self._metadata_cols or []
+            ]
+        except KeyError:
+            raise ValueError(
+                f"Metadata column name is not supported. Supported columns: {list(METADATA_COLUMN_TYPES.keys())}, got: {metadata_project}"
+            )
+
+        return metadata_project, metadata_schema
+
     def table(self, name: Union[str, Iterable[str]]) -> Table:
         """Returns a Table that points to the specified table.
 
@@ -389,13 +423,7 @@ def csv(self, path: str) -> DataFrame:
         else:
             schema = self._user_schema._to_attributes()
 
-        if self._metadata_cols:
-            metadata_project = [
-                self._session._analyzer.analyze(col._expression, {})
-                for col in self._metadata_cols
-            ]
-        else:
-            metadata_project = []
+        metadata_project, metadata_schema = self._get_metadata_project_and_schema()
 
         if self._session.sql_simplifier_enabled:
             df = DataFrame(
@@ -411,6 +439,7 @@ def csv(self, path: str) -> DataFrame:
                     schema_to_cast=schema_to_cast,
                     transformations=transformations,
                     metadata_project=metadata_project,
+                    metadata_schema=metadata_schema,
                 ),
                 analyzer=self._session._analyzer,
             ),
@@ -429,6 +458,7 @@ def csv(self, path: str) -> DataFrame:
                 schema_to_cast=schema_to_cast,
                 transformations=transformations,
                 metadata_project=metadata_project,
+                metadata_schema=metadata_schema,
             ),
         )
         df._reader = self
@@ -644,13 +674,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
         if new_schema:
             schema = new_schema
 
-        if self._metadata_cols:
-            metadata_project = [
-                self._session._analyzer.analyze(col._expression, {})
-                for col in self._metadata_cols
-            ]
-        else:
-            metadata_project = []
+        metadata_project, metadata_schema = self._get_metadata_project_and_schema()
 
         if self._session.sql_simplifier_enabled:
             df = DataFrame(
@@ -666,6 +690,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
                     schema_to_cast=schema_to_cast,
                     transformations=read_file_transformations,
                     metadata_project=metadata_project,
+                    metadata_schema=metadata_schema,
                 ),
                 analyzer=self._session._analyzer,
             ),
@@ -684,6 +709,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
                 schema_to_cast=schema_to_cast,
                 transformations=read_file_transformations,
                 metadata_project=metadata_project,
+                metadata_schema=metadata_schema,
             ),
         )
         df._reader = self
diff --git a/src/snowflake/snowpark/mock/plan_builder.py b/src/snowflake/snowpark/mock/plan_builder.py
index 3983670b535..ce1b39d9734 100644
--- a/src/snowflake/snowpark/mock/plan_builder.py
+++ b/src/snowflake/snowpark/mock/plan_builder.py
@@ -26,6 +26,7 @@ def read_file(
         schema_to_cast: Optional[List[Tuple[str, str]]] = None,
         transformations: Optional[List[str]] = None,
         metadata_project: Optional[List[str]] = None,
+        metadata_schema: Optional[List[Attribute]] = None,
     ) -> MockExecutionPlan:
         if format.upper() != "CSV":
             raise NotImplementedError(
diff --git a/tests/integ/scala/test_dataframe_reader_suite.py b/tests/integ/scala/test_dataframe_reader_suite.py
index e3cfcfc3940..b2b189696c0 100644
--- a/tests/integ/scala/test_dataframe_reader_suite.py
+++ b/tests/integ/scala/test_dataframe_reader_suite.py
@@ -662,18 +662,49 @@ def test_read_metadata_column_from_stage(session, file_format):
     assert isinstance(res[0]["METADATA$FILE_LAST_MODIFIED"], datetime.datetime)
     assert isinstance(res[0]["METADATA$START_SCAN_TIME"], datetime.datetime)
 
+    table_name = Utils.random_table_name()
+    df.write.save_as_table(table_name, mode="append")
+    with session.table(table_name) as table_df:
+        table_res = table_df.collect()
+        assert table_res[0]["METADATA$FILENAME"] == res[0]["METADATA$FILENAME"]
+        assert (
+            table_res[0]["METADATA$FILE_ROW_NUMBER"]
+            == res[0]["METADATA$FILE_ROW_NUMBER"]
+        )
+        assert (
+            table_res[0]["METADATA$FILE_CONTENT_KEY"]
+            == res[0]["METADATA$FILE_CONTENT_KEY"]
+        )
+        assert (
+            table_res[0]["METADATA$FILE_LAST_MODIFIED"]
+            == res[0]["METADATA$FILE_LAST_MODIFIED"]
+        )
+        assert isinstance(table_res[0]["METADATA$START_SCAN_TIME"], datetime.datetime)
+
     # test single column works
     reader = session.read.with_metadata(METADATA_FILENAME)
     df = get_df_from_reader_and_file_format(reader, file_format)
     res = df.collect()
     assert res[0]["METADATA$FILENAME"] == filename
 
+    table_name = Utils.random_table_name()
+    df.write.save_as_table(table_name, mode="append")
+    with session.table(table_name) as table_df:
+        table_res = table_df.collect()
+        assert table_res[0]["METADATA$FILENAME"] == res[0]["METADATA$FILENAME"]
+
     # test that alias works
     reader = session.read.with_metadata(METADATA_FILENAME.alias("filename"))
     df = get_df_from_reader_and_file_format(reader, file_format)
     res = df.collect()
     assert res[0]["FILENAME"] == filename
 
+    table_name = Utils.random_table_name()
+    df.write.save_as_table(table_name, mode="append")
+    with session.table(table_name) as table_df:
+        table_res = table_df.collect()
+        assert table_res[0]["FILENAME"] == res[0]["FILENAME"]
+
     # test that column name with str works
     reader = session.read.with_metadata("metadata$filename", "metadata$file_row_number")
     df = get_df_from_reader_and_file_format(reader, file_format)
@@ -681,6 +712,22 @@ def test_read_metadata_column_from_stage(session, file_format):
     assert res[0]["METADATA$FILENAME"] == filename
     assert res[0]["METADATA$FILE_ROW_NUMBER"] >= 0
 
+    table_name = Utils.random_table_name()
+    df.write.save_as_table(table_name, mode="append")
+    with session.table(table_name) as table_df:
+        table_res = table_df.collect()
+        assert table_res[0]["METADATA$FILENAME"] == res[0]["METADATA$FILENAME"]
+        assert (
+            table_res[0]["METADATA$FILE_ROW_NUMBER"]
+            == res[0]["METADATA$FILE_ROW_NUMBER"]
+        )
+
+    # test non-existing metadata column
+    with pytest.raises(ValueError, match="Metadata column name is not supported"):
+        get_df_from_reader_and_file_format(
+            session.read.with_metadata("metadata$non-existing"), file_format
+        )
+
 
 @pytest.mark.parametrize("mode", ["select", "copy"])
 def test_read_json_with_no_schema(session, mode):
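
A minimal usage sketch of the behavior this patch fixes, not part of the patch itself: before the change, a DataFrame built via `session.read.with_metadata` carried a plan schema that omitted the metadata columns, so `df.write.save_as_table` could create a table whose metadata columns had inconsistent types. The stage path, table name, and the preexisting `session` object below are hypothetical placeholders.

    # Assumes an existing Snowpark `session` and a CSV file staged at
    # @mystage/data.csv -- both hypothetical.
    from snowflake.snowpark.column import METADATA_FILE_ROW_NUMBER, METADATA_FILENAME
    from snowflake.snowpark.types import StringType, StructField, StructType

    user_schema = StructType(
        [StructField("a", StringType()), StructField("b", StringType())]
    )
    df = (
        session.read.with_metadata(METADATA_FILENAME, METADATA_FILE_ROW_NUMBER)
        .schema(user_schema)
        .csv("@mystage/data.csv")
    )
    # With metadata_schema threaded into read_file, the persisted table keeps
    # METADATA$FILENAME as StringType and METADATA$FILE_ROW_NUMBER as IntegerType.
    df.write.save_as_table("my_metadata_table", mode="append")
    print(session.table("my_metadata_table").schema)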