From 674f3fca6991247d84276d3eb8c262532069bc6a Mon Sep 17 00:00:00 2001 From: Mika Hietanen Date: Thu, 12 Feb 2026 13:09:02 +0200 Subject: [PATCH 1/2] fix: prevent SQL injection in search queries Replace f-string SQL construction with parameterized queries using psycopg.sql module. Sanitize user-controlled values like language_short and config_language that were directly interpolated into SQL strings. - Use sql.SQL() with sql.Identifier() for column names - Use sql.Literal() for configuration values - Pass search_query_str and limits as query parameters - Apply fix to both conditional branches in main search query - Apply similar parameterization to address sorting query This ensures all user input is properly sanitized before being executed as SQL queries. Refs: PL-210 --- services/search/api.py | 78 +++++++++++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/services/search/api.py b/services/search/api.py index 60f9b1b5..44afe0fe 100644 --- a/services/search/api.py +++ b/services/search/api.py @@ -29,6 +29,7 @@ from munigeo import api as munigeo_api from munigeo.models import Address, AdministrativeDivision from munigeo.utils import get_default_srid +from psycopg import sql from rest_framework import serializers, status from rest_framework.exceptions import ParseError from rest_framework.generics import GenericAPIView @@ -510,7 +511,11 @@ def get(self, request): except ValueError: raise ParseError("'sql_query_limit' need to be of type integer.") else: - sql_query_limit = DEFAULT_SEARCH_SQL_LIMIT_VALUE + sql_query_limit = ( + None + if DEFAULT_SEARCH_SQL_LIMIT_VALUE == "NULL" + else DEFAULT_SEARCH_SQL_LIMIT_VALUE + ) # Read values for limit values for each model for type_name in QUERY_PARAM_TYPE_NAMES: param_name = f"{type_name}_limit" @@ -551,18 +556,46 @@ def get(self, request): # This is ~100 times faster than using Django's SearchRank and allows searching # using wildcard "|*" and by ranking gives better results, e.g. extra fields # weight is counted. - sql = f""" - SELECT * from ( - SELECT id, type_name, name_{language_short}, ts_rank_cd(search_column_{language_short}, search_query) - AS rank FROM search_view, {search_fn}('{config_language}', %s) search_query - WHERE search_query @@ search_column_{language_short} - ORDER BY rank DESC LIMIT {sql_query_limit} - ) AS sub_query where sub_query.rank >= {rank_threshold}; - """ # noqa: E501 + if sql_query_limit is not None: + query = sql.SQL(""" + SELECT * FROM ( + SELECT id, type_name, {name_col}, + ts_rank_cd({search_col}, search_query) + AS rank FROM search_view, + {search_fn}({config_lang}, %s) search_query + WHERE search_query @@ {search_col_where} + ORDER BY rank DESC LIMIT %s + ) AS sub_query WHERE sub_query.rank >= %s + """).format( + name_col=sql.Identifier(f"name_{language_short}"), + search_col=sql.Identifier(f"search_column_{language_short}"), + search_col_where=sql.Identifier(f"search_column_{language_short}"), + search_fn=sql.Identifier(search_fn), + config_lang=sql.Literal(config_language), + ) + query_params = [search_query_str, sql_query_limit, rank_threshold] + else: + query = sql.SQL(""" + SELECT * FROM ( + SELECT id, type_name, {name_col}, + ts_rank_cd({search_col}, search_query) + AS rank FROM search_view, + {search_fn}({config_lang}, %s) search_query + WHERE search_query @@ {search_col_where} + ORDER BY rank DESC + ) AS sub_query WHERE sub_query.rank >= %s + """).format( + name_col=sql.Identifier(f"name_{language_short}"), + search_col=sql.Identifier(f"search_column_{language_short}"), + search_col_where=sql.Identifier(f"search_column_{language_short}"), + search_fn=sql.Identifier(search_fn), + config_lang=sql.Literal(config_language), + ) + query_params = [search_query_str, rank_threshold] cursor = connection.cursor() try: - cursor.execute(sql, [search_query_str]) + cursor.execute(query, query_params) except Exception as e: logger.error(f"Error in search query: {e}", exc_info=e) raise ParseError("Search query failed.") @@ -723,25 +756,22 @@ def get(self, request): # Use naturalsort function that is migrated to munigeo to # sort the addresses. if len(addresses_qs) > 0: - ids = [str(addr.id) for addr in addresses_qs] - # create string containing ids in format (1,4,2) - ids_str = ",".join(ids) - ids_str = f"({ids_str})" - sql = f""" - select id from munigeo_address where id in {ids_str} - order by naturalsort(full_name_{language_short}) asc; - """ + ids = [addr.id for addr in addresses_qs] + address_query = sql.SQL(""" + SELECT id FROM munigeo_address + WHERE id = ANY(%s) + ORDER BY naturalsort({full_name_col}) ASC + """).format(full_name_col=sql.Identifier(f"full_name_{language_short}")) cursor = connection.cursor() - cursor.execute(sql) + cursor.execute(address_query, [ids]) addresses = cursor.fetchall() - # addresses are in format e.g. [(12755,), (4067,)], remove comma and - # parenthesis - ids = [re.sub(r"[(,)]", "", str(a)) for a in addresses] + # addresses are in format e.g. [(12755,), (4067,)], extract ids + ids = [str(a[0]) for a in addresses] preserved = get_preserved_order(ids) addresses_qs = Address.objects.filter(id__in=ids).order_by(preserved) # if no units has been found without trigram search and addresses are - # found, - # do not return any units, thus they might confuse in the results. + # found, do not return any units, thus they might + # confuse in the results. if addresses_qs.exists() and show_only_address: units_qs = Unit.objects.none() else: From 3f9f256662e79ac3dd1b6f3e37d722dfc05ba1d1 Mon Sep 17 00:00:00 2001 From: Mika Hietanen Date: Thu, 12 Feb 2026 14:44:41 +0200 Subject: [PATCH 2/2] fix: mock geocode_address in test_address_filter The test was failing with a 509 error because it was making actual HTTP requests to the Nominatim geocoding service. Added @patch decorator to mock the geocode_address function, preventing external API calls and ensuring reliable test execution. The mock returns controlled coordinates to test both inside and outside the test area boundaries. Refs: PL-210 --- .../test_administrative_division_view_set_api.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/services/tests/test_administrative_division_view_set_api.py b/services/tests/test_administrative_division_view_set_api.py index 8a6fddc0..eb41d544 100644 --- a/services/tests/test_administrative_division_view_set_api.py +++ b/services/tests/test_administrative_division_view_set_api.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from django.conf import settings from django.contrib.gis.geos import MultiPolygon, Point, Polygon @@ -183,13 +185,18 @@ def test_municipality_filter(api_client): @pytest.mark.django_db -def test_address_filter(api_client): +@patch("services.api.geocode_address") +def test_address_filter(mock_geocode_address, api_client): create_administrative_divisions() division = AdministrativeDivision.objects.get(name="helsinki") AdministrativeDivisionGeometry.objects.create( division=division, boundary=create_test_area() ) + # Mock geocode_address to return coordinates inside the test area + # Test area is approximately: lat 60.159-60.178, lon 24.928-24.948 + mock_geocode_address.return_value = (60.168, 24.938) # Inside the test area + response = get( api_client, reverse("administrativedivision-list"), @@ -203,6 +210,9 @@ def test_address_filter(api_client): assert response.data["count"] == 1 assert response.data["results"][0]["municipality"] == "helsinki" + # Mock geocode_address to return coordinates outside the test area + mock_geocode_address.return_value = (60.150, 24.920) # Outside the test area + response = get( api_client, reverse("administrativedivision-list"),