diff --git a/DESCRIPTION.md b/DESCRIPTION.md index e39984b7..82ddebc9 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,9 @@ Source code is also available at: # Release Notes +- (Unreleased) + - Add support for the PARTITION BY clause in COPY INTO statements + - v1.7.0(November 22, 2024) - Add support for dynamic tables and required options diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index a1e16062..02e4f741 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -16,7 +16,8 @@ from sqlalchemy.schema import Sequence, Table from sqlalchemy.sql import compiler, expression, functions from sqlalchemy.sql.base import CompileState -from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.elements import BindParameter, quoted_name +from sqlalchemy.sql.expression import Executable from sqlalchemy.sql.selectable import Lateral, SelectState from snowflake.sqlalchemy._constants import DIALECT_NAME @@ -563,9 +564,8 @@ def visit_copy_into(self, copy_into, **kw): if isinstance(copy_into.into, Table) else copy_into.into._compiler_dispatch(self, **kw) ) - from_ = None if isinstance(copy_into.from_, Table): - from_ = copy_into.from_ + from_ = copy_into.from_.name # this is intended to catch AWSBucket and AzureContainer elif ( isinstance(copy_into.from_, AWSBucket) @@ -576,6 +576,21 @@ def visit_copy_into(self, copy_into, **kw): # everything else (selects, etc.) 
else: from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})" + + partition_by_value = None + if isinstance(copy_into.partition_by, (BindParameter, Executable)): + partition_by_value = copy_into.partition_by.compile( + compile_kwargs={"literal_binds": True} + ) + elif copy_into.partition_by is not None: + partition_by_value = copy_into.partition_by + + partition_by = ( + f"PARTITION BY {partition_by_value}" + if partition_by_value is not None and partition_by_value != "" + else "" + ) + credentials, encryption = "", "" if isinstance(into, tuple): into, credentials, encryption = into @@ -586,8 +601,7 @@ def visit_copy_into(self, copy_into, **kw): options_list.sort(key=operator.itemgetter(0)) options = ( ( - " " - + " ".join( + " ".join( [ "{} = {}".format( n, @@ -608,7 +622,7 @@ def visit_copy_into(self, copy_into, **kw): options += f" {credentials}" if encryption: options += f" {encryption}" - return f"COPY INTO {into} FROM {from_} {formatter}{options}" + return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}" def visit_copy_formatter(self, formatter, **kw): options_list = list(formatter.options.items()) diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 15585bd5..1b9260fe 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -115,18 +115,23 @@ class CopyInto(UpdateBase): __visit_name__ = "copy_into" _bind = None - def __init__(self, from_, into, formatter=None): + def __init__(self, from_, into, partition_by=None, formatter=None): self.from_ = from_ self.into = into self.formatter = formatter self.copy_options = {} + self.partition_by = partition_by def __repr__(self): """ repr for debugging / logging purposes only. 
For compilation logic, see the corresponding visitor in base.py """ - return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})" + val = f"COPY INTO {self.into} FROM {repr(self.from_)}" + if self.partition_by is not None: + val += f" PARTITION BY {self.partition_by}" + + return val + f" {repr(self.formatter)} ({self.copy_options})" def bind(self): return None diff --git a/tests/test_copy.py b/tests/test_copy.py index e0752d4f..8dfcf286 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table -from sqlalchemy.sql import select, text +from sqlalchemy.sql import functions, select, text from snowflake.sqlalchemy import ( AWSBucket, @@ -58,8 +58,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_1) - == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv " - "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" + == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv " + "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) copy_stmt_2 = CopyIntoStorage( @@ -73,8 +73,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): sql_compiler(copy_stmt_2) == "COPY INTO 's3://backup' FROM (SELECT python_tests_foods.id, " "python_tests_foods.name, python_tests_foods.quantity FROM python_tests_foods " - "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' " - "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') " + "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' " + "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') " "ENCRYPTION=(TYPE='AWS_SSE_S3')" ) copy_stmt_3 = CopyIntoStorage( @@ -87,7 +87,7 @@ def test_copy_into_location(engine_testaccount, 
sql_compiler): assert ( sql_compiler(copy_stmt_3) == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' " - "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " + "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " "CREDENTIALS=(AZURE_SAS_TOKEN='token')" ) @@ -95,7 +95,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): assert ( sql_compiler(copy_stmt_3) == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' " - "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " + "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " "MAX_FILE_SIZE = 50000000 " "CREDENTIALS=(AZURE_SAS_TOKEN='token')" ) @@ -112,8 +112,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_4) - == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " - "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" + == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " + "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -126,8 +126,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_5) - == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " - "FIELD_DELIMITER=',') ENCRYPTION=" + == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " + "FIELD_DELIMITER=',') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -138,7 +138,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_6) - == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " ) copy_stmt_7 = CopyIntoStorage( @@ -148,7 +148,38 @@ 
def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_7) - == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " + ) + + copy_stmt_8 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=text("('YEAR=' || year)"), + ) + assert ( + sql_compiler(copy_stmt_8) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year) " + ) + + copy_stmt_9 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=functions.concat( + text("'YEAR='"), text(food_items.columns["name"].name) + ), + ) + assert ( + sql_compiler(copy_stmt_9) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY concat('YEAR=', name) " + ) + + copy_stmt_10 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by="", + ) + assert ( + sql_compiler(copy_stmt_10) == "COPY INTO @stage_name FROM python_tests_foods " ) # NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but @@ -231,7 +262,7 @@ def test_copy_into_storage_csv_extended(sql_compiler): result = sql_compiler(copy_into) expected = ( r"COPY INTO TEST_IMPORT " - r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " + r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " r"FILE_FORMAT=(TYPE=csv COMPRESSION='auto' DATE_FORMAT='AUTO' " r"ERROR_ON_COLUMN_COUNT_MISMATCH=True ESCAPE=None " r"ESCAPE_UNENCLOSED_FIELD='\134' FIELD_DELIMITER=',' " @@ -288,7 +319,7 @@ def test_copy_into_storage_parquet_named_format(sql_compiler): expected = ( "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " - "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " + "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " "FILE_FORMAT=(format_name = parquet_file_format) force = TRUE" ) assert result 
== expected @@ -350,7 +381,7 @@ def test_copy_into_storage_parquet_files(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " + "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " "FORCE = true" ) assert result == expected @@ -412,6 +443,6 @@ def test_copy_into_storage_parquet_pattern(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" + "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" ) assert result == expected