Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement multi-table INSERT statement #404

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/snowflake/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
CreateStage,
CSVFormatter,
ExternalStage,
InsertMulti,
JSONFormatter,
MergeInto,
PARQUETFormatter,
Expand Down
43 changes: 43 additions & 0 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,49 @@ def visit_merge_into_clause(self, merge_into_clause, **kw):
" SET %s" % sets if merge_into_clause.set else "",
)

def visit_insert_multi(self, insert_all, **kw):
    """Compile a multi-table insert construct into a Snowflake ``INSERT`` string.

    The rendered statement has the shape::

        INSERT [OVERWRITE] {ALL | FIRST}
            [WHEN <condition> THEN] INTO <table> [(<cols>)] [VALUES (<vals>)] ...
            [ELSE INTO <table> [(<cols>)] [VALUES (<vals>)]]
        <source>

    :param insert_all: construct exposing ``clauses`` (an iterable of
        ``(condition, table, columns, values)`` tuples), ``else__`` (``None``
        or a ``(table, columns, values)`` tuple), ``overwrite``, ``first``,
        ``is_conditional`` and ``source``.
    :param kw: compiler keyword arguments, forwarded to every nested dispatch.
    :return: the SQL text of the statement.
    """

    def render_into(table, columns, values):
        # Render one ``INTO <table> [(cols)] [VALUES (...)]`` fragment; shared
        # by the WHEN/unconditional clauses and by the ELSE branch.
        fragment = f"INTO {table._compiler_dispatch(self, asfrom=True, **kw)}"
        if columns:
            column_sql = ", ".join(
                c._compiler_dispatch(self, include_table=False, **kw)
                for c in columns
            )
            fragment += f" ({column_sql})"
        if values:
            value_sql = ", ".join(
                v._compiler_dispatch(
                    self, include_table=False, render_label_as_label=v, **kw
                )
                for v in values
            )
            fragment += f" VALUES ({value_sql})"
        return fragment

    clauses = []
    for condition, table, columns, values in insert_all.clauses:
        # The condition is dispatched before the INTO fragment so nested
        # elements are visited in the same order as they appear in the SQL.
        prefix = ""
        if condition is not None:
            condition_sql = condition._compiler_dispatch(
                self, include_table=False, **kw
            )
            prefix = f"WHEN {condition_sql} THEN "
        clauses.append(prefix + render_into(table, columns, values))

    # Snowflake's grammar is ``ELSE intoClause`` and an intoClause always
    # starts with the INTO keyword, so it must be emitted here as well.
    else_ = ""
    if insert_all.else__:
        else_ = " ELSE " + render_into(*insert_all.else__)

    overwrite = " OVERWRITE" if insert_all.overwrite else ""
    # FIRST is only meaningful for conditional inserts; otherwise use ALL.
    condition = "FIRST" if insert_all.is_conditional and insert_all.first else "ALL"
    source = insert_all.source._compiler_dispatch(self, asfrom=True, **kw)
    return f"INSERT{overwrite} {condition} {' '.join(clauses)}{else_} {source}"

def visit_copy_into(self, copy_into, **kw):
if hasattr(copy_into, "formatter") and copy_into.formatter is not None:
formatter = copy_into.formatter._compiler_dispatch(self, **kw)
Expand Down
86 changes: 86 additions & 0 deletions src/snowflake/sqlalchemy/custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,92 @@ def when_not_matched_then_insert(self):
return clause


