Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/ad/location-api-bulk-update' int…
Browse files Browse the repository at this point in the history
…o autostaging
  • Loading branch information
esoergel committed Feb 29, 2024
2 parents d48dcfa + 88a91b8 commit 4bd8d00
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 47 deletions.
61 changes: 50 additions & 11 deletions corehq/apps/api/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from django.urls import NoReverseMatch

from tastypie import http
from tastypie.exceptions import ImmediateHttpResponse, InvalidSortError
from tastypie.resources import Resource

from tastypie.exceptions import BadRequest, ImmediateHttpResponse, InvalidSortError
from tastypie.resources import Resource, convert_post_to_patch
from corehq import privileges, toggles
from corehq.apps.accounting.utils import domain_has_privilege
from corehq.apps.analytics.tasks import track_workflow
Expand Down Expand Up @@ -46,7 +45,8 @@ def create_response(self, request, data, response_class=HttpResponse, **response
# http://stackoverflow.com/questions/17280513/tastypie-json-header-to-use-utf-8
desired_format = self.determine_format(request)
serialized = self.serialize(request, data, desired_format)
return response_class(content=serialized, content_type=build_content_type(desired_format), **response_kwargs)
return response_class(content=serialized, content_type=build_content_type(desired_format),
**response_kwargs)

def determine_format(self, request):
format = super(JsonResourceMixin, self).determine_format(request)
Expand Down Expand Up @@ -117,6 +117,41 @@ def dispatch(self, request_type, request, **kwargs):
def get_required_privilege(self):
return privileges.API_ACCESS

def patch_list_replica(self, create_or_update_object, request=None, obj_limit=None, **kwargs):
"""
Exactly copied fromhttps://github.com/toastdriven/django-tastypie/blob/v0.9.14/tastypie/resources.py#L1466
(BSD licensed) and modified to call custom method `create_or_update_object` on each bundle
"""
request = convert_post_to_patch(request)
deserialized = self.deserialize(request, request.body,
format=request.META.get('CONTENT_TYPE', 'application/json'))

collection_name = self._meta.collection_name
if collection_name not in deserialized:
raise BadRequest("Invalid data sent: missing '%s'" % collection_name)

if len(deserialized[collection_name]) and 'put' not in self._meta.detail_allowed_methods:
raise ImmediateHttpResponse(response=http.HttpMethodNotAllowed())

bundles_seen = []
status = http.HttpAccepted

if obj_limit and obj_limit < len(deserialized[collection_name]):
raise BadRequest("Object count exceeds limit for PATCH method.")

for data in deserialized[collection_name]:
data = self.alter_deserialized_detail_data(request, data)
bundle = self.build_bundle(data=data, request=request)
try:
create_or_update_object(bundle=bundle, **self.remove_api_resource_names(kwargs))
except AssertionError as e:
status = http.HttpBadRequest
bundle.data['_id'] = str(e)
bundles_seen.append(bundle)

to_be_serialized = [bundle.data['_id'] for bundle in bundles_seen]
return self.create_response(request, to_be_serialized, response_class=status)


class SimpleSortableResourceMixin(object):
'''
Expand All @@ -131,7 +166,7 @@ class SimpleSortableResourceMixin(object):
and should also have a meta field `ordering` that specifies the allowed fields
_meta :: [str]
'''

def apply_sorting(self, obj_list, options=None):
Expand All @@ -152,10 +187,10 @@ def apply_sorting(self, obj_list, options=None):
field_name = field

# Map the field back to the actual attribute
if not field_name in self.fields:
if field_name not in self.fields:
raise InvalidSortError("No matching '%s' field for ordering on." % field_name)

if not field_name in self._meta.ordering:
if field_name not in self._meta.ordering:
raise InvalidSortError("The '%s' field does not allow ordering." % field_name)

if self.fields[field_name].attribute is None:
Expand All @@ -179,8 +214,11 @@ def get_list(self, request, **kwargs):
base_bundle = self.build_bundle(request=request)
objects = self.obj_get_list(bundle=base_bundle, **self.remove_api_resource_names(kwargs))
sorted_objects = self.apply_sorting(objects, options=request.GET)

paginator = self._meta.paginator_class(request.GET, sorted_objects, resource_uri=self.get_resource_list_uri(request, kwargs), limit=self._meta.limit, max_limit=self._meta.max_limit, collection_name=self._meta.collection_name)

paginator = self._meta.paginator_class(request.GET, sorted_objects,
resource_uri=self.get_resource_list_uri(request, kwargs),
limit=self._meta.limit, max_limit=self._meta.max_limit,
collection_name=self._meta.collection_name)
to_be_serialized = paginator.page()

# Dehydrate the bundles in preparation for serialization.
Expand All @@ -199,14 +237,15 @@ def get_resource_list_uri(self, request=None, **kwargs):
Exactly copied from https://github.com/toastdriven/django-tastypie/blob/v0.9.11/tastypie/resources.py#L601
(BSD licensed) and modified to use the kwargs.
(v0.9.14 combines get_resource_list_uri and get_resource_uri; this re-separates them to keep things simpler)
(v0.9.14 combines get_resource_list_uri and get_resource_uri; this re-separates them to keep
things simpler)
"""
kwargs = dict(kwargs)
kwargs['resource_name'] = self._meta.resource_name

if self._meta.api_name is not None:
kwargs['api_name'] = self._meta.api_name

try:
return self._build_reverse_url("api_dispatch_list", kwargs=kwargs)
except NoReverseMatch:
Expand Down
34 changes: 2 additions & 32 deletions corehq/apps/api/resources/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from tastypie.bundle import Bundle
from tastypie.exceptions import BadRequest, ImmediateHttpResponse, NotFound
from tastypie.http import HttpForbidden, HttpUnauthorized
from tastypie.resources import ModelResource, Resource, convert_post_to_patch
from tastypie.resources import ModelResource, Resource


from phonelog.models import DeviceReportEntry
Expand Down Expand Up @@ -374,37 +374,7 @@ def serialize(self, request, data, format, options=None):
return self._meta.serializer.serialize(data, format, options)

def patch_list(self, request=None, **kwargs):
"""
Exactly copied from https://github.com/toastdriven/django-tastypie/blob/v0.9.14/tastypie/resources.py#L1466
(BSD licensed) and modified to pass the kwargs to `obj_create` and support only create method
"""
request = convert_post_to_patch(request)
deserialized = self.deserialize(request, request.body,
format=request.META.get('CONTENT_TYPE', 'application/json'))

collection_name = self._meta.collection_name
if collection_name not in deserialized:
raise BadRequest("Invalid data sent: missing '%s'" % collection_name)

if len(deserialized[collection_name]) and 'put' not in self._meta.detail_allowed_methods:
raise ImmediateHttpResponse(response=http.HttpMethodNotAllowed())

bundles_seen = []
status = http.HttpAccepted
for data in deserialized[collection_name]:

data = self.alter_deserialized_detail_data(request, data)
bundle = self.build_bundle(data=data, request=request)
try:

self.obj_create(bundle=bundle, **self.remove_api_resource_names(kwargs))
except AssertionError as e:
status = http.HttpBadRequest
bundle.data['_id'] = str(e)
bundles_seen.append(bundle)

to_be_serialized = [bundle.data['_id'] for bundle in bundles_seen]
return self.create_response(request, to_be_serialized, response_class=status)
super().patch_list_replica(self.obj_create, request, **kwargs)

def post_list(self, request, **kwargs):
"""
Expand Down
4 changes: 4 additions & 0 deletions corehq/apps/api/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def _assert_auth_post_resource(self, url, post_data, content_type='application/j
response = self.client.post(url, post_data, content_type=content_type)
elif method == "PUT":
response = self.client.put(url, post_data, content_type=content_type)
elif method == "PATCH":
response = self.client.patch(url, post_data, content_type=content_type)
elif method == "DELETE":
response = self.client.delete(url, post_data, content_type=content_type)
self.assertEqual(response.status_code, failure_code)
Expand All @@ -170,6 +172,8 @@ def _assert_auth_post_resource(self, url, post_data, content_type='application/j
response = self.client.post(url, post_data, content_type=content_type, **headers)
elif method == "PUT":
response = self.client.put(url, post_data, content_type=content_type, **headers)
elif method == "PATCH":
response = self.client.patch(url, post_data, content_type=content_type, **headers)
elif method == "DELETE":
response = self.client.delete(url, post_data, content_type=content_type, **headers)
return response
23 changes: 19 additions & 4 deletions corehq/apps/locations/resources/v0_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@
)
from corehq.apps.locations.views import LocationFieldsView

from django.db.transaction import atomic
from django.utils.translation import gettext as _

from tastypie.exceptions import BadRequest


class LocationResource(v0_5.LocationResource):
resource_name = 'location'
patch_limit = 100

class Meta:
queryset = SQLLocation.active_objects.all()
detail_uri_name = 'location_id'
authentication = RequirePermissionAuthentication(HqPermissions.edit_locations)
list_allowed_methods = ['get', 'post']
list_allowed_methods = ['get', 'post', 'patch']
detail_allowed_methods = ['get', 'put']
always_return_data = True
include_resource_uri = False
Expand Down Expand Up @@ -52,16 +54,17 @@ def dehydrate(self, bundle):
def obj_create(self, bundle, **kwargs):
domain = kwargs['domain']
if 'name' not in bundle.data or 'location_type_code' not in bundle.data:
raise BadRequest("'name' and 'location_type_code' are required fields.")
raise BadRequest("'name' and 'location_type_code' are required fields when creating a new location.")
bundle.obj = SQLLocation(domain=domain)
self._update(bundle, domain, is_new_location=True)
return bundle

def obj_update(self, bundle, **kwargs):
location_id = kwargs.get('location_id') or bundle.data.pop('location_id')
try:
bundle.obj = SQLLocation.objects.get(location_id=kwargs['location_id'], domain=kwargs['domain'])
bundle.obj = SQLLocation.objects.get(location_id=location_id, domain=kwargs['domain'])
except SQLLocation.DoesNotExist:
raise BadRequest(_("Could not find location with given ID on the domain."))
raise BadRequest(_("Could not update: could not find location with given ID on the domain."))
self._update(bundle, kwargs['domain'], is_new_location=False)
return bundle

Expand Down Expand Up @@ -120,3 +123,15 @@ def _validate_new_parent(self, domain, location, parent):
if location.location_type not in parent_allowed_types:
raise BadRequest(_("Parent cannot have children of this location's type."))
self._validate_unique_among_siblings(location, location.name, parent)

@atomic
def patch_list(self, request, **kwargs):
def create_or_update(bundle, **kwargs):
if 'location_id' in bundle.data and SQLLocation.objects.filter(
location_id=bundle.data['location_id'], domain=kwargs['domain']).exists():
bundle = self.obj_update(bundle, **kwargs)
else:
bundle = self.obj_create(bundle, **kwargs)
bundle.data['_id'] = bundle.obj.location_id # For serialization
return bundle
return super().patch_list_replica(create_or_update, request, obj_limit=self.patch_limit, **kwargs)
68 changes: 68 additions & 0 deletions corehq/apps/locations/tests/test_api_v6.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,71 @@ def test_site_code_unique(self):
response = self._assert_auth_post_resource(self.single_endpoint(self.location2.location_id),
put_data, method='PUT')
self.assertEqual(response.status_code, 400)

def test_successful_patch_list(self):
patch_data = {
"objects": [
{
"name": "newtown",
"latitude": "31.41",
"location_type_code": self.child_type.code,
"parent_location_id": self.location1.location_id
},
{
"location_id": self.south_park.location_id,
"latitude": "32.42",
"parent_location_id": self.location1.location_id
}
]
}
response = self._assert_auth_post_resource(self.list_endpoint,
patch_data, method='PATCH')
self.assertEqual(response.status_code, 202)

self.assertTrue(SQLLocation.objects.filter(
domain=self.domain.name, name="newtown").exists())
newtown = SQLLocation.objects.get(domain=self.domain.name, name="newtown")
self.assertEqual(newtown.parent_location_id, self.location1.location_id)
self.assertEqual(float(newtown.latitude), 31.41)

updated_south_park = SQLLocation.objects.get(domain=self.domain.name, name=self.south_park.name)
self.assertEqual(float(updated_south_park.latitude), 32.42)
self.assertEqual(updated_south_park.parent_location_id, self.location1.location_id)

def test_patch_list_is_atomic(self):
patch_data = {
"objects": [
{
"name": "newtown",
"latitude": "31.41",
"location_type_code": self.child_type.code,
"parent_location_id": self.location1.location_id
},
{
"location_id": self.south_park.location_id,
"latitude": "32.42",
"parent_location_id": self.location2.location_id # Invalid parent
}
]
}

response = self._assert_auth_post_resource(self.list_endpoint,
patch_data, method='PATCH')
self.assertEqual(response.status_code, 400)
# "newtown" should not be created since the update to South Park failed
self.assertFalse(SQLLocation.objects.filter(
domain=self.domain.name, name="newtown").exists())

def test_patch_list_missing_location_id(self):
patch_data = {
"objects": [
{
"_id": self.south_park.location_id, # Incorrect ID key
"latitude": "32.42",
}
]
}

response = self._assert_auth_post_resource(self.list_endpoint,
patch_data, method='PATCH')
self.assertEqual(response.status_code, 400)

0 comments on commit 4bd8d00

Please sign in to comment.