diff --git a/acapy_agent/protocols/out_of_band/v1_0/routes.py b/acapy_agent/protocols/out_of_band/v1_0/routes.py index 5826d9d1f5..ccd37a8bf9 100644 --- a/acapy_agent/protocols/out_of_band/v1_0/routes.py +++ b/acapy_agent/protocols/out_of_band/v1_0/routes.py @@ -18,6 +18,10 @@ from ....admin.request_context import AdminRequestContext from ....messaging.models.base import BaseModelError from ....messaging.models.openapi import OpenAPISchema +from ....messaging.models.paginated_query import ( + PaginatedQuerySchema, + get_paginated_query_params, +) from ....messaging.valid import UUID4_EXAMPLE, UUID4_VALIDATE from ....storage.error import StorageError, StorageNotFoundError from ...didcomm_prefix import DIDCommPrefix @@ -26,7 +30,7 @@ from .message_types import SPEC_URI from .messages.invitation import HSProto, InvitationMessage, InvitationMessageSchema from .models.invitation import InvitationRecordSchema -from .models.oob_record import OobRecordSchema +from .models.oob_record import OobRecord, OobRecordSchema LOGGER = logging.getLogger(__name__) @@ -240,6 +244,108 @@ class OobInvitationRecordMatchInfoSchema(OpenAPISchema): ) +class OobRecordListQueryStringSchema(PaginatedQuerySchema): + """Parameters and validators for OOB record list request query string.""" + + state = fields.Str( + required=False, + validate=validate.OneOf( + OobRecord.get_attributes_by_prefix("STATE_", walk_mro=True) + ), + metadata={ + "description": "OOB record state", + "example": OobRecord.STATE_INITIAL, + }, + ) + role = fields.Str( + required=False, + validate=validate.OneOf( + OobRecord.get_attributes_by_prefix("ROLE_", walk_mro=False) + ), + metadata={ + "description": "OOB record role", + "example": OobRecord.ROLE_SENDER, + }, + ) + connection_id = fields.Str( + required=False, + validate=UUID4_VALIDATE, + metadata={ + "description": "Connection identifier", + "example": UUID4_EXAMPLE, + }, + ) + invi_msg_id = fields.Str( + required=False, + validate=UUID4_VALIDATE, + metadata={ + "description": "Invitation message identifier", + "example": UUID4_EXAMPLE, + }, + ) + + +class OobRecordListSchema(OpenAPISchema): + """Result schema for OOB record list.""" + + results = fields.List( + fields.Nested(OobRecordSchema()), + required=True, + metadata={"description": "List of OOB records"}, + ) + + +@docs( + tags=["out-of-band"], + summary="Query OOB records", +) +@querystring_schema(OobRecordListQueryStringSchema()) +@response_schema(OobRecordListSchema(), 200, description="") +@tenant_authentication +async def oob_records_list(request: web.BaseRequest): + """Request handler for searching OOB records. + + Args: + request: aiohttp request object + + Returns: + The OOB record list response + + """ + context: AdminRequestContext = request["context"] + + tag_filter = { + k: request.query[k] + for k in ("connection_id", "invi_msg_id") + if request.query.get(k, "") != "" + } + post_filter = { + k: request.query[k] + for k in ("state", "role") + if request.query.get(k, "") != "" + } + + limit, offset, order_by, descending = get_paginated_query_params(request) + + profile = context.profile + try: + async with profile.session() as session: + records = await OobRecord.query( + session, + tag_filter, + limit=limit, + offset=offset, + order_by=order_by, + descending=descending, + post_filter_positive=post_filter, + ) + results = [record.serialize() for record in records] + except (StorageError, BaseModelError) as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + + return web.json_response({"results": results}) + + @docs(tags=["out-of-band"], summary="Fetch an existing Out-of-Band invitation.") @querystring_schema(OobIdQueryStringSchema()) @response_schema(InvitationRecordResponseSchema(), description="") @@ -414,6 +520,11 @@ async def register(app: web.Application): [ web.post("/out-of-band/create-invitation", invitation_create), web.post("/out-of-band/receive-invitation", invitation_receive), + web.get( + "/out-of-band/records", + oob_records_list, + allow_head=False, + ), web.get( "/out-of-band/invitations", invitation_fetch, diff --git a/acapy_agent/protocols/out_of_band/v1_0/tests/test_routes.py b/acapy_agent/protocols/out_of_band/v1_0/tests/test_routes.py index 1c787f94f0..0203140ff6 100644 --- a/acapy_agent/protocols/out_of_band/v1_0/tests/test_routes.py +++ b/acapy_agent/protocols/out_of_band/v1_0/tests/test_routes.py @@ -2,6 +2,7 @@ from .....admin.request_context import AdminRequestContext from .....connections.models.conn_record import ConnRecord +from .....storage.error import StorageError from .....tests import mock from .....utils.testing import create_test_profile from .. import routes as test_module @@ -232,6 +233,76 @@ async def test_invitation_receive_x(self): with self.assertRaises(test_module.web.HTTPBadRequest): await test_module.invitation_receive(self.request) + async def test_oob_records_list(self): + mock_record = mock.MagicMock( + serialize=mock.MagicMock(return_value={"oob_id": "test"}) + ) + with ( + mock.patch.object( + test_module.OobRecord, + "query", + mock.CoroutineMock(return_value=[mock_record]), + ), + mock.patch.object( + test_module.web, "json_response", mock.Mock() + ) as mock_json_response, + ): + await test_module.oob_records_list(self.request) + mock_json_response.assert_called_once_with( + {"results": [{"oob_id": "test"}]} + ) + + async def test_oob_records_list_with_filters(self): + self.request.query = { + "state": "initial", + "role": "sender", + "connection_id": "test-conn-id", + "invi_msg_id": "test-invi-id", + } + with ( + mock.patch.object( + test_module.OobRecord, + "query", + mock.CoroutineMock(return_value=[]), + ) as mock_query, + mock.patch.object( + test_module.web, "json_response", mock.Mock() + ) as mock_json_response, + ): + await test_module.oob_records_list(self.request) + mock_query.assert_called_once() + call_kwargs = mock_query.call_args + tag_filter = call_kwargs[0][1] + assert "connection_id" in tag_filter + assert "invi_msg_id" in tag_filter + assert "state" not in tag_filter + mock_json_response.assert_called_once_with({"results": []}) + + async def test_oob_records_list_with_pagination(self): + self.request.query = {"limit": "10", "offset": "5"} + with ( + mock.patch.object( + test_module.OobRecord, + "query", + mock.CoroutineMock(return_value=[]), + ) as mock_query, + mock.patch.object(test_module.web, "json_response", mock.Mock()), + ): + await test_module.oob_records_list(self.request) + mock_query.assert_called_once() + call_kwargs = mock_query.call_args[1] + assert call_kwargs["limit"] == 10 + assert call_kwargs["offset"] == 5 + + async def test_oob_records_list_storage_error(self): + with mock.patch.object( + test_module.OobRecord, + "query", + mock.CoroutineMock(side_effect=StorageError("test error")), + ): + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.oob_records_list(self.request) + async def test_register(self): mock_app = mock.MagicMock() mock_app.add_routes = mock.MagicMock()