class InsertMulti(UpdateBase):
    """Multi-table ``INSERT`` construct (``INSERT [OVERWRITE] {ALL | FIRST} ...``).

    Target clauses are added fluently: ``into()`` adds an unconditional
    ``INTO`` clause, ``when()`` adds a ``WHEN <condition> THEN INTO`` clause
    and ``else_()`` sets the fallback target of a conditional insert.
    Conditional and unconditional clauses cannot be mixed in one statement.

    :param source: selectable providing the rows; string entries passed as
        ``values`` are looked up in ``source.selected_columns``.
    :param overwrite: when true, the ``OVERWRITE`` keyword is rendered.
    :param first: when true and the statement is conditional, ``FIRST`` is
        rendered instead of ``ALL``.
    """

    __visit_name__ = "insert_multi"
    _bind = None

    def __init__(self, source, overwrite=False, first=False):
        self.source = source
        self.overwrite = overwrite
        self.first = first
        # List of (condition, table, columns, values); condition is None for
        # unconditional clauses.
        self.clauses = []
        # None, or a (table, columns, values) tuple for the ELSE branch.
        self.else__ = None

    @property
    def is_conditional(self):
        """Whether at least one clause was added with a condition."""
        return any(condition is not None for condition, _, _, _ in self.clauses)

    def __repr__(self):
        clauses = []
        for condition, table, columns, values in self.clauses:
            clause = f"WHEN {condition!r} THEN " if condition is not None else ""
            clause += f"INTO {table!r}"
            if columns:
                clause += f" ({', '.join(repr(c) for c in columns)})"
            if values:
                clause += f" VALUES ({', '.join(str(v) for v in values)})"
            clauses.append(clause)
        else_ = f" ELSE {self.else__!r}" if self.else__ else ""
        overwrite = " OVERWRITE" if self.overwrite else ""
        condition = "FIRST" if self.is_conditional and self.first else "ALL"
        return (
            f"INSERT{overwrite} {condition} {' '.join(clauses)}{else_} {self.source}"
        )

    def _adapt_columns(self, columns, coll):
        """Make sure all columns are column instances from the given table, not strings"""
        if columns is None:
            return None
        return [coll[c] if isinstance(c, str) else c for c in columns]

    def _build_target(self, table, columns, values):
        """Validate and normalise one ``(table, columns, values)`` target.

        :raises ValueError: if both ``columns`` and ``values`` are given and
            their lengths differ. A real exception is used rather than
            ``assert`` so the check survives ``python -O``.
        """
        if columns and values and len(columns) != len(values):
            raise ValueError("columns and values must be of the same length")
        return (
            table,
            self._adapt_columns(columns, table.c),
            self._adapt_columns(values, self.source.selected_columns),
        )

    def into(self, table, columns=None, values=None):
        """Add an unconditional ``INTO`` clause and return ``self``.

        :raises ValueError: if the statement already has a conditional clause
            or an ELSE branch, or if ``columns`` and ``values`` differ in length.
        """
        # An ELSE branch only exists in the conditional form, so a statement
        # that already has one cannot take unconditional clauses either.
        if self.is_conditional or self.else__ is not None:
            raise ValueError(
                "Cannot add an unconditional clause to a Conditional multi-table insert"
            )
        self.clauses.append((None, *self._build_target(table, columns, values)))
        return self

    def when(self, condition, table, columns=None, values=None):
        """Add a ``WHEN <condition> THEN INTO`` clause and return ``self``.

        :raises ValueError: if unconditional clauses were already added, or if
            ``columns`` and ``values`` differ in length.
        """
        if self.clauses and not self.is_conditional:
            raise ValueError(
                "Cannot add a conditional clause to an Unconditional multi-table insert"
            )
        self.clauses.append((condition, *self._build_target(table, columns, values)))
        return self

    def else_(self, table, columns=None, values=None):
        """Set the ELSE target of a conditional insert and return ``self``.

        :raises ValueError: if unconditional clauses were already added, or if
            ``columns`` and ``values`` differ in length.
        """
        if self.clauses and not self.is_conditional:
            raise ValueError("Cannot set ELSE on an Unconditional multi-table insert")
        self.else__ = self._build_target(table, columns, values)
        return self


