diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 063910fe..0e5818c2 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -37,6 +37,7 @@ CreateStage, CSVFormatter, ExternalStage, + InsertMulti, JSONFormatter, MergeInto, PARQUETFormatter, diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index d87b78c1..081a6868 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -205,6 +205,49 @@ def visit_merge_into_clause(self, merge_into_clause, **kw): " SET %s" % sets if merge_into_clause.set else "", ) + def visit_insert_multi(self, insert_all, **kw): + clauses = [] + for condition, table, columns, values in insert_all.clauses: + clauses.append( + ( + f"WHEN {condition._compiler_dispatch(self, include_table=False, **kw)} THEN " + if condition is not None + else "" + ) + + f"INTO {table._compiler_dispatch(self, asfrom=True, **kw)}" + + ( + f" ({', '.join(c._compiler_dispatch(self, include_table=False, **kw) for c in columns)})" + if columns + else "" + ) + + ( + f" VALUES ({', '.join(v._compiler_dispatch(self, include_table=False, **kw) for v in values)})" + if values + else "" + ) + ) + + source = insert_all.source._compiler_dispatch(self, **kw) + if insert_all.else__: + else_ = ( + f" ELSE {insert_all.else__[0]._compiler_dispatch(self, asfrom=True, **kw)}" + + ( + f" ({', '.join(c._compiler_dispatch(self, include_table=False, **kw) for c in insert_all.else__[1])})" + if insert_all.else__[1] + else "" + ) + + ( + f" VALUES ({', '.join(v._compiler_dispatch(self, include_table=False, **kw) for v in insert_all.else__[2])})" + if insert_all.else__[2] + else "" + ) + ) + else: + else_ = "" + overwrite = " OVERWRITE" if insert_all.overwrite else "" + condition = "FIRST" if insert_all.is_conditional and insert_all.first else "ALL" + return f"INSERT{overwrite} {condition} {' '.join(clauses)}{else_} {source}" + def visit_copy_into(self, copy_into, **kw): if hasattr(copy_into, "formatter") and copy_into.formatter is not None: formatter = copy_into.formatter._compiler_dispatch(self, **kw) diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 9cc14389..283de606 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -94,6 +94,92 @@ def when_not_matched_then_insert(self): return clause +class InsertMulti(UpdateBase): + __visit_name__ = "insert_multi" + _bind = None + + def __init__(self, source, overwrite=False, first=False): + self.source = source + self.overwrite = overwrite + self.first = first + self.clauses = [] + self.else__ = None + + @property + def is_conditional(self): + return any(condition is not None for condition, _, _, _ in self.clauses) + + def __repr__(self): + clauses = [] + for condition, table, columns, values in self.clauses: + clauses.append( + (f"WHEN {condition!r} THEN " if condition is not None else "") + + f" INTO {table!r}" + + (f"({', '.join(repr(c) for c in columns)})" if columns else "") + + (f" VALUES ({', '.join(str(v) for v in values)})" if values else "") + ) + else_ = f" ELSE {self.else__!r}" if self.else__ else "" + overwrite = " OVERWRITE" if self.overwrite else "" + condition = "FIRST" if self.is_conditional and self.first else "ALL" + return ( + f"INSERT{overwrite} {condition} {', '.join(clauses)}{else_} {self.source}" + ) + + def _adapt_columns(self, columns, coll): + """Make sure all columns are column instances from the given table, not strings""" + if columns is None: + return None + return [coll[c] if isinstance(c, str) else c for c in columns] + + def into(self, table, columns=None, values=None): + if self.is_conditional: + raise ValueError( + "Cannot add an unconditional clause to a Conditional multi-table insert" + ) + if columns and values: + assert len(columns) == len( + values + ), "columns and values must be of the same length" + self.clauses.append( + ( + None, + table, + self._adapt_columns(columns, table.c), + self._adapt_columns(values, self.source.selected_columns), + ) + ) + return self + + def when(self, condition, table, columns=None, values=None): + if self.clauses and not self.is_conditional: + raise ValueError( + "Cannot add a conditional clause to an Unconditional multi-table insert" + ) + if columns and values: + assert len(columns) == len( + values + ), "columns and values must be of the same length" + self.clauses.append( + ( + condition, + table, + self._adapt_columns(columns, table.c), + self._adapt_columns(values, self.source.selected_columns), + ) + ) + return self + + def else_(self, table, columns=None, values=None): + if self.clauses and not self.is_conditional: + raise ValueError("Cannot set ELSE on an Unconditional multi-table insert") + self.else__ = ( + table, + self._adapt_columns(columns, table.c), + self._adapt_columns(values, self.source.selected_columns), + ) + return self + + class FilesOption: """ Class to represent FILES option for the snowflake COPY INTO statement diff --git a/tests/test_core.py b/tests/test_core.py index 157889ff..cee4ff9c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -29,6 +29,7 @@ Table, UniqueConstraint, dialects, + func, inspect, text, ) @@ -39,7 +40,7 @@ import snowflake.connector.errors import snowflake.sqlalchemy.snowdialect from snowflake.connector import Error, ProgrammingError, connect -from snowflake.sqlalchemy import URL, MergeInto, dialect +from snowflake.sqlalchemy import URL, InsertMulti, MergeInto, dialect from snowflake.sqlalchemy._constants import ( APPLICATION_NAME, SNOWFLAKE_SQLALCHEMY_VERSION, @@ -1290,6 +1291,144 @@ def test_deterministic_merge_into(sql_compiler): ) +def test_unconditional_insert_all(sql_compiler): + meta = MetaData() + users1 = Table( + "users1", + meta, + Column("id", Integer, Sequence("user_id_seq"), primary_key=True), + Column("name", String), + Column("fullname", String), + Column("created_at", DateTime), + ) + users2 = Table( + "users2", + meta, + Column("id", Integer, Sequence("user_id_seq2"), primary_key=True), + Column("name", String), + Column("full/name", String), + ) + onboarding_users = Table( + "onboarding_users", + meta, + Column("id", Integer, Sequence("new_user_id_seq"), primary_key=True), + Column("name", String), + Column("fullname", String), + Column("delete", Boolean), + ) + insert_all = ( + InsertMulti( + select( + onboarding_users.c.id, + onboarding_users.c.name, + onboarding_users.c.fullname, + ) + ) + .into(users1) + .into(users2) + ) + assert ( + sql_compiler(insert_all) == "INSERT ALL INTO users1 INTO users2 " + "SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname " + "FROM onboarding_users" + ) + + stmt = select( + onboarding_users.c.id, + onboarding_users.c.name, + onboarding_users.c.fullname, + onboarding_users.c.delete, + ) + insert_all = ( + InsertMulti(stmt) + .into( + users1, + [users1.c.id, users1.c.name, users1.c.fullname, users1.c.created_at], + [ + stmt.selected_columns.id, + stmt.selected_columns.name, + stmt.selected_columns.fullname, + func.now(), + ], + ) + .into( + users2, + [users2.c.name, users2.c["full/name"]], + [stmt.selected_columns.fullname, stmt.selected_columns.name], + ) + ) + assert ( + sql_compiler(insert_all) == "INSERT ALL " + "INTO users1 (id, name, fullname, created_at) VALUES (id, name, fullname, CURRENT_TIMESTAMP) " + 'INTO users2 (name, "full/name") VALUES (fullname, name) ' + "SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname, " + 'onboarding_users."delete" FROM onboarding_users' + ) + + +def test_conditional_insert_multi(sql_compiler): + meta = MetaData() + users1 = Table( + "users1", + meta, + Column("id", Integer, Sequence("user_id_seq"), primary_key=True), + Column("name", String), + Column("fullname", String), + ) + users2 = Table( + "users2", + meta, + Column("id", Integer, Sequence("user_id_seq2"), primary_key=True), + Column("name", String), + Column("full/name", String), + ) + onboarding_users = Table( + "onboarding_users", + meta, + Column("id", Integer, Sequence("new_user_id_seq"), primary_key=True), + Column("name", String), + Column("fullname", String), + Column("delete", Boolean), + ) + stmt = select( + onboarding_users.c.id, + onboarding_users.c.name, + onboarding_users.c.fullname, + onboarding_users.c.delete, + ) + insert_all = ( + InsertMulti(stmt) + .when( + stmt.selected_columns.delete, + users1, + values=[ + stmt.selected_columns.id, + stmt.selected_columns.name, + stmt.selected_columns.fullname, + ], + ) + .when( + ~stmt.selected_columns.delete, + users2, + [users2.c.id, users2.c.name, users2.c["full/name"]], + [ + stmt.selected_columns.id, + stmt.selected_columns.name, + stmt.selected_columns.fullname, + ], + ) + .else_(users1) + ) + assert ( + sql_compiler(insert_all) == "INSERT ALL " + 'WHEN "delete" THEN INTO users1 VALUES (id, name, fullname) ' + 'WHEN NOT "delete" THEN INTO users2 (id, name, "full/name") VALUES (id, name, fullname) ' + "ELSE users1 " + "SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname, " + 'onboarding_users."delete" FROM onboarding_users' + ) + + def test_comments(engine_testaccount): """Tests strictly reading column comment through SQLAlchemy""" table_name = random_string(5, choices=string.ascii_uppercase)