diff --git a/corehq/apps/dump_reload/tests/test_serialization.py b/corehq/apps/dump_reload/tests/test_serialization.py index 47eb8dc4551c..6c3da36d9c37 100644 --- a/corehq/apps/dump_reload/tests/test_serialization.py +++ b/corehq/apps/dump_reload/tests/test_serialization.py @@ -1,6 +1,14 @@ +import json +from io import StringIO +from unittest.mock import patch + from django.core.serializers.python import Deserializer from django.test import SimpleTestCase +from corehq.apps.dump_reload.sql.dump import SqlDataDumper +from corehq.form_processor.models.cases import CaseTransaction, CommCareCase +from corehq.form_processor.models.forms import XFormInstance, XFormOperation + class TestJSONFieldSerialization(SimpleTestCase): """ @@ -22,3 +30,53 @@ def _test_json_field_after_serialization(serialized): _test_json_field_after_serialization(serialized_model_with_primary_key) _test_json_field_after_serialization(serialized_model_with_natural_key) + + +class TestForeignKeyFieldSerialization(SimpleTestCase): + """ + We use natural foreign keys when dumping SQL data, but CommCareCase and XFormInstance have natural_key methods + that intentionally return a string for the case_id or form_id, rather than a tuple as Django recommends for + all natural_key methods. We made this decision to optimize loading deserialized data back into a database. If + the natural_key method returns a tuple, it will use the get_by_natural_key method on the foreign key model's + default object manager to fetch the foreign keyed object, resulting in a database lookup everytime we write + a model that foreign keys to cases or forms in SqlDataLoader. + """ + + def test_serialized_foreign_key_field_referencing_User_returns_an_iterable(self): + from django.contrib.auth.models import User + + from corehq.apps.users.models import SQLUserData + user = User(username='testuser') + user_data = SQLUserData(django_user=user, data={'test': 1}) + + output_stream = StringIO() + with patch('corehq.apps.dump_reload.sql.dump.get_objects_to_dump', return_value=[user_data]): + SqlDataDumper('test', [], []).dump(output_stream) + + deserialized_model = json.loads(output_stream.getvalue()) + fk_field = deserialized_model['fields']['django_user'] + self.assertEqual(fk_field, ['testuser']) + + def test_serialized_foreign_key_field_referencing_CommCareCase_returns_a_str(self): + cc_case = CommCareCase(domain='test', case_id='abc123') + transaction = CaseTransaction(case=cc_case) + + output_stream = StringIO() + with patch('corehq.apps.dump_reload.sql.dump.get_objects_to_dump', return_value=[transaction]): + SqlDataDumper('test', [], []).dump(output_stream) + + deserialized_model = json.loads(output_stream.getvalue()) + fk_field = deserialized_model['fields']['case'] + self.assertEqual(fk_field, 'abc123') + + def test_serialized_foreign_key_field_referencing_XFormInstance_returns_a_str(self): + xform = XFormInstance(domain='test', form_id='abc123') + operation = XFormOperation(form=xform) + + output_stream = StringIO() + with patch('corehq.apps.dump_reload.sql.dump.get_objects_to_dump', return_value=[operation]): + SqlDataDumper('test', [], []).dump(output_stream) + + deserialized_model = json.loads(output_stream.getvalue()) + fk_field = deserialized_model['fields']['form'] + self.assertEqual(fk_field, 'abc123') diff --git a/corehq/apps/dump_reload/tests/test_sql_data_loader.py b/corehq/apps/dump_reload/tests/test_sql_data_loader.py new file mode 100644 index 000000000000..718e8a070a3b --- /dev/null +++ b/corehq/apps/dump_reload/tests/test_sql_data_loader.py @@ -0,0 +1,55 @@ +import json + +from django.contrib.auth.models import User +from django.test import TestCase + +from corehq.apps.dump_reload.sql.load import SqlDataLoader +from corehq.apps.users.models import SQLUserData +from corehq.form_processor.models.cases import CaseTransaction +from corehq.form_processor.tests.utils import create_case + + +class TestSqlDataLoader(TestCase): + + def test_loading_foreign_keys_using_iterable_natural_key(self): + user = User.objects.create(username='testuser') + model = { + "model": "users.sqluserdata", + "fields": { + "domain": "test", + "user_id": "testuser", + "django_user": ["testuser"], + "modified_on": "2024-01-01T12:00:00.000000Z", + "profile": None, + "data": {"test": "1"}, + }, + } + serialized_model = json.dumps(model) + + SqlDataLoader().load_objects([serialized_model]) + + user_data = SQLUserData.objects.get(django_user=user) + self.assertEqual(user_data.django_user.pk, user.pk) + + def test_loading_foreign_keys_using_non_iterable_natural_key(self): + # create_case will create a CaseTransaction too so test verifies the serialized one is saved properly + cc_case = create_case('test', case_id='abc123', save=True) + model = { + "model": "form_processor.casetransaction", + "fields": { + "case": "abc123", + "form_id": "fk-test", + "sync_log_id": None, + "server_date": "2024-01-01T12:00:00.000000Z", + "_client_date": None, + "type": 1, + "revoked": False, + "details": {}, + }, + } + serialized_model = json.dumps(model) + + SqlDataLoader().load_objects([serialized_model]) + + transaction = CaseTransaction.objects.partitioned_query('abc123').get(case=cc_case, form_id='fk-test') + self.assertEqual(transaction.case_id, 'abc123')