1515import asyncio
1616import json
1717import uuid
18- from typing import TYPE_CHECKING
19-
20- from fastapi import FastAPI , Request , Response , WebSocket
18+ from typing import TYPE_CHECKING , Any , Optional
19+
20+ from fastapi import FastAPI , HTTPException , Request , Response , WebSocket
21+ from fastapi .responses import StreamingResponse
22+ from google .adk .agents .run_config import StreamingMode
23+ from google .adk .artifacts import InMemoryArtifactService
24+ from google .adk .cli .adk_web_server import RunAgentRequest
25+ from google .adk .runners import Runner as GoogleRunner , RunConfig
26+ from google .adk .sessions import InMemorySessionService , Session
2127from google .adk .tools .mcp_tool .mcp_session_manager import (
2228 StreamableHTTPConnectionParams ,
2329)
2430from google .adk .tools .mcp_tool .mcp_toolset import MCPToolset
31+ from google .adk .utils .context_utils import Aclosing
2532from pydantic import BaseModel
2633
2734from veadk import Runner
@@ -93,11 +100,15 @@ def __init__(
93100 self .port = port
94101
95102 self .app = FastAPI ()
103+
104+ self .artifact_service = InMemoryArtifactService ()
105+
96106 # build routes for self.app
97107 self .build ()
98108
99109 self .ws_session_mgr = WebsocketSessionManager ()
100110 self .ws_agent_mgr : dict [str , "Agent" ] = {}
111+ self .ws_session_service_mgr : dict [str , "InMemorySessionService" ] = {}
101112
102113 def build (self ):
103114 logger .info ("Build routes for server with reverse mcp" )
@@ -126,19 +137,6 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
126137
127138 agent = self .ws_agent_mgr [payload .websocket_id ]
128139
129- if not agent .tools :
130- logger .debug ("Mount fake MCPToolset to agent" )
131-
132- # we hard code the mcp url with `/mcp` to obey the mcp protocol
133- agent .tools .append (
134- MCPToolset (
135- connection_params = StreamableHTTPConnectionParams (
136- url = f"http://127.0.0.1:{ self .port } /mcp" ,
137- headers = {REVERSE_MCP_HEADER_KEY : payload .websocket_id },
138- ),
139- )
140- )
141-
142140 runner = Runner (app_name = payload .app_name , agent = agent )
143141 response = await runner .run (
144142 messages = [prompt ],
@@ -152,28 +150,213 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
152150 @self .app .websocket ("/ws" )
153151 async def ws_endpoint (ws : WebSocket ):
154152 client_id = ws .query_params .get ("id" )
153+
155154 if not client_id :
156155 await ws .close (
157156 code = 400 ,
158157 reason = "WebSocket `id` is required like `/ws?id=my_id`" ,
159158 )
160159 return
161160
161+ # Parse filters from query params, comma-separated string
162+ filters_str = ws .query_params .get ("filters" )
163+ filters = None
164+ if filters_str :
165+ filters = [t .strip () for t in filters_str .split ("," ) if t .strip ()]
166+
162167 logger .info (f"Register websocket { client_id } to session manager." )
163168 self .ws_session_mgr .connections [client_id ] = ws
164169
165170 logger .info (f"Fork agent for websocket { client_id } " )
166- self .ws_agent_mgr [client_id ] = self .agent .clone ()
171+ agent = self .agent .clone ()
172+
173+ logger .info (
174+ f"clone agent \n model_name={ agent .model_name } \n instruction={ agent .instruction } \n "
175+ )
176+
177+ # Mount MCPToolset when creating agent
178+ mcp_toolset_url = f"http://127.0.0.1:{ self .port } /mcp"
179+ mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY : client_id }
180+ logger .debug (f"Mount MCPToolset to agent for websocket { client_id } " )
181+ agent .tools .append (
182+ MCPToolset (
183+ connection_params = StreamableHTTPConnectionParams (
184+ url = mcp_toolset_url ,
185+ headers = mcp_toolset_headers ,
186+ ),
187+ tool_filter = filters ,
188+ )
189+ )
190+ self .ws_agent_mgr [client_id ] = agent
191+
192+ logger .info (f"Create session service for websocket { client_id } " )
193+ self .ws_session_service_mgr [client_id ] = InMemorySessionService ()
167194
168195 await ws .accept ()
169196 logger .info (f"Websocket { client_id } connected" )
170197
171- while True :
172- raw = await ws .receive_text ()
173- await self .ws_session_mgr .handle_ws_message (client_id , raw )
198+ try :
199+ while True :
200+ raw = await ws .receive_text ()
201+ logger .debug (f"ws.receive_text() = { raw } " )
202+ await self .ws_session_mgr .handle_ws_message (client_id , raw )
203+ except Exception as e :
204+ logger .warning (f"client { client_id } web socket connection closed: { e } " )
205+
206+ class CreateSessionRequest (BaseModel ):
207+ state : Optional [dict [str , Any ]] = None
208+ session_id : Optional [str ] = None
209+ websocket_id : str
210+
211+ class RunAgentRequestWithWsId (RunAgentRequest ):
212+ websocket_id : str
213+
214+ def _get_session_service (websocket_id : str ) -> InMemorySessionService :
215+ """Get session service for the websocket client."""
216+ if websocket_id not in self .ws_session_service_mgr :
217+ raise HTTPException (
218+ status_code = 404 , detail = f"WebSocket client { websocket_id } not found"
219+ )
220+ return self .ws_session_service_mgr [websocket_id ]
221+
222+ @self .app .post (
223+ "/apps/{app_name}/users/{user_id}/sessions" ,
224+ response_model_exclude_none = True ,
225+ )
226+ async def create_session (
227+ app_name : str ,
228+ user_id : str ,
229+ req : CreateSessionRequest ,
230+ ) -> Session :
231+ """Create a new session."""
232+ session_id = req .session_id if req .session_id else str (uuid .uuid4 ())
233+ session = Session (
234+ app_name = app_name ,
235+ user_id = user_id ,
236+ id = session_id ,
237+ state = req .state if req .state else {},
238+ )
239+ session_service = _get_session_service (req .websocket_id )
240+ await session_service .create_session (
241+ app_name = app_name ,
242+ user_id = user_id ,
243+ session_id = session_id ,
244+ state = req .state if req .state else {},
245+ )
246+ logger .info (
247+ f"Created session: { session_id } for user { user_id } in app { app_name } "
248+ )
249+ return session
250+
251+ @self .app .post (
252+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}" ,
253+ response_model_exclude_none = True ,
254+ )
255+ async def create_session_with_id (
256+ app_name : str ,
257+ user_id : str ,
258+ session_id : str ,
259+ req : CreateSessionRequest ,
260+ ) -> Session :
261+ """Create a session with specific ID."""
262+ session_service = _get_session_service (req .websocket_id )
263+ await session_service .create_session (
264+ app_name = app_name ,
265+ user_id = user_id ,
266+ session_id = session_id ,
267+ state = req .state if req .state else {},
268+ )
269+ session = Session (
270+ app_name = app_name ,
271+ user_id = user_id ,
272+ id = session_id ,
273+ state = req .state if req .state else {},
274+ )
275+ logger .info (f"Created session with ID: { session_id } for user { user_id } " )
276+ return session
277+
278+ @self .app .post ("/run_sse" )
279+ async def run_agent_sse (req : RunAgentRequestWithWsId ) -> StreamingResponse :
280+ """Run agent with SSE streaming."""
281+ session_service = _get_session_service (req .websocket_id )
282+
283+ # Get session
284+ session = await session_service .get_session (
285+ app_name = req .app_name ,
286+ user_id = req .user_id ,
287+ session_id = req .session_id ,
288+ )
289+ if not session :
290+ raise HTTPException (status_code = 404 , detail = "Session not found" )
291+
292+ # Get agent for this websocket
293+ if req .websocket_id in self .ws_agent_mgr :
294+ agent = self .ws_agent_mgr [req .websocket_id ]
295+ logger .debug (f"Using agent from websocket { req .websocket_id } " )
296+ else :
297+ raise HTTPException (
298+ status_code = 404 ,
299+ detail = f"WebSocket client { req .websocket_id } not found" ,
300+ )
301+
302+ # Create runner
303+ runner = GoogleRunner (
304+ agent = agent ,
305+ app_name = req .app_name ,
306+ session_service = session_service ,
307+ artifact_service = self .artifact_service ,
308+ )
309+
310+ # Determine streaming mode from request
311+ stream_mode = StreamingMode .SSE if req .streaming else StreamingMode .NONE
312+
313+ async def event_generator ():
314+ try :
315+ async with Aclosing (
316+ runner .run_async (
317+ user_id = req .user_id ,
318+ session_id = req .session_id ,
319+ new_message = req .new_message ,
320+ state_delta = req .state_delta ,
321+ run_config = RunConfig (streaming_mode = stream_mode ),
322+ invocation_id = req .invocation_id ,
323+ )
324+ ) as agen :
325+ async for event in agen :
326+ # ADK Web renders artifacts from `actions.artifactDelta`
327+ # during part processing *and* during action processing
328+ # 1) the original event with `artifactDelta` cleared (content)
329+ # 2) a content-less "action-only" event carrying `artifactDelta`
330+ events_to_stream = [event ]
331+ if (
332+ event .actions .artifact_delta
333+ and event .content
334+ and event .content .parts
335+ ):
336+ content_event = event .model_copy (deep = True )
337+ content_event .actions .artifact_delta = {}
338+ artifact_event = event .model_copy (deep = True )
339+ artifact_event .content = None
340+ events_to_stream = [content_event , artifact_event ]
341+
342+ for event_to_stream in events_to_stream :
343+ sse_event = event_to_stream .model_dump_json (
344+ exclude_none = True , by_alias = True
345+ )
346+ logger .debug (f"SSE event: { sse_event } " )
347+ yield f"data: { sse_event } \n \n "
348+ except Exception as e :
349+ logger .exception (f"Error in event_generator: { e } " )
350+ yield f"data: { json .dumps ({'error' : 'Internal server error' })} \n \n "
351+
352+ return StreamingResponse (
353+ event_generator (),
354+ media_type = "text/event-stream" ,
355+ )
174356
175357 # build the fake MPC server,
176358 # and intercept all requests to the client websocket client.
359+ # NOTE: This catch-all route must be defined LAST
177360 @self .app .api_route ("/{path:path}" , methods = ["GET" , "POST" ])
178361 async def mcp_proxy (path : str , request : Request ):
179362 client_id = request .headers .get (REVERSE_MCP_HEADER_KEY )
@@ -202,10 +385,27 @@ async def mcp_proxy(path: str, request: Request):
202385
203386 logger .debug (f"[Reverse mcp proxy] Response from local: { resp } " )
204387
388+ # Filter hop-by-hop headers to avoid Content-Length mismatch
389+ headers = resp ["payload" ]["headers" ]
390+ hop_by_hop_headers = {
391+ "content-length" ,
392+ "transfer-encoding" ,
393+ "connection" ,
394+ "keep-alive" ,
395+ "proxy-authenticate" ,
396+ "proxy-authorization" ,
397+ "te" ,
398+ "trailers" ,
399+ "upgrade" ,
400+ }
401+ filtered_headers = {
402+ k : v for k , v in headers .items () if k .lower () not in hop_by_hop_headers
403+ }
404+
205405 return Response (
206406 content = resp ["payload" ]["body" ], # type: ignore
207407 status_code = resp ["payload" ]["status" ], # type: ignore
208- headers = resp [ "payload" ][ "headers" ] , # type: ignore
408+ headers = filtered_headers , # type: ignore
209409 )
210410
211411 def run (self ):
0 commit comments