Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various reliability fixes #76

Merged
merged 10 commits into from
Oct 15, 2024
7 changes: 5 additions & 2 deletions archive/frames/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_header_dict(self):
archive_settings.PUBLIC_DATE_KEY: self.public_date,
}

def as_dict(self, include_thumbnails=False):
def as_dict(self, include_thumbnails=False, include_related_frames=False):
ret_dict = model_to_dict(self, exclude=('related_frames', 'area'))
ret_dict['version_set'] = [v.as_dict() for v in self.version_set.all()]
ret_dict['url'] = self.url if self.version_set.exists() else None
Expand All @@ -158,7 +158,10 @@ def as_dict(self, include_thumbnails=False):

if self.area:
ret_dict['area'] = json.loads(self.area.geojson)
ret_dict['related_frames'] = list(self.related_frames.all().values_list('id', flat=True))
if include_thumbnails:
ret_dict['thumbnails'] = [t.as_dict() for t in Thumbnail.objects.filter(frame=self)]
if include_related_frames:
ret_dict['related_frames'] = list(self.related_frames.all().values_list('id', flat=True))
return ret_dict


Expand Down
25 changes: 25 additions & 0 deletions archive/frames/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,37 @@ def setUp(self):
self.client.force_login(user)
self.frames = FrameFactory.create_batch(5)
self.frame = self.frames[0]
# Used to check if the related frames are prefetched
self.patcher = patch('archive.frames.views.Prefetch')
self.mock_prefetch = self.patcher.start()

def test_get_frame(self):
response = self.client.get(reverse('frame-detail', args=(self.frame.id, )))
self.assertEqual(response.json()['basename'], self.frame.basename)
self.assertContains(response, 'related_frames')
self.assertEqual(self.mock_prefetch.call_count, 1)

def test_get_frame_exclude_related_frames(self):
response = self.client.get(reverse('frame-detail', args=(self.frame.id,)),
{'include_related_frames': False})

self.assertNotContains(response, 'related_frames')
self.assertEqual(self.mock_prefetch.call_count, 0)

def test_get_frame_list(self):
response = self.client.get(reverse('frame-list'))
self.assertEqual(response.json()['count'], 5)
self.assertContains(response, self.frame.basename)
self.assertContains(response, 'related_frames')
self.assertEqual(self.mock_prefetch.call_count, 1)

def test_get_frame_list_exclude_related_frames(self):
response = self.client.get(reverse('frame-list'),
{'include_related_frames': False})
self.assertEqual(response.json()['count'], 5)
self.assertContains(response, self.frame.basename)
self.assertNotContains(response, 'related_frames')
self.assertEqual(self.mock_prefetch.call_count, 0)

def test_get_frame_list_filter(self):
response = self.client.get(
Expand Down Expand Up @@ -75,6 +97,9 @@ def test_get_frame_list_exclude_empty_version_set(self):
self.assertEqual(Frame.objects.count(), 6)
self.assertEqual(len(response.json()['results']), 5)

def tearDown(self):
self.patcher.stop()


class TestFramePost(ReplicationTestCase):
def setUp(self):
Expand Down
7 changes: 6 additions & 1 deletion archive/frames/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ def archived_queue_payload(validated_data: dict, frame):
new_dictionary['area'] = validated_data.get('area').json if validated_data.get('area') else None
new_dictionary['basename'] = validated_data.get('basename')
new_dictionary['version_set'] = validated_data.get('version_set')
new_dictionary['filename'] = frame.filename
# construct filename from version_set
try:
filename = ''.join([validated_data.get('basename'), validated_data.get('version_set')[0].get('extension')])
except Exception:
filename = ''
new_dictionary['filename'] = filename
new_dictionary['frameid'] = frame.id
return new_dictionary

Expand Down
19 changes: 15 additions & 4 deletions archive/frames/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def get_queryset(self):
queryset = (
Frame.objects.exclude(observation_date=None)
.prefetch_related('version_set')
.prefetch_related(Prefetch('related_frames', queryset=Frame.objects.all().only('id')))
.prefetch_related('thumbnails')
)
# Only prefetch related frames if we're including them in the response
if self.request.query_params.get('include_related_frames', '').lower() != 'false':
queryset = queryset.prefetch_related(Prefetch('related_frames', queryset=Frame.objects.all().only('id')))
if self.request.user.is_superuser:
return queryset
elif self.request.user.is_authenticated:
Expand All @@ -84,16 +86,25 @@ def get_queryset(self):

# These two method overrides just force the use of the as_dict method for serialization for list and detail endpoints
def list(self, request, *args, **kwargs):
# TODO: Default to not include related frames once we've announced it to users
include_related_frames = True
if request.query_params.get('include_related_frames', '').lower() == 'false':
include_related_frames = False
include_thumbnails = True if request.query_params.get('include_thumbnails', '').lower() == 'true' else False

queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
include_thumbnails = True if request.query_params.get('include_thumbnails', '').lower() == 'true' else False
json_models = [model.as_dict(include_thumbnails) for model in page]
json_models = [model.as_dict(include_thumbnails, include_related_frames) for model in page]
json_models = [model for model in json_models if model['version_set']] # Filter out frames with no versions
return self.get_paginated_response(json_models)

def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
return Response(instance.as_dict())
include_thumbnails = True if request.query_params.get('include_thumbnails', '').lower() == 'true' else False
include_related_frames = True
if request.query_params.get('include_related_frames', '').lower() == 'false':
include_related_frames = False
return Response(instance.as_dict(include_thumbnails, include_related_frames))

def create(self, request):
basename = request.data.get('basename')
Expand Down
Loading