Skip to content

Commit

Permalink
Support Truncate with Cascade
Browse files Browse the repository at this point in the history
  • Loading branch information
marcostvz committed Jan 8, 2024
1 parent a6ea29a commit 8868201
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
21 changes: 18 additions & 3 deletions dj_anonymizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
'oracle': 'TRUNCATE TABLE',
}

VENDOR_TO_CASCADE = {
'postgresql': 'CASCADE',
'oracle': 'CASCADE',
}


def import_if_exist(filename):
"""
Expand All @@ -27,7 +32,7 @@ def import_if_exist(filename):
spec.loader.exec_module(mod)


def truncate_table(model):
def truncate_table(model, cascade=False):
"""
Generate and execute via Django ORM proper SQL to truncate table
"""
Expand All @@ -42,11 +47,21 @@ def truncate_table(model):
"Database vendor %s is not supported" % vendor
)

cascade_op = ''
try:
if cascade:
cascade_op = VENDOR_TO_CASCADE[vendor]
except KeyError:
raise NotImplementedError(
"Database vendor %s does not support TRUNCATE with CASCADE" % vendor
)

dbtable = '"{}"'.format(model._meta.db_table)

sql = '{operation} {dbtable}'.format(
sql = '{operation} {dbtable} {cascade}'.format(
operation=operation,
dbtable=dbtable,
)
cascade=cascade_op,
).strip()
with connection.cursor() as c:
c.execute(sql)
17 changes: 17 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,20 @@ def test_truncate_table(mock_connections):

with pytest.raises(NotImplementedError):
truncate_table(User)


@mock.patch('dj_anonymizer.utils.connections')
def test_truncate_table_with_cascade(mock_connections):
mock_cursor = mock_connections.\
__getitem__(DEFAULT_DB_ALIAS).\
cursor.return_value.__enter__.return_value
mock_connections.__getitem__(DEFAULT_DB_ALIAS).vendor = 'postgresql'

truncate_table(User, True)
mock_cursor.execute.assert_called_once_with('TRUNCATE TABLE "auth_user" CASCADE')

mock_connections.__getitem__(DEFAULT_DB_ALIAS).vendor = 'sqlite'

with pytest.raises(NotImplementedError):
truncate_table(User, True)

0 comments on commit 8868201

Please sign in to comment.