class FilesOption:
"""
Class to represent FILES option for the snowflake COPY INTO statement
Expand Down
141 changes: 140 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Table,
UniqueConstraint,
dialects,
func,
inspect,
text,
)
Expand All @@ -39,7 +40,7 @@
import snowflake.connector.errors
import snowflake.sqlalchemy.snowdialect
from snowflake.connector import Error, ProgrammingError, connect
from snowflake.sqlalchemy import URL, MergeInto, dialect
from snowflake.sqlalchemy import URL, InsertMulti, MergeInto, dialect
from snowflake.sqlalchemy._constants import (
APPLICATION_NAME,
SNOWFLAKE_SQLALCHEMY_VERSION,
Expand Down Expand Up @@ -1290,6 +1291,144 @@ def test_deterministic_merge_into(sql_compiler):
)


def test_unconditional_insert_all(sql_compiler):
    """Unconditional multi-table inserts compile to ``INSERT ALL INTO ...``.

    Covers both the bare form (no column/value lists) and the explicit form
    mixing string names, column objects, a labelled column and a SQL function.
    """
    metadata = MetaData()
    users1 = Table(
        "users1",
        metadata,
        Column("id", Integer, Sequence("user_id_seq"), primary_key=True),
        Column("name", String),
        Column("fullname", String),
        Column("created_at", DateTime),
    )
    users2 = Table(
        "users2",
        metadata,
        Column("id", Integer, Sequence("user_id_seq2"), primary_key=True),
        Column("name", String),
        Column("full/name", String),
    )
    onboarding_users = Table(
        "onboarding_users",
        metadata,
        Column("id", Integer, Sequence("new_user_id_seq"), primary_key=True),
        Column("name", String),
        Column("fullname", String),
        Column("delete", Boolean),
    )

    # Case 1: targets only, no explicit column or value lists.
    bare_source = select(
        onboarding_users.c.id,
        onboarding_users.c.name,
        onboarding_users.c.fullname,
    )
    bare_insert = InsertMulti(bare_source)
    bare_insert.into(users1)
    bare_insert.into(users2)
    expected_bare = (
        "INSERT ALL INTO users1 INTO users2 "
        "SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname "
        "FROM onboarding_users"
    )
    assert sql_compiler(bare_insert) == expected_bare

    # Case 2: explicit column and value lists given per target.
    labelled_source = select(
        onboarding_users.c.id,
        onboarding_users.c.name.label("name_label"),
        onboarding_users.c.fullname,
        onboarding_users.c.delete,
    )
    selected = labelled_source.selected_columns
    explicit_insert = InsertMulti(labelled_source)
    explicit_insert.into(
        users1,
        ["id", "name", users1.c.fullname, users1.c.created_at],
        ["id", "name_label", selected.fullname, func.now()],
    )
    explicit_insert.into(
        users2,
        [users2.c.name, users2.c["full/name"]],
        [selected.fullname, selected.name_label],
    )
    expected_explicit = (
        "INSERT ALL "
        "INTO users1 (id, name, fullname, created_at) VALUES (id, name_label, fullname, CURRENT_TIMESTAMP) "
        'INTO users2 (name, "full/name") VALUES (fullname, name_label) '
        "SELECT onboarding_users.id, onboarding_users.name AS name_label, onboarding_users.fullname, "
        'onboarding_users."delete" FROM onboarding_users'
    )
    assert sql_compiler(explicit_insert) == expected_explicit


def test_conditional_insert_multi(sql_compiler):
    """Conditional multi-table inserts compile to WHEN ... THEN INTO / ELSE INTO.

    Builds two WHEN clauses (one without a column list, one with) plus an
    ELSE target and checks the full compiled statement.
    """
    meta = MetaData()
    users1 = Table(
        "users1",
        meta,
        Column("id", Integer, Sequence("user_id_seq"), primary_key=True),
        Column("name", String),
        Column("fullname", String),
    )
    users2 = Table(
        "users2",
        meta,
        Column("id", Integer, Sequence("user_id_seq2"), primary_key=True),
        Column("name", String),
        Column("full/name", String),
    )
    onboarding_users = Table(
        "onboarding_users",
        meta,
        Column("id", Integer, Sequence("new_user_id_seq"), primary_key=True),
        Column("name", String),
        Column("fullname", String),
        Column("delete", Boolean),
    )
    stmt = select(
        onboarding_users.c.id,
        onboarding_users.c.name,
        onboarding_users.c.fullname,
        onboarding_users.c.delete,
    )
    insert_all = (
        InsertMulti(stmt)
        .when(
            stmt.selected_columns.delete,
            users1,
            values=[
                stmt.selected_columns.id,
                stmt.selected_columns.name,
                stmt.selected_columns.fullname,
            ],
        )
        .when(
            ~stmt.selected_columns.delete,
            users2,
            [users2.c.id, users2.c.name, users2.c["full/name"]],
            [
                stmt.selected_columns.id,
                stmt.selected_columns.name,
                stmt.selected_columns.fullname,
            ],
        )
        .else_(users1)
    )
    # Snowflake's grammar is ``ELSE intoClause`` and an intoClause starts with
    # INTO, so the ELSE branch must read ``ELSE INTO users1``; a bare
    # ``ELSE users1`` is not valid SQL.
    assert (
        sql_compiler(insert_all) == "INSERT ALL "
        'WHEN "delete" THEN INTO users1 VALUES (id, name, fullname) '
        'WHEN NOT "delete" THEN INTO users2 (id, name, "full/name") VALUES (id, name, fullname) '
        "ELSE INTO users1 "
        "SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname, "
        'onboarding_users."delete" FROM onboarding_users'
    )


def test_comments(engine_testaccount):
"""Tests strictly reading column comment through SQLAlchemy"""
table_name = random_string(5, choices=string.ascii_uppercase)
Expand Down