Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed #26223 -- Fixed migration optimizer for operations with transient defaults. #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions django/db/migrations/autodetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,25 +802,23 @@ def _generate_added_field(self, app_label, model_name, field_name):
dependencies.extend(self._get_dependecies_for_foreign_key(field))
# You can't just add NOT NULL fields with no default or fields
# which don't allow empty strings as default.
preserve_default = True
default = models.NOT_PROVIDED
time_fields = (models.DateField, models.DateTimeField, models.TimeField)
if (not field.null and not field.has_default() and
not field.many_to_many and
not (field.blank and field.empty_strings_allowed) and
not (isinstance(field, time_fields) and field.auto_now)):
field = field.clone()
if isinstance(field, time_fields) and field.auto_now_add:
field.default = self.questioner.ask_auto_now_add_addition(field_name, model_name)
default = self.questioner.ask_auto_now_add_addition(field_name, model_name)
else:
field.default = self.questioner.ask_not_null_addition(field_name, model_name)
preserve_default = False
default = self.questioner.ask_not_null_addition(field_name, model_name)
self.add_operation(
app_label,
operations.AddField(
model_name=model_name,
name=field_name,
field=field,
preserve_default=preserve_default,
default=default,
),
dependencies=dependencies,
)
Expand Down Expand Up @@ -881,23 +879,17 @@ def generate_altered_fields(self):
neither_m2m = not old_field.many_to_many and not new_field.many_to_many
if both_m2m or neither_m2m:
# Either both fields are m2m or neither is
preserve_default = True
default = models.NOT_PROVIDED
if (old_field.null and not new_field.null and not new_field.has_default() and
not new_field.many_to_many):
field = new_field.clone()
new_default = self.questioner.ask_not_null_alteration(field_name, model_name)
if new_default is not models.NOT_PROVIDED:
field.default = new_default
preserve_default = False
else:
field = new_field
default = self.questioner.ask_not_null_alteration(field_name, model_name)
self.add_operation(
app_label,
operations.AlterField(
model_name=model_name,
name=field_name,
field=field,
preserve_default=preserve_default,
field=new_field,
default=default,
)
)
else:
Expand Down
64 changes: 35 additions & 29 deletions django/db/migrations/operations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,15 @@ class AddField(FieldOperation):
Adds a field to a model.
"""

def __init__(self, model_name, name, field, preserve_default=True):
def __init__(self, model_name, name, field, default=NOT_PROVIDED, preserve_default=None):
if preserve_default is not None:
# TODO: add a deprecation warning here?
if not preserve_default:
field = field.clone()
default = field.default
field.default = NOT_PROVIDED
self.field = field
self.preserve_default = preserve_default
self.default = default
super(AddField, self).__init__(model_name, name)

def deconstruct(self):
Expand All @@ -54,37 +60,32 @@ def deconstruct(self):
'name': self.name,
'field': self.field,
}
if self.preserve_default is not True:
kwargs['preserve_default'] = self.preserve_default
if self.default is not NOT_PROVIDED:
kwargs['default'] = self.default
return (
self.__class__.__name__,
[],
kwargs
)

def state_forwards(self, app_label, state):
# If preserve default is off, don't use the default for future state
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
state.models[app_label, self.model_name_lower].fields.append((self.name, field))
state.models[app_label, self.model_name_lower].fields.append((self.name, self.field))
state.reload_model(app_label, self.model_name_lower)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
field = to_model._meta.get_field(self.name)
if not self.preserve_default:
field.default = self.field.default
field_default = field.default
if self.default is not NOT_PROVIDED:
field.default = self.default
schema_editor.add_field(
from_model,
field,
)
if not self.preserve_default:
field.default = NOT_PROVIDED
if self.default is not NOT_PROVIDED:
field.default = field_default

def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
Expand All @@ -102,6 +103,7 @@ def reduce(self, operation, in_between, app_label=None):
model_name=self.model_name,
name=operation.name,
field=operation.field,
default=self.default if self.default is not NOT_PROVIDED else operation.default,
),
]
elif isinstance(operation, RemoveField):
Expand All @@ -112,6 +114,7 @@ def reduce(self, operation, in_between, app_label=None):
model_name=self.model_name,
name=operation.new_name,
field=self.field,
default=self.default,
),
]
return super(AddField, self).reduce(operation, in_between, app_label=app_label)
Expand Down Expand Up @@ -161,9 +164,15 @@ class AlterField(FieldOperation):
Alters a field's database column (e.g. null, max_length) to the provided new field
"""

