Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-878116 Add support for PARTITION BY to COPY INTO location #542

Merged
merged 12 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Source code is also available at:

# Release Notes

- (Unreleased)
- Add support for the `PARTITION BY` clause in `COPY INTO <location>` statements

- v1.7.0(November 22, 2024)

- Add support for dynamic tables and required options
Expand Down
26 changes: 20 additions & 6 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from sqlalchemy.schema import Sequence, Table
from sqlalchemy.sql import compiler, expression, functions
from sqlalchemy.sql.base import CompileState
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import BindParameter, quoted_name
from sqlalchemy.sql.expression import Executable
from sqlalchemy.sql.selectable import Lateral, SelectState

from snowflake.sqlalchemy._constants import DIALECT_NAME
Expand Down Expand Up @@ -563,9 +564,8 @@ def visit_copy_into(self, copy_into, **kw):
if isinstance(copy_into.into, Table)
else copy_into.into._compiler_dispatch(self, **kw)
)
from_ = None
if isinstance(copy_into.from_, Table):
from_ = copy_into.from_
from_ = copy_into.from_.name
# this is intended to catch AWSBucket and AzureContainer
elif (
isinstance(copy_into.from_, AWSBucket)
Expand All @@ -576,6 +576,21 @@ def visit_copy_into(self, copy_into, **kw):
# everything else (selects, etc.)
else:
from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})"

partition_by_value = None
if isinstance(copy_into.partition_by, (BindParameter, Executable)):
partition_by_value = copy_into.partition_by.compile(
compile_kwargs={"literal_binds": True}
)
elif copy_into.partition_by is not None:
partition_by_value = copy_into.partition_by

partition_by = (
f"PARTITION BY {partition_by_value}"
if partition_by_value is not None and partition_by_value != ""
else ""
)

credentials, encryption = "", ""
if isinstance(into, tuple):
into, credentials, encryption = into
Expand All @@ -586,8 +601,7 @@ def visit_copy_into(self, copy_into, **kw):
options_list.sort(key=operator.itemgetter(0))
options = (
(
" "
+ " ".join(
" ".join(
[
"{} = {}".format(
n,
Expand All @@ -608,7 +622,7 @@ def visit_copy_into(self, copy_into, **kw):
options += f" {credentials}"
if encryption:
options += f" {encryption}"
return f"COPY INTO {into} FROM {from_} {formatter}{options}"
return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}"

def visit_copy_formatter(self, formatter, **kw):
options_list = list(formatter.options.items())
Expand Down
9 changes: 7 additions & 2 deletions src/snowflake/sqlalchemy/custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,23 @@ class CopyInto(UpdateBase):
__visit_name__ = "copy_into"
_bind = None

def __init__(self, from_, into, formatter=None):
def __init__(self, from_, into, partition_by=None, formatter=None):
    """Build a COPY INTO command.

    Args:
        from_: Source of the copy — a Table, a stage/bucket object
            (e.g. AWSBucket, AzureContainer, ExternalStage), or a
            selectable; compilation is handled in the visitor in base.py.
        into: Destination of the copy (Table or external location).
        partition_by: Optional PARTITION BY expression for unloading to a
            location; may be a plain string, a text() clause, or a SQL
            function expression. ``None`` (or ``""``) omits the clause.
        formatter: Optional file-format specification (compiled via
            ``visit_copy_formatter``).
    """
    self.from_ = from_
    self.into = into
    self.formatter = formatter
    # Accumulates options such as MAX_FILE_SIZE / FORCE added via the
    # fluent option methods; rendered after the file format.
    self.copy_options = {}
    self.partition_by = partition_by

def __repr__(self):
"""
repr for debugging / logging purposes only. For compilation logic, see
the corresponding visitor in base.py
"""
return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})"
val = f"COPY INTO {self.into} FROM {repr(self.from_)}"
if self.partition_by is not None:
val += f" PARTITION BY {self.partition_by}"

return val + f" {repr(self.formatter)} ({self.copy_options})"

def bind(self):
return None
Expand Down
65 changes: 48 additions & 17 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table
from sqlalchemy.sql import select, text
from sqlalchemy.sql import functions, select, text

from snowflake.sqlalchemy import (
AWSBucket,
Expand Down Expand Up @@ -58,8 +58,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_1)
== "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv "
"ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
== "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv "
"ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
"(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
)
copy_stmt_2 = CopyIntoStorage(
Expand All @@ -73,8 +73,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
sql_compiler(copy_stmt_2)
== "COPY INTO 's3://backup' FROM (SELECT python_tests_foods.id, "
"python_tests_foods.name, python_tests_foods.quantity FROM python_tests_foods "
"WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' "
"FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') "
"WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' "
"FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') "
"ENCRYPTION=(TYPE='AWS_SSE_S3')"
)
copy_stmt_3 = CopyIntoStorage(
Expand All @@ -87,15 +87,15 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
assert (
sql_compiler(copy_stmt_3)
== "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' "
"FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
"FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
"CREDENTIALS=(AZURE_SAS_TOKEN='token')"
)

copy_stmt_3.maxfilesize(50000000)
assert (
sql_compiler(copy_stmt_3)
== "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' "
"FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
"FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
"MAX_FILE_SIZE = 50000000 "
"CREDENTIALS=(AZURE_SAS_TOKEN='token')"
)
Expand All @@ -112,8 +112,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_4)
== "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
== "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
"(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
)

Expand All @@ -126,8 +126,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_5)
== "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"FIELD_DELIMITER=',') ENCRYPTION="
== "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"FIELD_DELIMITER=',') ENCRYPTION="
"(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
)

