1515import asyncio
1616import json
1717import uuid
18- from typing import TYPE_CHECKING , Any , Optional
18+ from typing import TYPE_CHECKING , Any , Callable , Optional
1919
2020from fastapi import FastAPI , HTTPException , Request , Response , WebSocket
2121from fastapi .responses import StreamingResponse
22- from google .adk .agents .run_config import StreamingMode
22+ from google .adk .agents .run_config import RunConfig , StreamingMode
2323from google .adk .artifacts import InMemoryArtifactService
2424from google .adk .cli .adk_web_server import RunAgentRequest
25- from google .adk .runners import Runner as GoogleRunner , RunConfig
25+ from google .adk .runners import Runner as GoogleRunner
2626from google .adk .sessions import InMemorySessionService , Session
2727from google .adk .tools .mcp_tool .mcp_session_manager import (
2828 StreamableHTTPConnectionParams ,
4242REVERSE_MCP_HEADER_KEY = "X-Reverse-MCP-ID"
4343
4444
45+ class ExtraRoute (BaseModel ):
46+ path : str
47+ endpoint : Callable
48+ methods : list [str ]
49+
50+
4551class WebsocketSessionManager :
4652 def __init__ (self ):
4753 # ws id -> ws instance
@@ -93,13 +99,21 @@ def __init__(
9399 agent : "Agent" ,
94100 host : str = "0.0.0.0" ,
95101 port : int = 8000 ,
102+ extra_routes : list [ExtraRoute ] | None = None ,
96103 ):
97104 self .agent = agent
98105
99106 self .host = host
100107 self .port = port
101108
102- self .app = FastAPI ()
109+ self .extra_routes = extra_routes
110+
111+ self .app = FastAPI (
112+ openapi_url = None ,
113+ docs_url = None ,
114+ redoc_url = None ,
115+ swagger_ui_oauth2_redirect_url = None ,
116+ )
103117
104118 self .artifact_service = InMemoryArtifactService ()
105119
@@ -215,7 +229,8 @@ def _get_session_service(websocket_id: str) -> InMemorySessionService:
215229 """Get session service for the websocket client."""
216230 if websocket_id not in self .ws_session_service_mgr :
217231 raise HTTPException (
218- status_code = 404 , detail = f"WebSocket client { websocket_id } not found"
232+ status_code = 404 ,
233+ detail = f"WebSocket client { websocket_id } not found" ,
219234 )
220235 return self .ws_session_service_mgr [websocket_id ]
221236
@@ -276,7 +291,9 @@ async def create_session_with_id(
276291 return session
277292
278293 @self .app .post ("/run_sse" )
279- async def run_agent_sse (req : RunAgentRequestWithWsId ) -> StreamingResponse :
294+ async def run_agent_sse (
295+ req : RunAgentRequestWithWsId ,
296+ ) -> StreamingResponse :
280297 """Run agent with SSE streaming."""
281298 session_service = _get_session_service (req .websocket_id )
282299
@@ -337,7 +354,10 @@ async def event_generator():
337354 content_event .actions .artifact_delta = {}
338355 artifact_event = event .model_copy (deep = True )
339356 artifact_event .content = None
340- events_to_stream = [content_event , artifact_event ]
357+ events_to_stream = [
358+ content_event ,
359+ artifact_event ,
360+ ]
341361
342362 for event_to_stream in events_to_stream :
343363 sse_event = event_to_stream .model_dump_json (
@@ -354,6 +374,14 @@ async def event_generator():
354374 media_type = "text/event-stream" ,
355375 )
356376
377+ if self .extra_routes :
378+ for route in self .extra_routes :
379+ self .app .add_api_route (
380+ path = route .path ,
381+ endpoint = route .endpoint ,
382+ methods = route .methods ,
383+ )
384+
357385 # build the fake MPC server,
358386 # and intercept all requests to the client websocket client.
359387 # NOTE: This catch-all route must be defined LAST
0 commit comments