def __init__(self, model_name, name, field, preserve_default=True):
def __init__(self, model_name, name, field, default=NOT_PROVIDED, preserve_default=None):
if preserve_default is not None:
# TODO: add a deprecation warning here?
if not preserve_default:
field = field.clone()
default = field.default
field.default = NOT_PROVIDED
self.field = field
self.preserve_default = preserve_default
self.default = default
super(AlterField, self).__init__(model_name, name)

def deconstruct(self):
Expand All @@ -172,22 +181,17 @@ def deconstruct(self):
'name': self.name,
'field': self.field,
}
if self.preserve_default is not True:
kwargs['preserve_default'] = self.preserve_default
if self.default is not NOT_PROVIDED:
kwargs['default'] = self.default
return (
self.__class__.__name__,
[],
kwargs
)

def state_forwards(self, app_label, state):
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
state.models[app_label, self.model_name_lower].fields = [
(n, field if n == self.name else f)
(n, self.field if n == self.name else f)
for n, f in
state.models[app_label, self.model_name_lower].fields
]
Expand All @@ -199,11 +203,12 @@ def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
from_field = from_model._meta.get_field(self.name)
to_field = to_model._meta.get_field(self.name)
if not self.preserve_default:
to_field.default = self.field.default
field_default = to_field.default
if self.default is not NOT_PROVIDED:
to_field.default = self.default
schema_editor.alter_field(from_model, from_field, to_field)
if not self.preserve_default:
to_field.default = NOT_PROVIDED
if self.default is not NOT_PROVIDED:
to_field.default = field_default

def database_backwards(self, app_label, schema_editor, from_state, to_state):
self.database_forwards(app_label, schema_editor, from_state, to_state)
Expand All @@ -221,6 +226,7 @@ def reduce(self, operation, in_between, app_label=None):
model_name=self.model_name,
name=operation.new_name,
field=self.field,
default=self.default,
),
]
return super(AlterField, self).reduce(operation, in_between, app_label=app_label)
Expand Down
10 changes: 5 additions & 5 deletions tests/migrations/test_autodetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def test_alter_field(self):
# Right number/type of migrations?
self.assertNumberMigrations(changes, 'testapp', 1)
self.assertOperationTypes(changes, 'testapp', 0, ["AlterField"])
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name", preserve_default=True)
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name", default=models.NOT_PROVIDED)

def test_supports_functools_partial(self):
def _content_file_name(instance, filename, key, **kwargs):
Expand Down Expand Up @@ -759,7 +759,7 @@ def test_alter_field_to_not_null_with_default(self, mocked_ask_method):
# Right number/type of migrations?
self.assertNumberMigrations(changes, 'testapp', 1)
self.assertOperationTypes(changes, 'testapp', 0, ["AlterField"])
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name", preserve_default=True)
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name", default=models.NOT_PROVIDED)
self.assertOperationFieldAttributes(changes, "testapp", 0, 0, default='Ada Lovelace')