Expand All @@ -138,7 +138,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_6)
== "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv)"
== "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv) "
)

copy_stmt_7 = CopyIntoStorage(
Expand All @@ -148,7 +148,38 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_7)
== "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)"
== "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv) "
)

copy_stmt_8 = CopyIntoStorage(
from_=food_items,
into=ExternalStage(name="stage_name"),
partition_by=text("('YEAR=' || year)"),
)
assert (
sql_compiler(copy_stmt_8)
== "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year) "
)

copy_stmt_9 = CopyIntoStorage(
from_=food_items,
into=ExternalStage(name="stage_name"),
partition_by=functions.concat(
text("'YEAR='"), text(food_items.columns["name"].name)
),
)
assert (
sql_compiler(copy_stmt_9)
== "COPY INTO @stage_name FROM python_tests_foods PARTITION BY concat('YEAR=', name) "
)

copy_stmt_10 = CopyIntoStorage(
from_=food_items,
into=ExternalStage(name="stage_name"),
partition_by="",
)
assert (
sql_compiler(copy_stmt_10) == "COPY INTO @stage_name FROM python_tests_foods "
)

# NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but
Expand Down Expand Up @@ -231,7 +262,7 @@ def test_copy_into_storage_csv_extended(sql_compiler):
result = sql_compiler(copy_into)
expected = (
r"COPY INTO TEST_IMPORT "
r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata "
r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata "
r"FILE_FORMAT=(TYPE=csv COMPRESSION='auto' DATE_FORMAT='AUTO' "
r"ERROR_ON_COLUMN_COUNT_MISMATCH=True ESCAPE=None "
r"ESCAPE_UNENCLOSED_FIELD='\134' FIELD_DELIMITER=',' "
Expand Down Expand Up @@ -288,7 +319,7 @@ def test_copy_into_storage_parquet_named_format(sql_compiler):
expected = (
"COPY INTO TEST_IMPORT "
"FROM (SELECT $1:COL1::number, $1:COL2::varchar "
"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) "
"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) "
"FILE_FORMAT=(format_name = parquet_file_format) force = TRUE"
)
assert result == expected
Expand Down Expand Up @@ -350,7 +381,7 @@ def test_copy_into_storage_parquet_files(sql_compiler):
"COPY INTO TEST_IMPORT "
"FROM (SELECT $1:COL1::number, $1:COL2::varchar "
"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet "
"(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') "
"(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') "
"FORCE = true"
)
assert result == expected
Expand Down Expand Up @@ -412,6 +443,6 @@ def test_copy_into_storage_parquet_pattern(sql_compiler):
"COPY INTO TEST_IMPORT "
"FROM (SELECT $1:COL1::number, $1:COL2::varchar "
"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet "
"(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'"
"(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'"
)
assert result == expected
Loading