Skip to content

Commit

Permalink
implement INSERT ALL statement
Browse files Browse the repository at this point in the history
  • Loading branch information
sjhewitt committed Apr 19, 2023
1 parent 7baaa1a commit 7758ef6
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/snowflake/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
CreateStage,
CSVFormatter,
ExternalStage,
InsertMulti,
JSONFormatter,
MergeInto,
PARQUETFormatter,
Expand Down
43 changes: 43 additions & 0 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,49 @@ def visit_merge_into_clause(self, merge_into_clause, **kw):
" SET %s" % sets if merge_into_clause.set else "",
)

def visit_insert_multi(self, insert_all, **kw):
clauses = []
for condition, table, columns, values in insert_all.clauses:
clauses.append(
(
f"WHEN {condition._compiler_dispatch(self, include_table=False, **kw)} THEN "
if condition is not None
else ""
)
+ f"INTO {table._compiler_dispatch(self, asfrom=True, **kw)}"
+ (
f" ({', '.join(c._compiler_dispatch(self, include_table=False, **kw) for c in columns)})"
if columns
else ""
)
+ (
f" VALUES ({', '.join(v._compiler_dispatch(self, include_table=False, render_label_as_label=v, **kw) for v in values)})"
if values
else ""
)
)

if insert_all.else__:
else_ = (
f" ELSE {insert_all.else__[0]._compiler_dispatch(self, asfrom=True, **kw)}"
+ (
f" ({', '.join(c._compiler_dispatch(self, include_table=False, **kw) for c in insert_all.else__[1])})"
if insert_all.else__[1]
else ""
)
+ (
f" VALUES ({', '.join(v._compiler_dispatch(self, include_table=False, render_label_as_label=v, **kw) for v in insert_all.else__[2])})"
if insert_all.else__[2]
else ""
)
)
else:
else_ = ""
overwrite = " OVERWRITE" if insert_all.overwrite else ""
condition = "FIRST" if insert_all.is_conditional and insert_all.first else "ALL"
source = insert_all.source._compiler_dispatch(self, asfrom=True, **kw)
return f"INSERT{overwrite} {condition} {' '.join(clauses)}{else_} {source}"

def visit_copy_into(self, copy_into, **kw):
if hasattr(copy_into, "formatter") and copy_into.formatter is not None:
formatter = copy_into.formatter._compiler_dispatch(self, **kw)
Expand Down
86 changes: 86 additions & 0 deletions src/snowflake/sqlalchemy/custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,92 @@ def when_not_matched_then_insert(self):
return clause


class InsertMulti(UpdateBase):
__visit_name__ = "insert_multi"
_bind = None

def __init__(self, source, overwrite=False, first=False):
self.source = source
self.overwrite = overwrite
self.first = first
self.clauses = []
self.else__ = None

@property
def is_conditional(self):
return any(condition is not None for condition, _, _, _ in self.clauses)

def __repr__(self):
clauses = []
for condition, table, columns, values in self.clauses:
clauses.append(
(f"WHEN {condition!r} THEN " if condition is not None else "")
+ f" INTO {table!r}"
+ (f"({', '.join(repr(c) for c in columns)})" if columns else "")
+ (f" VALUES ({', '.join(str(v) for v in values)})" if values else "")
)
else_ = f" ELSE {self.else__!r}" if self.else__ else ""
overwrite = " OVERWRITE" if self.overwrite else ""
condition = "FIRST" if self.is_conditional and self.first else "ALL"
return (
f"INSERT{overwrite} {condition} {', '.join(clauses)}{else_} {self.source}"
)

def _adapt_columns(self, columns, coll):
"""Make sure all columns are column instances from the given table, not strings"""
if columns is None:
return None
return [coll[c] if isinstance(c, str) else c for c in columns]

def into(self, table, columns=None, values=None):
if self.is_conditional:
raise ValueError(
"Cannot add an unconditional clause to a Conditional multi-table insert"
)
if columns and values:
assert len(columns) == len(
values
), "columns and values must be of the same length"
self.clauses.append(
(
None,
table,
self._adapt_columns(columns, table.c),
self._adapt_columns(values, self.source.selected_columns),
)
)
return self

def when(self, condition, table, columns=None, values=None):
if self.clauses and not self.is_conditional:
raise ValueError(
"Cannot add a conditional clause to an Unconditional multi-table insert"
)
if columns and values:
assert len(columns) == len(
values
), "columns and values must be of the same length"
self.clauses.append(
(
condition,
table,
self._adapt_columns(columns, table.c),
self._adapt_columns(values, self.source.selected_columns),
)
)
return self

def else_(self, table, columns=None, values=None):
if self.clauses and not self.is_conditional:
raise ValueError("Cannot set ELSE on an Unconditional multi-table insert")
self.else__ = (
table,
self._adapt_columns(columns, table.c),
self._adapt_columns(values, self.source.selected_columns),
)
return self