@mock.patch('django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration',
Expand All @@ -773,7 +773,7 @@ def test_alter_field_to_not_null_without_default(self, mocked_ask_method):
# Right number/type of migrations?
self.assertNumberMigrations(changes, 'testapp', 1)
self.assertOperationTypes(changes, 'testapp', 0, ["AlterField"])
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name", preserve_default=True)
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name", default=models.NOT_PROVIDED)
self.assertOperationFieldAttributes(changes, "testapp", 0, 0, default=models.NOT_PROVIDED)

@mock.patch('django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration',
Expand All @@ -787,8 +787,8 @@ def test_alter_field_to_not_null_oneoff_default(self, mocked_ask_method):
# Right number/type of migrations?
self.assertNumberMigrations(changes, 'testapp', 1)
self.assertOperationTypes(changes, 'testapp', 0, ["AlterField"])
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name", preserve_default=False)
self.assertOperationFieldAttributes(changes, "testapp", 0, 0, default="Some Name")
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name", default="Some Name")
self.assertOperationFieldAttributes(changes, "testapp", 0, 0, default=models.NOT_PROVIDED)

def test_rename_field(self):
"""Tests autodetection of renamed fields."""
Expand Down
10 changes: 5 additions & 5 deletions tests/migrations/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,18 +947,18 @@ def test_column_name_quoting(self):
operation.database_forwards("test_regr22168", editor, project_state, new_state)
self.assertColumnExists("test_regr22168_pony", "order")

def test_add_field_preserve_default(self):
def test_add_field_default(self):
"""
Tests the AddField operation's state alteration
when preserve_default = False.
when default is given.
"""
project_state = self.set_up_test_model("test_adflpd")
# Test the state alteration
operation = migrations.AddField(
"Pony",
"height",
models.FloatField(null=True, default=4),
preserve_default=False,
models.FloatField(null=True),
default=4,
)
new_state = project_state.clone()
operation.state_forwards("test_adflpd", new_state)
Expand All @@ -980,7 +980,7 @@ def test_add_field_preserve_default(self):
definition = operation.deconstruct()
self.assertEqual(definition[0], "AddField")
self.assertEqual(definition[1], [])
self.assertEqual(sorted(definition[2]), ["field", "model_name", "name", "preserve_default"])
self.assertEqual(sorted(definition[2]), ["default", "field", "model_name", "name"])

def test_add_field_m2m(self):
"""
Expand Down
96 changes: 94 additions & 2 deletions tests/migrations/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db import migrations, models
from django.db.migrations import operations
from django.db.migrations.optimizer import MigrationOptimizer
from django.db.migrations.serializer import serializer_factory
from django.test import SimpleTestCase

from .models import EmptyManager, UnicodeModel
Expand All @@ -20,10 +21,13 @@ def optimize(self, operations, app_label):
optimizer = MigrationOptimizer()
return optimizer.optimize(operations, app_label), optimizer._iterations

def serialize(self, value):
return serializer_factory(value).serialize()[0]

def assertOptimizesTo(self, operations, expected, exact=None, less_than=None, app_label=None):
result, iterations = self.optimize(operations, app_label)
result = [repr(f.deconstruct()) for f in result]
expected = [repr(f.deconstruct()) for f in expected]
result = [self.serialize(f) for f in result]
expected = [self.serialize(f) for f in expected]
self.assertEqual(expected, result)
if exact is not None and iterations != exact:
raise self.failureException(
Expand Down Expand Up @@ -696,3 +700,91 @@ def test_optimize_elidable_operation(self):
migrations.CreateModel("Phou", [("name", models.CharField(max_length=255))]),
],
)

def test_default_create_model_add_field(self):
"""
AddField optimizing into CreateModel should drop default if
preserve_default is False.
"""
self.assertOptimizesTo(
[
migrations.CreateModel("Foo", [("name", models.CharField(max_length=255))]),
migrations.AddField("Foo", "value", models.IntegerField(default=42), preserve_default=False),
],
[
migrations.CreateModel("Foo", [
("name", models.CharField(max_length=255)),
("value", models.IntegerField()),
]),
],
)

def test_default_create_model_alter_field(self):
"""
AlterField optimizing into CreateModel should drop default if
preserve_default is False.
"""
self.assertOptimizesTo(
[
migrations.CreateModel("Foo", [("value", models.IntegerField(null=True))]),
migrations.AlterField("Foo", "value", models.IntegerField(default=42), preserve_default=False),
],
[
migrations.CreateModel("Foo", [("value", models.IntegerField())]),
],
)

def test_default_add_field_alter_field(self):
"""
AddField and AlterField optimizing when one of them has
preserve_default=False should pass the preserve_default value.
"""
self.assertOptimizesTo(
[
migrations.AddField("Foo", "value", models.IntegerField(null=True)),
migrations.AlterField("Foo", "value", models.IntegerField(default=42), preserve_default=False),
],
[
migrations.AddField("Foo", "value", models.IntegerField(default=42), preserve_default=False),
],
)
self.assertOptimizesTo(
[
migrations.AddField("Foo", "value", models.IntegerField(default=42), preserve_default=False),
migrations.AlterField("Foo", "value", models.IntegerField("Value")),
],
[
migrations.AddField("Foo", "value", models.IntegerField("Value", default=42), preserve_default=False),
],
)

def test_default_add_field_rename_field(self):
"""
AddField optimizing with RenameField should retain its
preserve_default value.
"""
self.assertOptimizesTo(
[
migrations.AddField("Foo", "value", models.IntegerField(default=42), preserve_default=False),
migrations.RenameField("Foo", "value", "price"),
],
[
migrations.AddField("Foo", "price", models.IntegerField(default=42), preserve_default=False),
],
)

def test_default_alter_field_rename_field(self):
"""
AlterField optimizing with RenameField should retain its
preserve_default value.
"""
self.assertOptimizesTo(
[
migrations.AlterField("Foo", "value", models.IntegerField(default=42), preserve_default=False),
migrations.RenameField("Foo", "value", "price"),
],
[
migrations.RenameField("Foo", "value", "price"),
migrations.AlterField("Foo", "price", models.IntegerField(default=42), preserve_default=False),
],
)