From b36cff765167b368efd6b18ee493e33f216cac14 Mon Sep 17 00:00:00 2001 From: azban Date: Thu, 27 Jul 2023 18:25:47 -0700 Subject: [PATCH] add PARTITION BY option for CopyInto --- src/snowflake/sqlalchemy/base.py | 11 +++++++---- src/snowflake/sqlalchemy/custom_commands.py | 11 ++++++++--- tests/test_copy.py | 10 ++++++++++ 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 9835ced2..a47b54a1 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -215,7 +215,6 @@ def visit_copy_into(self, copy_into, **kw): if isinstance(copy_into.into, Table) else copy_into.into._compiler_dispatch(self, **kw) ) - from_ = None if isinstance(copy_into.from_, Table): from_ = copy_into.from_ # this is intended to catch AWSBucket and AzureContainer @@ -228,6 +227,11 @@ def visit_copy_into(self, copy_into, **kw): # everything else (selects, etc.) else: from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})" + + partition_by = "" + if copy_into.partition_by is not None: + partition_by = f"PARTITION BY {copy_into.partition_by}" + credentials, encryption = "", "" if isinstance(into, tuple): into, credentials, encryption = into @@ -238,8 +242,7 @@ def visit_copy_into(self, copy_into, **kw): options_list.sort(key=operator.itemgetter(0)) options = ( ( - " " - + " ".join( + " ".join( [ "{} = {}".format( n, @@ -258,7 +261,7 @@ def visit_copy_into(self, copy_into, **kw): options += f" {credentials}" if encryption: options += f" {encryption}" - return f"COPY INTO {into} FROM {from_} {formatter}{options}" + return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}" def visit_copy_formatter(self, formatter, **kw): options_list = list(formatter.options.items()) diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 9bb60916..64ffd2cf 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -114,18 +114,23 @@ class CopyInto(UpdateBase): __visit_name__ = "copy_into" _bind = None - def __init__(self, from_, into, formatter=None): + def __init__(self, from_, into, partition_by=None, formatter=None): self.from_ = from_ self.into = into self.formatter = formatter self.copy_options = {} + self.partition_by = partition_by def __repr__(self): """ repr for debugging / logging purposes only. For compilation logic, see the corresponding visitor in base.py """ - return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})" + val = f"COPY INTO {self.into} FROM {repr(self.from_)}" + if self.partition_by is not None: + val += f" PARTITION BY {self.partition_by}" + + return val + f" {repr(self.formatter)} ({self.copy_options})" def bind(self): return None @@ -530,7 +535,7 @@ def __repr__(self): ) def credentials( - self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None + self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None ): if aws_role is None and (aws_key_id is None and aws_secret_key is None): raise ValueError( diff --git a/tests/test_copy.py b/tests/test_copy.py index e0752d4f..92e38160 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -151,6 +151,16 @@ def test_copy_into_location(engine_testaccount, sql_compiler): == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" ) + copy_stmt_7 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by="('YEAR=' || year)" + ) + assert ( + sql_compiler(copy_stmt_7) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year)" + ) + # NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but # because of the right reasons acceptable_exc_reasons = {