Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions services/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django.contrib.gis.geos import Point
from django.contrib.gis.measure import D
from django.core.exceptions import ValidationError
from django.db.models import F, Prefetch, Q
from django.db.models import F, Prefetch, Q, Subquery
from django.http import Http404
from django.shortcuts import get_object_or_404, redirect
from django.template.loader import render_to_string
Expand Down Expand Up @@ -1160,11 +1160,12 @@ def validate_service_node_ids(service_node_ids):
mobility_service_nodes.split(",")
)
if mobility_service_node_ids:
queryset = queryset.filter(
matching_unit_ids = Unit.objects.filter(
mobility_service_nodes__in=service_nodes_by_ancestors(
mobility_service_node_ids, node_model=MobilityServiceNode
)
).distinct()
).values("id")
queryset = queryset.filter(id__in=Subquery(matching_unit_ids))

service_node_ids = None
if service_nodes:
Expand All @@ -1174,9 +1175,10 @@ def validate_service_node_ids(service_node_ids):
if level_specs["type"] == "include":
service_node_ids = level_specs["service_nodes"]
if service_node_ids:
queryset = queryset.filter(
matching_unit_ids = Unit.objects.filter(
service_nodes__in=service_nodes_by_ancestors(service_node_ids)
).distinct()
).values("id")
queryset = queryset.filter(id__in=Subquery(matching_unit_ids))

service_node_ids = None
val = filters.get("exclude_service_nodes", None)
Expand All @@ -1187,13 +1189,17 @@ def validate_service_node_ids(service_node_ids):
if level_specs["type"] == "exclude":
service_node_ids = level_specs["service_nodes"]
if service_node_ids:
queryset = queryset.exclude(
excluded_unit_ids = Unit.objects.filter(
service_nodes__in=service_nodes_by_ancestors(service_node_ids)
).distinct()
).values("id")
queryset = queryset.exclude(id__in=Subquery(excluded_unit_ids))

services = filters.get("service")
if services is not None:
queryset = queryset.filter(services__in=services.split(",")).distinct()
matching_unit_ids = Unit.objects.filter(
services__in=services.split(",")
).values("id")
queryset = queryset.filter(id__in=Subquery(matching_unit_ids))

if "division" in filters:
# Divisions can be specified with form:
Expand Down Expand Up @@ -1254,10 +1260,11 @@ def validate_service_node_ids(service_node_ids):
service_ids.append(value)
elif key == "service_node":
servicenode_ids.append(value)
queryset = queryset.filter(
matching_unit_ids = Unit.objects.filter(
Q(services__in=service_ids)
| Q(service_nodes__in=service_nodes_by_ancestors(servicenode_ids))
).distinct()
).values("id")
queryset = queryset.filter(id__in=Subquery(matching_unit_ids))

if "address" in filters:
language = filters["language"] if "language" in filters else "fi"
Expand Down
184 changes: 183 additions & 1 deletion services/tests/test_unit_view_set_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rest_framework.test import APIClient

from services.api import make_muni_ocd_id
from services.models import Department, MobilityServiceNode, ServiceNode, Unit
from services.models import Department, MobilityServiceNode, Service, ServiceNode, Unit
from services.models.unit import PROJECTION_SRID
from services.tests.utils import get

Expand Down Expand Up @@ -472,3 +472,185 @@ def test_heightprofilegeom_parameter(api_client):
results[4]["height_profile_geom"]["properties"]["label_fi"] == "Korkeusprofiili"
)
assert results[4]["height_profile_geom"]["properties"]["label_sv"] == "Höjdprofil"


@pytest.mark.django_db
def test_service_filtering(api_client):
"""
Test service filtering.
"""
create_units()

# Create services and associate them with units
service1 = Service.objects.create(
id=695, name="Service 1", last_modified_time=datetime.now(UTC_TIMEZONE)
)
service2 = Service.objects.create(
id=406, name="Service 2", last_modified_time=datetime.now(UTC_TIMEZONE)
)
service3 = Service.objects.create(
id=235, name="Service 3", last_modified_time=datetime.now(UTC_TIMEZONE)
)

unit1 = Unit.objects.get(id=1)
unit2 = Unit.objects.get(id=2)
unit3 = Unit.objects.get(id=3)

# Associate units with multiple services to test M2M optimization
service1.units.add(unit1, unit2)
service2.units.add(unit2, unit3)
service3.units.add(unit1)

# Test single service filtering
response = get(api_client, reverse("unit-list"), data={"service": "695"})
assert response.status_code == 200
assert response.data["count"] == 2
unit_ids = {result["id"] for result in response.data["results"]}
assert unit_ids == {1, 2}