class FilesOption:
"""
Class to represent FILES option for the snowflake COPY INTO statement
Expand Down
141 changes: 140 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Table,
UniqueConstraint,
dialects,
func,
inspect,
text,
)
Expand All @@ -39,7 +40,7 @@
import snowflake.connector.errors
import snowflake.sqlalchemy.snowdialect
from snowflake.connector import Error, ProgrammingError, connect
from snowflake.sqlalchemy import URL, MergeInto, dialect
from snowflake.sqlalchemy import URL, InsertMulti, MergeInto, dialect
from snowflake.sqlalchemy._constants import (
APPLICATION_NAME,
SNOWFLAKE_SQLALCHEMY_VERSION,
Expand Down Expand Up @@ -1290,6 +1291,144 @@ def test_deterministic_merge_into(sql_compiler):
)


def test_unconditional_insert_all(sql_compiler):
meta = MetaData()
users1 = Table(
"users1",
meta,
Column("id", Integer, Sequence("user_id_seq"), primary_key=True),
Column("name", String),
Column("fullname", String),
Column("created_at", DateTime),
)
users2 = Table(
"users2",
meta,
Column("id", Integer, Sequence("user_id_seq2"), primary_key=True),
Column("name", String),
Column("full/name", String),
)
onboarding_users = Table(
"onboarding_users",
meta,
Column("id", Integer, Sequence("new_user_id_seq"), primary_key=True),
Column("name", String),
Column("fullname", String),
Column("delete", Boolean),
)
insert_all = (
InsertMulti(
select(
onboarding_users.c.id,
onboarding_users.c.name,
onboarding_users.c.fullname,
)
)
.into(users1)
.into(users2)
)
assert (
sql_compiler(insert_all) == "INSERT ALL INTO users1 INTO users2 "
"SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname "
"FROM onboarding_users"
)

stmt = select(
onboarding_users.c.id,
onboarding_users.c.name.label("name_label"),
onboarding_users.c.fullname,
onboarding_users.c.delete,
)
insert_all = (
InsertMulti(stmt)
.into(
users1,
["id", "name", users1.c.fullname, users1.c.created_at],
[
"id",
"name_label",
stmt.selected_columns.fullname,
func.now(),
],
)
.into(
users2,
[users2.c.name, users2.c["full/name"]],
[stmt.selected_columns.fullname, stmt.selected_columns.name_label],
)
)
assert (
sql_compiler(insert_all) == "INSERT ALL "
"INTO users1 (id, name, fullname, created_at) VALUES (id, name_label, fullname, CURRENT_TIMESTAMP) "
'INTO users2 (name, "full/name") VALUES (fullname, name_label) '
"SELECT onboarding_users.id, onboarding_users.name AS name_label, onboarding_users.fullname, "
'onboarding_users."delete" FROM onboarding_users'
)


def test_conditional_insert_multi(sql_compiler):
meta = MetaData()
users1 = Table(
"users1",
meta,
Column("id", Integer, Sequence("user_id_seq"), primary_key=True),
Column("name", String),
Column("fullname", String),
)
users2 = Table(
"users2",
meta,
Column("id", Integer, Sequence("user_id_seq2"), primary_key=True),
Column("name", String),
Column("full/name", String),
)
onboarding_users = Table(
"onboarding_users",
meta,
Column("id", Integer, Sequence("new_user_id_seq"), primary_key=True),
Column("name", String),
Column("fullname", String),
Column("delete", Boolean),
)
stmt = select(
onboarding_users.c.id,
onboarding_users.c.name,
onboarding_users.c.fullname,
onboarding_users.c.delete,
)
insert_all = (
InsertMulti(stmt)
.when(
stmt.selected_columns.delete,
users1,
values=[
stmt.selected_columns.id,
stmt.selected_columns.name,
stmt.selected_columns.fullname,
],
)
.when(
~stmt.selected_columns.delete,
users2,
[users2.c.id, users2.c.name, users2.c["full/name"]],
[
stmt.selected_columns.id,
stmt.selected_columns.name,
stmt.selected_columns.fullname,
],
)
.else_(users1)
)
assert (
sql_compiler(insert_all) == "INSERT ALL "
'WHEN "delete" THEN INTO users1 VALUES (id, name, fullname) '
'WHEN NOT "delete" THEN INTO users2 (id, name, "full/name") VALUES (id, name, fullname) '
"ELSE users1 "
"SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname, "
'onboarding_users."delete" FROM onboarding_users'
)


def test_comments(engine_testaccount):
"""Tests strictly reading column comment through SQLAlchemy"""
table_name = random_string(5, choices=string.ascii_uppercase)
Expand Down

0 comments on commit 7758ef6

Please sign in to comment.