diff --git a/acapy_agent/admin/tests/test_admin_server.py b/acapy_agent/admin/tests/test_admin_server.py index 8c8a8e8039..ef29e6b04f 100644 --- a/acapy_agent/admin/tests/test_admin_server.py +++ b/acapy_agent/admin/tests/test_admin_server.py @@ -15,6 +15,7 @@ from ...core.event_bus import Event from ...core.goal_code_registry import GoalCodeRegistry from ...core.protocol_registry import ProtocolRegistry +from ...didcomm_v2.protocol_registry import V2ProtocolRegistry from ...multitenant.error import MultitenantManagerError from ...storage.base import BaseStorage from ...storage.error import StorageNotFoundError @@ -339,6 +340,7 @@ async def test_import_routes(self): # for routes with associated tests, this shouldn't make a difference in coverage context = InjectionContext() context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry()) + context.injector.bind_instance(V2ProtocolRegistry, V2ProtocolRegistry()) context.injector.bind_instance(GoalCodeRegistry, GoalCodeRegistry()) await DefaultContextBuilder().load_plugins(context) server = await self.get_admin_server({"admin.admin_insecure_mode": True}, context) @@ -347,6 +349,7 @@ async def test_import_routes(self): async def test_register_external_plugin_x(self): context = InjectionContext() context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry()) + context.injector.bind_instance(V2ProtocolRegistry, V2ProtocolRegistry()) context.injector.bind_instance(GoalCodeRegistry, GoalCodeRegistry()) with self.assertLogs(level="ERROR") as logs: builder = DefaultContextBuilder( diff --git a/acapy_agent/config/default_context.py b/acapy_agent/config/default_context.py index 136c79791d..969469f650 100644 --- a/acapy_agent/config/default_context.py +++ b/acapy_agent/config/default_context.py @@ -8,6 +8,7 @@ from ..core.plugin_registry import PluginRegistry from ..core.profile import ProfileManager, ProfileManagerProvider from ..core.protocol_registry import ProtocolRegistry +from ..didcomm_v2.protocol_registry import V2ProtocolRegistry from ..protocols.actionmenu.v1_0.base_service import BaseMenuService from ..protocols.actionmenu.v1_0.driver_service import DriverMenuService from ..protocols.introduction.v0_1.base_service import BaseIntroductionService @@ -45,6 +46,7 @@ async def build_context(self) -> InjectionContext: # Global protocol registry context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry()) + context.injector.bind_instance(V2ProtocolRegistry, V2ProtocolRegistry()) # Global goal code registry context.injector.bind_instance(GoalCodeRegistry, GoalCodeRegistry()) @@ -129,6 +131,7 @@ async def load_plugins(self, context: InjectionContext): # Register standard protocol plugins if not self.settings.get("transport.disabled"): plugin_registry.register_package("acapy_agent.protocols") + plugin_registry.register_package("acapy_agent.protocols_v2") # Currently providing admin routes only plugin_registry.register_plugin("acapy_agent.holder") diff --git a/acapy_agent/connections/models/conn_peer_record.py b/acapy_agent/connections/models/conn_peer_record.py new file mode 100644 index 0000000000..4468da4c8c --- /dev/null +++ b/acapy_agent/connections/models/conn_peer_record.py @@ -0,0 +1,685 @@ +"""Handle connection information interface with non-secrets storage.""" + +import json +from enum import Enum +from typing import Any, Optional, Union, List + +from marshmallow import fields, validate + +from ...core.profile import ProfileSession +from ...messaging.models.base_record import BaseRecord, BaseRecordSchema +from ...messaging.valid import ( + GENERIC_DID_EXAMPLE, + GENERIC_DID_VALIDATE, + RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + UUID4_EXAMPLE, +) +from ...protocols.connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO +from ...protocols.connections.v1_0.message_types import ( + CONNECTION_INVITATION, + CONNECTION_REQUEST, +) +from ...protocols.connections.v1_0.messages.connection_invitation import ( + ConnectionInvitation, +) +from ...protocols.connections.v1_0.messages.connection_request import ConnectionRequest +from ...protocols.didcomm_prefix import DIDCommPrefix +from ...protocols.didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDEX_1_1 +from ...protocols.didexchange.v1_0.message_types import DIDEX_1_0 +from ...protocols.didexchange.v1_0.messages.request import DIDXRequest +from ...protocols.out_of_band.v1_0.messages.invitation import ( + InvitationMessage as OOBInvitation, +) +from ...storage.base import BaseStorage +from ...storage.error import StorageNotFoundError +from ...storage.record import StorageRecord + + +class PeerwiseRecord(BaseRecord): + """Represents a single pairwise connection.""" + + class Meta: + """ConnRecord metadata.""" + + schema_class = "MaybeStoredConnRecordSchema" + + SUPPORTED_PROTOCOLS = (CONN_PROTO, DIDEX_1_0, DIDEX_1_1) + + class Role(Enum): + """RFC 160 (inviter, invitee) = RFC 23 (responder, requester).""" + + REQUESTER = ("invitee", "requester") # == RFC 23 initiator, RFC 434 receiver + RESPONDER = ("inviter", "responder") # == RFC 160 initiator(!), RFC 434 sender + + @property + def rfc160(self): + """Return RFC 160 (connection protocol) nomenclature.""" + return self.value[0] + + @property + def rfc23(self): + """Return RFC 23 (DID exchange protocol) nomenclature.""" + return self.value[1] + + @classmethod + def get(cls, label: Union[str, "ConnRecord.Role"]): + """Get role enum for label.""" + if isinstance(label, str): + for role in ConnRecord.Role: + if label in role.value: + return role + elif isinstance(label, ConnRecord.Role): + return label + return None + + def flip(self): + """Return opposite interlocutor role: theirs for ours, ours for theirs.""" + return ( + ConnRecord.Role.REQUESTER + if self is ConnRecord.Role.RESPONDER + else ConnRecord.Role.RESPONDER + ) + + def __eq__(self, other: Union[str, "ConnRecord.Role"]) -> bool: + """Comparison between roles.""" + return self is ConnRecord.Role.get(other) + + class State(Enum): + """Collator for equivalent states between RFC 160 and RFC 23. + + On the connection record, the state has to serve for both RFCs. + Hence, internally, RFC23 requester/responder states collate to + their RFC160 condensed equivalent. + """ + + INIT = ("init", "start") + INVITATION = ("invitation", "invitation") + REQUEST = ("request", "request") + RESPONSE = ("response", "response") + COMPLETED = ("active", "completed") + ABANDONED = ("error", "abandoned") + + @property + def rfc160(self): + """Return RFC 160 (connection protocol) nomenclature.""" + return self.value[0] + + @property + def rfc23(self): + """Return RFC 23 (DID exchange protocol) nomenclature to record logic.""" + return self.value[1] + + def rfc23strict(self, their_role: "ConnRecord.Role"): + """Return RFC 23 (DID exchange protocol) nomenclature to role as per RFC.""" + + if not their_role or self in ( + ConnRecord.State.INIT, + ConnRecord.State.COMPLETED, + ConnRecord.State.ABANDONED, + ): + return self.value[1] + + if self is ConnRecord.State.REQUEST: + return self.value[1] + ( + "-sent" + if ConnRecord.Role.get(their_role) is ConnRecord.Role.RESPONDER + else "-received" + ) + else: + return self.value[1] + ( + "-received" + if ConnRecord.Role.get(their_role) is ConnRecord.Role.RESPONDER + else "-sent" + ) + + @classmethod + def get(cls, label: Union[str, "ConnRecord.State"]): + """Get state enum for label.""" + if isinstance(label, str): + for state in ConnRecord.State: + if label in state.value: + return state + elif isinstance(label, ConnRecord.State): + return label + return None + + def __eq__(self, other: Union[str, "ConnRecord.State"]) -> bool: + """Comparison between states.""" + return self is ConnRecord.State.get(other) + + RECORD_ID_NAME = "pairwise_id" + RECORD_TOPIC = "peer_connections" + LOG_STATE_FLAG = "debug.connections" + TAG_NAMES = { + "my_did", + "their_did", + #"request_id", + "invitation_msg_id", + } + + RECORD_TYPE = "peer_connection" + RECORD_TYPE_INVITATION = "connection_invitation" + RECORD_TYPE_REQUEST = "connection_request" + RECORD_TYPE_METADATA = "connection_metadata" + + INVITATION_MODE_ONCE = "once" + INVITATION_MODE_MULTI = "multi" + INVITATION_MODE_STATIC = "static" + + ACCEPT_MANUAL = "manual" + ACCEPT_AUTO = "auto" + + def __init__( + self, + *, + pairwise_id: Optional[str] = None, + my_did: Optional[str] = None, + their_did: Optional[str] = None, + #their_label: Optional[str] = None, + endpoints: Optional[List[str]] = None, + invitation_msg_id: Optional[str] = None, + accept: Optional[str] = None, + alias: Optional[str] = None, + aka: Optional[List[str]] = None, + **kwargs, + ): + """Initialize a new ConnRecord.""" + super().__init__( + pairwise_id, + **kwargs, + ) + self.my_did = my_did + self.their_did = their_did + #self.their_label = their_label + self.invitation_msg_id = invitation_msg_id + self.accept = accept or self.ACCEPT_MANUAL + self.endpoints = endpoints + self.alias = alias + self.aka = aka + + @property + def pairwise_id(self) -> str: + """Accessor for the ID associated with this connection.""" + return self._id + + @property + def record_value(self) -> dict: + """Accessor to for the JSON record value properties for this connection.""" + return { + prop: getattr(self, prop) + for prop in ( + "accept", + "invitation_msg_id", + "alias", + #"their_label", + ) + } + + @classmethod + async def retrieve_by_did( + cls, + session: ProfileSession, + their_did: Optional[str] = None, + my_did: Optional[str] = None, + their_role: Optional[str] = None, + ) -> "ConnRecord": + """Retrieve a connection record by target DID. + + Args: + session: The active profile session + their_did: The target DID to filter by + my_did: One of our DIDs to filter by + my_role: Filter connections by their role + their_role: Filter connections by their role + """ + tag_filter = {} + if their_did: + tag_filter["their_did"] = their_did + if my_did: + tag_filter["my_did"] = my_did + + post_filter = {} + if their_role: + post_filter["their_role"] = cls.Role.get(their_role).rfc160 + + return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + + @classmethod + async def retrieve_by_did_peer_4( + cls, + session: ProfileSession, + their_did_long: Optional[str] = None, + their_did_short: Optional[str] = None, + my_did: Optional[str] = None, + their_role: Optional[str] = None, + ) -> "ConnRecord": + """Retrieve a connection record by target DID. + + Args: + session: The active profile session + their_did_long: The target DID to filter by, in long form + their_did_short: The target DID to filter by, in short form + my_did: One of our DIDs to filter by + my_role: Filter connections by their role + their_role: Filter connections by their role + """ + tag_filter = {} + if their_did_long and their_did_short: + tag_filter["$or"] = [ + {"their_did": their_did_long}, + {"their_did": their_did_short}, + ] + elif their_did_short: + tag_filter["their_did"] = their_did_short + elif their_did_long: + tag_filter["their_did"] = their_did_long + if my_did: + tag_filter["my_did"] = my_did + + post_filter = {} + if their_role: + post_filter["their_role"] = cls.Role.get(their_role).rfc160 + + return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + + @classmethod + async def retrieve_by_invitation_key( + cls, + session: ProfileSession, + invitation_key: str, + their_role: Optional[str] = None, + ) -> "ConnRecord": + """Retrieve a connection record by invitation key. + + Args: + session: The active profile session + invitation_key: The key on the originating invitation + their_role: Filter by their role + """ + tag_filter = { + "invitation_key": invitation_key, + "state": cls.State.INVITATION.rfc160, + } + post_filter = {"state": cls.State.INVITATION.rfc160} + + if their_role: + post_filter["their_role"] = cls.Role.get(their_role).rfc160 + tag_filter["their_role"] = cls.Role.get(their_role).rfc160 + + return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + + @classmethod + async def retrieve_by_invitation_msg_id( + cls, + session: ProfileSession, + invitation_msg_id: str, + their_role: Optional[str] = None, + ) -> Optional["ConnRecord"]: + """Retrieve a connection record by invitation_msg_id. + + Args: + session: The active profile session + invitation_msg_id: Invitation message identifier + their_role: Filter by their role + """ + tag_filter = {"invitation_msg_id": invitation_msg_id} + post_filter = { + "state": cls.State.INVITATION.rfc160, + } + if their_role: + post_filter["their_role"] = cls.Role.get(their_role).rfc160 + try: + return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + except StorageNotFoundError: + return None + + @classmethod + async def retrieve_by_request_id( + cls, session: ProfileSession, request_id: str, their_role: Optional[str] = None + ) -> "ConnRecord": + """Retrieve a connection record from our previous request ID. + + Args: + session: The active profile session + request_id: The ID of the originating connection request + their_role: Filter by their role + """ + tag_filter = {"request_id": request_id} + if their_role: + tag_filter["their_role"] = their_role + return await cls.retrieve_by_tag_filter(session, tag_filter) + + @classmethod + async def retrieve_by_alias(cls, session: ProfileSession, alias: str) -> "ConnRecord": + """Retrieve a connection record from an alias. + + Args: + session: The active profile session + alias: The alias of the connection + """ + post_filter = {"alias": alias} + return await cls.query(session, post_filter_positive=post_filter) + + async def attach_invitation( + self, + session: ProfileSession, + invitation: Union[ConnectionInvitation, OOBInvitation], + ): + """Persist the related connection invitation to storage. + + Args: + session: The active profile session + invitation: The invitation to relate to this connection record + """ + assert self.pairwise_id + record = StorageRecord( + self.RECORD_TYPE_INVITATION, # conn- or oob-invitation, to retrieve easily + invitation.to_json(), + {"pairwise_id": self.pairwise_id}, + ) + storage = session.inject(BaseStorage) + await storage.add_record(record) + + async def retrieve_invitation( + self, session: ProfileSession + ) -> Union[ConnectionInvitation, OOBInvitation]: + """Retrieve the related connection invitation. + + Args: + session: The active profile session + """ + assert self.pairwise_id + storage = session.inject(BaseStorage) + result = await storage.find_record( + self.RECORD_TYPE_INVITATION, + {"pairwise_id": self.pairwise_id}, + ) + ser = json.loads(result.value) + return ( + ConnectionInvitation + if DIDCommPrefix.unqualify(ser["@type"]) == CONNECTION_INVITATION + else OOBInvitation + ).deserialize(ser) + + async def attach_request( + self, + session: ProfileSession, + request: Union[ConnectionRequest, DIDXRequest], + ): + """Persist the related connection request to storage. + + Args: + session: The active profile session + request: The request to relate to this connection record + """ + assert self.pairwise_id + record = StorageRecord( + self.RECORD_TYPE_REQUEST, # conn- or didx-request, to retrieve easily + request.to_json(), + {"pairwise_id": self.pairwise_id}, + ) + storage: BaseStorage = session.inject(BaseStorage) + await storage.add_record(record) + + async def retrieve_request( + self, + session: ProfileSession, + ) -> Union[ConnectionRequest, DIDXRequest]: + """Retrieve the related connection invitation. + + Args: + session: The active profile session + """ + assert self.pairwise_id + storage: BaseStorage = session.inject(BaseStorage) + result = await storage.find_record( + self.RECORD_TYPE_REQUEST, {"pairwise_id": self.pairwise_id} + ) + ser = json.loads(result.value) + return ( + ConnectionRequest + if DIDCommPrefix.unqualify(ser["@type"]) == CONNECTION_REQUEST + else DIDXRequest + ).deserialize(ser) + + @property + def is_ready(self) -> str: + """Accessor for connection readiness.""" + return ConnRecord.State.get(self.state) in ( + ConnRecord.State.COMPLETED, + ConnRecord.State.RESPONSE, + ) + + @property + def is_multiuse_invitation(self) -> bool: + """Accessor for multi use invitation mode.""" + return self.invitation_mode == self.INVITATION_MODE_MULTI + + async def post_save(self, session: ProfileSession, *args, **kwargs): + """Perform post-save actions. + + Args: + session: The active profile session + args: Additional positional arguments + kwargs: Additional keyword arguments + """ + await super().post_save(session, *args, **kwargs) + + # clear cache key set by connection manager + cache_key = f"pairwise_connection::{self.pairwise_id}" + await self.clear_cached_key(session, cache_key) + + async def delete_record(self, session: ProfileSession): + """Perform connection record deletion actions. + + Args: + session (ProfileSession): session + + """ + await super().delete_record(session) + + storage = session.inject(BaseStorage) + # Delete metadata + if self.pairwise_id: + await storage.delete_all_records( + self.RECORD_TYPE_METADATA, + {"pairwise_id": self.pairwise_id}, + ) + + # Delete attached messages + await storage.delete_all_records( + self.RECORD_TYPE_REQUEST, + {"pairwise_id": self.pairwise_id}, + ) + await storage.delete_all_records( + self.RECORD_TYPE_INVITATION, + {"pairwise_id": self.pairwise_id}, + ) + + async def abandon(self, session: ProfileSession, *, reason: Optional[str] = None): + """Set state to abandoned.""" + reason = reason or "Connection abandoned" + self.state = ConnRecord.State.ABANDONED.rfc160 + self.error_msg = reason + await self.save(session, reason=reason) + + async def metadata_get( + self, session: ProfileSession, key: str, default: Optional[Any] = None + ) -> Any: + """Retrieve arbitrary metadata associated with this connection. + + Args: + session (ProfileSession): session used for storage + key (str): key identifying metadata + default (Any): default value to get; type should be a JSON + compatible value. + + Returns: + Any: metadata stored by key + + """ + assert self.pairwise_id + storage: BaseStorage = session.inject(BaseStorage) + try: + record = await storage.find_record( + self.RECORD_TYPE_METADATA, + {"key": key, "pairwise_id": self.pairwise_id}, + ) + return json.loads(record.value) + except StorageNotFoundError: + return default + + async def metadata_set(self, session: ProfileSession, key: str, value: Any): + """Set arbitrary metadata associated with this connection. + + Args: + session (ProfileSession): session used for storage + key (str): key identifying metadata + value (Any): value to set + """ + assert self.pairwise_id + value = json.dumps(value) + storage: BaseStorage = session.inject(BaseStorage) + try: + record = await storage.find_record( + self.RECORD_TYPE_METADATA, + {"key": key, "pairwise_id": self.pairwise_id}, + ) + await storage.update_record(record, value, record.tags) + except StorageNotFoundError: + record = StorageRecord( + self.RECORD_TYPE_METADATA, + value, + {"key": key, "pairwise_id": self.pairwise_id}, + ) + await storage.add_record(record) + + async def metadata_delete(self, session: ProfileSession, key: str): + """Delete custom metadata associated with this connection. + + Args: + session (ProfileSession): session used for storage + key (str): key of metadata to delete + """ + assert self.pairwise_id + storage: BaseStorage = session.inject(BaseStorage) + try: + record = await storage.find_record( + self.RECORD_TYPE_METADATA, + {"key": key, "pairwise_id": self.pairwise_id}, + ) + await storage.delete_record(record) + except StorageNotFoundError as err: + raise KeyError(f"{key} not found in connection metadata") from err + + async def metadata_get_all(self, session: ProfileSession) -> dict: + """Return all custom metadata associated with this connection. + + Args: + session (ProfileSession): session used for storage + + Returns: + dict: dictionary representation of all metadata values + + """ + assert self.pairwise_id + storage: BaseStorage = session.inject(BaseStorage) + records = await storage.find_all_records( + self.RECORD_TYPE_METADATA, + {"pairwise_id": self.pairwise_id}, + ) + return {record.tags["key"]: json.loads(record.value) for record in records} + + def __eq__(self, other: Any) -> bool: + """Comparison between records.""" + return super().__eq__(other) + + +class MaybeStoredConnRecordSchema(BaseRecordSchema): + """Schema to allow serialization/deserialization of connection records.""" + + class Meta: + """MaybeStoredConnRecordSchema metadata.""" + + model_class = PeerwiseRecord + + pairwise_id = fields.Str( + required=False, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) + my_did = fields.Str( + required=False, + validate=GENERIC_DID_VALIDATE, + metadata={ + "description": "Our DID for connection", + "example": GENERIC_DID_EXAMPLE, + }, + ) + their_did = fields.Str( + required=False, + validate=GENERIC_DID_VALIDATE, + metadata={ + "description": "Their DID for connection", + "example": GENERIC_DID_EXAMPLE, + }, + ) + endpoints = fields.List( + fields.Str(), + required=False, + metadata={ + "description": "list of endpoints that this peer-wise contact can be reached through", + "example": ["did:example:bob", "https://example.com/didcomm"], + }, + ) + #their_label = fields.Str( + # required=False, + # metadata={"description": "Their label for connection", "example": "Bob"}, + #) + invitation_msg_id = fields.Str( + required=False, + metadata={ + "description": "ID of out-of-band invitation message", + "example": UUID4_EXAMPLE, + }, + ) + accept = fields.Str( + required=False, + validate=validate.OneOf( + PeerwiseRecord.get_attributes_by_prefix("ACCEPT_", walk_mro=False) + ), + metadata={ + "description": "Connection acceptance: manual or auto", + "example": PeerwiseRecord.ACCEPT_AUTO, + }, + ) + alias = fields.Str( + required=False, + metadata={ + "description": "Optional alias to apply to connection for later use", + "example": "Bob, providing quotes", + }, + ) + aka = fields.List( + fields.Str(), + required=False, + metadata={ + "description": "Optional list of DIDs that this peer-wise contact is known as", + "example": ["did:example:bob", "did:example:bob-phone"], + }, + ) + + + +class PeerwiseRecordSchema(MaybeStoredConnRecordSchema): + """Schema representing stored ConnRecords.""" + + class Meta: + """ConnRecordSchema metadata.""" + + model_class = PeerwiseRecord + + pairwise_id = fields.Str( + required=True, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) diff --git a/acapy_agent/core/dispatcher.py b/acapy_agent/core/dispatcher.py index 962b69d97c..63f50e486c 100644 --- a/acapy_agent/core/dispatcher.py +++ b/acapy_agent/core/dispatcher.py @@ -12,9 +12,11 @@ from typing import Callable, Coroutine, Optional, Union from aiohttp.web import HTTPException +from didcomm_messaging import DIDCommMessaging, RoutingService from ..connections.base_manager import BaseConnectionManager from ..connections.models.conn_record import ConnRecord +from ..connections.models.connection_target import ConnectionTarget from ..core.profile import Profile from ..messaging.agent_message import AgentMessage from ..messaging.base_message import BaseMessage, DIDCommVersion @@ -28,12 +30,14 @@ from ..transport.inbound.message import InboundMessage from ..transport.outbound.message import OutboundMessage from ..transport.outbound.status import OutboundSendStatus +from ..storage.error import StorageNotFoundError from ..utils.classloader import DeferLoad from ..utils.stats import Collector from ..utils.task_queue import CompletedTask, PendingTask, TaskQueue from ..utils.tracing import get_timer, trace_event from .error import ProtocolMinorVersionNotSupported from .protocol_registry import ProtocolRegistry +from ..didcomm_v2.protocol_registry import V2ProtocolRegistry class ProblemReportParseError(MessageParseError): @@ -137,9 +141,109 @@ async def handle_v2_message( ): """Handle a DIDComm V2 message.""" + error_result = None + message = None + + session = await profile.session() + from ..connections.models.conn_peer_record import PeerwiseRecord + try: + existing_record = await PeerwiseRecord.retrieve_by_did( + session, + their_did=inbound_message.receipt.sender_verkey, + my_did=inbound_message.receipt.recipient_verkey, + ) + except StorageNotFoundError as err: + existing_record = None + #raise web.HTTPBadRequest(reason=err.roll_up) from err + + if existing_record: + peer = existing_record + else: + peer = PeerwiseRecord(their_did=inbound_message.receipt.sender_verkey, my_did=inbound_message.receipt.recipient_verkey) + + await profile.notify( + "acapy::webhook::pairwise_did", + { + "pairwise_id": peer.pairwise_id, + "status": "connected", + "recipient_did": inbound_message.receipt.sender_verkey, + "message": inbound_message.payload, + }, + ) + + try: + message = await self.make_v2_message(profile, inbound_message.payload) + except ProblemReportParseError: + pass # avoid problem report recursion + except MessageParseError as e: + self.logger.error(f"Message parsing failed: {str(e)}, sending problem report") + error_result = ProblemReport( + description={ + "en": str(e), + "code": "message-parse-failure", + } + ) + if inbound_message.receipt.thread_id: + error_result.assign_thread_id(inbound_message.receipt.thread_id) + + if not existing_record: # or existing_record["updated_at"] # Caching to not hammer did:web + messaging = session.inject(DIDCommMessaging) + routing_service = session.inject(RoutingService) + frm = inbound_message.payload.get("from") + + services = await routing_service._resolve_services(messaging.resolver, frm) + chain = [ + { + "did": frm, + "service": services, + } + ] + + # Loop through service DIDs until we run out of DIDs to forward to + to_did = services[0].service_endpoint.uri + found_forwardable_service = await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + while found_forwardable_service: + services = await routing_service._resolve_services(messaging.resolver, to_did) + if services: + chain.append( + { + "did": to_did, + "service": services, + } + ) + to_did = services[0].service_endpoint.uri + found_forwardable_service = ( + await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + if services + else False + ) + peer.endpoints = [ + service.service_endpoint.uri + for service in chain[-1]["service"] + if "didcomm/v2" in service.accept + ] + await peer.save(session) + elif not existing_record: + await peer.save(session) + + reply_destination = [ + ConnectionTarget( + did=inbound_message.receipt.sender_verkey, + endpoint=service.service_endpoint.uri, + recipient_keys=[inbound_message.receipt.sender_verkey], + sender_key=inbound_message.receipt.recipient_verkey, + ) + for service in chain[-1]["service"] + ] + # send a DCV2 Problem Report here for testing, and to punt procotol handling down # the road a bit context = RequestContext(profile) + context.message = message context.message_receipt = inbound_message.receipt responder = DispatcherResponder( context, @@ -147,21 +251,80 @@ async def handle_v2_message( send_outbound, reply_session_id=inbound_message.session_id, reply_to_verkey=inbound_message.receipt.sender_verkey, + target_list=reply_destination, ) context.injector.bind_instance(BaseResponder, responder) - error_result = V2AgentMessage( - message={ - "type": "https://didcomm.org/report-problem/2.0/problem-report", - "body": { - "comment": "No Handlers Found", - "code": "e.p.msg.not-found", - }, - } - ) - if inbound_message.receipt.thread_id: - error_result.message["pthid"] = inbound_message.receipt.thread_id - await responder.send_reply(error_result) + if not message: + error_result = V2AgentMessage( + message={ + "type": "https://didcomm.org/report-problem/2.0/problem-report", + "body": { + "comment": "No Handlers Found", + "code": "e.p.msg.not-found", + }, + "from": inbound_message.receipt.recipient_verkey, + "to": [inbound_message.receipt.sender_verkey], + } + ) + if inbound_message.receipt.thread_id: + error_result.message["pthid"] = inbound_message.receipt.thread_id + + if error_result: + await responder.send_reply(error_result) + elif context.message: + context.injector.bind_instance(BaseResponder, responder) + + handler = context.message + if self.collector: + handler = self.collector.wrap_coro(handler, [handler.__qualname__]) + await handler()(context, responder, payload=inbound_message.payload) + + async def make_v2_message(self, profile: Profile, parsed_msg: dict) -> BaseMessage: + """Deserialize a message dict into the appropriate message instance. + + Given a dict describing a message, this method + returns an instance of the related message class. + + Args: + parsed_msg: The parsed message + profile: Profile + + Returns: + An instance of the corresponding message class for this message + + Raises: + MessageParseError: If the message doesn't specify @type + MessageParseError: If there is no message class registered to handle + the given type + + """ + if not isinstance(parsed_msg, dict): + raise MessageParseError("Expected a JSON object") + message_type = parsed_msg.get("type") + + if not message_type: + raise MessageParseError("Message does not contain 'type' parameter") + + registry: V2ProtocolRegistry = self.profile.inject(V2ProtocolRegistry) + try: + message_cls = registry.protocols_matching_query(message_type) + except ProtocolMinorVersionNotSupported as e: + raise MessageParseError(f"Problem parsing message type. {e}") + + if not message_cls: + raise MessageParseError(f"Unrecognized message type {message_type}") + + try: + instance = registry.handlers[message_cls[0]] + if isinstance(instance, DeferLoad): + instance = instance.resolved + except BaseModelError as e: + if "/problem-report" in message_type: + raise ProblemReportParseError("Error parsing problem report message") + raise MessageParseError(f"Error deserializing message: {e}") from e + + return instance async def handle_v1_message( self, diff --git a/acapy_agent/core/plugin_registry.py b/acapy_agent/core/plugin_registry.py index b3fa709386..04b99c3f12 100644 --- a/acapy_agent/core/plugin_registry.py +++ b/acapy_agent/core/plugin_registry.py @@ -11,6 +11,7 @@ from .error import ProtocolDefinitionValidationError from .goal_code_registry import GoalCodeRegistry from .protocol_registry import ProtocolRegistry +from ..didcomm_v2.protocol_registry import V2ProtocolRegistry LOGGER = logging.getLogger(__name__) @@ -218,8 +219,12 @@ async def load_protocol_version( version_definition: Optional[dict] = None, ): """Load a particular protocol version.""" + v2_protocol_registry = context.inject(V2ProtocolRegistry) protocol_registry = context.inject(ProtocolRegistry) goal_code_registry = context.inject(GoalCodeRegistry) + if hasattr(mod, "HANDLERS"): + for message_type, handler in mod.HANDLERS: + v2_protocol_registry.register_handler(message_type, handler) if hasattr(mod, "MESSAGE_TYPES"): protocol_registry.register_message_types( mod.MESSAGE_TYPES, version_definition=version_definition diff --git a/acapy_agent/core/tests/test_plugin_registry.py b/acapy_agent/core/tests/test_plugin_registry.py index b5727a4546..6c47ffd434 100644 --- a/acapy_agent/core/tests/test_plugin_registry.py +++ b/acapy_agent/core/tests/test_plugin_registry.py @@ -11,6 +11,7 @@ from ..goal_code_registry import GoalCodeRegistry from ..plugin_registry import PluginRegistry from ..protocol_registry import ProtocolRegistry +from ...didcomm_v2.protocol_registry import V2ProtocolRegistry class TestPluginRegistry(IsolatedAsyncioTestCase): @@ -27,6 +28,7 @@ def setUp(self): register_controllers=mock.MagicMock(), ) self.context.injector.bind_instance(ProtocolRegistry, self.proto_registry) + self.context.injector.bind_instance(V2ProtocolRegistry, V2ProtocolRegistry()) self.context.injector.bind_instance(GoalCodeRegistry, self.goal_code_registry) async def test_setup(self): diff --git a/acapy_agent/didcomm_v2/protocol_registry.py b/acapy_agent/didcomm_v2/protocol_registry.py new file mode 100644 index 0000000000..df854bcd12 --- /dev/null +++ b/acapy_agent/didcomm_v2/protocol_registry.py @@ -0,0 +1,43 @@ +"""Registry for DIDComm V2 Protocols.""" + +from ..utils.classloader import DeferLoad +from typing import Coroutine, Dict, Sequence, Union + + +class V2ProtocolRegistry: + """DIDComm V2 Protocols.""" + + def __init__(self): + """Initialize a V2ProtocolRegistry instance.""" + self._type_to_message_handler: Dict[str, Coroutine] = {} + + @property + def handlers(self) -> Dict[str, Coroutine]: + """Accessor for a list of all message protocols.""" + return self._type_to_message_handler + + @property + def protocols(self) -> Sequence[str]: + """Accessor for a list of all message protocols.""" + return [str(key) for key in self._type_to_message_handler.keys()] + + def protocols_matching_query(self, query: str) -> Sequence[str]: + """Return a list of message protocols matching a query string.""" + all_types = self.protocols + result = None + + if query == "*" or query is None: + result = all_types + elif query: + if query.endswith("*"): + match = query[:-1] + result = tuple(k for k in all_types if k.startswith(match)) + elif query in all_types: + result = (query,) + return result or () + + def register_handler(self, message_type: str, handler: Union[Coroutine, str]): + """Register a new message type to handler association.""" + if isinstance(handler, str): + handler = DeferLoad(handler) + self._type_to_message_handler[message_type] = handler diff --git a/acapy_agent/messaging/responder.py b/acapy_agent/messaging/responder.py index 98a2bb6425..9ad65a5cdd 100644 --- a/acapy_agent/messaging/responder.py +++ b/acapy_agent/messaging/responder.py @@ -45,11 +45,14 @@ def __init__( connection_id: Optional[str] = None, reply_session_id: Optional[str] = None, reply_to_verkey: Optional[str] = None, + target: Optional[ConnectionTarget] = None, + target_list: Sequence[ConnectionTarget] = None, ): """Initialize a base responder.""" self.connection_id = connection_id self.reply_session_id = reply_session_id self.reply_to_verkey = reply_to_verkey + self.target_list = target_list async def create_outbound( self, @@ -133,7 +136,7 @@ async def send_reply( reply_session_id=self.reply_session_id, reply_to_verkey=self.reply_to_verkey, target=target, - target_list=target_list, + target_list=target_list or self.target_list, ) if isinstance(message, BaseMessage): msg_type = message._message_type diff --git a/acapy_agent/protocols_v2/__init__.py b/acapy_agent/protocols_v2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/basicmessage/__init__.py b/acapy_agent/protocols_v2/basicmessage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/basicmessage/definition.py b/acapy_agent/protocols_v2/basicmessage/definition.py new file mode 100644 index 0000000000..62bddef6f5 --- /dev/null +++ b/acapy_agent/protocols_v2/basicmessage/definition.py @@ -0,0 +1,10 @@ +"""Version definitions for this protocol.""" + +versions = [ + { + "major_version": 1, + "minimum_minor_version": 0, + "current_minor_version": 0, + "path": "v1_0", + } +] diff --git a/acapy_agent/protocols_v2/basicmessage/v1_0/__init__.py b/acapy_agent/protocols_v2/basicmessage/v1_0/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/basicmessage/v1_0/message_types.py b/acapy_agent/protocols_v2/basicmessage/v1_0/message_types.py new file mode 100644 index 0000000000..d5ea1b688b --- /dev/null +++ b/acapy_agent/protocols_v2/basicmessage/v1_0/message_types.py @@ -0,0 +1,47 @@ +"""Message type identifiers for Trust Pings.""" + +import logging +from ....messaging.v2_agent_message import V2AgentMessage + +SPEC_URI = "https://didcomm.org/basicmessage/2.0/message" + +# Message types +BASIC_MESSAGE = "https://didcomm.org/basicmessage/2.0/message" + +PROTOCOL_PACKAGE = "acapy_agent.protocols_v2.basicmessage.v1_0" + + +class basic_message: + """Basic Message 2.0 DIDComm V2 Protocol.""" + + async def __call__(self, *args, **kwargs): + """Call the Handler.""" + await self.handle(*args, **kwargs) + + @staticmethod + async def handle(context, responder, payload): + """Handle the incoming message.""" + logging.getLogger(__name__) + their_did = context.message_receipt.sender_verkey.split("#")[0] + our_did = context.message_receipt.recipient_verkey.split("#")[0] + error_result = V2AgentMessage( + message={ + "type": BASIC_MESSAGE, + "body": { + "content": "Hello from acapy", + }, + "to": [their_did], + "from": our_did, + "lang": "en", + } + ) + await responder.send_reply(error_result) + + +HANDLERS = { + BASIC_MESSAGE: f"{PROTOCOL_PACKAGE}.message_types.basic_message", +}.items() + +MESSAGE_TYPES = { + BASIC_MESSAGE: f"{PROTOCOL_PACKAGE}.message_types.basic_message", +} diff --git a/acapy_agent/protocols_v2/basicmessage/v1_0/routes.py b/acapy_agent/protocols_v2/basicmessage/v1_0/routes.py new file mode 100644 index 0000000000..39faf80e6d --- /dev/null +++ b/acapy_agent/protocols_v2/basicmessage/v1_0/routes.py @@ -0,0 +1,242 @@ +"""Trust ping admin routes.""" + +from aiohttp import web +from aiohttp_apispec import docs, request_schema, response_schema +from marshmallow import fields +from didcomm_messaging import DIDCommMessaging, RoutingService +from didcomm_messaging.resolver import DIDResolver as DMPResolver + +from ....admin.decorators.auth import tenant_authentication +from ....admin.request_context import AdminRequestContext +from ....messaging.models.openapi import OpenAPISchema +from ....messaging.valid import UUID4_EXAMPLE +from .message_types import SPEC_URI + +from ....wallet.base import BaseWallet +from ....wallet.did_info import DIDInfo +from ....wallet.did_method import ( + DIDMethod, + DIDMethods, +) +from ....wallet.did_posture import DIDPosture +from ....messaging.v2_agent_message import V2AgentMessage +from ....connections.models.connection_target import ConnectionTarget + + +class BaseDIDCommV2Schema(OpenAPISchema): + """Request schema for performing a ping.""" + + to_did = fields.Str( + required=True, + allow_none=False, + metadata={"description": "Comment for the ping message"}, + ) + + +class PingRequestSchema(BaseDIDCommV2Schema): + """Request schema for performing a ping.""" + + response_requested = fields.Bool( + required=False, + allow_none=True, + metadata={"description": "Comment for the ping message"}, + ) + + +class PingRequestResponseSchema(OpenAPISchema): + """Request schema for performing a ping.""" + + thread_id = fields.Str( + required=False, metadata={"description": "Thread ID of the ping message"} + ) + + +class PingConnIdMatchInfoSchema(OpenAPISchema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + required=True, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) + + +def format_did_info(info: DIDInfo): + """Serialize a DIDInfo object.""" + if info: + return { + "did": info.did, + "verkey": info.verkey, + "posture": DIDPosture.get(info.metadata).moniker, + "key_type": info.key_type.key_type, + "method": info.method.method_name, + "metadata": info.metadata, + } + + +async def get_mydid(request: web.BaseRequest): + """Get a DID that can be used for communication.""" + context: AdminRequestContext = request["context"] + # filter_did = request.query.get("did") + # filter_verkey = request.query.get("verkey") + filter_posture = DIDPosture.get(request.query.get("posture")) + results = [] + async with context.session() as session: + did_methods: DIDMethods = session.inject(DIDMethods) + filter_method: DIDMethod | None = did_methods.from_method( + request.query.get("method") or "did:peer:2" + ) + # key_types = session.inject(KeyTypes) + # filter_key_type = key_types.from_key_type(request.query.get("key_type", "")) + wallet: BaseWallet | None = session.inject_or(BaseWallet) + if not wallet: + raise web.HTTPForbidden(reason="No wallet available") + else: + dids = await wallet.get_local_dids() + results = [ + format_did_info(info) + for info in dids + if ( + filter_posture is None + or DIDPosture.get(info.metadata) is DIDPosture.WALLET_ONLY + ) + and (not filter_method or info.method == filter_method) + # and (not filter_key_type or info.key_type == filter_key_type) + ] + + results.sort(key=lambda info: (DIDPosture.get(info["posture"]).ordinal, info["did"])) + our_did = results[0]["did"] + return our_did + + +async def get_target(request: web.BaseRequest, to_did: str, from_did: str): + """Get Connection Target from did.""" + context: AdminRequestContext = request["context"] + + try: + async with context.profile.session() as session: + resolver = session.inject(DMPResolver) + await resolver.resolve(to_did) + except Exception as err: + raise web.HTTPNotFound(reason=str(err)) from err + + async with context.session() as session: + ctx = session + messaging = ctx.inject(DIDCommMessaging) + routing_service = ctx.inject(RoutingService) + frm = to_did + services = await routing_service._resolve_services(messaging.resolver, frm) + chain = [ + { + "did": frm, + "service": services, + } + ] + + # Loop through service DIDs until we run out of DIDs to forward to + to_target = services[0].service_endpoint.uri + found_forwardable_service = await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + while found_forwardable_service: + services = await routing_service._resolve_services( + messaging.resolver, to_target + ) + if services: + chain.append( + { + "did": to_target, + "service": services, + } + ) + to_target = services[0].service_endpoint.uri + found_forwardable_service = ( + await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + if services + else False + ) + reply_destination = [ + ConnectionTarget( + did=f"{to_did}#key-1", + endpoint=service.service_endpoint.uri, + recipient_keys=[f"{to_did}#key-1"], + sender_key=from_did + "#key-1", + ) + for service in chain[-1]["service"] + ] + return reply_destination + + +class BasicMessageSchema(BaseDIDCommV2Schema): + """Request schema for performing a ping.""" + + content = fields.Str( + required=True, + allow_none=False, + metadata={"description": "Basic Message message content"}, + ) + + +@docs(tags=["basicmessagev2", "didcommv2"], summary="Send a Basic Message") +@request_schema(BasicMessageSchema()) +@response_schema(PingRequestResponseSchema(), 200, description="") +@tenant_authentication +async def basic_message_send(request: web.BaseRequest): + """Request handler for sending a trust ping to a connection. + + Args: + request: aiohttp request object + + """ + request["context"] + outbound_handler = request["outbound_message_router"] + body = await request.json() + to_did = body.get("to_did") + message = body.get("content") + + our_did = await get_mydid(request) + their_did = to_did + reply_destination = await get_target(request, to_did, our_did) + msg = V2AgentMessage( + message={ + "type": "https://didcomm.org/basicmessage/2.0/message", + "body": {"content": message}, + "lang": "en", + "to": [their_did], + "from": our_did, + } + ) + await outbound_handler(msg, target_list=reply_destination) + return web.json_response(msg.message) + + +async def register(app: web.Application): + """Register routes.""" + + app.add_routes([web.post("/basic-message/send-message", basic_message_send)]) + + +def post_process_routes(app: web.Application): + """Amend swagger API.""" + + # Add top-level tags description + if "tags" not in app._state["swagger_dict"]: + app._state["swagger_dict"]["tags"] = [] + app._state["swagger_dict"]["tags"].append( + { + "name": "basicmessagev2", + "description": "Basic Message to contact", + "externalDocs": {"description": "Specification", "url": SPEC_URI}, + } + ) + app._state["swagger_dict"]["tags"].append( + { + "name": "didcommv2", + "description": "DIDComm V2 based protocols for Interop-a-thon", + "externalDocs": { + "description": "Specification", + "url": "https://didcomm.org", + }, + } + ) diff --git a/acapy_agent/protocols_v2/connections/__init__.py b/acapy_agent/protocols_v2/connections/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/connections/definition.py b/acapy_agent/protocols_v2/connections/definition.py new file mode 100644 index 0000000000..62bddef6f5 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/definition.py @@ -0,0 +1,10 @@ +"""Version definitions for this protocol.""" + +versions = [ + { + "major_version": 1, + "minimum_minor_version": 0, + "current_minor_version": 0, + "path": "v1_0", + } +] diff --git a/acapy_agent/protocols_v2/connections/v1_0/__init__.py b/acapy_agent/protocols_v2/connections/v1_0/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/__init__.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_invitation_handler.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_invitation_handler.py new file mode 100644 index 0000000000..124e9130e3 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_invitation_handler.py @@ -0,0 +1,30 @@ +"""Connect invitation handler.""" + +from .....messaging.base_handler import BaseHandler, BaseResponder, RequestContext +from ..messages.connection_invitation import ConnectionInvitation +from ..messages.problem_report import ConnectionProblemReport, ProblemReportReason + + +class ConnectionInvitationHandler(BaseHandler): + """Handler class for connection invitations.""" + + async def handle(self, context: RequestContext, responder: BaseResponder): + """Handle connection invitation. + + Args: + context: Request context + responder: Responder callback + """ + + self._logger.debug(f"ConnectionInvitationHandler called with context {context}") + assert isinstance(context.message, ConnectionInvitation) + + report = ConnectionProblemReport( + description={ + "code": ProblemReportReason.INVITATION_NOT_ACCEPTED.value, + "en": ("Connection invitations cannot be submitted via agent messaging"), + } + ) + report.assign_thread_from(context.message) + # client likely needs to be using direct responses to receive the problem report + await responder.send_reply(report) diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_request_handler.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_request_handler.py new file mode 100644 index 0000000000..042759f225 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_request_handler.py @@ -0,0 +1,58 @@ +"""Connection request handler.""" + +from .....connections.models.conn_record import ConnRecord +from .....messaging.base_handler import BaseHandler, BaseResponder, RequestContext +from ....coordinate_mediation.v1_0.manager import MediationManager +from ..manager import ConnectionManager, ConnectionManagerError +from ..messages.connection_request import ConnectionRequest + + +class ConnectionRequestHandler(BaseHandler): + """Handler class for connection requests.""" + + async def handle(self, context: RequestContext, responder: BaseResponder): + """Handle connection request. + + Args: + context: Request context + responder: Responder callback + """ + + self._logger.debug(f"ConnectionRequestHandler called with context {context}") + assert isinstance(context.message, ConnectionRequest) + + profile = context.profile + mgr = ConnectionManager(profile) + + mediation_id = None + if context.connection_record: + async with profile.session() as session: + mediation_metadata = await context.connection_record.metadata_get( + session, MediationManager.METADATA_KEY, {} + ) + mediation_id = mediation_metadata.get(MediationManager.METADATA_ID) + + try: + connection = await mgr.receive_request( + context.message, + context.message_receipt, + ) + + if connection.accept == ConnRecord.ACCEPT_AUTO: + response = await mgr.create_response( + connection, mediation_id=mediation_id + ) + await responder.send_reply( + response, connection_id=connection.connection_id + ) + else: + self._logger.debug("Connection request will await acceptance") + except ConnectionManagerError as e: + report, targets = mgr.manager_error_to_problem_report( + e, context.message, context.message_receipt + ) + if report and targets: + await responder.send_reply( + message=report, + target_list=targets, + ) diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_response_handler.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_response_handler.py new file mode 100644 index 0000000000..61fd814362 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/handlers/connection_response_handler.py @@ -0,0 +1,41 @@ +"""Connection response handler.""" + +from .....messaging.base_handler import BaseHandler, BaseResponder, RequestContext +from .....protocols.trustping.v1_0.messages.ping import Ping +from ..manager import ConnectionManager, ConnectionManagerError +from ..messages.connection_response import ConnectionResponse + + +class ConnectionResponseHandler(BaseHandler): + """Handler class for connection responses.""" + + async def handle(self, context: RequestContext, responder: BaseResponder): + """Handle connection response. + + Args: + context: Request context + responder: Responder callback + """ + self._logger.debug(f"ConnectionResponseHandler called with context {context}") + assert isinstance(context.message, ConnectionResponse) + + profile = context.profile + mgr = ConnectionManager(profile) + try: + connection = await mgr.accept_response( + context.message, context.message_receipt + ) + except ConnectionManagerError as e: + report, targets = mgr.manager_error_to_problem_report( + e, context.message, context.message_receipt + ) + if report and targets: + await responder.send_reply( + message=report, + target_list=targets, + ) + return + + # send trust ping in response + if context.settings.get("auto_ping_connection"): + (await responder.send(Ping(), connection_id=connection.connection_id),) diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/problem_report_handler.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/problem_report_handler.py new file mode 100644 index 0000000000..8be8ec31fd --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/handlers/problem_report_handler.py @@ -0,0 +1,46 @@ +"""Problem report handler for Connection Protocol.""" + +from .....connections.models.conn_record import ConnRecord +from .....messaging.base_handler import ( + BaseHandler, + BaseResponder, + HandlerException, + RequestContext, +) +from .....storage.error import StorageNotFoundError +from ..manager import ConnectionManager, ConnectionManagerError +from ..messages.problem_report import ConnectionProblemReport + + +class ConnectionProblemReportHandler(BaseHandler): + """Handler class for Connection problem report messages.""" + + async def handle(self, context: RequestContext, responder: BaseResponder): + """Handle problem report message.""" + self._logger.debug( + f"ConnectionProblemReportHandler called with context {context}" + ) + assert isinstance(context.message, ConnectionProblemReport) + + self._logger.info(f"Received problem report: {context.message.problem_code}") + profile = context.profile + mgr = ConnectionManager(profile) + try: + conn_rec = context.connection_record + if not conn_rec: + # try to find connection by thread_id/request_id + try: + async with profile.session() as session: + conn_rec = await ConnRecord.retrieve_by_request_id( + session, context.message._thread_id + ) + except StorageNotFoundError: + pass + + if conn_rec: + await mgr.receive_problem_report(conn_rec, context.message) + else: + raise HandlerException("No connection established for problem report") + except ConnectionManagerError: + # Unrecognized problem report code + self._logger.exception("Error receiving Connection problem report") diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/__init__.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_invitation_handler.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_invitation_handler.py new file mode 100644 index 0000000000..5c51c04f61 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_invitation_handler.py @@ -0,0 +1,37 @@ +import pytest + +from ......messaging.request_context import RequestContext +from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt +from ......utils.testing import create_test_profile +from ...handlers.connection_invitation_handler import ConnectionInvitationHandler +from ...messages.connection_invitation import ConnectionInvitation +from ...messages.problem_report import ConnectionProblemReport, ProblemReportReason + + +@pytest.fixture() +async def request_context(): + ctx = RequestContext.test_context(await create_test_profile()) + ctx.message_receipt = MessageReceipt() + yield ctx + + +class TestInvitationHandler: + @pytest.mark.asyncio + async def test_problem_report(self, request_context): + request_context.message = ConnectionInvitation() + handler = ConnectionInvitationHandler() + responder = MockResponder() + await handler.handle(request_context, responder) + messages = responder.messages + assert len(messages) == 1 + result, target = messages[0] + assert ( + isinstance(result, ConnectionProblemReport) + and result.description + and ( + result.description["code"] + == ProblemReportReason.INVITATION_NOT_ACCEPTED.value + ) + ) + assert not target diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_request_handler.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_request_handler.py new file mode 100644 index 0000000000..57b7dc76ed --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_request_handler.py @@ -0,0 +1,274 @@ +import pytest + +from acapy_agent.tests import mock + +from ......connections.models import connection_target +from ......connections.models.conn_record import ConnRecord +from ......connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service +from ......messaging.request_context import RequestContext +from ......messaging.responder import MockResponder +from ......storage.base import BaseStorage +from ......storage.error import StorageNotFoundError +from ......transport.inbound.receipt import MessageReceipt +from ......utils.testing import create_test_profile +from ...handlers import connection_request_handler as handler +from ...manager import ConnectionManagerError +from ...messages.connection_request import ConnectionRequest +from ...messages.problem_report import ConnectionProblemReport, ProblemReportReason +from ...models.connection_detail import ConnectionDetail + + +@pytest.fixture() +async def request_context(): + ctx = RequestContext.test_context(await create_test_profile()) + ctx.message_receipt = MessageReceipt() + yield ctx + + +@pytest.fixture() +async def session(request_context): + yield await request_context.session() + + +@pytest.fixture() +async def connection_record(request_context, session): + record = ConnRecord() + request_context.connection_record = record + await record.save(session) + yield record + + +TEST_DID = "55GkHamhTU1ZbTbV2ab9DE" +TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" +TEST_LABEL = "Label" +TEST_ENDPOINT = "http://localhost" +TEST_IMAGE_URL = "http://aries.ca/images/sample.png" + + +@pytest.fixture() +def did_doc(): + doc = DIDDoc(did=TEST_DID) + controller = TEST_DID + ident = "1" + pk_value = TEST_VERKEY + pk = PublicKey( + TEST_DID, + ident, + pk_value, + PublicKeyType.ED25519_SIG_2018, + controller, + False, + ) + doc.set(pk) + recip_keys = [pk] + router_keys = [] + service = Service( + TEST_DID, + "indy", + "IndyAgent", + recip_keys, + router_keys, + TEST_ENDPOINT, + ) + doc.set(service) + yield doc + + +class TestRequestHandler: + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + async def test_called(self, mock_conn_mgr, request_context): + mock_conn_mgr.return_value.receive_request = mock.CoroutineMock() + request_context.message = ConnectionRequest() + handler_inst = handler.ConnectionRequestHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + mock_conn_mgr.return_value.receive_request.assert_called_once_with( + request_context.message, request_context.message_receipt + ) + assert not responder.messages + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + async def test_called_with_auto_response(self, mock_conn_mgr, request_context): + mock_conn_rec = mock.MagicMock() + mock_conn_rec.accept = ConnRecord.ACCEPT_AUTO + mock_conn_mgr.return_value.receive_request = mock.CoroutineMock( + return_value=mock_conn_rec + ) + mock_conn_mgr.return_value.create_response = mock.CoroutineMock() + request_context.message = ConnectionRequest() + handler_inst = handler.ConnectionRequestHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + mock_conn_mgr.return_value.receive_request.assert_called_once_with( + request_context.message, request_context.message_receipt + ) + mock_conn_mgr.return_value.create_response.assert_called_once_with( + mock_conn_rec, mediation_id=None + ) + assert responder.messages + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + async def test_connection_record_with_mediation_metadata_auto_response( + self, mock_conn_mgr, request_context, connection_record + ): + mock_conn_rec = mock.MagicMock() + mock_conn_rec.accept = ConnRecord.ACCEPT_AUTO + mock_conn_mgr.return_value.receive_request = mock.CoroutineMock( + return_value=mock_conn_rec + ) + mock_conn_mgr.return_value.create_response = mock.CoroutineMock() + request_context.message = ConnectionRequest() + with mock.patch.object( + connection_record, + "metadata_get", + mock.CoroutineMock(return_value={"id": "test-mediation-id"}), + ): + handler_inst = handler.ConnectionRequestHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + mock_conn_mgr.return_value.receive_request.assert_called_once() + mock_conn_mgr.return_value.create_response.assert_called_once_with( + mock_conn_rec, mediation_id="test-mediation-id" + ) + assert responder.messages + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + async def test_connection_record_without_mediation_metadata( + self, mock_conn_mgr, request_context, session, connection_record + ): + mock_conn_mgr.return_value.receive_request = mock.CoroutineMock() + request_context.message = ConnectionRequest() + storage: BaseStorage = session.inject(BaseStorage) + with mock.patch.object( + storage, + "find_record", + mock.CoroutineMock(side_effect=StorageNotFoundError), + ): + handler_inst = handler.ConnectionRequestHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + mock_conn_mgr.return_value.receive_request.assert_called_once_with( + request_context.message, + request_context.message_receipt, + ) + assert not responder.messages + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + @mock.patch.object(connection_target, "ConnectionTarget") + async def test_problem_report(self, mock_conn_target, mock_conn_mgr, request_context): + mock_conn_mgr.return_value.receive_request = mock.CoroutineMock() + mock_conn_mgr.return_value.receive_request.side_effect = ConnectionManagerError( + error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value + ) + mock_conn_mgr.return_value.manager_error_to_problem_report = mock.MagicMock( + return_value=( + ConnectionProblemReport( + description={ + "en": "test error", + "code": ProblemReportReason.REQUEST_NOT_ACCEPTED.value, + } + ), + [mock_conn_target], + ) + ) + request_context.message = ConnectionRequest() + handler_inst = handler.ConnectionRequestHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + messages = responder.messages + assert len(messages) == 1 + result, target = messages[0] + assert ( + isinstance(result, ConnectionProblemReport) + and result.description + and ( + result.description["code"] + == ProblemReportReason.REQUEST_NOT_ACCEPTED.value + ) + ) + assert target == {"target_list": [mock_conn_target]} + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + @mock.patch.object(connection_target, "ConnectionTarget") + async def test_problem_report_did_doc( + self, mock_conn_target, mock_conn_mgr, request_context, did_doc + ): + mock_conn_mgr.return_value.receive_request = mock.CoroutineMock() + mock_conn_mgr.return_value.receive_request.side_effect = ConnectionManagerError( + error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value + ) + mock_conn_mgr.return_value.diddoc_connection_targets = mock.MagicMock( + return_value=[mock_conn_target] + ) + mock_conn_mgr.return_value.manager_error_to_problem_report = mock.MagicMock( + return_value=( + ConnectionProblemReport( + description={ + "en": "test error", + "code": ProblemReportReason.REQUEST_NOT_ACCEPTED.value, + } + ), + [mock_conn_target], + ) + ) + request_context.message = ConnectionRequest( + connection=ConnectionDetail(did=TEST_DID, did_doc=did_doc), + label=TEST_LABEL, + image_url=TEST_IMAGE_URL, + ) + handler_inst = handler.ConnectionRequestHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + messages = responder.messages + assert len(messages) == 1 + result, target = messages[0] + assert ( + isinstance(result, ConnectionProblemReport) + and result.description + and ( + result.description["code"] + == ProblemReportReason.REQUEST_NOT_ACCEPTED.value + ) + ) + assert target == {"target_list": [mock_conn_target]} + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + @mock.patch.object(connection_target, "ConnectionTarget") + async def test_problem_report_did_doc_no_conn_target( + self, mock_conn_target, mock_conn_mgr, request_context, did_doc + ): + mock_conn_mgr.return_value.receive_request = mock.CoroutineMock() + mock_conn_mgr.return_value.receive_request.side_effect = ConnectionManagerError( + error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value + ) + mock_conn_mgr.return_value.diddoc_connection_targets = mock.MagicMock( + side_effect=ConnectionManagerError("no targets") + ) + mock_conn_mgr.return_value.manager_error_to_problem_report = mock.MagicMock( + return_value=( + ConnectionProblemReport( + description={ + "en": "test error", + "code": ProblemReportReason.REQUEST_NOT_ACCEPTED.value, + } + ), + None, + ) + ) + request_context.message = ConnectionRequest( + connection=ConnectionDetail(did=TEST_DID, did_doc=did_doc), + label=TEST_LABEL, + image_url=TEST_IMAGE_URL, + ) + handler_inst = handler.ConnectionRequestHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + messages = responder.messages + assert len(messages) == 0 # messages require a target! diff --git a/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_response_handler.py b/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_response_handler.py new file mode 100644 index 0000000000..414a8e0d39 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/handlers/tests/test_response_handler.py @@ -0,0 +1,203 @@ +import pytest + +from acapy_agent.tests import mock + +from ......connections.models import connection_target +from ......connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service +from ......messaging.request_context import RequestContext +from ......messaging.responder import MockResponder +from ......protocols.trustping.v1_0.messages.ping import Ping +from ......transport.inbound.receipt import MessageReceipt +from ......utils.testing import create_test_profile +from ...handlers import connection_response_handler as handler +from ...manager import ConnectionManagerError +from ...messages.connection_response import ConnectionResponse +from ...messages.problem_report import ConnectionProblemReport, ProblemReportReason +from ...models.connection_detail import ConnectionDetail + + +@pytest.fixture() +async def request_context(): + ctx = RequestContext.test_context(await create_test_profile()) + ctx.message_receipt = MessageReceipt() + yield ctx + + +TEST_DID = "55GkHamhTU1ZbTbV2ab9DE" +TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" +TEST_LABEL = "Label" +TEST_ENDPOINT = "http://localhost" +TEST_IMAGE_URL = "http://aries.ca/images/sample.png" + + +@pytest.fixture() +def did_doc(): + doc = DIDDoc(did=TEST_DID) + controller = TEST_DID + ident = "1" + pk_value = TEST_VERKEY + pk = PublicKey( + TEST_DID, + ident, + pk_value, + PublicKeyType.ED25519_SIG_2018, + controller, + False, + ) + doc.set(pk) + recip_keys = [pk] + router_keys = [] + service = Service( + TEST_DID, + "indy", + "IndyAgent", + recip_keys, + router_keys, + TEST_ENDPOINT, + ) + doc.set(service) + yield doc + + +class TestResponseHandler: + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + async def test_called(self, mock_conn_mgr, request_context): + mock_conn_mgr.return_value.accept_response = mock.CoroutineMock() + request_context.message = ConnectionResponse() + handler_inst = handler.ConnectionResponseHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + mock_conn_mgr.return_value.accept_response.assert_called_once_with( + request_context.message, request_context.message_receipt + ) + assert not responder.messages + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + async def test_called_auto_ping(self, mock_conn_mgr, request_context): + request_context.update_settings({"auto_ping_connection": True}) + mock_conn_mgr.return_value.accept_response = mock.CoroutineMock() + request_context.message = ConnectionResponse() + handler_inst = handler.ConnectionResponseHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + mock_conn_mgr.return_value.accept_response.assert_called_once_with( + request_context.message, request_context.message_receipt + ) + messages = responder.messages + assert len(messages) == 1 + result, _ = messages[0] + assert isinstance(result, Ping) + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + @mock.patch.object(connection_target, "ConnectionTarget") + async def test_problem_report(self, mock_conn_target, mock_conn_mgr, request_context): + mock_conn_mgr.return_value.accept_response = mock.CoroutineMock() + mock_conn_mgr.return_value.accept_response.side_effect = ConnectionManagerError( + error_code=ProblemReportReason.RESPONSE_NOT_ACCEPTED.value, + ) + mock_conn_mgr.return_value.manager_error_to_problem_report = mock.MagicMock( + return_value=( + ConnectionProblemReport( + description={ + "en": "test error", + "code": ProblemReportReason.RESPONSE_NOT_ACCEPTED.value, + } + ), + [mock_conn_target], + ) + ) + request_context.message = ConnectionResponse() + handler_inst = handler.ConnectionResponseHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + messages = responder.messages + assert len(messages) == 1 + result, target = messages[0] + assert ( + isinstance(result, ConnectionProblemReport) + and result.description + and ( + result.description["code"] + == ProblemReportReason.RESPONSE_NOT_ACCEPTED.value + ) + ) + assert target == {"target_list": [mock_conn_target]} + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + @mock.patch.object(connection_target, "ConnectionTarget") + async def test_problem_report_did_doc( + self, mock_conn_target, mock_conn_mgr, request_context, did_doc + ): + mock_conn_mgr.return_value.accept_response = mock.CoroutineMock() + mock_conn_mgr.return_value.accept_response.side_effect = ConnectionManagerError( + error_code=ProblemReportReason.RESPONSE_NOT_ACCEPTED.value, + ) + mock_conn_mgr.return_value.diddoc_connection_targets = mock.MagicMock( + return_value=[mock_conn_target] + ) + mock_conn_mgr.return_value.manager_error_to_problem_report = mock.MagicMock( + return_value=( + ConnectionProblemReport( + description={ + "en": "test error", + "code": ProblemReportReason.RESPONSE_NOT_ACCEPTED.value, + } + ), + [mock_conn_target], + ) + ) + request_context.message = ConnectionResponse( + connection=ConnectionDetail(did=TEST_DID, did_doc=did_doc) + ) + handler_inst = handler.ConnectionResponseHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + messages = responder.messages + assert len(messages) == 1 + result, target = messages[0] + assert ( + isinstance(result, ConnectionProblemReport) + and result.description + and ( + result.description["code"] + == ProblemReportReason.RESPONSE_NOT_ACCEPTED.value + ) + ) + assert target == {"target_list": [mock_conn_target]} + + @pytest.mark.asyncio + @mock.patch.object(handler, "ConnectionManager") + @mock.patch.object(connection_target, "ConnectionTarget") + async def test_problem_report_did_doc_no_conn_target( + self, mock_conn_target, mock_conn_mgr, request_context, did_doc + ): + mock_conn_mgr.return_value.accept_response = mock.CoroutineMock() + mock_conn_mgr.return_value.accept_response.side_effect = ConnectionManagerError( + error_code=ProblemReportReason.RESPONSE_NOT_ACCEPTED.value, + ) + mock_conn_mgr.return_value.diddoc_connection_targets = mock.MagicMock( + side_effect=ConnectionManagerError("no target") + ) + mock_conn_mgr.return_value.manager_error_to_problem_report = mock.MagicMock( + return_value=( + ConnectionProblemReport( + description={ + "en": "test error", + "code": ProblemReportReason.RESPONSE_NOT_ACCEPTED.value, + } + ), + None, + ) + ) + request_context.message = ConnectionResponse( + connection=ConnectionDetail(did=TEST_DID, did_doc=did_doc) + ) + handler_inst = handler.ConnectionResponseHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + messages = responder.messages + assert len(messages) == 0 # need a connection target to send message diff --git a/acapy_agent/protocols_v2/connections/v1_0/manager.py b/acapy_agent/protocols_v2/connections/v1_0/manager.py new file mode 100644 index 0000000000..302db6f1cc --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/manager.py @@ -0,0 +1,843 @@ +"""Classes to manage connections.""" + +import logging +import warnings +from typing import Optional, Sequence, Tuple, Union, cast + +from ....connections.base_manager import BaseConnectionManager +from ....connections.models.conn_record import ConnRecord +from ....connections.models.connection_target import ConnectionTarget +from ....core.error import BaseError +from ....core.oob_processor import OobMessageProcessor +from ....core.profile import Profile +from ....messaging.responder import BaseResponder +from ....messaging.valid import IndyDID +from ....storage.error import StorageNotFoundError +from ....transport.inbound.receipt import MessageReceipt +from ....wallet.base import BaseWallet +from ....wallet.did_method import SOV +from ....wallet.key_type import ED25519 +from ....protocols.coordinate_mediation.v1_0.manager import MediationManager +from .message_types import ARIES_PROTOCOL as CONN_PROTO +from .messages.connection_invitation import ConnectionInvitation +from .messages.connection_request import ConnectionRequest +from .messages.connection_response import ConnectionResponse +from .messages.problem_report import ConnectionProblemReport, ProblemReportReason +from .models.connection_detail import ConnectionDetail + + +class ConnectionManagerError(BaseError): + """Connection error.""" + + +class ConnectionManager(BaseConnectionManager): + """Class for managing connections.""" + + def __init__(self, profile: Profile): + """Initialize a ConnectionManager. + + Args: + profile: The profile for this connection manager + """ + self._profile = profile + self._logger = logging.getLogger(__name__) + super().__init__(self._profile) + + @property + def profile(self) -> Profile: + """Accessor for the current profile. + + Returns: + The profile for this connection manager + + """ + return self._profile + + def deprecation_warning(self): + """Log a deprecation warning.""" + warnings.warn( + "Aries RFC 0160: Connection Protocol is deprecated and support will be " + "removed in a future version; use RFC 0023: DID Exchange instead.", + DeprecationWarning, + ) + self._logger.warning( + "Aries RFC 0160: Connection Protocol is deprecated and support will be " + "removed in a future version; use RFC 0023: DID Exchange instead." + ) + + async def create_invitation( + self, + my_label: Optional[str] = None, + my_endpoint: Optional[str] = None, + auto_accept: Optional[bool] = None, + public: bool = False, + multi_use: bool = False, + alias: Optional[str] = None, + routing_keys: Optional[Sequence[str]] = None, + recipient_keys: Optional[Sequence[str]] = None, + metadata: Optional[dict] = None, + mediation_id: Optional[str] = None, + ) -> Tuple[ConnRecord, ConnectionInvitation]: + """Generate new connection invitation. + + This interaction represents an out-of-band communication channel. In the future + and in practice, these sort of invitations will be received over any number of + channels such as SMS, Email, QR Code, NFC, etc. + + Structure of an invite message: + + :: + + { + "@type": "https://didcomm.org/connections/1.0/invitation", + "label": "Alice", + "did": "did:sov:QmWbsNYhMrjHiqZDTUTEJs" + } + + Or, in the case of a peer DID: + + :: + + { + "@type": "https://didcomm.org/connections/1.0/invitation", + "label": "Alice", + "did": "did:peer:oiSqsNYhMrjHiqZDTUthsw", + "recipient_keys": ["8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"], + "service_endpoint": "https://example.com/endpoint" + "routing_keys": ["9EH5gYEeNc3z7PYXmd53d5x6qAfCNrqQqEB4nS7Zfu6K"], + } + + Args: + my_label: label for this connection + my_endpoint: endpoint where other party can reach me + auto_accept: auto-accept a corresponding connection request + (None to use config) + public: set to create an invitation from the public DID + multi_use: set to True to create an invitation for multiple use + alias: optional alias to apply to connection for later use + routing_keys: optional list of routing keys for the invitation + recipient_keys: optional list of recipient keys for the invitation + metadata: optional metadata to include in the connection record + mediation_id: optional mediation ID for the connection + + Returns: + A tuple of the new `ConnRecord` and `ConnectionInvitation` instances + + Raises: + ConnectionManagerError: if public invitations are not enabled or + no public DID is available + + """ + self.deprecation_warning() + # Mediation Record can still be None after this operation if no + # mediation id passed and no default + mediation_record = await self._route_manager.mediation_record_if_id( + self.profile, + mediation_id, + or_default=True, + ) + image_url = self.profile.context.settings.get("image_url") + invitation = None + connection = None + + invitation_mode = ConnRecord.INVITATION_MODE_ONCE + if multi_use: + invitation_mode = ConnRecord.INVITATION_MODE_MULTI + + if not my_label: + my_label = self.profile.settings.get("default_label") + + accept = ( + ConnRecord.ACCEPT_AUTO + if ( + auto_accept + or ( + auto_accept is None + and self.profile.settings.get("debug.auto_accept_requests") + ) + ) + else ConnRecord.ACCEPT_MANUAL + ) + + if recipient_keys: + # TODO: register recipient keys for relay + # TODO: check that recipient keys are in wallet + invitation_key = recipient_keys[0] # TODO first key appropriate? + else: + # Create and store new invitation key + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + invitation_signing_key = await wallet.create_signing_key(key_type=ED25519) + invitation_key = invitation_signing_key.verkey + recipient_keys = [invitation_key] + + if public: + if not self.profile.settings.get("public_invites"): + raise ConnectionManagerError("Public invitations are not enabled") + + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + public_did = await wallet.get_public_did() + if not public_did: + raise ConnectionManagerError( + "Cannot create public invitation with no public DID" + ) + + # FIXME - allow ledger instance to format public DID with prefix? + public_did_did = public_did.did + if bool(IndyDID.PATTERN.match(public_did_did)): + public_did_did = f"did:sov:{public_did.did}" + + invitation = ConnectionInvitation( + label=my_label, did=public_did_did, image_url=image_url + ) + + connection = ConnRecord( # create connection record + invitation_key=public_did.verkey, + invitation_msg_id=invitation._id, + invitation_mode=invitation_mode, + their_role=ConnRecord.Role.REQUESTER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + accept=accept, + alias=alias, + connection_protocol=CONN_PROTO, + ) + + async with self.profile.session() as session: + await connection.save(session, reason="Created new invitation") + + # Add mapping for multitenant relaying. + # Mediation of public keys is not supported yet + await self._route_manager.route_verkey(self.profile, public_did.verkey) + + else: + # Create connection record + connection = ConnRecord( + invitation_key=invitation_key, # TODO: determine correct key to use + their_role=ConnRecord.Role.REQUESTER.rfc160, + state=ConnRecord.State.INVITATION.rfc160, + accept=accept, + invitation_mode=invitation_mode, + alias=alias, + connection_protocol=CONN_PROTO, + ) + async with self.profile.session() as session: + await connection.save(session, reason="Created new invitation") + + await self._route_manager.route_invitation( + self.profile, connection, mediation_record + ) + routing_keys, routing_endpoint = await self._route_manager.routing_info( + self.profile, + mediation_record, + ) + my_endpoint = ( + routing_endpoint + or my_endpoint + or cast(str, self.profile.settings.get("default_endpoint")) + ) + + # Create connection invitation message + # Note: Need to split this into two stages + # to support inbound routing of invites + # Would want to reuse create_did_document and convert the result + invitation = ConnectionInvitation( + label=my_label, + recipient_keys=recipient_keys, + routing_keys=routing_keys, + endpoint=my_endpoint, + image_url=image_url, + ) + + async with self.profile.session() as session: + await connection.attach_invitation(session, invitation) + + if metadata: + for key, value in metadata.items(): + await connection.metadata_set(session, key, value) + + return connection, invitation + + async def receive_invitation( + self, + invitation: ConnectionInvitation, + their_public_did: Optional[str] = None, + auto_accept: Optional[bool] = None, + alias: Optional[str] = None, + mediation_id: Optional[str] = None, + ) -> ConnRecord: + """Create a new connection record to track a received invitation. + + Args: + invitation: The `ConnectionInvitation` to store + their_public_did: The public DID of the inviting party (optional) + auto_accept: Set to True to auto-accept the invitation, False to manually + accept, or None to use the default setting from the configuration + (optional) + alias: An optional alias to set on the connection record (optional) + mediation_id: The mediation ID to associate with the connection (optional) + + Returns: + The new `ConnRecord` instance representing the connection + + Raises: + ConnectionManagerError: If the invitation is missing recipient keys or an + endpoint + + """ + self.deprecation_warning() + if not invitation.did: + if not invitation.recipient_keys: + raise ConnectionManagerError( + "Invitation must contain recipient key(s)", + error_code=ProblemReportReason.MISSING_RECIPIENT_KEYS.value, + ) + if not invitation.endpoint: + raise ConnectionManagerError( + "Invitation must contain an endpoint", + error_code=ProblemReportReason.MISSING_ENDPOINT.value, + ) + accept = ( + ConnRecord.ACCEPT_AUTO + if ( + auto_accept + or ( + auto_accept is None + and self.profile.settings.get("debug.auto_accept_invites") + ) + ) + else ConnRecord.ACCEPT_MANUAL + ) + # Create connection record + connection = ConnRecord( + invitation_key=invitation.recipient_keys and invitation.recipient_keys[0], + their_label=invitation.label, + invitation_msg_id=invitation._id, + their_role=ConnRecord.Role.RESPONDER.rfc160, + state=ConnRecord.State.INVITATION.rfc160, + accept=accept, + alias=alias, + their_public_did=their_public_did, + connection_protocol=CONN_PROTO, + ) + + async with self.profile.session() as session: + await connection.save( + session, + reason="Created new connection record from invitation", + log_params={"invitation": invitation, "their_label": invitation.label}, + ) + + # Save the invitation for later processing + await connection.attach_invitation(session, invitation) + + await self._route_manager.save_mediator_for_connection( + self.profile, connection, mediation_id=mediation_id + ) + + if connection.accept == ConnRecord.ACCEPT_AUTO: + request = await self.create_request(connection, mediation_id=mediation_id) + responder = self.profile.inject_or(BaseResponder) + if responder: + await responder.send(request, connection_id=connection.connection_id) + # refetch connection for accurate state + async with self.profile.session() as session: + connection = await ConnRecord.retrieve_by_id( + session, connection.connection_id + ) + else: + self._logger.debug("Connection invitation will await acceptance") + return connection + + async def create_request( + self, + connection: ConnRecord, + my_label: Optional[str] = None, + my_endpoint: Optional[str] = None, + mediation_id: Optional[str] = None, + ) -> ConnectionRequest: + """Create a new connection request for a previously-received invitation. + + Args: + connection: The `ConnRecord` representing the invitation to accept + my_label: My label + my_endpoint: My endpoint + mediation_id: The record id for mediation + + Returns: + A new `ConnectionRequest` message to send to the other agent + + """ + self.deprecation_warning() + + mediation_records = await self._route_manager.mediation_records_for_connection( + self.profile, + connection, + mediation_id, + or_default=True, + ) + + if connection.my_did: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.get_local_did(connection.my_did) + else: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + # Create new DID for connection + my_info = await wallet.create_local_did(SOV, ED25519) + connection.my_did = my_info.did + + # Idempotent; if routing has already been set up, no action taken + await self._route_manager.route_connection_as_invitee( + self.profile, connection, mediation_records + ) + + # Create connection request message + if my_endpoint: + my_endpoints = [my_endpoint] + else: + my_endpoints = [] + default_endpoint = self.profile.settings.get("default_endpoint") + if default_endpoint: + my_endpoints.append(default_endpoint) + my_endpoints.extend(self.profile.settings.get("additional_endpoints", [])) + + did_doc = await self.create_did_document( + my_info, + my_endpoints, + mediation_records=mediation_records, + ) + + if not my_label: + my_label = self.profile.settings.get("default_label") + request = ConnectionRequest( + label=my_label, + connection=ConnectionDetail(did=connection.my_did, did_doc=did_doc), + image_url=self.profile.settings.get("image_url"), + ) + request.assign_thread_id(thid=request._id, pthid=connection.invitation_msg_id) + + # Update connection state + connection.request_id = request._id + connection.state = ConnRecord.State.REQUEST.rfc160 + + async with self.profile.session() as session: + await connection.save(session, reason="Created connection request") + + return request + + async def receive_request( + self, + request: ConnectionRequest, + receipt: MessageReceipt, + ) -> ConnRecord: + """Receive and store a connection request. + + Args: + request: The `ConnectionRequest` to accept + receipt: The message receipt + + Returns: + The new or updated `ConnRecord` instance + + """ + self.deprecation_warning() + ConnRecord.log_state( + "Receiving connection request", + {"request": request}, + settings=self.profile.settings, + ) + + connection = None + connection_key = None + my_info = None + + # Determine what key will need to sign the response + if receipt.recipient_did_public: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.get_local_did(receipt.recipient_did) + connection_key = my_info.verkey + else: + connection_key = receipt.recipient_verkey + try: + async with self.profile.session() as session: + connection = await ConnRecord.retrieve_by_invitation_key( + session=session, + invitation_key=connection_key, + their_role=ConnRecord.Role.REQUESTER.rfc160, + ) + except StorageNotFoundError: + raise ConnectionManagerError( + "No invitation found for pairwise connection " + f"in state {ConnRecord.State.INVITATION.rfc160}: " + "a prior connection request may have updated the connection state", + error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value, + ) + + invitation = None + if connection: + async with self.profile.session() as session: + invitation = await connection.retrieve_invitation(session) + connection_key = connection.invitation_key + ConnRecord.log_state( + "Found invitation", + {"invitation": invitation}, + settings=self.profile.settings, + ) + + if connection.is_multiuse_invitation: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.create_local_did(SOV, ED25519) + + new_connection = ConnRecord( + invitation_key=connection_key, + my_did=my_info.did, + state=ConnRecord.State.REQUEST.rfc160, + accept=connection.accept, + their_role=connection.their_role, + connection_protocol=CONN_PROTO, + ) + async with self.profile.session() as session: + await new_connection.save( + session, + reason=( + "Received connection request from multi-use invitation DID" + ), + event=False, + ) + + # Transfer metadata from multi-use to new connection + # Must come after save so there's an ID to associate with metadata + async with self.profile.session() as session: + for key, value in ( + await connection.metadata_get_all(session) + ).items(): + await new_connection.metadata_set(session, key, value) + + connection = new_connection + + conn_did_doc = request.connection.did_doc + if not conn_did_doc: + raise ConnectionManagerError( + "No DIDDoc provided; cannot connect to public DID", + ) + if request.connection.did != conn_did_doc.did: + raise ConnectionManagerError( + "Connection DID does not match DIDDoc id", + error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value, + ) + await self.store_did_document(conn_did_doc) + + if connection: + connection.their_label = request.label + connection.their_did = request.connection.did + connection.state = ConnRecord.State.REQUEST.rfc160 + async with self.profile.session() as session: + # force emitting event that would be ignored for multi-use invitations + # since the record is not new, and the state was not updated + await connection.save( + session, + reason="Received connection request from invitation", + event=True, + ) + elif not self.profile.settings.get("public_invites"): + raise ConnectionManagerError("Public invitations are not enabled") + else: # request from public did + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.create_local_did(SOV, ED25519) + + async with self.profile.session() as session: + connection = await ConnRecord.retrieve_by_invitation_msg_id( + session=session, + invitation_msg_id=request._thread.pthid, + their_role=ConnRecord.Role.REQUESTER.rfc160, + ) + if not connection: + if not self.profile.settings.get("requests_through_public_did"): + raise ConnectionManagerError( + "Unsolicited connection requests to public DID is not enabled" + ) + connection = ConnRecord() + connection.invitation_key = connection_key + connection.my_did = my_info.did + connection.their_role = ConnRecord.Role.RESPONDER.rfc160 + connection.their_did = request.connection.did + connection.their_label = request.label + connection.accept = ( + ConnRecord.ACCEPT_AUTO + if self.profile.settings.get("debug.auto_accept_requests") + else ConnRecord.ACCEPT_MANUAL + ) + connection.state = ConnRecord.State.REQUEST.rfc160 + connection.connection_protocol = CONN_PROTO + async with self.profile.session() as session: + await connection.save( + session, reason="Received connection request from public DID" + ) + + async with self.profile.session() as session: + # Attach the connection request so it can be found and responded to + await connection.attach_request(session, request) + + # Clean associated oob record if not needed anymore + oob_processor = self.profile.inject(OobMessageProcessor) + await oob_processor.clean_finished_oob_record(self.profile, request) + + return connection + + async def create_response( + self, + connection: ConnRecord, + my_endpoint: Optional[str] = None, + mediation_id: Optional[str] = None, + ) -> ConnectionResponse: + """Create a connection response for a received connection request. + + Args: + connection: The `ConnRecord` with a pending connection request + my_endpoint: The endpoint I can be reached at + mediation_id: The record id for mediation that contains routing_keys and + service endpoint + Returns: + A tuple of the updated `ConnRecord` new `ConnectionResponse` message + + """ + self.deprecation_warning() + ConnRecord.log_state( + "Creating connection response", + {"connection_id": connection.connection_id}, + settings=self.profile.settings, + ) + + mediation_records = await self._route_manager.mediation_records_for_connection( + self.profile, connection, mediation_id + ) + + if ConnRecord.State.get(connection.state) not in ( + ConnRecord.State.REQUEST, + ConnRecord.State.RESPONSE, + ): + raise ConnectionManagerError( + "Connection is not in the request or response state" + ) + + async with self.profile.session() as session: + request = await connection.retrieve_request(session) + + if connection.my_did: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.get_local_did(connection.my_did) + else: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.create_local_did(SOV, ED25519) + connection.my_did = my_info.did + + # Idempotent; if routing has already been set up, no action taken + await self._route_manager.route_connection_as_inviter( + self.profile, connection, mediation_records + ) + + # Create connection response message + if my_endpoint: + my_endpoints = [my_endpoint] + else: + my_endpoints = [] + default_endpoint = self.profile.settings.get("default_endpoint") + if default_endpoint: + my_endpoints.append(default_endpoint) + my_endpoints.extend(self.profile.settings.get("additional_endpoints", [])) + + did_doc = await self.create_did_document( + my_info, + my_endpoints, + mediation_records=mediation_records, + ) + + response = ConnectionResponse( + connection=ConnectionDetail(did=my_info.did, did_doc=did_doc) + ) + + # Assign thread information + response.assign_thread_from(request) + response.assign_trace_from(request) + # Sign connection field using the invitation key + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + await response.sign_field("connection", connection.invitation_key, wallet) + + # Update connection state + connection.state = ConnRecord.State.RESPONSE.rfc160 + + await connection.save( + session, + reason="Created connection response", + log_params={"response": response}, + ) + + # TODO It's possible the mediation request sent here might arrive + # before the connection response. This would result in an error condition + # difficult to accommodate for without modifying handlers for trust ping + # to ensure the connection is active. + async with self.profile.session() as session: + send_mediation_request = await connection.metadata_get( + session, MediationManager.SEND_REQ_AFTER_CONNECTION + ) + if send_mediation_request: + mgr = MediationManager(self.profile) + _record, request = await mgr.prepare_request(connection.connection_id) + responder = self.profile.inject(BaseResponder) + await responder.send(request, connection_id=connection.connection_id) + + return response + + async def accept_response( + self, response: ConnectionResponse, receipt: MessageReceipt + ) -> ConnRecord: + """Accept a connection response. + + Process a ConnectionResponse message by looking up + the connection request and setting up the pairwise connection. + + Args: + response: The `ConnectionResponse` to accept + receipt: The message receipt + + Returns: + The updated `ConnRecord` representing the connection + + Raises: + ConnectionManagerError: If there is no DID associated with the + connection response + ConnectionManagerError: If the corresponding connection is not + at the request or response stage + + """ + self.deprecation_warning() + connection = None + if response._thread: + # identify the request by the thread ID + try: + async with self.profile.session() as session: + connection = await ConnRecord.retrieve_by_request_id( + session, response._thread_id + ) + except StorageNotFoundError: + pass + + if not connection and receipt.sender_did: + # identify connection by the DID they used for us + try: + async with self.profile.session() as session: + connection = await ConnRecord.retrieve_by_did( + session, receipt.sender_did, receipt.recipient_did + ) + except StorageNotFoundError: + pass + + if not connection: + raise ConnectionManagerError( + "No corresponding connection request found", + error_code=ProblemReportReason.RESPONSE_NOT_ACCEPTED.value, + ) + + if ConnRecord.State.get(connection.state) not in ( + ConnRecord.State.REQUEST, + ConnRecord.State.RESPONSE, + ): + raise ConnectionManagerError( + f"Cannot accept connection response for connection" + f" in state: {connection.state}" + ) + + their_did = response.connection.did + conn_did_doc = response.connection.did_doc + if not conn_did_doc: + raise ConnectionManagerError( + "No DIDDoc provided; cannot connect to public DID" + ) + if their_did != conn_did_doc.did: + raise ConnectionManagerError("Connection DID does not match DIDDoc id") + # Verify connection response using connection field + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + try: + await response.verify_signed_field( + "connection", wallet, connection.invitation_key + ) + except ValueError: + raise ConnectionManagerError( + "connection field verification using invitation_key failed" + ) + await self.store_did_document(conn_did_doc) + + connection.their_did = their_did + connection.state = ConnRecord.State.RESPONSE.rfc160 + async with self.profile.session() as session: + await connection.save(session, reason="Accepted connection response") + + send_mediation_request = await connection.metadata_get( + session, MediationManager.SEND_REQ_AFTER_CONNECTION + ) + if send_mediation_request: + mgr = MediationManager(self.profile) + _record, request = await mgr.prepare_request(connection.connection_id) + responder = self.profile.inject(BaseResponder) + await responder.send(request, connection_id=connection.connection_id) + + return connection + + async def receive_problem_report( + self, + conn_rec: ConnRecord, + report: ConnectionProblemReport, + ): + """Receive problem report.""" + self.deprecation_warning() + if not report.description: + raise ConnectionManagerError("Missing description in problem report") + + if report.description.get("code") in { + reason.value for reason in ProblemReportReason + }: + self._logger.info("Problem report indicates connection is abandoned") + async with self.profile.session() as session: + await conn_rec.abandon( + session, + reason=report.description.get("en"), + ) + else: + raise ConnectionManagerError( + f"Received unrecognized problem report: {report.description}" + ) + + def manager_error_to_problem_report( + self, + e: ConnectionManagerError, + message: Union[ConnectionRequest, ConnectionResponse], + message_receipt, + ) -> tuple[ConnectionProblemReport, Sequence[ConnectionTarget]]: + """Convert ConnectionManagerError to problem report.""" + self._logger.exception("Error receiving connection request") + targets = None + report = None + if e.error_code: + report = ConnectionProblemReport( + description={"en": e.message, "code": e.error_code} + ) + report.assign_thread_from(message) + if message.connection and message.connection.did_doc: + try: + targets = self.diddoc_connection_targets( + message.connection.did_doc, + message_receipt.recipient_verkey, + ) + except ConnectionManagerError: + self._logger.exception("Error parsing DIDDoc for problem report") + + return report, targets diff --git a/acapy_agent/protocols_v2/connections/v1_0/message_types.py b/acapy_agent/protocols_v2/connections/v1_0/message_types.py new file mode 100644 index 0000000000..436cd1df8d --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/message_types.py @@ -0,0 +1,34 @@ +"""Message type identifiers for Connections.""" + +from ....protocols.didcomm_prefix import DIDCommPrefix + +SPEC_URI = ( + "https://github.com/hyperledger/aries-rfcs/tree/" + "9b0aaa39df7e8bd434126c4b33c097aae78d65bf/features/0160-connection-protocol" +) +ARIES_PROTOCOL = "connections/1.0" + +# Message types +CONNECTION_INVITATION = f"{ARIES_PROTOCOL}/invitation" +CONNECTION_REQUEST = f"{ARIES_PROTOCOL}/request" +CONNECTION_RESPONSE = f"{ARIES_PROTOCOL}/response" +PROBLEM_REPORT = f"{ARIES_PROTOCOL}/problem_report" + +PROTOCOL_PACKAGE = "acapy_agent.protocols.connections.v1_0" + +MESSAGE_TYPES = DIDCommPrefix.qualify_all( + { + CONNECTION_INVITATION: ( + f"{PROTOCOL_PACKAGE}.messages.connection_invitation.ConnectionInvitation" + ), + CONNECTION_REQUEST: ( + f"{PROTOCOL_PACKAGE}.messages.connection_request.ConnectionRequest" + ), + CONNECTION_RESPONSE: ( + f"{PROTOCOL_PACKAGE}.messages.connection_response.ConnectionResponse" + ), + PROBLEM_REPORT: ( + f"{PROTOCOL_PACKAGE}.messages.problem_report.ConnectionProblemReport" + ), + } +) diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/__init__.py b/acapy_agent/protocols_v2/connections/v1_0/messages/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/connection_invitation.py b/acapy_agent/protocols_v2/connections/v1_0/messages/connection_invitation.py new file mode 100644 index 0000000000..51e9f5f3c2 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/messages/connection_invitation.py @@ -0,0 +1,211 @@ +"""Represents an invitation message for establishing connection.""" + +from typing import Optional, Sequence +from urllib.parse import parse_qs, urljoin, urlparse + +from marshmallow import EXCLUDE, ValidationError, fields, pre_load, validates_schema + +from .....did.did_key import DIDKey +from .....messaging.agent_message import AgentMessage, AgentMessageSchema +from .....messaging.valid import ( + GENERIC_DID_EXAMPLE, + GENERIC_DID_VALIDATE, + RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, +) +from .....wallet.util import b64_to_bytes, bytes_to_b64 +from ..message_types import CONNECTION_INVITATION, PROTOCOL_PACKAGE + +HANDLER_CLASS = ( + f"{PROTOCOL_PACKAGE}.handlers" + ".connection_invitation_handler.ConnectionInvitationHandler" +) + + +class ConnectionInvitation(AgentMessage): + """Class representing a connection invitation.""" + + class Meta: + """Metadata for a connection invitation.""" + + handler_class = HANDLER_CLASS + message_type = CONNECTION_INVITATION + schema_class = "ConnectionInvitationSchema" + + def __init__( + self, + *, + label: Optional[str] = None, + did: Optional[str] = None, + recipient_keys: Sequence[str] = None, + endpoint: Optional[str] = None, + routing_keys: Sequence[str] = None, + image_url: Optional[str] = None, + **kwargs, + ): + """Initialize connection invitation object. + + Args: + label: Optional label for connection invitation + did: DID for this connection invitation + recipient_keys: List of recipient keys + endpoint: Endpoint which this agent can be reached at + routing_keys: List of routing keys + image_url: Optional image URL for connection invitation + kwargs: Additional keyword arguments for the message + """ + super().__init__(**kwargs) + self.label = label + self.did = did + self.recipient_keys = list(recipient_keys) if recipient_keys else None + self.endpoint = endpoint + self.routing_keys = list(routing_keys) if routing_keys else None + self.routing_keys = ( + [ + ( + DIDKey.from_did(key).public_key_b58 + if key.startswith("did:key:") + else key + ) + for key in self.routing_keys + ] + if self.routing_keys + else None + ) + self.image_url = image_url + + def to_url(self, base_url: Optional[str] = None) -> str: + """Convert an invitation to URL format for sharing. + + Returns: + An invite url + + """ + c_json = self.to_json() + c_i = bytes_to_b64(c_json.encode("ascii"), urlsafe=True, pad=False) + result = urljoin(base_url or self.endpoint or "", "?c_i={}".format(c_i)) + return result + + @classmethod + def from_url(cls, url: str) -> "ConnectionInvitation": + """Parse a URL-encoded invitation into a `ConnectionInvitation` message. + + Args: + url: Url to decode + + Returns: + A `ConnectionInvitation` object. + + """ + parts = urlparse(url) + query = parse_qs(parts.query) + if "c_i" in query: + c_i = b64_to_bytes(query["c_i"][0], urlsafe=True) + return cls.from_json(c_i) + return None + + +class ConnectionInvitationSchema(AgentMessageSchema): + """Connection invitation schema class.""" + + class Meta: + """Connection invitation schema metadata.""" + + model_class = ConnectionInvitation + unknown = EXCLUDE + + label = fields.Str( + required=False, + metadata={ + "description": "Optional label for connection invitation", + "example": "Bob", + }, + ) + did = fields.Str( + required=False, + validate=GENERIC_DID_VALIDATE, + metadata={ + "description": "DID for connection invitation", + "example": GENERIC_DID_EXAMPLE, + }, + ) + recipient_keys = fields.List( + fields.Str( + validate=RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + metadata={ + "description": "Recipient public key", + "example": RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + }, + ), + data_key="recipientKeys", + required=False, + metadata={"description": "List of recipient keys"}, + ) + endpoint = fields.Str( + data_key="serviceEndpoint", + required=False, + metadata={ + "description": "Service endpoint at which to reach this agent", + "example": "http://192.168.56.101:8020", + }, + ) + routing_keys = fields.List( + fields.Str( + validate=RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + metadata={ + "description": "Routing key", + "example": RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + }, + ), + data_key="routingKeys", + required=False, + metadata={"description": "List of routing keys"}, + ) + image_url = fields.URL( + data_key="imageUrl", + required=False, + allow_none=True, + metadata={ + "description": "Optional image URL for connection invitation", + "example": "http://192.168.56.101/img/logo.jpg", + }, + ) + + @pre_load + def transform_routing_keys(self, data, **kwargs): + """Transform routingKeys from did:key refs, if necessary.""" + routing_keys = data.get("routingKeys") + if routing_keys: + data["routingKeys"] = [ + ( + DIDKey.from_did(key).public_key_b58 + if key.startswith("did:key:") + else key + ) + for key in routing_keys + ] + return data + + @validates_schema + def validate_fields(self, data, **kwargs): + """Validate schema fields. + + Args: + data: The data to validate + kwargs: Additional keyword arguments + + Raises: + ValidationError: If any of the fields do not validate + + """ + if data.get("did"): + if data.get("recipient_keys"): + raise ValidationError("Fields are incompatible", ("did", "recipientKeys")) + if data.get("endpoint"): + raise ValidationError( + "Fields are incompatible", ("did", "serviceEndpoint") + ) + elif not data.get("recipient_keys") or not data.get("endpoint"): + raise ValidationError( + "Missing required field(s)", ("did", "recipientKeys", "serviceEndpoint") + ) diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/connection_request.py b/acapy_agent/protocols_v2/connections/v1_0/messages/connection_request.py new file mode 100644 index 0000000000..d1a6940be5 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/messages/connection_request.py @@ -0,0 +1,74 @@ +"""Represents a connection request message.""" + +from typing import Optional + +from marshmallow import EXCLUDE, fields + +from .....messaging.agent_message import AgentMessage, AgentMessageSchema +from ..message_types import CONNECTION_REQUEST, PROTOCOL_PACKAGE +from ..models.connection_detail import ConnectionDetail, ConnectionDetailSchema + +HANDLER_CLASS = ( + f"{PROTOCOL_PACKAGE}.handlers.connection_request_handler.ConnectionRequestHandler" +) + + +class ConnectionRequest(AgentMessage): + """Class representing a connection request.""" + + class Meta: + """Metadata for a connection request.""" + + handler_class = HANDLER_CLASS + message_type = CONNECTION_REQUEST + schema_class = "ConnectionRequestSchema" + + def __init__( + self, + *, + connection: Optional[ConnectionDetail] = None, + label: Optional[str] = None, + image_url: Optional[str] = None, + **kwargs, + ): + """Initialize connection request object. + + Args: + connection (ConnectionDetail): Connection details object + label: Label for this connection request + image_url: Optional image URL for this connection request + kwargs: Additional keyword arguments for the message + + """ + super().__init__(**kwargs) + self.connection = connection + self.label = label + self.image_url = image_url + + +class ConnectionRequestSchema(AgentMessageSchema): + """Connection request schema class.""" + + class Meta: + """Connection request schema metadata.""" + + model_class = ConnectionRequest + unknown = EXCLUDE + + connection = fields.Nested(ConnectionDetailSchema, required=True) + label = fields.Str( + required=True, + metadata={ + "description": "Label for connection request", + "example": "Request to connect with Bob", + }, + ) + image_url = fields.Str( + data_key="imageUrl", + required=False, + allow_none=True, + metadata={ + "description": "Optional image URL for connection request", + "example": "http://192.168.56.101/img/logo.jpg", + }, + ) diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/connection_response.py b/acapy_agent/protocols_v2/connections/v1_0/messages/connection_response.py new file mode 100644 index 0000000000..aeb9512314 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/messages/connection_response.py @@ -0,0 +1,48 @@ +"""Represents a connection response message.""" + +from typing import Optional + +from marshmallow import EXCLUDE, fields + +from .....messaging.agent_message import AgentMessage, AgentMessageSchema +from ..message_types import CONNECTION_RESPONSE, PROTOCOL_PACKAGE +from ..models.connection_detail import ConnectionDetail, ConnectionDetailSchema + +HANDLER_CLASS = ( + f"{PROTOCOL_PACKAGE}.handlers.connection_response_handler.ConnectionResponseHandler" +) + + +class ConnectionResponse(AgentMessage): + """Class representing a connection response.""" + + class Meta: + """Metadata for a connection response.""" + + handler_class = HANDLER_CLASS + schema_class = "ConnectionResponseSchema" + message_type = CONNECTION_RESPONSE + + def __init__(self, *, connection: Optional[ConnectionDetail] = None, **kwargs): + """Initialize connection response object. + + Args: + connection: Connection details object + kwargs: Additional keyword arguments for the message + + """ + super().__init__(**kwargs) + self.connection = connection + + +class ConnectionResponseSchema(AgentMessageSchema): + """Connection response schema class.""" + + class Meta: + """Connection response schema metadata.""" + + model_class = ConnectionResponse + signed_fields = ("connection",) + unknown = EXCLUDE + + connection = fields.Nested(ConnectionDetailSchema, required=True) diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/problem_report.py b/acapy_agent/protocols_v2/connections/v1_0/messages/problem_report.py new file mode 100644 index 0000000000..b38a6235ec --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/messages/problem_report.py @@ -0,0 +1,85 @@ +"""Represents a connection problem report message.""" + +import logging +from enum import Enum +from typing import Optional + +from marshmallow import EXCLUDE, ValidationError, validates_schema + +from .....protocols.problem_report.v1_0.message import ProblemReport, ProblemReportSchema +from ..message_types import PROBLEM_REPORT + +HANDLER_CLASS = ( + "acapy_agent.protocols.connections.v1_0.handlers." + "problem_report_handler.ConnectionProblemReportHandler" +) + +LOGGER = logging.getLogger(__name__) + + +class ProblemReportReason(Enum): + """Supported reason codes.""" + + INVITATION_NOT_ACCEPTED = "invitation_not_accepted" + REQUEST_NOT_ACCEPTED = "request_not_accepted" + REQUEST_PROCESSING_ERROR = "request_processing_error" + RESPONSE_NOT_ACCEPTED = "response_not_accepted" + RESPONSE_PROCESSING_ERROR = "response_processing_error" + MISSING_RECIPIENT_KEYS = "invitation_missing_recipient_keys" + MISSING_ENDPOINT = "invitation_missing_endpoint" + + +class ConnectionProblemReport(ProblemReport): + """Base class representing a connection problem report message.""" + + class Meta: + """Connection problem report metadata.""" + + handler_class = HANDLER_CLASS + message_type = PROBLEM_REPORT + schema_class = "ConnectionProblemReportSchema" + + def __init__( + self, + *, + problem_code: Optional[str] = None, + explain: Optional[str] = None, + **kwargs, + ): + """Initialize a ProblemReport message instance. + + Args: + problem_code: The standard error identifier + explain: The localized error explanation + kwargs: Additional keyword arguments + """ + super().__init__(**kwargs) + self.explain = explain + self.problem_code = problem_code + + +class ConnectionProblemReportSchema(ProblemReportSchema): + """Schema for ConnectionProblemReport base class.""" + + class Meta: + """Metadata for connection problem report schema.""" + + model_class = ConnectionProblemReport + unknown = EXCLUDE + + @validates_schema + def validate_fields(self, data, **kwargs): + """Validate schema fields.""" + + if not data.get("description", {}).get("code", ""): + raise ValidationError("Value for description.code must be present") + elif data.get("description", {}).get("code", "") not in [ + prr.value for prr in ProblemReportReason + ]: + locales = list(data.get("description").keys()) + locales.remove("code") + LOGGER.warning( + "Unexpected error code received.\n" + f"Code: {data.get('description').get('code')}, " + f"Description: {data.get('description').get(locales[0])}" + ) diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/tests/__init__.py b/acapy_agent/protocols_v2/connections/v1_0/messages/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_invitation.py b/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_invitation.py new file mode 100644 index 0000000000..0149bd0224 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_invitation.py @@ -0,0 +1,118 @@ +from unittest import TestCase, mock + +from ......messaging.models.base import BaseModelError +from .....didcomm_prefix import DIDCommPrefix +from ...message_types import CONNECTION_INVITATION +from ..connection_invitation import ConnectionInvitation + + +class TestConnectionInvitation(TestCase): + label = "Label" + did = "did:sov:QmWbsNYhMrjHiqZDTUTEJs" + endpoint_url = "https://example.com/endpoint" + endpoint_did = "did:sov:A2wBhNYhMrjHiqZDTUYH7u" + image_url = "https://example.com/image.jpg" + key = "8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K" + + def test_init(self): + connection_invitation = ConnectionInvitation( + label=self.label, recipient_keys=[self.key], endpoint=self.endpoint_url + ) + assert connection_invitation.label == self.label + assert connection_invitation.recipient_keys == [self.key] + assert connection_invitation.endpoint == self.endpoint_url + + connection_invitation = ConnectionInvitation(label=self.label, did=self.did) + assert connection_invitation.did == self.did + + def test_type(self): + connection_invitation = ConnectionInvitation( + label=self.label, recipient_keys=[self.key], endpoint=self.endpoint_url + ) + + assert connection_invitation._type == DIDCommPrefix.qualify_current( + CONNECTION_INVITATION + ) + + @mock.patch( + "acapy_agent.protocols.connections.v1_0.messages." + "connection_invitation.ConnectionInvitationSchema.load" + ) + def test_deserialize(self, mock_connection_invitation_schema_load): + obj = {"obj": "obj"} + + connection_invitation = ConnectionInvitation.deserialize(obj) + mock_connection_invitation_schema_load.assert_called_once_with(obj) + + assert ( + connection_invitation is mock_connection_invitation_schema_load.return_value + ) + + @mock.patch( + "acapy_agent.protocols.connections.v1_0.messages." + "connection_invitation.ConnectionInvitationSchema.dump" + ) + def test_serialize(self, mock_connection_invitation_schema_dump): + connection_invitation = ConnectionInvitation( + label=self.label, recipient_keys=[self.key], endpoint=self.endpoint_url + ) + + connection_invitation_dict = connection_invitation.serialize() + mock_connection_invitation_schema_dump.assert_called_once_with( + connection_invitation + ) + + assert ( + connection_invitation_dict + is mock_connection_invitation_schema_dump.return_value + ) + + def test_url_round_trip(self): + connection_invitation = ConnectionInvitation( + label=self.label, recipient_keys=[self.key], endpoint=self.endpoint_url + ) + url = connection_invitation.to_url() + assert isinstance(url, str) + invitation = ConnectionInvitation.from_url(url) + assert isinstance(invitation, ConnectionInvitation) + + def test_from_no_url(self): + url = "http://aries.ca/no_ci" + assert ConnectionInvitation.from_url(url) is None + + +class TestConnectionInvitationSchema(TestCase): + connection_invitation = ConnectionInvitation( + label="label", did="did:sov:QmWbsNYhMrjHiqZDTUTEJs" + ) + + def test_make_model(self): + data = self.connection_invitation.serialize() + model_instance = ConnectionInvitation.deserialize(data) + assert isinstance(model_instance, ConnectionInvitation) + + def test_make_model_invalid(self): + x_conns = [ + ConnectionInvitation( + label="did-and-recip-keys", + did="did:sov:QmWbsNYhMrjHiqZDTUTEJs", + recipient_keys=["8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"], + ), + ConnectionInvitation( + label="did-and-endpoint", + did="did:sov:QmWbsNYhMrjHiqZDTUTEJs", + endpoint="https://example.com/endpoint", + ), + ConnectionInvitation( + label="no-did-no-recip-keys", + endpoint="https://example.com/endpoint", + ), + ConnectionInvitation( + label="no-did-no-endpoint", + recipient_keys=["8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"], + ), + ] + for x_conn in x_conns: + data = x_conn.serialize() + with self.assertRaises(BaseModelError): + ConnectionInvitation.deserialize(data) diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_request.py b/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_request.py new file mode 100644 index 0000000000..f5ea7e4433 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_request.py @@ -0,0 +1,119 @@ +from unittest import IsolatedAsyncioTestCase, TestCase, mock + +from ......connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service +from .....didcomm_prefix import DIDCommPrefix +from ...message_types import CONNECTION_REQUEST +from ...models.connection_detail import ConnectionDetail +from ..connection_request import ConnectionRequest + + +class TestConfig: + test_seed = "testseed000000000000000000000001" + test_did = "55GkHamhTU1ZbTbV2ab9DE" + test_verkey = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" + test_label = "Label" + test_endpoint = "http://localhost" + + def make_did_doc(self): + doc = DIDDoc(did=self.test_did) + controller = self.test_did + ident = "1" + pk_value = self.test_verkey + pk = PublicKey( + self.test_did, + ident, + pk_value, + PublicKeyType.ED25519_SIG_2018, + controller, + False, + ) + doc.set(pk) + recip_keys = [pk] + router_keys = [] + service = Service( + self.test_did, + "indy", + "IndyAgent", + recip_keys, + router_keys, + self.test_endpoint, + ) + doc.set(service) + return doc + + +class TestConnectionRequest(TestCase, TestConfig): + def setUp(self): + self.connection_request = ConnectionRequest( + connection=ConnectionDetail(did=self.test_did, did_doc=self.make_did_doc()), + label=self.test_label, + ) + + def test_init(self): + """Test initialization.""" + assert self.connection_request.label == self.test_label + assert self.connection_request.connection.did == self.test_did + # assert self.connection_request.verkey == self.verkey + + def test_type(self): + """Test type.""" + assert self.connection_request._type == DIDCommPrefix.qualify_current( + CONNECTION_REQUEST + ) + + @mock.patch( + "acapy_agent.protocols.connections.v1_0.messages." + "connection_request.ConnectionRequestSchema.load" + ) + def test_deserialize(self, mock_connection_request_schema_load): + """ + Test deserialization. + """ + obj = {"obj": "obj"} + + connection_request = ConnectionRequest.deserialize(obj) + mock_connection_request_schema_load.assert_called_once_with(obj) + + assert connection_request is mock_connection_request_schema_load.return_value + + @mock.patch( + "acapy_agent.protocols.connections.v1_0.messages." + "connection_request.ConnectionRequestSchema.dump" + ) + def test_serialize(self, mock_connection_request_schema_dump): + """ + Test serialization. + """ + connection_request_dict = self.connection_request.serialize() + mock_connection_request_schema_dump.assert_called_once_with( + self.connection_request + ) + + assert connection_request_dict is mock_connection_request_schema_dump.return_value + + +class TestConnectionRequestSchema(IsolatedAsyncioTestCase, TestConfig): + """Test connection request schema.""" + + async def test_make_model(self): + connection_request = ConnectionRequest( + connection=ConnectionDetail(did=self.test_did, did_doc=self.make_did_doc()), + label=self.test_label, + ) + data = connection_request.serialize() + model_instance = ConnectionRequest.deserialize(data) + assert type(model_instance) is type(connection_request) + + async def test_make_model_conn_detail_interpolate_authn_service(self): + did_doc_dict = self.make_did_doc().serialize() + del did_doc_dict["authentication"] + del did_doc_dict["service"] + did_doc = DIDDoc.deserialize(did_doc_dict) + + connection_request = ConnectionRequest( + connection=ConnectionDetail(did=self.test_did, did_doc=did_doc), + label=self.test_label, + ) + data = connection_request.serialize() + model_instance = ConnectionRequest.deserialize(data) + assert type(model_instance) is type(connection_request) diff --git a/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_response.py b/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_response.py new file mode 100644 index 0000000000..d3f08ad618 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/messages/tests/test_connection_response.py @@ -0,0 +1,106 @@ +from unittest import IsolatedAsyncioTestCase, TestCase, mock + +from ......connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service +from ......utils.testing import create_test_profile +from ......wallet.base import BaseWallet +from ......wallet.key_type import ED25519 +from .....didcomm_prefix import DIDCommPrefix +from ...message_types import CONNECTION_RESPONSE +from ...models.connection_detail import ConnectionDetail +from ..connection_response import ConnectionResponse + + +class TestConfig: + test_seed = "testseed000000000000000000000001" + test_did = "55GkHamhTU1ZbTbV2ab9DE" + test_verkey = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" + test_endpoint = "http://localhost" + + def make_did_doc(self): + doc = DIDDoc(did=self.test_did) + controller = self.test_did + ident = "1" + pk_value = self.test_verkey + pk = PublicKey( + self.test_did, + ident, + pk_value, + PublicKeyType.ED25519_SIG_2018, + controller, + False, + ) + doc.set(pk) + recip_keys = [pk] + routing_keys = [] + service = Service( + self.test_did, + "indy", + "IndyAgent", + recip_keys, + routing_keys, + self.test_endpoint, + ) + doc.set(service) + return doc + + +class TestConnectionResponse(TestCase, TestConfig): + def setUp(self): + self.connection_response = ConnectionResponse( + connection=ConnectionDetail(did=self.test_did, did_doc=self.make_did_doc()) + ) + + def test_init(self): + assert self.connection_response.connection.did == self.test_did + + def test_type(self): + assert self.connection_response._type == DIDCommPrefix.qualify_current( + CONNECTION_RESPONSE + ) + + @mock.patch( + "acapy_agent.protocols.connections.v1_0.messages." + "connection_response.ConnectionResponseSchema.load" + ) + def test_deserialize(self, mock_connection_response_schema_load): + """ + Test deserialization. + """ + obj = {"obj": "obj"} + + connection_response = ConnectionResponse.deserialize(obj) + mock_connection_response_schema_load.assert_called_once_with(obj) + + assert connection_response is mock_connection_response_schema_load.return_value + + @mock.patch( + "acapy_agent.protocols.connections.v1_0.messages." + "connection_response.ConnectionResponseSchema.dump" + ) + def test_serialize(self, mock_connection_response_schema_dump): + """ + Test serialization. + """ + connection_response_dict = self.connection_response.serialize() + mock_connection_response_schema_dump.assert_called_once_with( + self.connection_response + ) + + assert ( + connection_response_dict is mock_connection_response_schema_dump.return_value + ) + + +class TestConnectionResponseSchema(IsolatedAsyncioTestCase, TestConfig): + async def test_make_model(self): + connection_response = ConnectionResponse( + connection=ConnectionDetail(did=self.test_did, did_doc=self.make_did_doc()) + ) + self.profile = await create_test_profile() + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + key_info = await wallet.create_signing_key(ED25519) + await connection_response.sign_field("connection", key_info.verkey, wallet) + data = connection_response.serialize() + model_instance = ConnectionResponse.deserialize(data) + assert type(model_instance) is type(connection_response) diff --git a/acapy_agent/protocols_v2/connections/v1_0/models/__init__.py b/acapy_agent/protocols_v2/connections/v1_0/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/connections/v1_0/models/connection_detail.py b/acapy_agent/protocols_v2/connections/v1_0/models/connection_detail.py new file mode 100644 index 0000000000..6029cb1614 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/models/connection_detail.py @@ -0,0 +1,114 @@ +"""An object for containing the connection request/response DID information.""" + +from typing import Optional + +from marshmallow import EXCLUDE, fields + +from .....connections.models.diddoc import DIDDoc +from .....messaging.models.base import BaseModel, BaseModelSchema +from .....messaging.valid import INDY_DID_EXAMPLE, INDY_DID_VALIDATE + + +class DIDDocWrapper(fields.Field): + """Field that loads and serializes DIDDoc.""" + + def _serialize(self, value: DIDDoc, attr, obj, **kwargs): + """Serialize the DIDDoc. + + Args: + value: The value to serialize + attr: The attribute being serialized + obj: The object being serialized + kwargs: Additional keyword arguments + + Returns: + The serialized DIDDoc + + """ + return value.serialize(normalize_routing_keys=True) + + def _deserialize(self, value, attr=None, data=None, **kwargs): + """Deserialize a value into a DIDDoc. + + Args: + value: The value to deserialize + attr: The attribute being deserialized + data: The full data being deserialized + kwargs: Additional keyword arguments + + Returns: + The deserialized value + + """ + return DIDDoc.deserialize(value) + + +class ConnectionDetail(BaseModel): + """Class representing the details of a connection.""" + + class Meta: + """ConnectionDetail metadata.""" + + schema_class = "ConnectionDetailSchema" + + def __init__( + self, *, did: Optional[str] = None, did_doc: Optional[DIDDoc] = None, **kwargs + ): + """Initialize a ConnectionDetail instance. + + Args: + did: DID for the connection detail + did_doc: DIDDoc for connection detail + kwargs: Additional keyword arguments + + """ + super().__init__(**kwargs) + self._did = did + self._did_doc = did_doc + + @property + def did(self) -> str: + """Accessor for the connection DID. + + Returns: + The DID for this connection + + """ + return self._did + + @property + def did_doc(self) -> DIDDoc: + """Accessor for the connection DID Document. + + Returns: + The DIDDoc for this connection + + """ + return self._did_doc + + +class ConnectionDetailSchema(BaseModelSchema): + """ConnectionDetail schema.""" + + class Meta: + """ConnectionDetailSchema metadata.""" + + model_class = ConnectionDetail + unknown = EXCLUDE + + did = fields.Str( + data_key="DID", + required=False, + validate=INDY_DID_VALIDATE, + metadata={ + "description": "DID for connection detail", + "example": INDY_DID_EXAMPLE, + }, + ) + did_doc = DIDDocWrapper( + data_key="DIDDoc", + required=False, + metadata={ + "description": "DID document for connection detail", + }, + ) diff --git a/acapy_agent/protocols_v2/connections/v1_0/routes.py b/acapy_agent/protocols_v2/connections/v1_0/routes.py new file mode 100644 index 0000000000..84e43b9874 --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/routes.py @@ -0,0 +1,701 @@ +"""Connection handling admin routes.""" + +import json +from typing import cast + +from aiohttp import web +from aiohttp_apispec import ( + docs, + match_info_schema, + querystring_schema, + request_schema, + response_schema, +) +from marshmallow import fields, validate, validates_schema + +from ....admin.decorators.auth import tenant_authentication +from ....admin.request_context import AdminRequestContext +from ....cache.base import BaseCache +from ....connections.models.conn_record import ConnRecord, ConnRecordSchema +from ....connections.models.conn_peer_record import PeerwiseRecord, PeerwiseRecordSchema +from ....messaging.models.base import BaseModelError +from ....messaging.models.openapi import OpenAPISchema +from ....messaging.models.paginated_query import PaginatedQuerySchema, get_limit_offset +from ....messaging.valid import ( + ENDPOINT_EXAMPLE, + ENDPOINT_VALIDATE, + GENERIC_DID_VALIDATE, + INDY_DID_EXAMPLE, + INDY_DID_VALIDATE, + RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + UUID4_EXAMPLE, + UUID4_VALIDATE, +) +from ....storage.error import StorageError, StorageNotFoundError +from ....wallet.error import WalletError +from .manager import ConnectionManager, ConnectionManagerError +from .message_types import SPEC_URI +from .messages.connection_invitation import ( + ConnectionInvitation, + ConnectionInvitationSchema, +) + + +class ConnectionModuleResponseSchema(OpenAPISchema): + """Response schema for connection module.""" + + +class ConnectionListSchema(OpenAPISchema): + """Result schema for connection list.""" + + results = fields.List( + fields.Nested(ConnRecordSchema()), + required=True, + metadata={"description": "List of connection records"}, + ) + + +class ConnectionMetadataSchema(OpenAPISchema): + """Result schema for connection metadata.""" + + results = fields.Dict( + metadata={"description": "Dictionary of metadata associated with connection."} + ) + + +class ConnectionMetadataSetRequestSchema(OpenAPISchema): + """Request Schema for set metadata.""" + + metadata = fields.Dict( + required=True, + metadata={"description": "Dictionary of metadata to set for connection."}, + ) + + +class ConnectionMetadataQuerySchema(OpenAPISchema): + """Query schema for metadata.""" + + key = fields.Str(required=False, metadata={"description": "Key to retrieve."}) + + +class ReceiveInvitationRequestSchema(ConnectionInvitationSchema): + """Request schema for receive invitation request.""" + + @validates_schema + def validate_fields(self, data, **kwargs): + """Bypass middleware field validation: marshmallow has no data yet.""" + + +class CreateInvitationRequestSchema(OpenAPISchema): + """Request schema for invitation connection target.""" + + recipient_keys = fields.List( + fields.Str( + validate=RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + metadata={ + "description": "Recipient public key", + "example": RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + }, + ), + required=False, + metadata={"description": "List of recipient keys"}, + ) + service_endpoint = fields.Str( + required=False, + metadata={ + "description": "Connection endpoint", + "example": "http://192.168.56.102:8020", + }, + ) + routing_keys = fields.List( + fields.Str( + validate=RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + metadata={ + "description": "Routing key", + "example": RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + }, + ), + required=False, + metadata={"description": "List of routing keys"}, + ) + my_label = fields.Str( + required=False, + metadata={ + "description": "Optional label for connection invitation", + "example": "Bob", + }, + ) + metadata = fields.Dict( + required=False, + metadata={ + "description": ( + "Optional metadata to attach to the connection created with the" + " invitation" + ) + }, + ) + mediation_id = fields.Str( + required=False, + validate=UUID4_VALIDATE, + metadata={ + "description": "Identifier for active mediation record to be used", + "example": UUID4_EXAMPLE, + }, + ) + + +class InvitationResultSchema(OpenAPISchema): + """Result schema for a new connection invitation.""" + + connection_id = fields.Str( + required=True, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) + invitation = fields.Nested(ConnectionInvitationSchema(), required=True) + invitation_url = fields.Str( + required=True, + metadata={ + "description": "Invitation URL", + "example": "http://192.168.56.101:8020/invite?c_i=eyJAdHlwZSI6Li4ufQ==", + }, + ) + + +class ConnectionStaticRequestSchema(OpenAPISchema): + """Request schema for a new static connection.""" + + my_seed = fields.Str( + required=False, metadata={"description": "Seed to use for the local DID"} + ) + my_did = fields.Str( + required=False, + validate=INDY_DID_VALIDATE, + metadata={"description": "Local DID", "example": INDY_DID_EXAMPLE}, + ) + their_seed = fields.Str( + required=False, metadata={"description": "Seed to use for the remote DID"} + ) + their_did = fields.Str( + required=False, + validate=INDY_DID_VALIDATE, + metadata={"description": "Remote DID", "example": INDY_DID_EXAMPLE}, + ) + their_verkey = fields.Str( + required=False, metadata={"description": "Remote verification key"} + ) + their_endpoint = fields.Str( + required=False, + validate=ENDPOINT_VALIDATE, + metadata={ + "description": "URL endpoint for other party", + "example": ENDPOINT_EXAMPLE, + }, + ) + their_label = fields.Str( + required=False, + metadata={"description": "Other party's label for this connection"}, + ) + alias = fields.Str( + required=False, metadata={"description": "Alias to assign to this connection"} + ) + + +class ConnectionStaticResultSchema(OpenAPISchema): + """Result schema for new static connection.""" + + my_did = fields.Str( + required=True, + validate=INDY_DID_VALIDATE, + metadata={"description": "Local DID", "example": INDY_DID_EXAMPLE}, + ) + my_verkey = fields.Str( + required=True, + validate=RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + metadata={ + "description": "My verification key", + "example": RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + }, + ) + my_endpoint = fields.Str( + required=True, + validate=ENDPOINT_VALIDATE, + metadata={"description": "My URL endpoint", "example": ENDPOINT_EXAMPLE}, + ) + their_did = fields.Str( + required=True, + validate=INDY_DID_VALIDATE, + metadata={"description": "Remote DID", "example": INDY_DID_EXAMPLE}, + ) + their_verkey = fields.Str( + required=True, + validate=RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + metadata={ + "description": "Remote verification key", + "example": RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + }, + ) + record = fields.Nested(ConnRecordSchema(), required=True) + + +class ConnectionsListQueryStringSchema(PaginatedQuerySchema): + """Parameters and validators for connections list request query string.""" + + alias = fields.Str( + required=False, metadata={"description": "Alias", "example": "Barry"} + ) + invitation_key = fields.Str( + required=False, + validate=RAW_ED25519_2018_PUBLIC_KEY_VALIDATE, + metadata={ + "description": "invitation key", + "example": RAW_ED25519_2018_PUBLIC_KEY_EXAMPLE, + }, + ) + my_did = fields.Str( + required=False, + validate=GENERIC_DID_VALIDATE, + metadata={"description": "My DID", "example": INDY_DID_EXAMPLE}, + ) + state = fields.Str( + required=False, + validate=validate.OneOf( + sorted({label for state in ConnRecord.State for label in state.value}) + ), + metadata={"description": "Connection state"}, + ) + their_did = fields.Str( + required=False, + validate=GENERIC_DID_VALIDATE, + metadata={"description": "Their DID", "example": INDY_DID_EXAMPLE}, + ) + their_public_did = fields.Str( + required=False, + validate=GENERIC_DID_VALIDATE, + metadata={"description": "Their Public DID", "example": INDY_DID_EXAMPLE}, + ) + their_role = fields.Str( + required=False, + validate=validate.OneOf( + [label for role in ConnRecord.Role for label in role.value] + ), + metadata={ + "description": "Their role in the connection protocol", + "example": ConnRecord.Role.REQUESTER.rfc160, + }, + ) + connection_protocol = fields.Str( + required=False, + validate=validate.OneOf(ConnRecord.SUPPORTED_PROTOCOLS), + metadata={ + "description": "Connection protocol used", + "example": "connections/1.0", + }, + ) + invitation_msg_id = fields.Str( + required=False, + metadata={ + "description": "Identifier of the associated Invitation Message", + "example": UUID4_EXAMPLE, + }, + ) + + +class CreateInvitationQueryStringSchema(OpenAPISchema): + """Parameters and validators for create invitation request query string.""" + + alias = fields.Str( + required=False, metadata={"description": "Alias", "example": "Barry"} + ) + auto_accept = fields.Boolean( + required=False, + metadata={"description": "Auto-accept connection (defaults to configuration)"}, + ) + public = fields.Boolean( + required=False, + metadata={"description": "Create invitation from public DID (default false)"}, + ) + multi_use = fields.Boolean( + required=False, + metadata={"description": "Create invitation for multiple use (default false)"}, + ) + + +class ReceiveInvitationQueryStringSchema(OpenAPISchema): + """Parameters and validators for receive invitation request query string.""" + + alias = fields.Str( + required=False, metadata={"description": "Alias", "example": "Barry"} + ) + auto_accept = fields.Boolean( + required=False, + metadata={"description": "Auto-accept connection (defaults to configuration)"}, + ) + mediation_id = fields.Str( + required=False, + validate=UUID4_VALIDATE, + metadata={ + "description": "Identifier for active mediation record to be used", + "example": UUID4_EXAMPLE, + }, + ) + + +class AcceptInvitationQueryStringSchema(OpenAPISchema): + """Parameters and validators for accept invitation request query string.""" + + my_endpoint = fields.Str( + required=False, + validate=ENDPOINT_VALIDATE, + metadata={"description": "My URL endpoint", "example": ENDPOINT_EXAMPLE}, + ) + my_label = fields.Str( + required=False, + metadata={"description": "Label for connection", "example": "Broker"}, + ) + mediation_id = fields.Str( + required=False, + validate=UUID4_VALIDATE, + metadata={ + "description": "Identifier for active mediation record to be used", + "example": UUID4_EXAMPLE, + }, + ) + + +class AcceptRequestQueryStringSchema(OpenAPISchema): + """Parameters and validators for accept conn-request web-request query string.""" + + my_endpoint = fields.Str( + required=False, + validate=ENDPOINT_VALIDATE, + metadata={"description": "My URL endpoint", "example": ENDPOINT_EXAMPLE}, + ) + + +class ConnectionsConnIdMatchInfoSchema(OpenAPISchema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + required=True, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) + + +class ConnIdRefIdMatchInfoSchema(OpenAPISchema): + """Path parameters and validators for request taking connection and ref ids.""" + + conn_id = fields.Str( + required=True, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) + + ref_id = fields.Str( + required=True, + metadata={ + "description": "Inbound connection identifier", + "example": UUID4_EXAMPLE, + }, + ) + + +class EndpointsResultSchema(OpenAPISchema): + """Result schema for connection endpoints.""" + + my_endpoint = fields.Str( + validate=ENDPOINT_VALIDATE, + metadata={"description": "My endpoint", "example": ENDPOINT_EXAMPLE}, + ) + their_endpoint = fields.Str( + validate=ENDPOINT_VALIDATE, + metadata={"description": "Their endpoint", "example": ENDPOINT_EXAMPLE}, + ) + + +def connection_sort_key(conn): + """Get the sorting key for a particular connection.""" + + conn_rec_state = ConnRecord.State.get(conn["state"]) + if conn_rec_state is ConnRecord.State.ABANDONED: + pfx = "2" + elif conn_rec_state is ConnRecord.State.INVITATION: + pfx = "1" + else: + pfx = "0" + + return pfx + conn["created_at"] + + +@docs( + tags=["connection-v2"], + summary="Query agent-to-agent connections", +) +@querystring_schema(ConnectionsListQueryStringSchema()) +@response_schema(ConnectionListSchema(), 200, description="") +@tenant_authentication +async def connections_list(request: web.BaseRequest): + """Request handler for searching connection records. + + Args: + request: aiohttp request object + + Returns: + The connection list response + + """ + context: AdminRequestContext = request["context"] + + tag_filter = {} + for param_name in ( + "invitation_id", + "my_did", + "their_did", + "request_id", + "invitation_key", + "their_public_did", + "invitation_msg_id", + ): + if param_name in request.query and request.query[param_name] != "": + tag_filter[param_name] = request.query[param_name] + + post_filter = {} + if request.query.get("alias"): + post_filter["alias"] = request.query["alias"] + if request.query.get("state"): + post_filter["state"] = list(ConnRecord.State.get(request.query["state"]).value) + if request.query.get("their_role"): + post_filter["their_role"] = list( + ConnRecord.Role.get(request.query["their_role"]).value + ) + if request.query.get("connection_protocol"): + post_filter["connection_protocol"] = request.query["connection_protocol"] + + limit, offset = get_limit_offset(request) + + profile = context.profile + try: + async with profile.session() as session: + records = await PeerwiseRecord.query( + session, + tag_filter, + limit=limit, + offset=offset, + post_filter_positive=post_filter, + alt=True, + ) + results = [record.serialize() for record in records] + #results.sort(key=connection_sort_key) + except (StorageError, BaseModelError) as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + + return web.json_response({"results": results}) + + +@docs(tags=["connection-v2"], summary="Fetch a single connection record") +@match_info_schema(ConnectionsConnIdMatchInfoSchema()) +@response_schema(ConnRecordSchema(), 200, description="") +@tenant_authentication +async def connections_retrieve(request: web.BaseRequest): + """Request handler for fetching a single connection record. + + Args: + request: aiohttp request object + + Returns: + The connection record response + + """ + context: AdminRequestContext = request["context"] + connection_id = request.match_info["conn_id"] + + profile = context.profile + try: + async with profile.session() as session: + record = await PeerwiseRecord.retrieve_by_id(session, connection_id) + result = record.serialize() + except StorageNotFoundError as err: + raise web.HTTPNotFound(reason=err.roll_up) from err + except BaseModelError as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + + return web.json_response(result) + + +@docs(tags=["connection-v2"], summary="Fetch connection remote endpoint") +@match_info_schema(ConnectionsConnIdMatchInfoSchema()) +@response_schema(EndpointsResultSchema(), 200, description="") +@tenant_authentication +async def connections_endpoints(request: web.BaseRequest): + """Request handler for fetching connection endpoints. + + Args: + request: aiohttp request object + + Returns: + The endpoints response + + """ + context: AdminRequestContext = request["context"] + connection_id = request.match_info["conn_id"] + + profile = context.profile + try: + async with profile.session() as session: + record = await PeerwiseRecord.retrieve_by_id(session, connection_id) + endpoints = record.endpoints + except StorageNotFoundError as err: + raise web.HTTPNotFound(reason=err.roll_up) from err + except (BaseModelError, StorageError, WalletError) as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + + return web.json_response(dict("their_endpoints", endpoints)) + + +@docs( + tags=["connection-v2"], + summary="Accept a stored connection request", + deprecated=True, +) +@match_info_schema(ConnectionsConnIdMatchInfoSchema()) +@querystring_schema(AcceptRequestQueryStringSchema()) +@response_schema(ConnRecordSchema(), 200, description="") +@tenant_authentication +async def connections_accept_request(request: web.BaseRequest): + """Request handler for accepting a stored connection request. + + Args: + request: aiohttp request object + + Returns: + The resulting connection record details + + """ + context: AdminRequestContext = request["context"] + outbound_handler = request["outbound_message_router"] + connection_id = request.match_info["conn_id"] + + profile = context.profile + try: + async with profile.session() as session: + connection = await ConnRecord.retrieve_by_id(session, connection_id) + connection_mgr = ConnectionManager(profile) + my_endpoint = request.query.get("my_endpoint") or None + response = await connection_mgr.create_response(connection, my_endpoint) + result = connection.serialize() + except StorageNotFoundError as err: + raise web.HTTPNotFound(reason=err.roll_up) from err + except (StorageError, WalletError, ConnectionManagerError, BaseModelError) as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + + await outbound_handler(response, connection_id=connection.connection_id) + return web.json_response(result) + + +@docs(tags=["connection-v2"], summary="Remove an existing connection record") +@match_info_schema(ConnectionsConnIdMatchInfoSchema()) +@response_schema(ConnectionModuleResponseSchema, 200, description="") +@tenant_authentication +async def connections_remove(request: web.BaseRequest): + """Request handler for removing a connection record. + + Args: + request: aiohttp request object + """ + context: AdminRequestContext = request["context"] + connection_id = request.match_info["conn_id"] + profile = context.profile + + try: + async with profile.session() as session: + connection = await PeerwiseRecord.retrieve_by_id(session, connection_id) + await connection.delete_record(session) + cache = session.inject_or(BaseCache) + if cache: + await cache.clear(f"conn_rec_state::{connection_id}") + except StorageNotFoundError as err: + raise web.HTTPNotFound(reason=err.roll_up) from err + except StorageError as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + + return web.json_response({}) + + +@docs(tags=["connection-v2"], summary="Create a new static connection") +@request_schema(ConnectionStaticRequestSchema()) +@response_schema(ConnectionStaticResultSchema(), 200, description="") +@tenant_authentication +async def connections_create_static(request: web.BaseRequest): + """Request handler for creating a new static connection. + + Args: + request: aiohttp request object + + Returns: + The new connection record + + """ + context: AdminRequestContext = request["context"] + body = await request.json() + + profile = context.profile + connection_mgr = ConnectionManager(profile) + try: + ( + my_info, + their_info, + connection, + ) = await connection_mgr.create_static_connection( + my_seed=body.get("my_seed") or None, + my_did=body.get("my_did") or None, + their_seed=body.get("their_seed") or None, + their_did=body.get("their_did") or None, + their_verkey=body.get("their_verkey") or None, + their_endpoint=body.get("their_endpoint") or None, + their_label=body.get("their_label") or None, + alias=body.get("alias") or None, + ) + response = { + "my_did": my_info.did, + "my_verkey": my_info.verkey, + "my_endpoint": context.settings.get("default_endpoint"), + "their_did": their_info.did, + "their_verkey": their_info.verkey, + "record": connection.serialize(), + } + except (WalletError, StorageError, BaseModelError) as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + + return web.json_response(response) + + +async def register(app: web.Application): + """Register routes.""" + + app.add_routes( + [ + web.get("/connections-v2", connections_list, allow_head=False), + web.get("/connections-v2/{conn_id}", connections_retrieve, allow_head=False), + web.get( + "/connections-v2/{conn_id}/endpoints", + connections_endpoints, + allow_head=False, + ), + web.post("/connections-v2/create-static", connections_create_static), + web.delete("/connections-v2/{conn_id}", connections_remove), + ] + ) + + +def post_process_routes(app: web.Application): + """Amend swagger API.""" + + # Add top-level tags description + if "tags" not in app._state["swagger_dict"]: + app._state["swagger_dict"]["tags"] = [] + app._state["swagger_dict"]["tags"].append( + { + "name": "connection-v2", + "description": "Connection management V2", + "externalDocs": {"description": "Specification", "url": SPEC_URI}, + } + ) diff --git a/acapy_agent/protocols_v2/connections/v1_0/tests/__init__.py b/acapy_agent/protocols_v2/connections/v1_0/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/connections/v1_0/tests/test_manager.py b/acapy_agent/protocols_v2/connections/v1_0/tests/test_manager.py new file mode 100644 index 0000000000..eb2c5b7f6d --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/tests/test_manager.py @@ -0,0 +1,1241 @@ +from unittest import IsolatedAsyncioTestCase + +import pytest + +from .....cache.base import BaseCache +from .....cache.in_memory import InMemoryCache +from .....connections.models.conn_record import ConnRecord +from .....connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service +from .....core.oob_processor import OobMessageProcessor +from .....messaging.responder import BaseResponder, MockResponder +from .....multitenant.base import BaseMultitenantManager +from .....multitenant.manager import MultitenantManager +from .....resolver.default.legacy_peer import LegacyPeerDIDResolver +from .....resolver.did_resolver import DIDResolver +from .....storage.error import StorageNotFoundError +from .....tests import mock +from .....transport.inbound.receipt import MessageReceipt +from .....utils.testing import create_test_profile +from .....wallet.askar import AskarWallet +from .....wallet.base import BaseWallet, DIDInfo +from .....wallet.did_method import SOV, DIDMethods +from .....wallet.key_type import ED25519, KeyTypes +from ....coordinate_mediation.v1_0.manager import MediationManager +from ....coordinate_mediation.v1_0.messages.mediate_request import MediationRequest +from ....coordinate_mediation.v1_0.models.mediation_record import MediationRecord +from ....coordinate_mediation.v1_0.route_manager import RouteManager +from ..manager import ConnectionManager, ConnectionManagerError +from ..messages.connection_invitation import ConnectionInvitation +from ..messages.connection_request import ConnectionRequest +from ..messages.connection_response import ConnectionResponse +from ..models.connection_detail import ConnectionDetail + + +@pytest.mark.filterwarnings("ignore:Aries RFC 0160.*:DeprecationWarning") +class TestConnectionManager(IsolatedAsyncioTestCase): + def make_did_doc(self, did, verkey): + doc = DIDDoc(did=did) + controller = did + ident = "1" + pk_value = verkey + pk = PublicKey( + did, ident, pk_value, PublicKeyType.ED25519_SIG_2018, controller, False + ) + doc.set(pk) + recip_keys = [pk] + router_keys = [] + service = Service( + did, "indy", "IndyAgent", recip_keys, router_keys, self.test_endpoint + ) + doc.set(service) + return doc + + async def asyncSetUp(self): + self.test_seed = "testseed000000000000000000000001" + self.test_did = "55GkHamhTU1ZbTbV2ab9DE" + self.test_verkey = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" + self.test_endpoint = "http://localhost" + + self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP" + self.test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" + + self.responder = MockResponder() + + self.oob_mock = mock.MagicMock(OobMessageProcessor, autospec=True) + self.oob_mock.clean_finished_oob_record = mock.CoroutineMock(return_value=None) + self.route_manager = mock.MagicMock(RouteManager) + self.route_manager.routing_info = mock.CoroutineMock( + return_value=([], self.test_endpoint) + ) + self.route_manager.mediation_record_if_id = mock.CoroutineMock(return_value=None) + self.resolver = DIDResolver() + self.resolver.register_resolver(LegacyPeerDIDResolver()) + + self.profile = await create_test_profile( + { + "default_endpoint": "http://aries.ca/endpoint", + "default_label": "This guy", + "additional_endpoints": ["http://aries.ca/another-endpoint"], + "debug.auto_accept_invites": True, + "debug.auto_accept_requests": True, + }, + ) + + self.profile.context.injector.bind_instance(BaseResponder, self.responder) + self.profile.context.injector.bind_instance(BaseCache, InMemoryCache()) + self.profile.context.injector.bind_instance(OobMessageProcessor, self.oob_mock) + self.profile.context.injector.bind_instance(RouteManager, self.route_manager) + self.profile.context.injector.bind_instance(DIDMethods, DIDMethods()) + self.profile.context.injector.bind_instance(DIDResolver, self.resolver) + self.profile.context.injector.bind_instance(KeyTypes, KeyTypes()) + self.context = self.profile.context + + self.multitenant_mgr = mock.MagicMock(MultitenantManager, autospec=True) + self.context.injector.bind_instance(BaseMultitenantManager, self.multitenant_mgr) + + self.test_mediator_routing_keys = ["3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRR"] + self.test_mediator_conn_id = "mediator-conn-id" + self.test_mediator_endpoint = "http://mediator.example.com" + + self.manager = ConnectionManager(self.profile) + assert self.manager.profile + + async def test_create_invitation_non_multi_use_invitation_fails_on_reuse(self): + connect_record, _ = await self.manager.create_invitation() + + receipt = MessageReceipt(recipient_verkey=connect_record.invitation_key) + + request_a = ConnectionRequest( + connection=ConnectionDetail( + did=self.test_target_did, + did_doc=self.make_did_doc(self.test_target_did, self.test_target_verkey), + ), + label="SameInviteRequestA", + ) + + await self.manager.receive_request(request_a, receipt) + + request_b = ConnectionRequest( + connection=ConnectionDetail( + did=self.test_did, + did_doc=self.make_did_doc(self.test_did, self.test_verkey), + ), + label="SameInviteRequestB", + ) + + # requestB fails because the invitation was not set to multi-use + with self.assertRaises(ConnectionManagerError): + await self.manager.receive_request(request_b, receipt) + + async def test_create_invitation_public(self): + self.context.update_settings({"public_invites": True}) + + self.route_manager.route_verkey = mock.CoroutineMock() + with mock.patch.object( + AskarWallet, "get_public_did", autospec=True + ) as mock_wallet_get_public_did: + mock_wallet_get_public_did.return_value = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + connect_record, connect_invite = await self.manager.create_invitation( + public=True, my_endpoint="testendpoint" + ) + + assert connect_record + assert connect_invite.did.endswith(self.test_did) + self.route_manager.route_verkey.assert_called_once_with( + self.profile, self.test_verkey + ) + + async def test_create_invitation_public_no_public_invites(self): + self.context.update_settings({"public_invites": False}) + + with self.assertRaises(ConnectionManagerError): + await self.manager.create_invitation(public=True, my_endpoint="testendpoint") + + async def test_create_invitation_public_no_public_did(self): + self.context.update_settings({"public_invites": True}) + + with mock.patch.object( + AskarWallet, "get_public_did", autospec=True + ) as mock_wallet_get_public_did: + mock_wallet_get_public_did.return_value = None + with self.assertRaises(ConnectionManagerError): + await self.manager.create_invitation( + public=True, my_endpoint="testendpoint" + ) + + async def test_create_invitation_multi_use(self): + connect_record, _ = await self.manager.create_invitation( + my_endpoint="testendpoint", multi_use=True + ) + + receipt = MessageReceipt(recipient_verkey=connect_record.invitation_key) + + request_a = ConnectionRequest( + connection=ConnectionDetail( + did=self.test_target_did, + did_doc=self.make_did_doc(self.test_target_did, self.test_target_verkey), + ), + label="SameInviteRequestA", + ) + + await self.manager.receive_request(request_a, receipt) + + request_b = ConnectionRequest( + connection=ConnectionDetail( + did=self.test_did, + did_doc=self.make_did_doc(self.test_did, self.test_verkey), + ), + label="SameInviteRequestB", + ) + + await self.manager.receive_request(request_b, receipt) + + async def test_create_invitation_recipient_routing_endpoint(self): + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + connect_record, _ = await self.manager.create_invitation( + my_endpoint=self.test_endpoint, + recipient_keys=[self.test_verkey], + routing_keys=[self.test_verkey], + ) + + receipt = MessageReceipt(recipient_verkey=connect_record.invitation_key) + + request_a = ConnectionRequest( + connection=ConnectionDetail( + did=self.test_target_did, + did_doc=self.make_did_doc( + self.test_target_did, self.test_target_verkey + ), + ), + label="InviteRequestA", + ) + + await self.manager.receive_request(request_a, receipt) + + async def test_create_invitation_metadata_assigned(self): + async with self.profile.session() as session: + record, _ = await self.manager.create_invitation(metadata={"hello": "world"}) + + assert await record.metadata_get_all(session) == {"hello": "world"} + + async def test_create_invitation_multi_use_metadata_transfers_to_connection(self): + async with self.profile.session() as session: + connect_record, _ = await self.manager.create_invitation( + my_endpoint="testendpoint", multi_use=True, metadata={"test": "value"} + ) + + receipt = MessageReceipt(recipient_verkey=connect_record.invitation_key) + + request = ConnectionRequest( + connection=ConnectionDetail( + did=self.test_target_did, + did_doc=self.make_did_doc( + self.test_target_did, self.test_target_verkey + ), + ), + label="request", + ) + + new_conn_rec = await self.manager.receive_request(request, receipt) + assert new_conn_rec != connect_record + assert await new_conn_rec.metadata_get_all(session) == {"test": "value"} + + async def test_create_invitation_mediation_overwrites_routing_and_endpoint(self): + self.route_manager.routing_info = mock.CoroutineMock( + return_value=(self.test_mediator_routing_keys, self.test_mediator_endpoint) + ) + async with self.profile.session() as session: + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + await mediation_record.save(session) + with mock.patch.object( + MediationManager, + "get_default_mediator", + ) as mock_get_default_mediator: + _, invite = await self.manager.create_invitation( + routing_keys=[self.test_verkey], + my_endpoint=self.test_endpoint, + mediation_id=mediation_record.mediation_id, + ) + assert invite.routing_keys == self.test_mediator_routing_keys + assert invite.endpoint == self.test_mediator_endpoint + mock_get_default_mediator.assert_not_called() + + async def test_create_invitation_mediation_using_default(self): + self.route_manager.routing_info = mock.CoroutineMock( + return_value=(self.test_mediator_routing_keys, self.test_mediator_endpoint) + ) + async with self.profile.session() as session: + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + await mediation_record.save(session) + with mock.patch.object( + self.route_manager, + "mediation_record_if_id", + mock.CoroutineMock(return_value=mediation_record), + ): + _, invite = await self.manager.create_invitation( + routing_keys=[self.test_verkey], + my_endpoint=self.test_endpoint, + ) + assert invite.routing_keys == self.test_mediator_routing_keys + assert invite.endpoint == self.test_mediator_endpoint + self.route_manager.routing_info.assert_awaited_once_with( + self.profile, mediation_record + ) + + async def test_receive_invitation(self): + (_, connect_invite) = await self.manager.create_invitation( + my_endpoint="testendpoint" + ) + + invitee_record = await self.manager.receive_invitation(connect_invite) + assert ConnRecord.State.get(invitee_record.state) is ConnRecord.State.REQUEST + + async def test_receive_invitation_no_auto_accept(self): + (_, connect_invite) = await self.manager.create_invitation( + my_endpoint="testendpoint" + ) + + invitee_record = await self.manager.receive_invitation( + connect_invite, auto_accept=False + ) + assert ConnRecord.State.get(invitee_record.state) is ConnRecord.State.INVITATION + + async def test_receive_invitation_bad_invitation(self): + x_invites = [ + ConnectionInvitation(), + ConnectionInvitation( + recipient_keys=["3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx"] + ), + ] + + for x_invite in x_invites: + with self.assertRaises(ConnectionManagerError): + await self.manager.receive_invitation(x_invite) + + async def test_receive_invitation_with_did(self): + """Test invitation received with a public DID instead of service info.""" + invite = ConnectionInvitation(did=self.test_did) + invitee_record = await self.manager.receive_invitation(invite) + assert ConnRecord.State.get(invitee_record.state) is ConnRecord.State.REQUEST + + async def test_receive_invitation_mediation_passes_id_when_auto_accept(self): + with mock.patch.object(ConnectionManager, "create_request") as create_request: + _, connect_invite = await self.manager.create_invitation( + my_endpoint="testendpoint" + ) + + invitee_record = await self.manager.receive_invitation( + connect_invite, mediation_id="test-mediation-id", auto_accept=True + ) + create_request.assert_called_once_with( + invitee_record, mediation_id="test-mediation-id" + ) + + async def test_create_request(self): + conn_req = await self.manager.create_request( + ConnRecord( + invitation_key=self.test_verkey, + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + ) + ) + assert conn_req + + async def test_create_request_my_endpoint(self): + conn_req = await self.manager.create_request( + ConnRecord( + invitation_key=self.test_verkey, + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + ), + my_endpoint="http://testendpoint.com/endpoint", + ) + assert conn_req + + async def test_create_request_my_did(self): + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + conn_req = await self.manager.create_request( + ConnRecord( + invitation_key=self.test_verkey, + my_did=self.test_did, + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + ) + ) + assert conn_req + + async def test_create_request_multitenant(self): + self.context.update_settings( + {"wallet.id": "test_wallet", "multitenant.enabled": True} + ) + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + + with ( + mock.patch.object( + AskarWallet, "create_local_did", autospec=True + ) as mock_wallet_create_local_did, + mock.patch.object( + ConnectionManager, "create_did_document", autospec=True + ) as create_did_document, + mock.patch.object( + self.route_manager, + "mediation_records_for_connection", + mock.CoroutineMock(return_value=[mediation_record]), + ), + ): + mock_wallet_create_local_did.return_value = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + await self.manager.create_request( + ConnRecord( + invitation_key=self.test_verkey, + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + ), + my_endpoint=self.test_endpoint, + ) + create_did_document.assert_called_once_with( + self.manager, + mock_wallet_create_local_did.return_value, + [self.test_endpoint], + mediation_records=[mediation_record], + ) + self.route_manager.route_connection_as_invitee.assert_called_once() + + async def test_create_request_mediation_id(self): + mediation_record = MediationRecord( + mediation_id="test_mediation_id", + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + + record = ConnRecord( + invitation_key=self.test_verkey, + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + ) + + # Ensure the path with new did creation is hit + record.my_did = None + + with ( + mock.patch.object( + ConnectionManager, "create_did_document", autospec=True + ) as create_did_document, + mock.patch.object(AskarWallet, "create_local_did") as create_local_did, + mock.patch.object( + self.route_manager, + "mediation_records_for_connection", + mock.CoroutineMock(return_value=[mediation_record]), + ), + ): + did_info = DIDInfo( + did=self.test_did, + verkey=self.test_verkey, + metadata={}, + method=SOV, + key_type=ED25519, + ) + create_local_did.return_value = did_info + await self.manager.create_request( + record, + mediation_id=mediation_record.mediation_id, + my_endpoint=self.test_endpoint, + ) + create_local_did.assert_called_once_with(SOV, ED25519) + create_did_document.assert_called_once_with( + self.manager, + did_info, + [self.test_endpoint], + mediation_records=[mediation_record], + ) + + async def test_create_request_default_mediator(self): + async with self.profile.session() as session: + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + await mediation_record.save(session) + + record = ConnRecord( + invitation_key=self.test_verkey, + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + ) + + # Ensure the path with new did creation is hit + record.my_did = None + + with ( + mock.patch.object( + ConnectionManager, "create_did_document", autospec=True + ) as create_did_document, + mock.patch.object(AskarWallet, "create_local_did") as create_local_did, + mock.patch.object( + self.route_manager, + "mediation_records_for_connection", + mock.CoroutineMock(return_value=[mediation_record]), + ), + ): + did_info = DIDInfo( + did=self.test_did, + verkey=self.test_verkey, + metadata={}, + method=SOV, + key_type=ED25519, + ) + create_local_did.return_value = did_info + await self.manager.create_request( + record, + my_endpoint=self.test_endpoint, + ) + create_local_did.assert_called_once_with(SOV, ED25519) + create_did_document.assert_called_once_with( + self.manager, + did_info, + [self.test_endpoint], + mediation_records=[mediation_record], + ) + + async def test_receive_request_public_did_oob_invite(self): + async with self.profile.session() as session: + mock_request = mock.MagicMock() + mock_request.connection = mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings({"public_invites": True}) + with ( + mock.patch.object(ConnRecord, "connection_id", autospec=True), + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnRecord, "attach_request", autospec=True), + mock.patch.object(ConnRecord, "retrieve_by_id", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id, + mock.patch.object( + self.manager, "store_did_document", mock.CoroutineMock() + ), + ): + mock_conn_retrieve_by_invitation_msg_id.return_value = ConnRecord() + conn_rec = await self.manager.receive_request(mock_request, receipt) + assert conn_rec + + self.oob_mock.clean_finished_oob_record.assert_called_once_with( + self.profile, mock_request + ) + + async def test_receive_request_public_did_unsolicited_fails(self): + async with self.profile.session() as session: + mock_request = mock.MagicMock() + mock_request.connection = mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings({"public_invites": True}) + with ( + self.assertRaises(ConnectionManagerError), + mock.patch.object(ConnRecord, "connection_id", autospec=True), + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnRecord, "attach_request", autospec=True), + mock.patch.object(ConnRecord, "retrieve_by_id", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id, + mock.patch.object( + self.manager, "store_did_document", mock.CoroutineMock() + ), + ): + mock_conn_retrieve_by_invitation_msg_id.return_value = None + await self.manager.receive_request(mock_request, receipt) + + async def test_receive_request_public_did_conn_invite(self): + async with self.profile.session() as session: + mock_request = mock.MagicMock() + mock_request.connection = mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + mock_connection_record = mock.MagicMock() + mock_connection_record.save = mock.CoroutineMock() + mock_connection_record.attach_request = mock.CoroutineMock() + + self.context.update_settings({"public_invites": True}) + with ( + mock.patch.object(ConnRecord, "connection_id", autospec=True), + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnRecord, "attach_request", autospec=True), + mock.patch.object(ConnRecord, "retrieve_by_id", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object( + ConnRecord, + "retrieve_by_invitation_msg_id", + mock.CoroutineMock(return_value=mock_connection_record), + ), + mock.patch.object( + self.manager, "store_did_document", mock.CoroutineMock() + ), + ): + conn_rec = await self.manager.receive_request(mock_request, receipt) + assert conn_rec + + async def test_receive_request_public_did_unsolicited(self): + async with self.profile.session() as session: + mock_request = mock.MagicMock() + mock_request.connection = mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings({"public_invites": True}) + self.context.update_settings({"requests_through_public_did": True}) + with ( + mock.patch.object(ConnRecord, "connection_id", autospec=True), + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnRecord, "attach_request", autospec=True), + mock.patch.object(ConnRecord, "retrieve_by_id", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id, + mock.patch.object( + self.manager, "store_did_document", mock.CoroutineMock() + ), + ): + mock_conn_retrieve_by_invitation_msg_id.return_value = None + conn_rec = await self.manager.receive_request(mock_request, receipt) + assert conn_rec + + async def test_receive_request_public_did_no_did_doc(self): + async with self.profile.session() as session: + mock_request = mock.MagicMock() + mock_request.connection = mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = None + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings({"public_invites": True}) + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnRecord, "attach_request", autospec=True), + mock.patch.object(ConnRecord, "retrieve_by_id", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + ): + with self.assertRaises(ConnectionManagerError): + await self.manager.receive_request(mock_request, receipt) + + async def test_receive_request_public_did_wrong_did(self): + async with self.profile.session() as session: + mock_request = mock.MagicMock() + mock_request.connection = mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_request.connection.did_doc.did = "dummy" + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings({"public_invites": True}) + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnRecord, "attach_request", autospec=True), + mock.patch.object(ConnRecord, "retrieve_by_id", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + ): + with self.assertRaises(ConnectionManagerError): + await self.manager.receive_request(mock_request, receipt) + + async def test_receive_request_public_did_no_public_invites(self): + mock_request = mock.MagicMock() + mock_request.connection = mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt(recipient_did=self.test_did, recipient_did_public=True) + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings({"public_invites": False}) + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnRecord, "attach_request", autospec=True), + mock.patch.object(ConnRecord, "retrieve_by_id", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object(self.manager, "store_did_document", mock.CoroutineMock()), + ): + with self.assertRaises(ConnectionManagerError): + await self.manager.receive_request(mock_request, receipt) + + async def test_receive_request_public_did_no_auto_accept(self): + async with self.profile.session() as session: + mock_request = mock.MagicMock() + mock_request.connection = mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + wallet = session.inject(BaseWallet) + await wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings( + {"public_invites": True, "debug.auto_accept_requests": False} + ) + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnRecord, "attach_request", autospec=True), + mock.patch.object(ConnRecord, "retrieve_by_id", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id, + mock.patch.object( + self.manager, "store_did_document", mock.CoroutineMock() + ), + ): + mock_conn_retrieve_by_invitation_msg_id.return_value = ConnRecord() + conn_rec = await self.manager.receive_request(mock_request, receipt) + assert conn_rec + + messages = self.responder.messages + assert not messages + + async def test_create_response(self): + conn_rec = ConnRecord(state=ConnRecord.State.REQUEST.rfc160) + + with ( + mock.patch.object(ConnRecord, "log_state", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnectionResponse, "sign_field", autospec=True), + mock.patch.object(conn_rec, "metadata_get", mock.CoroutineMock()), + ): + await self.manager.create_response(conn_rec, "http://10.20.30.40:5060/") + + async def test_create_response_multitenant(self): + self.context.update_settings( + {"wallet.id": "test_wallet", "multitenant.enabled": True} + ) + + mediation_record = MediationRecord( + mediation_id="test_mediation_id", + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + + with ( + mock.patch.object(ConnRecord, "log_state", autospec=True), + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "metadata_get", mock.CoroutineMock(return_value=False) + ), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object(ConnectionResponse, "sign_field", autospec=True), + mock.patch.object( + AskarWallet, "create_local_did", autospec=True + ) as mock_wallet_create_local_did, + mock.patch.object( + ConnectionManager, "create_did_document", autospec=True + ) as create_did_document, + mock.patch.object( + self.route_manager, + "mediation_records_for_connection", + mock.CoroutineMock(return_value=[mediation_record]), + ), + ): + mock_wallet_create_local_did.return_value = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + await self.manager.create_response( + ConnRecord( + state=ConnRecord.State.REQUEST, + ), + my_endpoint=self.test_endpoint, + ) + create_did_document.assert_called_once_with( + self.manager, + mock_wallet_create_local_did.return_value, + [self.test_endpoint], + mediation_records=[mediation_record], + ) + self.route_manager.route_connection_as_inviter.assert_called_once() + + async def test_create_response_bad_state(self): + with self.assertRaises(ConnectionManagerError): + await self.manager.create_response( + ConnRecord( + invitation_key=self.test_verkey, + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + state=ConnRecord.State.ABANDONED.rfc160, + ) + ) + + async def test_create_response_mediation(self): + mediation_record = MediationRecord( + mediation_id="test_mediation_id", + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + + record = ConnRecord( + connection_id="test-conn-id", + invitation_key=self.test_verkey, + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + state=ConnRecord.State.REQUEST.rfc160, + ) + + # Ensure the path with new did creation is hit + record.my_did = None + + with ( + mock.patch.object(ConnRecord, "log_state", autospec=True), + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + record, "metadata_get", mock.CoroutineMock(return_value=False) + ), + mock.patch.object( + ConnectionManager, "create_did_document", autospec=True + ) as create_did_document, + mock.patch.object(AskarWallet, "create_local_did") as create_local_did, + mock.patch.object( + self.route_manager, + "mediation_records_for_connection", + mock.CoroutineMock(return_value=[mediation_record]), + ), + mock.patch.object(record, "retrieve_request", autospec=True), + mock.patch.object(ConnectionResponse, "sign_field", autospec=True), + ): + did_info = DIDInfo( + did=self.test_did, + verkey=self.test_verkey, + metadata={}, + method=SOV, + key_type=ED25519, + ) + create_local_did.return_value = did_info + await self.manager.create_response( + record, + mediation_id=mediation_record.mediation_id, + my_endpoint=self.test_endpoint, + ) + create_local_did.assert_called_once_with(SOV, ED25519) + create_did_document.assert_called_once_with( + self.manager, + did_info, + [self.test_endpoint], + mediation_records=[mediation_record], + ) + self.route_manager.route_connection_as_inviter.assert_called_once() + + async def test_create_response_auto_send_mediation_request(self): + conn_rec = ConnRecord( + state=ConnRecord.State.REQUEST.rfc160, + ) + conn_rec.my_did = None + + with ( + mock.patch.object(ConnRecord, "log_state", autospec=True), + mock.patch.object(ConnRecord, "retrieve_request", autospec=True), + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object(ConnectionResponse, "sign_field", autospec=True), + mock.patch.object( + conn_rec, "metadata_get", mock.CoroutineMock(return_value=True) + ), + ): + await self.manager.create_response(conn_rec) + + assert len(self.responder.messages) == 1 + message, target = self.responder.messages[0] + assert isinstance(message, MediationRequest) + assert target["connection_id"] == conn_rec.connection_id + + async def test_accept_response_find_by_thread_id(self): + mock_response = mock.MagicMock() + mock_response._thread = mock.MagicMock() + mock_response.connection = mock.MagicMock() + mock_response.connection.did = self.test_target_did + mock_response.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_response.connection.did_doc.did = self.test_target_did + mock_response.verify_signed_field = mock.CoroutineMock(return_value="sig_verkey") + receipt = MessageReceipt(recipient_did=self.test_did, recipient_did_public=True) + + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_request_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, + mock.patch.object( + MediationManager, "get_default_mediator", mock.CoroutineMock() + ), + mock.patch.object(self.manager, "store_did_document", mock.CoroutineMock()), + ): + mock_conn_retrieve_by_req_id.return_value = mock.MagicMock( + did=self.test_target_did, + did_doc=mock.MagicMock(did=self.test_target_did), + state=ConnRecord.State.RESPONSE.rfc23, + save=mock.CoroutineMock(), + metadata_get=mock.CoroutineMock(), + connection_id="test-conn-id", + invitation_key="test-invitation-key", + ) + conn_rec = await self.manager.accept_response(mock_response, receipt) + assert conn_rec.their_did == self.test_target_did + assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.RESPONSE + + async def test_accept_response_not_found_by_thread_id_receipt_has_sender_did(self): + mock_response = mock.MagicMock() + mock_response._thread = mock.MagicMock() + mock_response.connection = mock.MagicMock() + mock_response.connection.did = self.test_target_did + mock_response.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_response.connection.did_doc.did = self.test_target_did + mock_response.verify_signed_field = mock.CoroutineMock(return_value="sig_verkey") + + receipt = MessageReceipt(sender_did=self.test_target_did) + + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_request_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, + mock.patch.object( + ConnRecord, "retrieve_by_did", mock.CoroutineMock() + ) as mock_conn_retrieve_by_did, + mock.patch.object( + MediationManager, "get_default_mediator", mock.CoroutineMock() + ), + mock.patch.object(self.manager, "store_did_document", mock.CoroutineMock()), + ): + mock_conn_retrieve_by_req_id.side_effect = StorageNotFoundError() + mock_conn_retrieve_by_did.return_value = mock.MagicMock( + did=self.test_target_did, + did_doc=mock.MagicMock(did=self.test_target_did), + state=ConnRecord.State.RESPONSE.rfc23, + save=mock.CoroutineMock(), + metadata_get=mock.CoroutineMock(return_value=False), + connection_id="test-conn-id", + invitation_key="test-invitation-id", + ) + + conn_rec = await self.manager.accept_response(mock_response, receipt) + assert conn_rec.their_did == self.test_target_did + assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.RESPONSE + + assert not self.responder.messages + + async def test_accept_response_not_found_by_thread_id_nor_receipt_sender_did(self): + mock_response = mock.MagicMock() + mock_response._thread = mock.MagicMock() + mock_response.connection = mock.MagicMock() + mock_response.connection.did = self.test_target_did + mock_response.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_response.connection.did_doc.did = self.test_target_did + + receipt = MessageReceipt(sender_did=self.test_target_did) + + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_request_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, + mock.patch.object( + ConnRecord, "retrieve_by_did", mock.CoroutineMock() + ) as mock_conn_retrieve_by_did, + ): + mock_conn_retrieve_by_req_id.side_effect = StorageNotFoundError() + mock_conn_retrieve_by_did.side_effect = StorageNotFoundError() + + with self.assertRaises(ConnectionManagerError): + await self.manager.accept_response(mock_response, receipt) + + async def test_accept_response_find_by_thread_id_bad_state(self): + mock_response = mock.MagicMock() + mock_response._thread = mock.MagicMock() + mock_response.connection = mock.MagicMock() + mock_response.connection.did = self.test_target_did + mock_response.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_response.connection.did_doc.did = self.test_target_did + + receipt = MessageReceipt(sender_did=self.test_target_did) + + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_request_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, + ): + mock_conn_retrieve_by_req_id.return_value = mock.MagicMock( + state=ConnRecord.State.ABANDONED.rfc23 + ) + + with self.assertRaises(ConnectionManagerError): + await self.manager.accept_response(mock_response, receipt) + + async def test_accept_response_find_by_thread_id_no_connection_did_doc(self): + mock_response = mock.MagicMock() + mock_response._thread = mock.MagicMock() + mock_response.connection = mock.MagicMock() + mock_response.connection.did = self.test_target_did + mock_response.connection.did_doc = None + + receipt = MessageReceipt(sender_did=self.test_target_did) + + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_request_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, + ): + mock_conn_retrieve_by_req_id.return_value = mock.MagicMock( + did=self.test_target_did, + did_doc=mock.MagicMock(did=self.test_target_did), + state=ConnRecord.State.RESPONSE.rfc23, + ) + + with self.assertRaises(ConnectionManagerError): + await self.manager.accept_response(mock_response, receipt) + + async def test_accept_response_find_by_thread_id_did_mismatch(self): + mock_response = mock.MagicMock() + mock_response._thread = mock.MagicMock() + mock_response.connection = mock.MagicMock() + mock_response.connection.did = self.test_target_did + mock_response.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_response.connection.did_doc.did = self.test_did + + receipt = MessageReceipt(sender_did=self.test_target_did) + + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_request_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, + ): + mock_conn_retrieve_by_req_id.return_value = mock.MagicMock( + did=self.test_target_did, + did_doc=mock.MagicMock(did=self.test_target_did), + state=ConnRecord.State.RESPONSE.rfc23, + ) + + with self.assertRaises(ConnectionManagerError): + await self.manager.accept_response(mock_response, receipt) + + async def test_accept_response_verify_invitation_key_sign_failure(self): + mock_response = mock.MagicMock() + mock_response._thread = mock.MagicMock() + mock_response.connection = mock.MagicMock() + mock_response.connection.did = self.test_target_did + mock_response.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_response.connection.did_doc.did = self.test_target_did + mock_response.verify_signed_field = mock.CoroutineMock(side_effect=ValueError) + receipt = MessageReceipt(recipient_did=self.test_did, recipient_did_public=True) + + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_request_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, + mock.patch.object( + MediationManager, "get_default_mediator", mock.CoroutineMock() + ), + ): + mock_conn_retrieve_by_req_id.return_value = mock.MagicMock( + did=self.test_target_did, + did_doc=mock.MagicMock(did=self.test_target_did), + state=ConnRecord.State.RESPONSE.rfc23, + save=mock.CoroutineMock(), + metadata_get=mock.CoroutineMock(), + connection_id="test-conn-id", + invitation_key="test-invitation-key", + ) + with self.assertRaises(ConnectionManagerError): + await self.manager.accept_response(mock_response, receipt) + + async def test_accept_response_auto_send_mediation_request(self): + mock_response = mock.MagicMock() + mock_response._thread = mock.MagicMock() + mock_response.connection = mock.MagicMock() + mock_response.connection.did = self.test_target_did + mock_response.connection.did_doc = mock.MagicMock(spec=DIDDoc) + mock_response.connection.did_doc.did = self.test_target_did + mock_response.verify_signed_field = mock.CoroutineMock(return_value="sig_verkey") + receipt = MessageReceipt(recipient_did=self.test_did, recipient_did_public=True) + + with ( + mock.patch.object(ConnRecord, "save", autospec=True), + mock.patch.object( + ConnRecord, "retrieve_by_request_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, + mock.patch.object( + MediationManager, "get_default_mediator", mock.CoroutineMock() + ), + mock.patch.object(self.manager, "store_did_document", mock.CoroutineMock()), + ): + mock_conn_retrieve_by_req_id.return_value = mock.MagicMock( + did=self.test_target_did, + did_doc=mock.MagicMock(did=self.test_target_did), + state=ConnRecord.State.RESPONSE.rfc23, + save=mock.CoroutineMock(), + metadata_get=mock.CoroutineMock(return_value=True), + connection_id="test-conn-id", + invitation_key="test-invitation-key", + ) + conn_rec = await self.manager.accept_response(mock_response, receipt) + assert conn_rec.their_did == self.test_target_did + assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.RESPONSE + + assert len(self.responder.messages) == 1 + message, target = self.responder.messages[0] + assert isinstance(message, MediationRequest) + assert target["connection_id"] == conn_rec.connection_id diff --git a/acapy_agent/protocols_v2/connections/v1_0/tests/test_routes.py b/acapy_agent/protocols_v2/connections/v1_0/tests/test_routes.py new file mode 100644 index 0000000000..917c795c5f --- /dev/null +++ b/acapy_agent/protocols_v2/connections/v1_0/tests/test_routes.py @@ -0,0 +1,828 @@ +import json +from unittest import IsolatedAsyncioTestCase +from unittest.mock import ANY + +from .....admin.request_context import AdminRequestContext +from .....cache.base import BaseCache +from .....cache.in_memory import InMemoryCache +from .....connections.models.conn_record import ConnRecord +from .....storage.error import StorageNotFoundError +from .....tests import mock +from .....utils.testing import create_test_profile +from .. import routes as test_module + + +class TestConnectionRoutes(IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.session_inject = {} + self.profile = await create_test_profile( + settings={ + "admin.admin_api_key": "secret-key", + } + ) + self.context = AdminRequestContext.test_context(self.session_inject, self.profile) + self.request_dict = { + "context": self.context, + "outbound_message_router": mock.CoroutineMock(), + } + self.request = mock.MagicMock( + app={}, + match_info={}, + query={}, + __getitem__=lambda _, k: self.request_dict[k], + headers={"x-api-key": "secret-key"}, + ) + + async def test_connections_list(self): + self.request.query = { + "invitation_id": "dummy", # exercise tag filter assignment + "their_role": ConnRecord.Role.REQUESTER.rfc160, + "connection_protocol": "connections/1.0", + "invitation_key": "some-invitation-key", + "their_public_did": "a_public_did", + "invitation_msg_id": "dummy_msg", + } + + STATE_COMPLETED = ConnRecord.State.COMPLETED + STATE_INVITATION = ConnRecord.State.INVITATION + STATE_ABANDONED = ConnRecord.State.ABANDONED + with mock.patch.object(test_module, "ConnRecord", autospec=True) as mock_conn_rec: + mock_conn_rec.query = mock.CoroutineMock() + mock_conn_rec.Role = ConnRecord.Role + mock_conn_rec.State = mock.MagicMock( + COMPLETED=STATE_COMPLETED, + INVITATION=STATE_INVITATION, + ABANDONED=STATE_ABANDONED, + get=mock.MagicMock( + side_effect=[ + ConnRecord.State.ABANDONED, + ConnRecord.State.COMPLETED, + ConnRecord.State.INVITATION, + ] + ), + ) + conns = [ # in ascending order here + mock.MagicMock( + serialize=mock.MagicMock( + return_value={ + "state": ConnRecord.State.COMPLETED.rfc23, + "created_at": "1234567890", + } + ) + ), + mock.MagicMock( + serialize=mock.MagicMock( + return_value={ + "state": ConnRecord.State.INVITATION.rfc23, + "created_at": "1234567890", + } + ) + ), + mock.MagicMock( + serialize=mock.MagicMock( + return_value={ + "state": ConnRecord.State.ABANDONED.rfc23, + "created_at": "1234567890", + } + ) + ), + ] + mock_conn_rec.query.return_value = [conns[2], conns[0], conns[1]] # jumbled + + with mock.patch.object(test_module.web, "json_response") as mock_response: + await test_module.connections_list(self.request) + mock_conn_rec.query.assert_called_once_with( + ANY, + { + "invitation_id": "dummy", + "invitation_key": "some-invitation-key", + "their_public_did": "a_public_did", + "invitation_msg_id": "dummy_msg", + }, + limit=100, + offset=0, + post_filter_positive={ + "their_role": list(ConnRecord.Role.REQUESTER.value), + "connection_protocol": "connections/1.0", + }, + alt=True, + ) + mock_response.assert_called_once_with( + { + "results": [ + { + k: c.serialize.return_value[k] + for k in ["state", "created_at"] + } + for c in conns + ] + } # sorted + ) + + async def test_connections_list_x(self): + self.request.query = { + "their_role": ConnRecord.Role.REQUESTER.rfc160, + "alias": "my connection", + "state": ConnRecord.State.COMPLETED.rfc23, + } + + STATE_COMPLETED = ConnRecord.State.COMPLETED + ROLE_REQUESTER = ConnRecord.Role.REQUESTER + with mock.patch.object(test_module, "ConnRecord", autospec=True) as mock_conn_rec: + mock_conn_rec.Role = mock.MagicMock(return_value=ROLE_REQUESTER) + mock_conn_rec.State = mock.MagicMock( + COMPLETED=STATE_COMPLETED, + get=mock.MagicMock(return_value=ConnRecord.State.COMPLETED), + ) + mock_conn_rec.query = mock.CoroutineMock( + side_effect=test_module.StorageError() + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_list(self.request) + + async def test_connections_retrieve(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock(return_value={"hello": "world"}) + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + + await test_module.connections_retrieve(self.request) + mock_response.assert_called_once_with({"hello": "world"}) + + async def test_connections_endpoints(self): + self.request.match_info = {"conn_id": "dummy"} + + with ( + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr_cls, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_mgr_cls.return_value = mock.MagicMock( + get_endpoints=mock.CoroutineMock( + return_value=("localhost:8080", "1.2.3.4:8081") + ) + ) + await test_module.connections_endpoints(self.request) + mock_response.assert_called_once_with( + { + "my_endpoint": "localhost:8080", + "their_endpoint": "1.2.3.4:8081", + } + ) + + async def test_connections_endpoints_x(self): + self.request.match_info = {"conn_id": "dummy"} + + with ( + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr_cls, + mock.patch.object(test_module.web, "json_response"), + ): + mock_conn_mgr_cls.return_value = mock.MagicMock( + get_endpoints=mock.CoroutineMock(side_effect=StorageNotFoundError()) + ) + + with self.assertRaises(test_module.web.HTTPNotFound): + await test_module.connections_endpoints(self.request) + + mock_conn_mgr_cls.return_value = mock.MagicMock( + get_endpoints=mock.CoroutineMock(side_effect=test_module.WalletError()) + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_endpoints(self.request) + + async def test_connections_metadata(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object( + mock_conn_rec, "metadata_get_all", mock.CoroutineMock() + ) as mock_metadata_get_all, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + mock_metadata_get_all.return_value = {"hello": "world"} + + await test_module.connections_metadata(self.request) + mock_metadata_get_all.assert_called_once() + mock_response.assert_called_once_with({"results": {"hello": "world"}}) + + async def test_connections_metadata_get_single(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + self.request.query = {"key": "test"} + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object(mock_conn_rec, "metadata_get_all", mock.CoroutineMock()), + mock.patch.object( + mock_conn_rec, "metadata_get", mock.CoroutineMock() + ) as mock_metadata_get, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + mock_metadata_get.return_value = {"test": "value"} + + await test_module.connections_metadata(self.request) + mock_metadata_get.assert_called_once() + mock_response.assert_called_once_with({"results": {"test": "value"}}) + + async def test_connections_metadata_x(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object( + mock_conn_rec, "metadata_get_all", mock.CoroutineMock() + ) as mock_metadata_get_all, + mock.patch.object(test_module.web, "json_response"), + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + mock_metadata_get_all.side_effect = StorageNotFoundError() + + with self.assertRaises(test_module.web.HTTPNotFound): + await test_module.connections_metadata(self.request) + + mock_metadata_get_all.side_effect = test_module.BaseModelError() + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_metadata(self.request) + + async def test_connections_metadata_set(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + self.request.json = mock.CoroutineMock( + return_value={"metadata": {"hello": "world"}} + ) + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object( + mock_conn_rec, "metadata_get_all", mock.CoroutineMock() + ) as mock_metadata_get_all, + mock.patch.object( + mock_conn_rec, "metadata_set", mock.CoroutineMock() + ) as mock_metadata_set, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + mock_metadata_get_all.return_value = {"hello": "world"} + + await test_module.connections_metadata_set(self.request) + mock_metadata_set.assert_called_once() + mock_response.assert_called_once_with({"results": {"hello": "world"}}) + + async def test_connections_metadata_set_x(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + self.request.json = mock.CoroutineMock( + return_value={"metadata": {"hello": "world"}} + ) + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object(mock_conn_rec, "metadata_get_all", mock.CoroutineMock()), + mock.patch.object( + mock_conn_rec, "metadata_set", mock.CoroutineMock() + ) as mock_metadata_set, + mock.patch.object(test_module.web, "json_response"), + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + mock_metadata_set.side_effect = StorageNotFoundError() + + with self.assertRaises(test_module.web.HTTPNotFound): + await test_module.connections_metadata_set(self.request) + + mock_metadata_set.side_effect = test_module.BaseModelError() + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_metadata_set(self.request) + + async def test_connections_retrieve_not_found(self): + self.request.match_info = {"conn_id": "dummy"} + + with mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.side_effect = StorageNotFoundError() + + with self.assertRaises(test_module.web.HTTPNotFound): + await test_module.connections_retrieve(self.request) + + async def test_connections_retrieve_x(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock(side_effect=test_module.BaseModelError()) + + with mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_retrieve(self.request) + + async def test_connections_create_invitation(self): + self.context.update_settings({"public_invites": True}) + body = { + "recipient_keys": ["test"], + "routing_keys": ["test"], + "service_endpoint": "http://example.com", + "metadata": {"hello": "world"}, + "mediation_id": "some-id", + } + self.request.json = mock.CoroutineMock(return_value=body) + self.request.query = { + "auto_accept": "true", + "alias": "alias", + "public": "true", + "multi_use": "true", + } + + with ( + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_mgr.return_value.create_invitation = mock.CoroutineMock( + return_value=( + mock.MagicMock( # connection record + connection_id="dummy", alias="conn-alias" + ), + mock.MagicMock( # invitation + serialize=mock.MagicMock(return_value={"a": "value"}), + to_url=mock.MagicMock(return_value="http://endpoint.ca"), + ), + ) + ) + + await test_module.connections_create_invitation(self.request) + mock_conn_mgr.return_value.create_invitation.assert_called_once_with( + **{ + key: json.loads(value) if key != "alias" else value + for key, value in self.request.query.items() + }, + my_label=None, + recipient_keys=body["recipient_keys"], + routing_keys=body["routing_keys"], + my_endpoint=body["service_endpoint"], + metadata=body["metadata"], + mediation_id="some-id", + ) + mock_response.assert_called_once_with( + { + "connection_id": "dummy", + "invitation": {"a": "value"}, + "invitation_url": "http://endpoint.ca", + "alias": "conn-alias", + } + ) + + async def test_connections_create_invitation_x(self): + self.context.update_settings({"public_invites": True}) + self.request.json = mock.CoroutineMock() + self.request.query = { + "auto_accept": "true", + "alias": "alias", + "public": "true", + "multi_use": "true", + } + + with mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr: + mock_conn_mgr.return_value.create_invitation = mock.CoroutineMock( + side_effect=test_module.ConnectionManagerError() + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_create_invitation(self.request) + + async def test_connections_create_invitation_x_bad_mediation_id(self): + self.context.update_settings({"public_invites": True}) + body = { + "recipient_keys": ["test"], + "routing_keys": ["test"], + "service_endpoint": "http://example.com", + "metadata": {"hello": "world"}, + "mediation_id": "some-id", + } + self.request.json = mock.CoroutineMock(return_value=body) + self.request.query = { + "auto_accept": "true", + "alias": "alias", + "public": "true", + "multi_use": "true", + } + with mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr: + mock_conn_mgr.return_value.create_invitation = mock.CoroutineMock( + side_effect=StorageNotFoundError() + ) + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_create_invitation(self.request) + + async def test_connections_create_invitation_public_forbidden(self): + self.context.update_settings({"public_invites": False}) + self.request.json = mock.CoroutineMock() + self.request.query = { + "auto_accept": "true", + "alias": "alias", + "public": "true", + "multi_use": "true", + } + + with self.assertRaises(test_module.web.HTTPForbidden): + await test_module.connections_create_invitation(self.request) + + async def test_connections_receive_invitation(self): + self.request.json = mock.CoroutineMock() + self.request.query = { + "auto_accept": "true", + "alias": "alias", + } + + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock() + + with ( + mock.patch.object( + test_module.ConnectionInvitation, "deserialize", autospec=True + ), + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_mgr.return_value.receive_invitation = mock.CoroutineMock( + return_value=mock_conn_rec + ) + + await test_module.connections_receive_invitation(self.request) + mock_response.assert_called_once_with(mock_conn_rec.serialize.return_value) + + async def test_connections_receive_invitation_bad(self): + self.request.json = mock.CoroutineMock() + self.request.query = { + "auto_accept": "true", + "alias": "alias", + } + + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock() + + with ( + mock.patch.object( + test_module.ConnectionInvitation, "deserialize", autospec=True + ) as mock_inv_deser, + mock.patch.object(test_module, "ConnectionManager", autospec=True), + ): + mock_inv_deser.side_effect = test_module.BaseModelError() + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_receive_invitation(self.request) + + async def test_connections_receive_invitation_forbidden(self): + self.context.update_settings({"admin.no_receive_invites": True}) + + with self.assertRaises(test_module.web.HTTPForbidden): + await test_module.connections_receive_invitation(self.request) + + async def test_connections_receive_invitation_x_bad_mediation_id(self): + self.request.json = mock.CoroutineMock() + self.request.query = { + "auto_accept": "true", + "alias": "alias", + "mediation_id": "some-id", + } + + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock() + + with ( + mock.patch.object( + test_module.ConnectionInvitation, "deserialize", autospec=True + ), + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + ): + mock_conn_mgr.return_value.receive_invitation = mock.CoroutineMock( + side_effect=StorageNotFoundError() + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_receive_invitation(self.request) + + async def test_connections_accept_invitation(self): + self.request.match_info = {"conn_id": "dummy"} + self.request.query = { + "my_label": "label", + "my_endpoint": "http://endpoint.ca", + } + + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock() + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + mock_conn_mgr.return_value.create_request = mock.CoroutineMock() + + await test_module.connections_accept_invitation(self.request) + mock_response.assert_called_once_with(mock_conn_rec.serialize.return_value) + + async def test_connections_accept_invitation_not_found(self): + self.request.match_info = {"conn_id": "dummy"} + + with mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.side_effect = StorageNotFoundError() + + with self.assertRaises(test_module.web.HTTPNotFound): + await test_module.connections_accept_invitation(self.request) + + async def test_connections_accept_invitation_x(self): + self.request.match_info = {"conn_id": "dummy"} + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ), + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + ): + mock_conn_mgr.return_value.create_request = mock.CoroutineMock( + side_effect=test_module.ConnectionManagerError() + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_accept_invitation(self.request) + + async def test_connections_accept_invitation_x_bad_mediation_id(self): + self.request.match_info = {"conn_id": "dummy"} + self.request.query["mediation_id"] = "some-id" + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ), + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + ): + mock_conn_mgr.return_value.create_request = mock.CoroutineMock( + side_effect=StorageNotFoundError() + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_accept_invitation(self.request) + + async def test_connections_accept_request(self): + self.request.match_info = {"conn_id": "dummy"} + self.request.query = { + "my_endpoint": "http://endpoint.ca", + } + + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock() + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + mock_conn_mgr.return_value.create_response = mock.CoroutineMock() + + await test_module.connections_accept_request(self.request) + mock_response.assert_called_once_with(mock_conn_rec.serialize.return_value) + + async def test_connections_accept_request_not_found(self): + self.request.match_info = {"conn_id": "dummy"} + + with mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.side_effect = StorageNotFoundError() + + with self.assertRaises(test_module.web.HTTPNotFound): + await test_module.connections_accept_request(self.request) + + async def test_connections_accept_request_x(self): + self.request.match_info = {"conn_id": "dummy"} + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ), + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + mock.patch.object(test_module.web, "json_response"), + ): + mock_conn_mgr.return_value.create_response = mock.CoroutineMock( + side_effect=test_module.ConnectionManagerError() + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_accept_request(self.request) + + async def test_connections_remove(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + mock_conn_rec.delete_record = mock.CoroutineMock() + + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + + await test_module.connections_remove(self.request) + mock_response.assert_called_once_with({}) + + async def test_connections_remove_cache_key(self): + cache = InMemoryCache() + profile = self.context.profile + await cache.set("conn_rec_state::dummy", "active") + profile.context.injector.bind_instance(BaseCache, cache) + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock() + mock_conn_rec.delete_record = mock.CoroutineMock() + assert (await cache.get("conn_rec_state::dummy")) == "active" + with ( + mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + + await test_module.connections_remove(self.request) + mock_response.assert_called_once_with({}) + assert not (await cache.get("conn_rec_state::dummy")) + + async def test_connections_remove_not_found(self): + self.request.match_info = {"conn_id": "dummy"} + + with mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.side_effect = StorageNotFoundError() + + with self.assertRaises(test_module.web.HTTPNotFound): + await test_module.connections_remove(self.request) + + async def test_connections_remove_x(self): + self.request.match_info = {"conn_id": "dummy"} + mock_conn_rec = mock.MagicMock( + delete_record=mock.CoroutineMock(side_effect=test_module.StorageError()) + ) + + with mock.patch.object( + test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn_rec + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_remove(self.request) + + async def test_connections_create_static(self): + self.request.json = mock.CoroutineMock( + return_value={ + "my_seed": "my_seed", + "my_did": "my_did", + "their_seed": "their_seed", + "their_did": "their_did", + "their_verkey": "their_verkey", + "their_endpoint": "their_endpoint", + "their_role": "their_role", + "alias": "alias", + } + ) + self.request.query = { + "auto_accept": "true", + "alias": "alias", + } + self.request.match_info = {"conn_id": "dummy"} + + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock() + mock_my_info = mock.MagicMock() + mock_my_info.did = "my_did" + mock_my_info.verkey = "my_verkey" + mock_their_info = mock.MagicMock() + mock_their_info.did = "their_did" + mock_their_info.verkey = "their_verkey" + + with ( + mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr, + mock.patch.object(test_module.web, "json_response") as mock_response, + ): + mock_conn_mgr.return_value.create_static_connection = mock.CoroutineMock( + return_value=(mock_my_info, mock_their_info, mock_conn_rec) + ) + + await test_module.connections_create_static(self.request) + mock_response.assert_called_once_with( + { + "my_did": mock_my_info.did, + "my_verkey": mock_my_info.verkey, + "their_did": mock_their_info.did, + "their_verkey": mock_their_info.verkey, + "my_endpoint": self.context.settings.get("default_endpoint"), + "record": mock_conn_rec.serialize.return_value, + } + ) + + async def test_connections_create_static_x(self): + self.request.json = mock.CoroutineMock( + return_value={ + "my_seed": "my_seed", + "my_did": "my_did", + "their_seed": "their_seed", + "their_did": "their_did", + "their_verkey": "their_verkey", + "their_endpoint": "their_endpoint", + "their_role": "their_role", + "alias": "alias", + } + ) + self.request.query = { + "auto_accept": "true", + "alias": "alias", + } + self.request.match_info = {"conn_id": "dummy"} + + mock_conn_rec = mock.MagicMock() + mock_conn_rec.serialize = mock.MagicMock() + mock_my_info = mock.MagicMock() + mock_my_info.did = "my_did" + mock_my_info.verkey = "my_verkey" + mock_their_info = mock.MagicMock() + mock_their_info.did = "their_did" + mock_their_info.verkey = "their_verkey" + + with mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as mock_conn_mgr: + mock_conn_mgr.return_value.create_static_connection = mock.CoroutineMock( + side_effect=test_module.WalletError() + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.connections_create_static(self.request) + + async def test_register(self): + mock_app = mock.MagicMock() + mock_app.add_routes = mock.MagicMock() + + await test_module.register(mock_app) + mock_app.add_routes.assert_called_once() + + async def test_post_process_routes(self): + mock_app = mock.MagicMock(_state={"swagger_dict": {}}) + test_module.post_process_routes(mock_app) + assert "tags" in mock_app._state["swagger_dict"] diff --git a/acapy_agent/protocols_v2/discovery/__init__.py b/acapy_agent/protocols_v2/discovery/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/discovery/definition.py b/acapy_agent/protocols_v2/discovery/definition.py new file mode 100644 index 0000000000..62bddef6f5 --- /dev/null +++ b/acapy_agent/protocols_v2/discovery/definition.py @@ -0,0 +1,10 @@ +"""Version definitions for this protocol.""" + +versions = [ + { + "major_version": 1, + "minimum_minor_version": 0, + "current_minor_version": 0, + "path": "v1_0", + } +] diff --git a/acapy_agent/protocols_v2/discovery/v1_0/__init__.py b/acapy_agent/protocols_v2/discovery/v1_0/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/discovery/v1_0/message_types.py b/acapy_agent/protocols_v2/discovery/v1_0/message_types.py new file mode 100644 index 0000000000..e85a45570f --- /dev/null +++ b/acapy_agent/protocols_v2/discovery/v1_0/message_types.py @@ -0,0 +1,58 @@ +"""Message type identifiers for Trust Pings.""" + +import logging +from ....messaging.v2_agent_message import V2AgentMessage + +SPEC_URI = "https://didcomm.org/discover-features/2.0/queries" + +# Message types +QUERIES = "https://didcomm.org/discover-features/2.0/queries" +DISCLOSE = "https://didcomm.org/discover-features/2.0/disclose" + +PROTOCOL_PACKAGE = "acapy_agent.protocols_v2.discovery.v1_0" + +BASIC_MESSAGE = "https://didcomm.org/basicmessage/2.0/message" +EMPTY = "https://didcomm.org/empty/1.0/empty" +PING = "https://didcomm.org/trust-ping/2.0/ping" + + +class discover_features: + """Discover Features 2.0 DIDComm V2 Protocol.""" + + async def __call__(self, *args, **kwargs): + """Call the Handler.""" + await self.handle(*args, **kwargs) + + @staticmethod + async def handle(context, responder, payload): + """Handle the incoming message.""" + logging.getLogger(__name__) + their_did = context.message_receipt.sender_verkey.split("#")[0] + our_did = context.message_receipt.recipient_verkey.split("#")[0] + error_result = V2AgentMessage( + message={ + "type": DISCLOSE, + "thid": payload["id"], + "body": { + "disclosures": [ + { + "feature-type": "protocol", + "id": protocol.rsplit("/", 1)[0], + } + for protocol in [PING, BASIC_MESSAGE, QUERIES, EMPTY] + ], + }, + "to": [their_did], + "from": our_did, + } + ) + await responder.send_reply(error_result) + + +HANDLERS = { + QUERIES: f"{PROTOCOL_PACKAGE}.message_types.discover_features", +}.items() + +MESSAGE_TYPES = { + QUERIES: f"{PROTOCOL_PACKAGE}.message_types.discover_features", +} diff --git a/acapy_agent/protocols_v2/discovery/v1_0/routes.py b/acapy_agent/protocols_v2/discovery/v1_0/routes.py new file mode 100644 index 0000000000..ac74320d33 --- /dev/null +++ b/acapy_agent/protocols_v2/discovery/v1_0/routes.py @@ -0,0 +1,251 @@ +"""Trust ping admin routes.""" + +from aiohttp import web +from aiohttp_apispec import docs, request_schema, response_schema +from marshmallow import fields +from didcomm_messaging import DIDCommMessaging, RoutingService +from didcomm_messaging.resolver import DIDResolver as DMPResolver + +from ....admin.decorators.auth import tenant_authentication +from ....admin.request_context import AdminRequestContext +from ....messaging.models.openapi import OpenAPISchema +from ....messaging.valid import UUID4_EXAMPLE +from .message_types import SPEC_URI + +from ....wallet.base import BaseWallet +from ....wallet.did_info import DIDInfo +from ....wallet.did_method import ( + DIDMethod, + DIDMethods, +) +from ....wallet.did_posture import DIDPosture +from ....messaging.v2_agent_message import V2AgentMessage +from ....connections.models.connection_target import ConnectionTarget + + +class BaseDIDCommV2Schema(OpenAPISchema): + """Request schema for performing a ping.""" + + to_did = fields.Str( + required=True, + allow_none=False, + metadata={"description": "Comment for the ping message"}, + ) + + +class PingRequestSchema(BaseDIDCommV2Schema): + """Request schema for performing a ping.""" + + response_requested = fields.Bool( + required=False, + allow_none=True, + metadata={"description": "Comment for the ping message"}, + ) + + +class PingRequestResponseSchema(OpenAPISchema): + """Request schema for performing a ping.""" + + thread_id = fields.Str( + required=False, metadata={"description": "Thread ID of the ping message"} + ) + + +class PingConnIdMatchInfoSchema(OpenAPISchema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + required=True, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) + + +def format_did_info(info: DIDInfo): + """Serialize a DIDInfo object.""" + if info: + return { + "did": info.did, + "verkey": info.verkey, + "posture": DIDPosture.get(info.metadata).moniker, + "key_type": info.key_type.key_type, + "method": info.method.method_name, + "metadata": info.metadata, + } + + +async def get_mydid(request: web.BaseRequest): + """Get a DID that can be used for communication.""" + context: AdminRequestContext = request["context"] + # filter_did = request.query.get("did") + # filter_verkey = request.query.get("verkey") + filter_posture = DIDPosture.get(request.query.get("posture")) + results = [] + async with context.session() as session: + did_methods: DIDMethods = session.inject(DIDMethods) + filter_method: DIDMethod | None = did_methods.from_method( + request.query.get("method") or "did:peer:2" + ) + # key_types = session.inject(KeyTypes) + # filter_key_type = key_types.from_key_type(request.query.get("key_type", "")) + wallet: BaseWallet | None = session.inject_or(BaseWallet) + if not wallet: + raise web.HTTPForbidden(reason="No wallet available") + else: + dids = await wallet.get_local_dids() + results = [ + format_did_info(info) + for info in dids + if ( + filter_posture is None + or DIDPosture.get(info.metadata) is DIDPosture.WALLET_ONLY + ) + and (not filter_method or info.method == filter_method) + # and (not filter_key_type or info.key_type == filter_key_type) + ] + + results.sort(key=lambda info: (DIDPosture.get(info["posture"]).ordinal, info["did"])) + our_did = results[0]["did"] + return our_did + + +async def get_target(request: web.BaseRequest, to_did: str, from_did: str): + """Get Connection Target from did.""" + context: AdminRequestContext = request["context"] + + try: + async with context.profile.session() as session: + resolver = session.inject(DMPResolver) + await resolver.resolve(to_did) + except Exception as err: + raise web.HTTPNotFound(reason=str(err)) from err + + async with context.session() as session: + ctx = session + messaging = ctx.inject(DIDCommMessaging) + routing_service = ctx.inject(RoutingService) + frm = to_did + services = await routing_service._resolve_services(messaging.resolver, frm) + chain = [ + { + "did": frm, + "service": services, + } + ] + + # Loop through service DIDs until we run out of DIDs to forward to + to_target = services[0].service_endpoint.uri + found_forwardable_service = await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + while found_forwardable_service: + services = await routing_service._resolve_services( + messaging.resolver, to_target + ) + if services: + chain.append( + { + "did": to_target, + "service": services, + } + ) + to_target = services[0].service_endpoint.uri + found_forwardable_service = ( + await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + if services + else False + ) + reply_destination = [ + ConnectionTarget( + did=f"{to_did}#key-1", + endpoint=service.service_endpoint.uri, + recipient_keys=[f"{to_did}#key-1"], + sender_key=from_did + "#key-1", + ) + for service in chain[-1]["service"] + ] + return reply_destination + + +class DiscoverFeaturesQuerySchema(BaseDIDCommV2Schema): + """Request schema for performing a ping.""" + + queries = fields.Bool( + required=False, + allow_none=True, + metadata={"description": "Comment for the ping message"}, + ) + + +@docs(tags=["discoveryv2", "didcommv2"], summary="Request the list of supported features") +@request_schema(DiscoverFeaturesQuerySchema()) +@response_schema(PingRequestResponseSchema(), 200, description="") +@tenant_authentication +async def discover_features_query(request: web.BaseRequest): + """Request handler for sending a trust ping to a connection. + + Args: + request: aiohttp request object + + """ + request["context"] + outbound_handler = request["outbound_message_router"] + body = await request.json() + to_did = body.get("to_did") + + our_did = await get_mydid(request) + their_did = to_did + reply_destination = await get_target(request, to_did, our_did) + msg = V2AgentMessage( + message={ + "type": "https://didcomm.org/discover-features/2.0/queries", + "body": { + "queries": [ + { + "feature-type": "protocol", + "match": "https://didcomm.org/*", + }, + { + "feature-type": "goal-code", + "match": "org.didcomm.*", + }, + ] + }, + "to": [their_did], + "from": our_did, + } + ) + await outbound_handler(msg, target_list=reply_destination) + return web.json_response(msg.message) + + +async def register(app: web.Application): + """Register routes.""" + + app.add_routes([web.post("/discover-features/send-query", discover_features_query)]) + + +def post_process_routes(app: web.Application): + """Amend swagger API.""" + + # Add top-level tags description + if "tags" not in app._state["swagger_dict"]: + app._state["swagger_dict"]["tags"] = [] + app._state["swagger_dict"]["tags"].append( + { + "name": "discoveryv2", + "description": "Feature Discovery to Contact", + "externalDocs": {"description": "Specification", "url": SPEC_URI}, + } + ) + app._state["swagger_dict"]["tags"].append( + { + "name": "didcommv2", + "description": "DIDComm V2 based protocols for Interop-a-thon", + "externalDocs": { + "description": "Specification", + "url": "https://didcomm.org", + }, + } + ) diff --git a/acapy_agent/protocols_v2/empty/__init__.py b/acapy_agent/protocols_v2/empty/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/empty/definition.py b/acapy_agent/protocols_v2/empty/definition.py new file mode 100644 index 0000000000..62bddef6f5 --- /dev/null +++ b/acapy_agent/protocols_v2/empty/definition.py @@ -0,0 +1,10 @@ +"""Version definitions for this protocol.""" + +versions = [ + { + "major_version": 1, + "minimum_minor_version": 0, + "current_minor_version": 0, + "path": "v1_0", + } +] diff --git a/acapy_agent/protocols_v2/empty/v1_0/__init__.py b/acapy_agent/protocols_v2/empty/v1_0/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/empty/v1_0/message_types.py b/acapy_agent/protocols_v2/empty/v1_0/message_types.py new file mode 100644 index 0000000000..e1064f1168 --- /dev/null +++ b/acapy_agent/protocols_v2/empty/v1_0/message_types.py @@ -0,0 +1,33 @@ +"""Message type identifiers for Trust Pings.""" + +import logging + +SPEC_URI = "https://identity.foundation/didcomm-messaging/spec/v2.1/#the-empty-message" + +# Message types +EMPTY = "https://didcomm.org/empty/1.0/empty" + +PROTOCOL_PACKAGE = "acapy_agent.protocols_v2.empty.v1_0" + + +class basic_message: + """Empty 1.0 DIDComm V2 Protocol.""" + + async def __call__(self, *args, **kwargs): + """Call the Handler.""" + await self.handle(*args, **kwargs) + + @staticmethod + async def handle(context, responder, payload): + """Handle the incoming message.""" + logger = logging.getLogger(__name__) + logger.trace("Received empty message") + + +HANDLERS = { + EMPTY: f"{PROTOCOL_PACKAGE}.message_types.empty", +}.items() + +MESSAGE_TYPES = { + EMPTY: f"{PROTOCOL_PACKAGE}.message_types.empty", +} diff --git a/acapy_agent/protocols_v2/nametag/__init__.py b/acapy_agent/protocols_v2/nametag/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/nametag/definition.py b/acapy_agent/protocols_v2/nametag/definition.py new file mode 100644 index 0000000000..62bddef6f5 --- /dev/null +++ b/acapy_agent/protocols_v2/nametag/definition.py @@ -0,0 +1,10 @@ +"""Version definitions for this protocol.""" + +versions = [ + { + "major_version": 1, + "minimum_minor_version": 0, + "current_minor_version": 0, + "path": "v1_0", + } +] diff --git a/acapy_agent/protocols_v2/nametag/v1_0/__init__.py b/acapy_agent/protocols_v2/nametag/v1_0/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/nametag/v1_0/message_types.py b/acapy_agent/protocols_v2/nametag/v1_0/message_types.py new file mode 100644 index 0000000000..f4ad62304c --- /dev/null +++ b/acapy_agent/protocols_v2/nametag/v1_0/message_types.py @@ -0,0 +1,47 @@ +"""Message type identifiers for Trust Pings.""" + +import logging +from ....messaging.v2_agent_message import V2AgentMessage + +SPEC_URI = "https://didcomm.org/basicmessage/2.0/message" + +# Message types +BASIC_MESSAGE = "https://colton.wolkins.net/dev/name-tag/2.0/get-name" + +PROTOCOL_PACKAGE = "acapy_agent.protocols_v2.nametag.v1_0" + + +class basic_message: + """Basic Message 2.0 DIDComm V2 Protocol.""" + + async def __call__(self, *args, **kwargs): + """Call the Handler.""" + await self.handle(*args, **kwargs) + + @staticmethod + async def handle(context, responder, payload): + """Handle the incoming message.""" + logging.getLogger(__name__) + their_did = context.message_receipt.sender_verkey.split("#")[0] + our_did = context.message_receipt.recipient_verkey.split("#")[0] + error_result = V2AgentMessage( + message={ + "type": BASIC_MESSAGE, + "body": { + "content": "Hello from acapy", + }, + "to": [their_did], + "from": our_did, + "lang": "en", + } + ) + await responder.send_reply(error_result) + + +HANDLERS = { + BASIC_MESSAGE: f"{PROTOCOL_PACKAGE}.message_types.basic_message", +}.items() + +MESSAGE_TYPES = { + BASIC_MESSAGE: f"{PROTOCOL_PACKAGE}.message_types.basic_message", +} diff --git a/acapy_agent/protocols_v2/nametag/v1_0/routes.py b/acapy_agent/protocols_v2/nametag/v1_0/routes.py new file mode 100644 index 0000000000..f8798ca200 --- /dev/null +++ b/acapy_agent/protocols_v2/nametag/v1_0/routes.py @@ -0,0 +1,250 @@ +"""Trust ping admin routes.""" + +from aiohttp import web +from aiohttp_apispec import docs, request_schema, response_schema +from marshmallow import fields +from didcomm_messaging import DIDCommMessaging, RoutingService +from didcomm_messaging.resolver import DIDResolver as DMPResolver + +from ....admin.decorators.auth import tenant_authentication +from ....admin.request_context import AdminRequestContext +from ....messaging.models.openapi import OpenAPISchema +from ....messaging.valid import UUID4_EXAMPLE +from .message_types import SPEC_URI + +from ....wallet.base import BaseWallet +from ....wallet.did_info import DIDInfo +from ....wallet.did_method import ( + DIDMethod, + DIDMethods, +) +from ....wallet.did_posture import DIDPosture +from ....messaging.v2_agent_message import V2AgentMessage +from ....connections.models.connection_target import ConnectionTarget + + +class BaseDIDCommV2Schema(OpenAPISchema): + """Request schema for performing a ping.""" + + to_did = fields.Str( + required=True, + allow_none=False, + metadata={"description": "Comment for the ping message"}, + ) + + +class PingRequestSchema(BaseDIDCommV2Schema): + """Request schema for performing a ping.""" + + response_requested = fields.Bool( + required=False, + allow_none=True, + metadata={"description": "Comment for the ping message"}, + ) + + +class PingRequestResponseSchema(OpenAPISchema): + """Request schema for performing a ping.""" + + thread_id = fields.Str( + required=False, metadata={"description": "Thread ID of the ping message"} + ) + + +class PingConnIdMatchInfoSchema(OpenAPISchema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + required=True, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) + + +def format_did_info(info: DIDInfo): + """Serialize a DIDInfo object.""" + if info: + return { + "did": info.did, + "verkey": info.verkey, + "posture": DIDPosture.get(info.metadata).moniker, + "key_type": info.key_type.key_type, + "method": info.method.method_name, + "metadata": info.metadata, + } + + +async def get_mydid(request: web.BaseRequest): + """Get a DID that can be used for communication.""" + context: AdminRequestContext = request["context"] + # filter_did = request.query.get("did") + # filter_verkey = request.query.get("verkey") + filter_posture = DIDPosture.get(request.query.get("posture")) + results = [] + async with context.session() as session: + did_methods: DIDMethods = session.inject(DIDMethods) + filter_method: DIDMethod | None = did_methods.from_method( + request.query.get("method") or "did:peer:2" + ) + # key_types = session.inject(KeyTypes) + # filter_key_type = key_types.from_key_type(request.query.get("key_type", "")) + wallet: BaseWallet | None = session.inject_or(BaseWallet) + if not wallet: + raise web.HTTPForbidden(reason="No wallet available") + else: + dids = await wallet.get_local_dids() + results = [ + format_did_info(info) + for info in dids + if ( + filter_posture is None + or DIDPosture.get(info.metadata) is DIDPosture.WALLET_ONLY + ) + and (not filter_method or info.method == filter_method) + # and (not filter_key_type or info.key_type == filter_key_type) + ] + + results.sort(key=lambda info: (DIDPosture.get(info["posture"]).ordinal, info["did"])) + our_did = results[0]["did"] + return our_did + + +async def get_target(request: web.BaseRequest, to_did: str, from_did: str): + """Get Connection Target from did.""" + context: AdminRequestContext = request["context"] + + try: + async with context.profile.session() as session: + resolver = session.inject(DMPResolver) + await resolver.resolve(to_did) + except Exception as err: + raise web.HTTPNotFound(reason=str(err)) from err + + async with context.session() as session: + ctx = session + messaging = ctx.inject(DIDCommMessaging) + routing_service = ctx.inject(RoutingService) + frm = to_did + services = await routing_service._resolve_services(messaging.resolver, frm) + chain = [ + { + "did": frm, + "service": services, + } + ] + + # Loop through service DIDs until we run out of DIDs to forward to + to_target = services[0].service_endpoint.uri + found_forwardable_service = await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + while found_forwardable_service: + services = await routing_service._resolve_services( + messaging.resolver, to_target + ) + if services: + chain.append( + { + "did": to_target, + "service": services, + } + ) + to_target = services[0].service_endpoint.uri + found_forwardable_service = ( + await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + if services + else False + ) + reply_destination = [ + ConnectionTarget( + did=f"{to_did}#key-1", + endpoint=service.service_endpoint.uri, + recipient_keys=[f"{to_did}#key-1"], + sender_key=from_did + "#key-1", + ) + for service in chain[-1]["service"] + ] + return reply_destination + + +class BasicMessageSchema(BaseDIDCommV2Schema): + """Request schema for performing a ping.""" + + content = fields.Str( + required=True, + allow_none=False, + metadata={"description": "Basic Message message content"}, + ) + + +@docs(tags=["basicmessagev2", "didcommv2"], summary="Send a Basic Message") +@request_schema(BasicMessageSchema()) +@response_schema(PingRequestResponseSchema(), 200, description="") +@tenant_authentication +async def basic_message_send(request: web.BaseRequest): + """Request handler for sending a trust ping to a connection. + + Args: + request: aiohttp request object + + """ + context = request["context"] + outbound_handler = request["outbound_message_router"] + body = await request.json() + to_did = body.get("to_did") + message = body.get("content") + await context.profile.notify( + "acapy::webhook::nametag", + { + "to_did": to_did, + "name": message, + }, + ) + + + our_did = await get_mydid(request) + their_did = to_did + reply_destination = await get_target(request, to_did, our_did) + msg = V2AgentMessage( + message={ + "type": "https://colton.wolkins.net/dev/name-tag/2.0/set-name", + "body": {"name": message}, + "lang": "en", + "to": [their_did], + "from": our_did, + } + ) + await outbound_handler(msg, target_list=reply_destination) + return web.json_response(msg.message) + + +async def register(app: web.Application): + """Register routes.""" + + app.add_routes([web.post("/name-tag/set-name", basic_message_send)]) + + +def post_process_routes(app: web.Application): + """Amend swagger API.""" + + # Add top-level tags description + if "tags" not in app._state["swagger_dict"]: + app._state["swagger_dict"]["tags"] = [] + app._state["swagger_dict"]["tags"].append( + { + "name": "basicmessagev2", + "description": "Basic Message to contact", + "externalDocs": {"description": "Specification", "url": SPEC_URI}, + } + ) + app._state["swagger_dict"]["tags"].append( + { + "name": "didcommv2", + "description": "DIDComm V2 based protocols for Interop-a-thon", + "externalDocs": { + "description": "Specification", + "url": "https://didcomm.org", + }, + } + ) diff --git a/acapy_agent/protocols_v2/trustping/__init__.py b/acapy_agent/protocols_v2/trustping/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/trustping/definition.py b/acapy_agent/protocols_v2/trustping/definition.py new file mode 100644 index 0000000000..62bddef6f5 --- /dev/null +++ b/acapy_agent/protocols_v2/trustping/definition.py @@ -0,0 +1,10 @@ +"""Version definitions for this protocol.""" + +versions = [ + { + "major_version": 1, + "minimum_minor_version": 0, + "current_minor_version": 0, + "path": "v1_0", + } +] diff --git a/acapy_agent/protocols_v2/trustping/v1_0/__init__.py b/acapy_agent/protocols_v2/trustping/v1_0/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/trustping/v1_0/message_types.py b/acapy_agent/protocols_v2/trustping/v1_0/message_types.py new file mode 100644 index 0000000000..b5c27b19e9 --- /dev/null +++ b/acapy_agent/protocols_v2/trustping/v1_0/message_types.py @@ -0,0 +1,48 @@ +"""Message type identifiers for Trust Pings.""" + +import logging +from ....messaging.v2_agent_message import V2AgentMessage + +SPEC_URI = "https://identity.foundation/didcomm-messaging/spec/#trust-ping-protocol-20" + +# Message types +PING = "https://didcomm.org/trust-ping/2.0/ping" +PING_RESPONSE = "https://didcomm.org/trust-ping/2.0/ping-response" + +PROTOCOL_PACKAGE = "acapy_agent.protocols_v2.trustping.v1_0" + + +class trust_ping: + """Trust Ping 2.0 DIDComm V2 Protocol.""" + + async def __call__(self, *args, **kwargs): + """Call the Handler.""" + await self.handle(*args, **kwargs) + + @staticmethod + async def handle(context, responder, payload): + """Handle the incoming message.""" + logging.getLogger(__name__) + if not payload["body"].get("response_requested", False): + return + their_did = context.message_receipt.sender_verkey.split("#")[0] + our_did = context.message_receipt.recipient_verkey.split("#")[0] + error_result = V2AgentMessage( + message={ + "type": "https://didcomm.org/trust-ping/2.0/ping-response", + "thid": payload["id"], + "body": {}, + "to": [their_did], + "from": our_did, + } + ) + await responder.send_reply(error_result) + + +HANDLERS = { + PING: f"{PROTOCOL_PACKAGE}.message_types.trust_ping", +}.items() + +MESSAGE_TYPES = { + PING: f"{PROTOCOL_PACKAGE}.message_types.trust_ping", +} diff --git a/acapy_agent/protocols_v2/trustping/v1_0/routes.py b/acapy_agent/protocols_v2/trustping/v1_0/routes.py new file mode 100644 index 0000000000..0b6569970a --- /dev/null +++ b/acapy_agent/protocols_v2/trustping/v1_0/routes.py @@ -0,0 +1,236 @@ +"""Trust ping admin routes.""" + +from aiohttp import web +from aiohttp_apispec import docs, request_schema, response_schema +from marshmallow import fields +from didcomm_messaging import DIDCommMessaging, RoutingService +from didcomm_messaging.resolver import DIDResolver as DMPResolver + +from ....admin.decorators.auth import tenant_authentication +from ....admin.request_context import AdminRequestContext +from ....messaging.models.openapi import OpenAPISchema +from ....messaging.valid import UUID4_EXAMPLE +from .message_types import SPEC_URI + +from ....wallet.base import BaseWallet +from ....wallet.did_info import DIDInfo +from ....wallet.did_method import ( + DIDMethod, + DIDMethods, +) +from ....wallet.did_posture import DIDPosture +from ....messaging.v2_agent_message import V2AgentMessage +from ....connections.models.connection_target import ConnectionTarget + + +class BaseDIDCommV2Schema(OpenAPISchema): + """Request schema for performing a ping.""" + + to_did = fields.Str( + required=True, + allow_none=False, + metadata={"description": "Comment for the ping message"}, + ) + + +class PingRequestSchema(BaseDIDCommV2Schema): + """Request schema for performing a ping.""" + + response_requested = fields.Bool( + required=False, + allow_none=True, + metadata={"description": "Comment for the ping message"}, + ) + + +class PingRequestResponseSchema(OpenAPISchema): + """Request schema for performing a ping.""" + + thread_id = fields.Str( + required=False, metadata={"description": "Thread ID of the ping message"} + ) + + +class PingConnIdMatchInfoSchema(OpenAPISchema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + required=True, + metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, + ) + + +def format_did_info(info: DIDInfo): + """Serialize a DIDInfo object.""" + if info: + return { + "did": info.did, + "verkey": info.verkey, + "posture": DIDPosture.get(info.metadata).moniker, + "key_type": info.key_type.key_type, + "method": info.method.method_name, + "metadata": info.metadata, + } + + +async def get_mydid(request: web.BaseRequest): + """Get a DID that can be used for communication.""" + context: AdminRequestContext = request["context"] + # filter_did = request.query.get("did") + # filter_verkey = request.query.get("verkey") + filter_posture = DIDPosture.get(request.query.get("posture")) + results = [] + async with context.session() as session: + did_methods: DIDMethods = session.inject(DIDMethods) + filter_method: DIDMethod | None = did_methods.from_method( + request.query.get("method") or "did:peer:2" + ) + # key_types = session.inject(KeyTypes) + # filter_key_type = key_types.from_key_type(request.query.get("key_type", "")) + wallet: BaseWallet | None = session.inject_or(BaseWallet) + if not wallet: + raise web.HTTPForbidden(reason="No wallet available") + else: + dids = await wallet.get_local_dids() + results = [ + format_did_info(info) + for info in dids + if ( + filter_posture is None + or DIDPosture.get(info.metadata) is DIDPosture.WALLET_ONLY + ) + and (not filter_method or info.method == filter_method) + # and (not filter_key_type or info.key_type == filter_key_type) + ] + + results.sort(key=lambda info: (DIDPosture.get(info["posture"]).ordinal, info["did"])) + our_did = results[0]["did"] + return our_did + + +async def get_target(request: web.BaseRequest, to_did: str, from_did: str): + """Get Connection Target from did.""" + context: AdminRequestContext = request["context"] + + try: + async with context.session() as session: + resolver = session.inject(DMPResolver) + await resolver.resolve(to_did) + except Exception as err: + raise web.HTTPNotFound(reason=str(err)) from err + + async with context.session() as session: + ctx = session + messaging = ctx.inject(DIDCommMessaging) + routing_service = ctx.inject(RoutingService) + frm = to_did + services = await routing_service._resolve_services(messaging.resolver, frm) + chain = [ + { + "did": frm, + "service": services, + } + ] + + # Loop through service DIDs until we run out of DIDs to forward to + to_target = services[0].service_endpoint.uri + found_forwardable_service = await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + while found_forwardable_service: + services = await routing_service._resolve_services( + messaging.resolver, to_target + ) + if services: + chain.append( + { + "did": to_target, + "service": services, + } + ) + to_target = services[0].service_endpoint.uri + found_forwardable_service = ( + await routing_service.is_forwardable_service( + messaging.resolver, services[0] + ) + if services + else False + ) + reply_destination = [ + ConnectionTarget( + did=f"{to_did}#key-1", + endpoint=service.service_endpoint.uri, + recipient_keys=[f"{to_did}#key-1"], + sender_key=from_did + "#key-1", + ) + for service in chain[-1]["service"] + ] + return reply_destination + + +@docs(tags=["trustping", "didcommv2"], summary="Send a trust ping to a connection") +@request_schema(PingRequestSchema()) +@response_schema(PingRequestResponseSchema(), 200, description="") +@tenant_authentication +async def connections_send_ping(request: web.BaseRequest): + """Request handler for sending a trust ping to a connection. + + Args: + request: aiohttp request object + + """ + request["context"] + outbound_handler = request["outbound_message_router"] + body = await request.json() + to_did = body.get("to_did") + response_requested = body.get("response_requested") + + our_did = await get_mydid(request) + their_did = to_did + reply_destination = await get_target(request, to_did, our_did) + msg = V2AgentMessage( + message={ + "type": "https://didcomm.org/trust-ping/2.0/ping", + "body": {}, + "to": [their_did], + "from": our_did, + } + ) + + if response_requested: + msg.message["response_requested"] = True + + await outbound_handler(msg, target_list=reply_destination) + + return web.json_response(msg.message) + + +async def register(app: web.Application): + """Register routes.""" + + app.add_routes([web.post("/trust-ping/send-ping", connections_send_ping)]) + + +def post_process_routes(app: web.Application): + """Amend swagger API.""" + + # Add top-level tags description + if "tags" not in app._state["swagger_dict"]: + app._state["swagger_dict"]["tags"] = [] + app._state["swagger_dict"]["tags"].append( + { + "name": "trustpingv2", + "description": "Trust-ping to contact", + "externalDocs": {"description": "Specification", "url": SPEC_URI}, + } + ) + app._state["swagger_dict"]["tags"].append( + { + "name": "didcommv2", + "description": "DIDComm V2 based protocols for Interop-a-thon", + "externalDocs": { + "description": "Specification", + "url": "https://didcomm.org", + }, + } + ) diff --git a/acapy_agent/protocols_v2/trustping/v1_0/tests/__init__.py b/acapy_agent/protocols_v2/trustping/v1_0/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/acapy_agent/protocols_v2/trustping/v1_0/tests/test_routes.py b/acapy_agent/protocols_v2/trustping/v1_0/tests/test_routes.py new file mode 100644 index 0000000000..f590a458c9 --- /dev/null +++ b/acapy_agent/protocols_v2/trustping/v1_0/tests/test_routes.py @@ -0,0 +1,134 @@ +from unittest import IsolatedAsyncioTestCase + +from .....admin.request_context import AdminRequestContext +from .....tests import mock +from .....utils.testing import create_test_profile +from .. import routes as test_module +from .....wallet.did_method import ( + DIDMethods, +) + +# from didcomm_messaging import DIDCommMessaging, RoutingService +from didcomm_messaging.resolver import DIDResolver as DMPResolver +from didcomm_messaging import ( + CryptoService, + DIDCommMessaging, + PackagingService, + RoutingService, + SecretsManager, +) +from didcomm_messaging.crypto.backend.askar import AskarCryptoService + + +class TestTrustpingRoutes(IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.session_inject = {} + self.profile = await create_test_profile( + settings={ + "admin.admin_api_key": "secret-key", + } + ) + self.context = AdminRequestContext.test_context(self.session_inject, self.profile) + self.context.injector.bind_instance(DIDMethods, DIDMethods()) + from .....didcomm_v2.adapters import ResolverAdapter, SecretsAdapter + from .....resolver.did_resolver import DIDResolver + from .....resolver.default.peer4 import PeerDID4Resolver + + self.context.injector.bind_instance(DIDResolver, DIDResolver()) + didResolver = self.context.inject_or(DIDResolver) + self.context.injector.bind_instance( + DMPResolver, ResolverAdapter(self.profile, didResolver) + ) + self.context.injector.bind_instance(SecretsManager, SecretsAdapter(self.profile)) + self.context.injector.bind_instance(RoutingService, RoutingService()) + self.context.injector.bind_instance(CryptoService, AskarCryptoService()) + self.context.injector.bind_instance(PackagingService, PackagingService()) + peer_did_4_resolver = PeerDID4Resolver() + await peer_did_4_resolver.setup(self.context) + didResolver.register_resolver(peer_did_4_resolver) + + self.context.injector.bind_instance( + DIDCommMessaging, + DIDCommMessaging( + self.context.inject_or(CryptoService), + self.context.inject_or(SecretsManager), + self.context.inject_or(DMPResolver), + self.context.inject_or(PackagingService), + self.context.inject_or(RoutingService), + ), + ) + + self.request_dict = { + "context": self.context, + "outbound_message_router": mock.CoroutineMock(), + } + self.request = mock.MagicMock( + app={}, + match_info={}, + query={}, + __getitem__=lambda _, k: self.request_dict[k], + headers={"x-api-key": "secret-key"}, + ) + + async def test_connections_send_ping(self): + DID = "did:peer:4zQmXrH3ADfT6LtLQgrVkQtitAnYtQaEaaonP8yehJv79DAD:z6uysxxSHsMeCGVLxaA5yaTNMqkacZmod7a5nG9Seq8SNjt8NqK7oXreduL22hybjVvWgUA6TVq9enGQC3PP3RU3tKvxfnPhvYDrs3CoYx8VsdFvbuUVYGhsxPVgg8yByGgV6kmteqdACzThtpVLCXcLcxuxJj4i6v3W2AMyUTKy18aPbupzMMBLbdxsuVT1ePydY4AuB2VpVPz1XRxBJZjzQd1Va6BnzCPS9y87XpZwS8hc5GQcqss7XG1Pmmq3xCbiKzSBfx3NL6sxgWY1Vuc3aYaPXkXEtAvMbUnbyGx9UDY6rozLdv2WnyHP6B9krNz4TgfoTFSz9exNctKt7BTNuPi5cgEhAEKs81sqAKSQbuN514mrsAU3QEyEGFH46Wnfm1PxDwwhg4E6oqxKTxykvz698MuZijAr8nD6fV9gMLoF2FHSyqn21CQZXsWC5ZTo534HbjVNEifEkev493JdZQTLmTkA8rdEnnCBwSgXaHMtPmYbd6XzSj4PVKa31K14Q2ktrxkic7kTVSf5Dv9mzvUgD4iUSmtsxj2VrBNaZHFbWYD8QANJY6H6NRt9WjwZDypLdPDmUZDzNLmqxSxdxT5DYxboVPtuh9dWtT7tpeDuzD6XBs9jjw5cSZ7DWWRSvTWuovm9sSuSC8zdCuzRrHFu6JaYDVHujyUvNyQ3cJ4EJxvb4LjmaCfZiuY6VvtGSkdyxoe22PyH3eBVzgwxYN7XPDCc7ewE8uS76x7PC2qWnFzrXCHaP51jg3cZtXoysktXGsvZ8B1XnZJYTteS5GJmaUgZYD9QsKSyq3GEiygzBN1StPSopBzpAscdfH4VtGNbrBYNCVBaGQetJfm9FV9HRyk1XJHTGe5JTBb2862dGHY7zmpinwM4XriPRizGPkEVE3FJddFUrmMz6iRMpDp7ZyxeYUrnGN94vW8nuPa77CkyPu82LkdJgrFfAwGyB6B26Br3YmBPBH8af1uKSfufgyZr4KqFL3NFDp8DPPvMSgJaxVnHdjYnCWYDJWaoFR3LKMdBiH1Z894akrb4DEGNVV7YSZoWLEAtRiDP87526z9pv85QiZyau6St4L2bMfEnYcR5TDQtj4oWGbZtfxTbURWZr2RoCdo9vQSa1YFjrN8rX3ob8CCeexRb38eJj9o5gRaVdij3JZyxPDuHDQHBUeAU75RueK2QCHTCBAkhp1JrFLGiHuTrfX6Q1HmpE2YVyoabybCHuu7joMZrVmCkmbYUhPXmFXd2mX3drBApVdxvrbS4VJZxzVETTnDnXxSeVffubam4cWZGaHHzjUUJ593wkWwbnafjUxoQHeH5gRE9fo7stAWCNn4hYRFdTVKRe4zC7pkLXcTThRvZDHwhixPaxKYFXAd5Vhkixo5DLjvt8t4kBZEDfBmNxef1Bkf3TUXAW7bh21SMjSwzekQuVezZwYjiTMPCPMXPv7BvRQfT3NqftD45B3TkAbYXRCo5t18fJt6eKfEvQYyJQbJbgJHvvJyPM54t9z4y9qQtJEPGNNAhbwmttoshCznEkYEpLBUXiffXyS6LDjSURntutCL2GQp8YBMup9xEm32o44NupwND56a78dxdoF9XDxwp8vvY89rTsPGX2bRMyFW8uyXyyrwpnqQMoykidWrBexjTYc4oZpRzHznaVXnLxsWZRqKCdthw2jmTyxDoJncWvnpLHRWNwW44oP1pzCHMf9nwunySm6dp79wMKo5tmfanFud596J237C8MhZc4sFcknPC2BkeBqu4E5WryhA2ZCAmiNqKg6EjPbhEAG45yT1gw9fTtA3ydqwvarsUyXT7DKPjwasSAY6VD5iyzKpC1obEoXTzCb5sVmtUDFywArUGChPed8uVzaWyN73vJqo6SG" + self.request.json = mock.CoroutineMock(return_value={"to_did": DID}) + self.request.match_info = {"to_did": DID} + + with ( + mock.patch.object( + test_module, "V2AgentMessage", mock.MagicMock() + ) as mock_ping, + mock.patch.object(test_module, "get_mydid", mock.CoroutineMock()) as mock_did, + mock.patch.object( + test_module.web, "json_response", mock.MagicMock() + ) as json_response, + ): + mock_ping.return_value = mock.MagicMock(_thread_id="dummy") + mock_did.return_value = mock.CoroutineMock(return_value=DID) + # mock_retrieve.return_value = mock.MagicMock(is_ready=True) + result = await test_module.connections_send_ping(self.request) + expected = mock_ping( + message={ + "type": "https://didcomm.org/trust-ping/2.0/ping", + "body": {}, + "to": [DID], + "from": DID, + } + ) + json_response.assert_called_once_with(expected.message) + assert result is json_response.return_value + + # async def test_connections_send_ping_no_conn(self): + # self.request.json = mock.CoroutineMock(return_value={"comment": "some comment"}) + # self.request.match_info = {"conn_id": "dummy"} + + # with ( + # mock.patch.object(test_module.web, "json_response", mock.MagicMock()), + # ): + # # mock_retrieve.side_effect = test_module.StorageNotFoundError() + # with self.assertRaises(test_module.web.HTTPNotFound): + # await test_module.connections_send_ping(self.request) + + # async def test_connections_send_ping_not_ready(self): + # self.request.json = mock.CoroutineMock(return_value={"comment": "some comment"}) + # self.request.match_info = {"conn_id": "dummy"} + + # with ( + # mock.patch.object(test_module.web, "json_response", mock.MagicMock()), + # ): + # # mock_retrieve.return_value = mock.MagicMock(is_ready=False) + # with self.assertRaises(test_module.web.HTTPBadRequest): + # await test_module.connections_send_ping(self.request) + + async def test_register(self): + mock_app = mock.MagicMock() + mock_app.add_routes = mock.MagicMock() + + await test_module.register(mock_app) + mock_app.add_routes.assert_called_once() + + async def test_post_process_routes(self): + mock_app = mock.MagicMock(_state={"swagger_dict": {}}) + test_module.post_process_routes(mock_app) + assert "tags" in mock_app._state["swagger_dict"] diff --git a/acapy_agent/transport/inbound/http.py b/acapy_agent/transport/inbound/http.py index f4a29ba3e7..4c634c1523 100644 --- a/acapy_agent/transport/inbound/http.py +++ b/acapy_agent/transport/inbound/http.py @@ -38,6 +38,7 @@ async def make_application(self) -> web.Application: app = web.Application(**app_args) app.add_routes([web.get("/", self.invite_message_handler)]) app.add_routes([web.post("/", self.inbound_message_handler)]) + app.add_routes([web.options("/", self.options_message_handler)]) return app async def start(self) -> None: @@ -128,6 +129,25 @@ async def inbound_message_handler(self, request: web.BaseRequest): ) return web.Response(status=200) + async def options_message_handler(self, request: web.BaseRequest): + """Message handler for invites. + + Args: + request: aiohttp request object + + Returns: + The web response + + """ + return web.Response( + status=200, + headers={ + "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET,POST", + }, + ) + async def invite_message_handler(self, request: web.BaseRequest): """Message handler for invites.