diff --git a/events/api.py b/events/api.py index 4a16eb7ff..cee07e280 100644 --- a/events/api.py +++ b/events/api.py @@ -2408,6 +2408,16 @@ def _optimize_include(includes, queryset): "keywords__data_source", "keywords__publisher", ) + + if "images" in includes: + queryset = queryset.prefetch_related( + "images__created_by", + "images__data_source", + "images__last_modified_by", + "images__license", + "images__publisher", + ) + return queryset def get_queryset(self): diff --git a/events/tests/test_event_get.py b/events/tests/test_event_get.py index c573f3c46..9b973878b 100644 --- a/events/tests/test_event_get.py +++ b/events/tests/test_event_get.py @@ -12,12 +12,13 @@ from django.utils import timezone from django.utils.timezone import localtime from freezegun import freeze_time +from pytest_django.asserts import assertNumQueries from resilient_logger.models import ResilientLogEntry from rest_framework import status from events.models import Event, Language, License, PublicationStatus from events.tests.conftest import APIClient -from events.tests.factories import EventFactory, OfferFactory +from events.tests.factories import EventFactory, ImageFactory, OfferFactory from events.tests.utils import ( assert_fields_exist, create_super_event, @@ -1892,41 +1893,99 @@ def test_sub_events_increase_query_count_sanely( discover sub-sub events. Sub-sub events should also require 8 queries. """ - def get_num_queries(): - with CaptureQueriesContext(connections[DEFAULT_DB_ALIAS]) as queries: - response = get_list(api_client, query_string=qs) - assert response.status_code == status.HTTP_200_OK - - return len(queries) + def get_and_check_response(): + response = get_list(api_client, query_string=qs) + assert response.status_code == status.HTTP_200_OK + return response event_1 = EventFactory() event_2 = EventFactory() + def count_queries(): + with CaptureQueriesContext(connections[DEFAULT_DB_ALIAS]) as queries: + get_and_check_response() + return len(queries) + # Do one warmup list, there's some savepoint/insert happening # on first call related to test setup that would give higher # than expected number of queries. - get_num_queries() + count_queries() - base_count = get_num_queries() + base_count = count_queries() sub_event_1 = EventFactory(super_event=event_1) - one_sub_event_count = get_num_queries() - - assert one_sub_event_count == base_count + sub_event_query_count + with assertNumQueries(base_count + sub_event_query_count): + get_and_check_response() # More than one sub event should NOT increase number of queries sub_event_2 = EventFactory(super_event=event_2) - assert get_num_queries() == one_sub_event_count + one_sub_event_count = base_count + sub_event_query_count + with assertNumQueries(one_sub_event_count): + get_and_check_response() EventFactory(super_event=sub_event_1) - one_sub_sub_event_count = get_num_queries() - assert one_sub_sub_event_count == one_sub_event_count + sub_sub_event_query_count + one_sub_sub_event_count = one_sub_event_count + sub_sub_event_query_count + with assertNumQueries(one_sub_sub_event_count): + get_and_check_response() # More than one sub-sub event should NOT increase number of queries EventFactory(super_event=sub_event_2) EventFactory(super_event=sub_event_2) - assert get_num_queries() == one_sub_sub_event_count + with assertNumQueries(one_sub_sub_event_count): + get_and_check_response() + + +@pytest.mark.django_db +def test_images_with_include_uses_prefetch_optimization( + api_client, data_source, organization, user +): + """ + Test that include=images uses prefetch_related optimization for image + related fields (created_by, data_source, last_modified_by, license, publisher). + + The prefetch_related optimization in the EventViewSet._optimize_include method + should efficiently fetch all related fields for images, preventing N+1 query + problems when multiple events with images are listed. + """ + + def count_queries(query_string=""): + with CaptureQueriesContext(connections[DEFAULT_DB_ALIAS]) as queries: + response = get_list(api_client, query_string=query_string) + assert response.status_code == status.HTTP_200_OK + return len(queries) + + _license = License.objects.create(name="Test License", url="http://test.license") + + for i in range(3): + event = EventFactory(data_source=data_source, publisher=organization) + image = ImageFactory( + data_source=data_source, + publisher=organization, + created_by=user, + last_modified_by=user, + license=_license, + url=f"http://test.image/{i}.jpg", + ) + event.images.add(image) + + baseline_queries = count_queries(query_string="include=images") + + for i in range(3, 6): + event = EventFactory(data_source=data_source, publisher=organization) + image = ImageFactory( + data_source=data_source, + publisher=organization, + created_by=user, + last_modified_by=user, + license=_license, + url=f"http://test.image/{i}.jpg", + ) + event.images.add(image) + + with assertNumQueries(baseline_queries): + response = get_list(api_client, query_string="include=images") + assert response.status_code == status.HTTP_200_OK @pytest.mark.django_db