Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions services/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,21 +1336,33 @@ def _add_content_disposition_header(self, response):
response["Content-Disposition"] = header
return response

def _get_unit(self, pk):
def _get_unit(self, pk, queryset=None):
try:
int(pk)
except ValueError:
raise Http404

if queryset is None:
queryset = Unit.objects.filter(public=True, is_active=True)

try:
unit = Unit.objects.get(pk=pk, public=True, is_active=True)
unit = queryset.get(pk=pk)
except Unit.DoesNotExist:
# When unit is not found by pk, try to find it via UnitAlias
# We must fetch the aliased unit through the queryset to maintain
# prefetch optimizations and filter constraints (public, is_active)
unit_alias = get_object_or_404(UnitAlias, second=pk)
unit = unit_alias.first
try:
unit = queryset.get(pk=unit_alias.first_id)
except Unit.DoesNotExist:
# Aliased unit exists but doesn't meet filter criteria
# (e.g., not public or not active)
raise Http404
return unit

def retrieve(self, request, pk=None):
unit = self._get_unit(pk)
queryset = self.get_queryset()
unit = self._get_unit(pk, queryset=queryset)
serializer = self.serializer_class(unit, context=self.get_serializer_context())
return Response(serializer.data)

Expand Down
132 changes: 131 additions & 1 deletion services/tests/test_unit_view_set_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,32 @@
AdministrativeDivisionType,
Municipality,
)
from pytest_django.asserts import assertNumQueries
from rest_framework.test import APIClient

from services.api import make_muni_ocd_id
from services.models import Department, MobilityServiceNode, Service, ServiceNode, Unit
from services.models import (
Department,
Keyword,
MobilityServiceNode,
Service,
ServiceNode,
Unit,
UnitAlias,
UnitConnection,
)
from services.models.unit import PROJECTION_SRID
from services.tests.utils import get

UTC_TIMEZONE = pytz.timezone("UTC")

# Expected database query counts for unit retrieve operations with proper prefetching.
# These represent the actual query counts after optimization to prevent N+1 issues.
# Before optimization: 64+ queries (N+1 for each related object)
# After optimization: ~12-14 queries (prefetch_related eliminates N+1)
EXPECTED_QUERIES_UNIT_RETRIEVE = 12
EXPECTED_QUERIES_UNIT_RETRIEVE_WITH_ALIAS = 14


def create_units():
municipality_id = "helsinki"
Expand Down Expand Up @@ -654,3 +671,116 @@ def test_category_filtering(api_client):
assert response.data["count"] == 3
unit_ids = {result["id"] for result in response.data["results"]}
assert unit_ids == {1, 2, 3}


@pytest.mark.django_db
def test_unit_retrieve_prevents_n_plus_1_queries(api_client):
"""
Test that retrieving a single unit does not cause N+1 queries.

This test ensures that UnitViewSet.retrieve() uses the optimized
get_queryset() with prefetch_related to avoid N+1 database queries.
"""
create_units()
service_nodes = create_service_nodes()
unit = Unit.objects.get(id=1)
unit.service_nodes.add(service_nodes[0], service_nodes[1])
keyword1 = Keyword.objects.create(name="keyword1")
keyword2 = Keyword.objects.create(name="keyword2")
unit.keywords.add(keyword1, keyword2)
UnitConnection.objects.create(
unit=unit,
name="Phone",
section_type=UnitConnection.PHONE_OR_EMAIL_TYPE,
)
UnitConnection.objects.create(
unit=unit,
name="Website",
section_type=UnitConnection.LINK_TYPE,
)

with assertNumQueries(EXPECTED_QUERIES_UNIT_RETRIEVE):
response = get(api_client, reverse("unit-detail", kwargs={"pk": unit.id}))

assert response.status_code == 200
assert response.data["id"] == unit.id


@pytest.mark.django_db
def test_unit_retrieve_with_include_parameter(api_client):
"""
Test that unit retrieve with include parameter still uses prefetching.
Verifies that the fix works correctly with query parameters.
"""
create_units()
service_nodes = create_service_nodes()
unit = Unit.objects.get(id=1)
unit.service_nodes.add(service_nodes[0])
UnitConnection.objects.create(
unit=unit,
name="Connection",
section_type=UnitConnection.PHONE_OR_EMAIL_TYPE,
)

with assertNumQueries(EXPECTED_QUERIES_UNIT_RETRIEVE):
response = get(
api_client,
reverse("unit-detail", kwargs={"pk": unit.id}),
data={"include": "service_nodes,connections"},
)

assert response.status_code == 200
assert "service_nodes" in response.data
assert "connections" in response.data


@pytest.mark.django_db
def test_unit_retrieve_via_alias_prevents_n_plus_1(api_client):
"""
Test that retrieving a unit via UnitAlias does not cause N+1 queries.

Verifies that when a unit is accessed via an alias, the prefetch
optimizations are still applied and filters (public, is_active) are
enforced. This prevents the N+1 issue from being reintroduced through
the alias path.
"""
create_units()
service_nodes = create_service_nodes()
unit = Unit.objects.get(id=1)
unit.service_nodes.add(service_nodes[0])
keyword = Keyword.objects.create(name="test_keyword")
unit.keywords.add(keyword)
UnitConnection.objects.create(
unit=unit,
name="Test Connection",
section_type=UnitConnection.PHONE_OR_EMAIL_TYPE,
)
UnitAlias.objects.create(first=unit, second=9999)

with assertNumQueries(EXPECTED_QUERIES_UNIT_RETRIEVE_WITH_ALIAS):
response = get(api_client, reverse("unit-detail", kwargs={"pk": 9999}))

assert response.status_code == 200
assert response.data["id"] == unit.id


@pytest.mark.django_db
def test_unit_alias_respects_public_and_active_filters(api_client):
"""
Test that UnitAlias lookups respect public and is_active filters.

Ensures that accessing a unit via alias still enforces the queryset
filters, preventing access to non-public or inactive units.
"""
create_units()

non_public_unit = Unit.objects.get(id=5)
inactive_unit = Unit.objects.get(id=6)
UnitAlias.objects.create(first=non_public_unit, second=8888)
UnitAlias.objects.create(first=inactive_unit, second=7777)

response = api_client.get(reverse("unit-detail", kwargs={"pk": 8888}))
assert response.status_code == 404

response = api_client.get(reverse("unit-detail", kwargs={"pk": 7777}))
assert response.status_code == 404