diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 2a1bb51a..38112daa 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -564,11 +564,11 @@ def visit_copy_into(self, copy_into, **kw): # everything else (selects, etc.) else: from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})" - credentials, encryption = "", "" + storage_integration, credentials, encryption = "", "", "" if isinstance(into, tuple): - into, credentials, encryption = into + into, storage_integration, credentials, encryption = into elif isinstance(from_, tuple): - from_, credentials, encryption = from_ + from_, storage_integration, credentials, encryption = from_ options_list = list(copy_into.copy_options.items()) if kw.get("deterministic", False): options_list.sort(key=operator.itemgetter(0)) @@ -592,6 +592,8 @@ def visit_copy_into(self, copy_into, **kw): if copy_into.copy_options else "" ) + if storage_integration: + options += f" {storage_integration}" if credentials: options += f" {credentials}" if encryption: @@ -630,6 +632,9 @@ def visit_aws_bucket(self, aws_bucket, **kw): credentials_list = list(aws_bucket.credentials_used.items()) if kw.get("deterministic", False): credentials_list.sort(key=operator.itemgetter(0)) + storage_integration = "STORAGE_INTEGRATION={}".format( + aws_bucket.storage_integration_used + ) credentials = "CREDENTIALS=({})".format( " ".join(f"{n}='{v}'" for n, v in credentials_list) ) @@ -647,6 +652,7 @@ def visit_aws_bucket(self, aws_bucket, **kw): ) return ( uri, + storage_integration if aws_bucket.storage_integration_used else "", credentials if aws_bucket.credentials_used else "", encryption if aws_bucket.encryption_used else "", ) @@ -655,6 +661,9 @@ def visit_azure_container(self, azure_container, **kw): credentials_list = list(azure_container.credentials_used.items()) if kw.get("deterministic", False): credentials_list.sort(key=operator.itemgetter(0)) + storage_integration = "STORAGE_INTEGRATION={}".format( + azure_container.storage_integration_used + ) credentials = "CREDENTIALS=({})".format( " ".join(f"{n}='{v}'" for n, v in credentials_list) ) @@ -674,6 +683,7 @@ def visit_azure_container(self, azure_container, **kw): ) return ( uri, + storage_integration if azure_container.storage_integration_used else "", credentials if azure_container.credentials_used else "", encryption if azure_container.encryption_used else "", ) diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index cec16673..26683705 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -502,6 +502,7 @@ def __init__(self, bucket, path=None): self.path = path self.encryption_used = {} self.credentials_used = {} + self.storage_integration_used = None @classmethod def from_uri(cls, uri): @@ -531,6 +532,10 @@ def __repr__(self): f" {encryption}" if self.encryption_used else "", ) + def storage_integration(self, integration_name): + self.storage_integration_used = integration_name + return self + def credentials( self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None ): @@ -575,6 +580,7 @@ def __init__(self, account, container, path=None): self.path = path self.encryption_used = {} self.credentials_used = {} + self.storage_integration_used = None @classmethod def from_uri(cls, uri): @@ -609,6 +615,10 @@ def __repr__(self): f" {encryption}" if self.encryption_used else "", ) + def storage_integration(self, integration_name): + self.storage_integration_used = integration_name + return self + def credentials(self, azure_sas_token): self.credentials_used = {"AZURE_SAS_TOKEN": azure_sas_token} return self diff --git a/tests/test_copy.py b/tests/test_copy.py index e0752d4f..1a701118 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -62,6 +62,16 @@ def test_copy_into_location(engine_testaccount, sql_compiler): "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) + copy_stmt_1a = CopyIntoStorage( + from_=food_items, + into=AWSBucket.from_uri("s3://backup").storage_integration("foobar"), + formatter=CSVFormatter(), + ) + assert ( + sql_compiler(copy_stmt_1a) + == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " + "STORAGE_INTEGRATION=foobar" + ) copy_stmt_2 = CopyIntoStorage( from_=select(food_items).where(food_items.c.id == 1), # Test sub-query into=AWSBucket.from_uri("s3://backup")