Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 54 additions & 24 deletions services/search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from munigeo import api as munigeo_api
from munigeo.models import Address, AdministrativeDivision
from munigeo.utils import get_default_srid
from psycopg import sql
from rest_framework import serializers, status
from rest_framework.exceptions import ParseError
from rest_framework.generics import GenericAPIView
Expand Down Expand Up @@ -510,7 +511,11 @@ def get(self, request):
except ValueError:
raise ParseError("'sql_query_limit' need to be of type integer.")
else:
sql_query_limit = DEFAULT_SEARCH_SQL_LIMIT_VALUE
sql_query_limit = (
None
if DEFAULT_SEARCH_SQL_LIMIT_VALUE == "NULL"
else DEFAULT_SEARCH_SQL_LIMIT_VALUE
)
# Read values for limit values for each model
for type_name in QUERY_PARAM_TYPE_NAMES:
param_name = f"{type_name}_limit"
Expand Down Expand Up @@ -551,18 +556,46 @@ def get(self, request):
# This is ~100 times faster than using Django's SearchRank and allows searching
# using wildcard "|*" and by ranking gives better results, e.g. extra fields
# weight is counted.
sql = f"""
SELECT * from (
SELECT id, type_name, name_{language_short}, ts_rank_cd(search_column_{language_short}, search_query)
AS rank FROM search_view, {search_fn}('{config_language}', %s) search_query
WHERE search_query @@ search_column_{language_short}
ORDER BY rank DESC LIMIT {sql_query_limit}
) AS sub_query where sub_query.rank >= {rank_threshold};
""" # noqa: E501
if sql_query_limit is not None:
query = sql.SQL("""
SELECT * FROM (
SELECT id, type_name, {name_col},
ts_rank_cd({search_col}, search_query)
AS rank FROM search_view,
{search_fn}({config_lang}, %s) search_query
WHERE search_query @@ {search_col_where}
ORDER BY rank DESC LIMIT %s
) AS sub_query WHERE sub_query.rank >= %s
""").format(
name_col=sql.Identifier(f"name_{language_short}"),
search_col=sql.Identifier(f"search_column_{language_short}"),
search_col_where=sql.Identifier(f"search_column_{language_short}"),
search_fn=sql.Identifier(search_fn),
config_lang=sql.Literal(config_language),
)
query_params = [search_query_str, sql_query_limit, rank_threshold]
else:
query = sql.SQL("""
SELECT * FROM (
SELECT id, type_name, {name_col},
ts_rank_cd({search_col}, search_query)
AS rank FROM search_view,
{search_fn}({config_lang}, %s) search_query
WHERE search_query @@ {search_col_where}
ORDER BY rank DESC
) AS sub_query WHERE sub_query.rank >= %s
""").format(
name_col=sql.Identifier(f"name_{language_short}"),
search_col=sql.Identifier(f"search_column_{language_short}"),
search_col_where=sql.Identifier(f"search_column_{language_short}"),
search_fn=sql.Identifier(search_fn),
config_lang=sql.Literal(config_language),
)
query_params = [search_query_str, rank_threshold]

cursor = connection.cursor()
try:
cursor.execute(sql, [search_query_str])
cursor.execute(query, query_params)
except Exception as e:
logger.error(f"Error in search query: {e}", exc_info=e)
raise ParseError("Search query failed.")
Expand Down Expand Up @@ -723,25 +756,22 @@ def get(self, request):
# Use naturalsort function that is migrated to munigeo to
# sort the addresses.
if len(addresses_qs) > 0:
ids = [str(addr.id) for addr in addresses_qs]
# create string containing ids in format (1,4,2)
ids_str = ",".join(ids)
ids_str = f"({ids_str})"
sql = f"""
select id from munigeo_address where id in {ids_str}
order by naturalsort(full_name_{language_short}) asc;
"""
ids = [addr.id for addr in addresses_qs]
address_query = sql.SQL("""
SELECT id FROM munigeo_address
WHERE id = ANY(%s)
ORDER BY naturalsort({full_name_col}) ASC
""").format(full_name_col=sql.Identifier(f"full_name_{language_short}"))
cursor = connection.cursor()
cursor.execute(sql)
cursor.execute(address_query, [ids])
addresses = cursor.fetchall()
# addresses are in format e.g. [(12755,), (4067,)], remove comma and
# parenthesis
ids = [re.sub(r"[(,)]", "", str(a)) for a in addresses]
# addresses are in format e.g. [(12755,), (4067,)], extract ids
ids = [str(a[0]) for a in addresses]
preserved = get_preserved_order(ids)
addresses_qs = Address.objects.filter(id__in=ids).order_by(preserved)
# if no units has been found without trigram search and addresses are
# found,
# do not return any units, thus they might confuse in the results.
# found, do not return any units, thus they might
# confuse in the results.
if addresses_qs.exists() and show_only_address:
units_qs = Unit.objects.none()
else:
Expand Down
12 changes: 11 additions & 1 deletion services/tests/test_administrative_division_view_set_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import pytest
from django.conf import settings
from django.contrib.gis.geos import MultiPolygon, Point, Polygon
Expand Down Expand Up @@ -183,13 +185,18 @@ def test_municipality_filter(api_client):


@pytest.mark.django_db
def test_address_filter(api_client):
@patch("services.api.geocode_address")
def test_address_filter(mock_geocode_address, api_client):
create_administrative_divisions()
division = AdministrativeDivision.objects.get(name="helsinki")
AdministrativeDivisionGeometry.objects.create(
division=division, boundary=create_test_area()
)

# Mock geocode_address to return coordinates inside the test area
# Test area is approximately: lat 60.159-60.178, lon 24.928-24.948
mock_geocode_address.return_value = (60.168, 24.938) # Inside the test area

response = get(
api_client,
reverse("administrativedivision-list"),
Expand All @@ -203,6 +210,9 @@ def test_address_filter(api_client):
assert response.data["count"] == 1
assert response.data["results"][0]["municipality"] == "helsinki"

# Mock geocode_address to return coordinates outside the test area
mock_geocode_address.return_value = (60.150, 24.920) # Outside the test area

response = get(
api_client,
reverse("administrativedivision-list"),
Expand Down