Skip to content

Commit 930fabe

Browse files
chore(janus): support extra routes for reverse mcp server (#511)
1 parent 22f4599 commit 930fabe

File tree

1 file changed

+35
-7
lines changed

1 file changed

+35
-7
lines changed

veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
import asyncio
1616
import json
1717
import uuid
18-
from typing import TYPE_CHECKING, Any, Optional
18+
from typing import TYPE_CHECKING, Any, Callable, Optional
1919

2020
from fastapi import FastAPI, HTTPException, Request, Response, WebSocket
2121
from fastapi.responses import StreamingResponse
22-
from google.adk.agents.run_config import StreamingMode
22+
from google.adk.agents.run_config import RunConfig, StreamingMode
2323
from google.adk.artifacts import InMemoryArtifactService
2424
from google.adk.cli.adk_web_server import RunAgentRequest
25-
from google.adk.runners import Runner as GoogleRunner, RunConfig
25+
from google.adk.runners import Runner as GoogleRunner
2626
from google.adk.sessions import InMemorySessionService, Session
2727
from google.adk.tools.mcp_tool.mcp_session_manager import (
2828
StreamableHTTPConnectionParams,
@@ -42,6 +42,12 @@
4242
REVERSE_MCP_HEADER_KEY = "X-Reverse-MCP-ID"
4343

4444

45+
class ExtraRoute(BaseModel):
46+
path: str
47+
endpoint: Callable
48+
methods: list[str]
49+
50+
4551
class WebsocketSessionManager:
4652
def __init__(self):
4753
# ws id -> ws instance
@@ -93,13 +99,21 @@ def __init__(
9399
agent: "Agent",
94100
host: str = "0.0.0.0",
95101
port: int = 8000,
102+
extra_routes: list[ExtraRoute] | None = None,
96103
):
97104
self.agent = agent
98105

99106
self.host = host
100107
self.port = port
101108

102-
self.app = FastAPI()
109+
self.extra_routes = extra_routes
110+
111+
self.app = FastAPI(
112+
openapi_url=None,
113+
docs_url=None,
114+
redoc_url=None,
115+
swagger_ui_oauth2_redirect_url=None,
116+
)
103117

104118
self.artifact_service = InMemoryArtifactService()
105119

@@ -215,7 +229,8 @@ def _get_session_service(websocket_id: str) -> InMemorySessionService:
215229
"""Get session service for the websocket client."""
216230
if websocket_id not in self.ws_session_service_mgr:
217231
raise HTTPException(
218-
status_code=404, detail=f"WebSocket client {websocket_id} not found"
232+
status_code=404,
233+
detail=f"WebSocket client {websocket_id} not found",
219234
)
220235
return self.ws_session_service_mgr[websocket_id]
221236

@@ -276,7 +291,9 @@ async def create_session_with_id(
276291
return session
277292

278293
@self.app.post("/run_sse")
279-
async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
294+
async def run_agent_sse(
295+
req: RunAgentRequestWithWsId,
296+
) -> StreamingResponse:
280297
"""Run agent with SSE streaming."""
281298
session_service = _get_session_service(req.websocket_id)
282299

@@ -337,7 +354,10 @@ async def event_generator():
337354
content_event.actions.artifact_delta = {}
338355
artifact_event = event.model_copy(deep=True)
339356
artifact_event.content = None
340-
events_to_stream = [content_event, artifact_event]
357+
events_to_stream = [
358+
content_event,
359+
artifact_event,
360+
]
341361

342362
for event_to_stream in events_to_stream:
343363
sse_event = event_to_stream.model_dump_json(
@@ -354,6 +374,14 @@ async def event_generator():
354374
media_type="text/event-stream",
355375
)
356376

377+
if self.extra_routes:
378+
for route in self.extra_routes:
379+
self.app.add_api_route(
380+
path=route.path,
381+
endpoint=route.endpoint,
382+
methods=route.methods,
383+
)
384+
357385
# build the fake MPC server,
358386
# and intercept all requests to the client websocket client.
359387
# NOTE: This catch-all route must be defined LAST

0 commit comments

Comments
 (0)