diff --git a/server/mergin/sync/public_api_v2_controller.py b/server/mergin/sync/public_api_v2_controller.py
index 2b7f124e..1070c830 100644
--- a/server/mergin/sync/public_api_v2_controller.py
+++ b/server/mergin/sync/public_api_v2_controller.py
@@ -50,7 +50,7 @@
from .storages.disk import move_to_tmp, save_to_file
from .utils import get_device_id, get_ip, get_user_agent, get_chunk_location
from .workspace import WorkspaceRole
-from ..utils import parse_order_params
+from ..utils import parse_order_params, get_schema_fields_map
@auth_required
@@ -445,11 +445,15 @@ def list_workspace_projects(workspace_id, page, per_page, order_params=None, q=N
projects = projects.filter(Project.name.ilike(f"%{q}%"))
if order_params:
- order_by_params = parse_order_params(Project, order_params)
+ schema_map = get_schema_fields_map(ProjectSchemaV2)
+ order_by_params = parse_order_params(
+ Project, order_params, field_map=schema_map
+ )
projects = projects.order_by(*order_by_params)
- result = projects.paginate(page, per_page).items
- total = projects.paginate(page, per_page).total
+ pagination = projects.paginate(page=page, per_page=per_page)
+ result = pagination.items
+ total = pagination.total
data = ProjectSchemaV2(many=True).dump(result)
return jsonify(projects=data, count=total, page=page, per_page=per_page), 200
diff --git a/server/mergin/tests/test_public_api_v2.py b/server/mergin/tests/test_public_api_v2.py
index 6e702f31..f6434cd2 100644
--- a/server/mergin/tests/test_public_api_v2.py
+++ b/server/mergin/tests/test_public_api_v2.py
@@ -684,6 +684,17 @@ def test_list_workspace_projects(client):
url + f"?page={page}&per_page={per_page}&q=1&order_params=created DESC"
)
assert response.json["projects"][0]["name"] == "project_10"
+ # using field name instead column names for sorting
+ p4 = Project.query.filter(Project.name == project_name).first()
+ p4.disk_usage = 1234567
+ db.session.commit()
+ response = client.get(url + f"?page=1&per_page=10&order_params=size DESC")
+ resp_data = json.loads(response.data)
+ assert resp_data["projects"][0]["name"] == project_name
+
+ # invalid order param
+ response = client.get(url + f"?page=1&per_page=10&order_params=invalid DESC")
+ assert response.status_code == 200
# no permissions to workspace
user2 = add_user("user", "password")
diff --git a/server/mergin/tests/test_utils.py b/server/mergin/tests/test_utils.py
index bf5f4666..00b3e1c6 100644
--- a/server/mergin/tests/test_utils.py
+++ b/server/mergin/tests/test_utils.py
@@ -7,6 +7,7 @@
import json
import pytest
from flask import url_for, current_app
+from marshmallow import Schema, fields
from sqlalchemy import desc
import os
from unittest.mock import patch
@@ -14,7 +15,7 @@
from pygeodiff import GeoDiff
from pathlib import PureWindowsPath
-from ..utils import save_diagnostic_log_file
+from ..utils import save_diagnostic_log_file, get_schema_fields_map
from ..sync.utils import (
is_reserved_word,
@@ -297,3 +298,27 @@ def test_save_diagnostic_log_file(client, app):
with open(saved_file_path, "r") as f:
content = f.read()
assert content == body.decode("utf-8")
+
+
+def test_get_schema_fields_map():
+ """Test that schema map correctly resolves DB attributes, keeps all fields, and ignores virtual fields."""
+
+ # dummy schema for testing
+ class TestSchema(Schema):
+ # standard field -> map 'name': 'name'
+ name = fields.String()
+ # aliased field -> map 'size': 'disk_usage
+ size = fields.Integer(attribute="disk_usage")
+ # virtual fields -> skip
+ version = fields.Function(lambda obj: "v1")
+ role = fields.Method("get_role")
+ # excluded field - set to None in schema inheritance -> skip
+ hidden_field = None
+
+ schema_map = get_schema_fields_map(TestSchema)
+
+ expected_map = {
+ "name": "name",
+ "size": "disk_usage",
+ }
+ assert schema_map == expected_map
diff --git a/server/mergin/utils.py b/server/mergin/utils.py
index 9acc6124..7b062770 100644
--- a/server/mergin/utils.py
+++ b/server/mergin/utils.py
@@ -1,6 +1,8 @@
# Copyright (C) Lutra Consulting Limited
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-MerginMaps-Commercial
+import logging
+
import math
from collections import namedtuple
from datetime import datetime, timedelta, timezone
@@ -8,11 +10,11 @@
import os
from flask import current_app
from flask_sqlalchemy import Model
+from marshmallow import Schema, fields
from pathvalidate import sanitize_filename
from sqlalchemy import Column, JSON
from sqlalchemy.sql.elements import UnaryExpression
-from typing import Optional
-
+from typing import Optional, Type
OrderParam = namedtuple("OrderParam", "name direction")
@@ -33,7 +35,7 @@ def split_order_param(order_param: str) -> Optional[OrderParam]:
def get_order_param(
- cls: Model, order_param: OrderParam, json_sort: dict = None
+ cls: Model, order_param: OrderParam, json_sort: dict = None, field_map: dict = None
) -> Optional[UnaryExpression]:
"""Return order by clause parameter for SQL query
@@ -43,15 +45,22 @@ def get_order_param(
:type order_param: OrderParam
:param json_sort: type mapping for sort by json field, e.g. '{"storage": "int"}', defaults to None
:type json_sort: dict
+ :param field_map: mapping for translating public field names to internal DB columns, e.g. '{"size": "disk_usage"}'
+ :type field_map: dict
"""
+ # translate field name to column name
+ db_column_name = order_param.name
+ if field_map and order_param.name in field_map:
+ db_column_name = field_map[order_param.name]
# find candidate for nested json sort
- if "." in order_param.name:
- col, attr = order_param.name.split(".")
+ if "." in db_column_name:
+ col, attr = db_column_name.split(".")
else:
- col = order_param.name
+ col = db_column_name
attr = None
order_attr = cls.__table__.c.get(col, None)
if not isinstance(order_attr, Column):
+ logging.warning("Ignoring invalid order parameter.")
return
# sort by key in JSON field
if attr:
@@ -80,7 +89,9 @@ def get_order_param(
return order_attr.desc()
-def parse_order_params(cls: Model, order_params: str, json_sort: dict = None):
+def parse_order_params(
+ cls: Model, order_params: str, json_sort: dict = None, field_map: dict = None
+) -> list[UnaryExpression]:
"""Convert order parameters in query string to list of order by clauses.
:param cls: Db model class
@@ -89,6 +100,8 @@ def parse_order_params(cls: Model, order_params: str, json_sort: dict = None):
:type order_params: str
:param json_sort: type mapping for sort by json field, e.g. '{"storage": "int"}', defaults to None
:type json_sort: dict
+ :param field_map: mapping response fields to database column names, e.g. '{"size": "disk_usage"}'
+ :type field_map: dict
:rtype: List[Column]
"""
@@ -97,7 +110,7 @@ def parse_order_params(cls: Model, order_params: str, json_sort: dict = None):
order_param = split_order_param(p)
if not order_param:
continue
- order_attr = get_order_param(cls, order_param, json_sort)
+ order_attr = get_order_param(cls, order_param, json_sort, field_map)
if order_attr is not None:
order_by_params.append(order_attr)
return order_by_params
@@ -135,3 +148,27 @@ def save_diagnostic_log_file(app: str, username: str, body: bytes) -> str:
f.write(content)
return file_name
+
+
+def get_schema_fields_map(schema: Type[Schema]) -> dict:
+ """
+ Creates a mapping of schema field names to corresponding DB columns.
+ This allows sorting by the API field name (e.g. 'size') while
+ actually sorting by the database column (e.g. 'disk_usage').
+ """
+ mapping = {}
+ for name, field in schema._declared_fields.items():
+ # some fields could have been overridden with None to be excluded
+ if not field:
+ continue
+ # skip virtual fields as DB cannot sort by them
+ if isinstance(
+ field, (fields.Function, fields.Method, fields.Nested, fields.List)
+ ):
+ continue
+ if field.attribute:
+ mapping[name] = field.attribute
+ # keep the map complete
+ else:
+ mapping[name] = name
+ return mapping
diff --git a/web-app/packages/lib/src/modules/project/views/ProjectsListViewTemplate.vue b/web-app/packages/lib/src/modules/project/views/ProjectsListViewTemplate.vue
index 5d91bebc..afa28bed 100644
--- a/web-app/packages/lib/src/modules/project/views/ProjectsListViewTemplate.vue
+++ b/web-app/packages/lib/src/modules/project/views/ProjectsListViewTemplate.vue
@@ -32,7 +32,7 @@ SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-MerginMaps-Commercial
class="w-full"
/>
-
+