# Test multiple service filtering (this was causing HARAKIRI due to distinct() performance issues)
response = get(api_client, reverse("unit-list"), data={"service": "695,406,235"})
assert response.status_code == 200
assert response.data["count"] == 3
unit_ids = {result["id"] for result in response.data["results"]}
assert unit_ids == {1, 2, 3}

# Test with non-existent service
response = get(api_client, reverse("unit-list"), data={"service": "999"})
assert response.status_code == 200
assert response.data["count"] == 0


@pytest.mark.django_db
def test_service_node_filtering(api_client):
"""
Test service node filtering.
"""
create_units()
service_node_1, service_node_2, service_node_3, service_node_4 = (
create_service_nodes()
)

unit1 = Unit.objects.get(id=1)
unit2 = Unit.objects.get(id=2)
unit3 = Unit.objects.get(id=3)

# Associate units with service nodes
unit1.service_nodes.add(service_node_2)
unit2.service_nodes.add(service_node_2, service_node_3)
unit3.service_nodes.add(service_node_3, service_node_4)

# Test multiple service node filtering
response = get(
api_client,
reverse("unit-list"),
data={
"service_node": f"{service_node_2.id},{service_node_3.id},{service_node_4.id}"
},
)
assert response.status_code == 200
assert response.data["count"] == 3
unit_ids = {result["id"] for result in response.data["results"]}
assert unit_ids == {1, 2, 3}


@pytest.mark.django_db
def test_exclude_service_nodes_filtering(api_client):
"""
Test exclude service nodes filtering.
"""
create_units()
service_node_1, service_node_2, service_node_3, service_node_4 = (
create_service_nodes()
)

unit1 = Unit.objects.get(id=1)
unit2 = Unit.objects.get(id=2)
unit3 = Unit.objects.get(id=3)
Unit.objects.get(id=4)
Unit.objects.get(id=7)

# Associate units with service nodes
unit1.service_nodes.add(service_node_2)
unit2.service_nodes.add(service_node_3)
unit3.service_nodes.add(service_node_4)

# Test excluding service nodes
response = get(
api_client,
reverse("unit-list"),
data={"exclude_service_nodes": f"{service_node_2.id},{service_node_3.id}"},
)
assert response.status_code == 200
# Should exclude units 1 and 2, leaving units 3, 4, 7
assert response.data["count"] == 3
unit_ids = {result["id"] for result in response.data["results"]}
assert unit_ids == {3, 4, 7}


@pytest.mark.django_db
def test_mobility_node_filtering(api_client):
"""
Test mobility node filtering.
"""
create_units()
mobility_node_1, mobility_node_2, mobility_node_3 = create_mobility_nodes()

unit1 = Unit.objects.get(id=1)
unit2 = Unit.objects.get(id=2)
unit3 = Unit.objects.get(id=3)

# Associate units with mobility service nodes
unit1.mobility_service_nodes.add(mobility_node_2)
unit2.mobility_service_nodes.add(mobility_node_2, mobility_node_3)
unit3.mobility_service_nodes.add(mobility_node_3)

# Test multiple mobility node filtering
response = get(
api_client,
reverse("unit-list"),
data={"mobility_node": f"{mobility_node_2.id},{mobility_node_3.id}"},
)
assert response.status_code == 200
assert response.data["count"] == 3
unit_ids = {result["id"] for result in response.data["results"]}
assert unit_ids == {1, 2, 3}


@pytest.mark.django_db
def test_category_filtering(api_client):
"""
Test category filtering.
"""
create_units()
service_node_1, service_node_2, service_node_3, service_node_4 = (
create_service_nodes()
)

# Create services
service1 = Service.objects.create(
id=100, name="Service 100", last_modified_time=datetime.now(UTC_TIMEZONE)
)
service2 = Service.objects.create(
id=200, name="Service 200", last_modified_time=datetime.now(UTC_TIMEZONE)
)

unit1 = Unit.objects.get(id=1)
unit2 = Unit.objects.get(id=2)
unit3 = Unit.objects.get(id=3)

# Associate units with services and service nodes
service1.units.add(unit1)
service2.units.add(unit2)
unit3.service_nodes.add(service_node_4)

# Test category filtering with both services and service nodes
response = get(
api_client,
reverse("unit-list"),
data={"category": f"service:100,service:200,service_node:{service_node_4.id}"},
)
assert response.status_code == 200
assert response.data["count"] == 3
unit_ids = {result["id"] for result in response.data["results"]}
assert unit_ids == {1, 2, 3}