1414
1515import asyncio
1616import json
17+ import time
1718import uuid
19+ import threading
20+ from dataclasses import dataclass , field
1821from typing import TYPE_CHECKING , Any , Callable , Optional
1922
2023from fastapi import FastAPI , HTTPException , Request , Response , WebSocket
2124from fastapi .responses import StreamingResponse
22- from google .adk .agents .run_config import RunConfig , StreamingMode
25+ from google .adk .agents .run_config import StreamingMode
2326from google .adk .artifacts import InMemoryArtifactService
2427from google .adk .cli .adk_web_server import RunAgentRequest
25- from google .adk .runners import Runner as GoogleRunner
28+ from google .adk .runners import Runner as GoogleRunner , RunConfig
2629from google .adk .sessions import InMemorySessionService , Session
2730from google .adk .tools .mcp_tool .mcp_session_manager import (
2831 StreamableHTTPConnectionParams ,
3437from veadk import Runner
3538from veadk .utils .logger import get_logger
3639
40+ from fastapi .middleware .cors import CORSMiddleware
41+
3742if TYPE_CHECKING :
3843 from veadk import Agent
3944
@@ -48,47 +53,121 @@ class ExtraRoute(BaseModel):
4853 methods : list [str ]
4954
5055
51- class WebsocketSessionManager :
52- def __init__ (self ):
53- # ws id -> ws instance
54- self .connections : dict [str , WebSocket ] = {}
56+ @dataclass
57+ class ClientResource :
58+ websocket : WebSocket
59+ agent : "Agent"
60+ session_service : InMemorySessionService
61+ artifact_service : InMemoryArtifactService
62+ pending_requests : dict [str , asyncio .Future ] = field (default_factory = dict )
63+ last_active_time : float = field (default_factory = time .time )
64+
65+ def update_activity (self ):
66+ self .last_active_time = time .time ()
67+
68+
69+ class ResourceManager :
70+ def __init__ (self , timeout_seconds : int = 3600 ):
71+ self ._lock : threading .Lock = threading .Lock ()
72+ self .resources : dict [str , ClientResource ] = {}
73+ self .timeout_seconds = timeout_seconds
74+ self .cleanup_task : Optional [asyncio .Task ] = None
75+
76+ def register (
77+ self ,
78+ client_id : str ,
79+ websocket : WebSocket ,
80+ agent : "Agent" ,
81+ session_service : InMemorySessionService ,
82+ artifact_service : InMemoryArtifactService ,
83+ ):
84+ with self ._lock :
85+ self .resources [client_id ] = ClientResource (
86+ websocket = websocket ,
87+ agent = agent ,
88+ session_service = session_service ,
89+ artifact_service = artifact_service ,
90+ )
91+ logger .info (f"client { client_id } registered" )
92+
93+ def get (self , client_id : str ) -> Optional [ClientResource ]:
94+ with self ._lock :
95+ logger .info (f"get { client_id } " )
96+ resource = self .resources .get (client_id )
97+ if resource :
98+ resource .update_activity ()
99+ return resource
100+
101+ async def remove (self , client_id : str ):
102+ if client_id in self .resources :
103+ resource = self .resources .pop (client_id )
104+ try :
105+ await resource .websocket .close ()
106+ for fut in resource .pending_requests .values ():
107+ if not fut .done ():
108+ fut .cancel ()
109+ except Exception as e :
110+ logger .warning (
111+ f"client { client_id } resource websocket close error: { e } "
112+ )
113+ pass
114+
115+ async def start_cleanup_loop (self ):
116+ logger .info ("ResourceManager: active cleanup loop" )
117+ while True :
118+ await asyncio .sleep (60 ) # Check every minute
119+ logger .debug ("cleanup loop running..." )
120+ now = time .time ()
121+ to_remove = []
122+ for client_id , resource in self .resources .items ():
123+ logger .debug (
124+ f"check { client_id } , last_active_time={ resource .last_active_time } , timeout={ self .timeout_seconds } "
125+ )
126+ if now - resource .last_active_time > self .timeout_seconds :
127+ to_remove .append (client_id )
128+
129+ for client_id in to_remove :
130+ logger .info (f"Removing inactive client { client_id } " )
131+ await self .remove (client_id )
132+
133+ def start (self ):
134+ self .cleanup_task = asyncio .create_task (self .start_cleanup_loop ())
55135
56- # ws id -> msg id -> ret
57- self .pendings : dict [str , dict [str , asyncio .Future ]] = {}
136+ def stop (self ):
137+ if self .cleanup_task :
138+ self .cleanup_task .cancel ()
58139
59- async def call_mcp_http (self , ws_id : str , request : dict ):
140+ async def call_mcp_http (self , client_id : str , request : dict ):
60141 """Forward MCP request to client."""
61- try :
62- ws = self .connections [ws_id ]
63- except KeyError :
64- logger .error (f"Websocket { ws_id } not found" )
142+ resource = self .get (client_id )
143+ if not resource :
144+ logger .error (f"Client { client_id } not found" )
65145 return b""
66146
67- msg = {}
68-
69- msg ["id" ] = str (uuid .uuid4 ())
70- msg ["type" ] = "http_request"
71- msg ["payload" ] = request
147+ ws = resource .websocket
148+ msg = {"id" : str (uuid .uuid4 ()), "type" : "http_request" , "payload" : request }
72149
73150 fut = asyncio .get_event_loop ().create_future ()
74151
75- if ws_id not in self .pendings :
76- self .pendings [ws_id ] = {}
77-
78- self .pendings [ws_id ][msg ["id" ]] = fut
152+ resource .pending_requests [msg ["id" ]] = fut
79153
80154 await ws .send_text (json .dumps (msg ))
81155 return await fut
82156
83- async def handle_ws_message (self , ws_id : str , raw : str ):
157+ async def handle_ws_message (self , client_id : str , raw : str ):
158+ resource = self .get (client_id )
159+ if not resource :
160+ return
161+
84162 msg = json .loads (raw )
85163 if msg .get ("type" ) != "http_response" :
86164 return
87165
88166 req_id = msg ["id" ]
89- fut = self . pendings [ ws_id ] .pop (req_id , None )
167+ fut = resource . pending_requests .pop (req_id , None )
90168 if fut :
91169 fut .set_result (msg )
170+ # todo : 异常ID处理
92171
93172
94173class ServerWithReverseMCP :
@@ -102,27 +181,36 @@ def __init__(
102181 extra_routes : list [ExtraRoute ] | None = None ,
103182 ):
104183 self .agent = agent
105-
106184 self .host = host
107185 self .port = port
108-
109186 self .extra_routes = extra_routes
110187
111- self .app = FastAPI (
112- openapi_url = None ,
113- docs_url = None ,
114- redoc_url = None ,
115- swagger_ui_oauth2_redirect_url = None ,
188+ self .app = FastAPI ()
189+ origins = [
190+ "*" , # 允许所有源(开发环境可用,生产环境不推荐)
191+ ]
192+
193+ self .app .add_middleware (
194+ CORSMiddleware ,
195+ allow_origins = origins , # 允许访问的源
196+ allow_credentials = True , # 允许携带Cookie
197+ allow_methods = ["*" ], # 允许所有HTTP方法(GET、POST、PUT等)
198+ allow_headers = ["*" ], # 允许所有请求头
116199 )
117200
118201 self .artifact_service = InMemoryArtifactService ()
202+ self .resource_manager = ResourceManager ()
119203
120204 # build routes for self.app
121205 self .build ()
122206
123- self .ws_session_mgr = WebsocketSessionManager ()
124- self .ws_agent_mgr : dict [str , "Agent" ] = {}
125- self .ws_session_service_mgr : dict [str , "InMemorySessionService" ] = {}
207+ @self .app .on_event ("startup" )
208+ async def startup_event ():
209+ self .resource_manager .start ()
210+
211+ @self .app .on_event ("shutdown" )
212+ async def shutdown_event ():
213+ self .resource_manager .stop ()
126214
127215 def build (self ):
128216 logger .info ("Build routes for server with reverse mcp" )
@@ -149,9 +237,18 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
149237 session_id = payload .session_id
150238 prompt = payload .prompt
151239
152- agent = self .ws_agent_mgr [payload .websocket_id ]
240+ resource = self .resource_manager .get (payload .websocket_id )
241+ if not resource :
242+ raise HTTPException (
243+ status_code = 404 , detail = f"Client { payload .websocket_id } not found"
244+ )
245+ agent = resource .agent
153246
154- runner = Runner (app_name = payload .app_name , agent = agent )
247+ runner = Runner (
248+ app_name = payload .app_name ,
249+ agent = agent ,
250+ session_service = resource .session_service ,
251+ )
155252 response = await runner .run (
156253 messages = [prompt ],
157254 user_id = user_id ,
@@ -160,6 +257,12 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
160257
161258 return InvokeResponse (response = response )
162259
260+ @self .app .delete ("/management/clients/{client_id}" )
261+ async def delete_client (client_id : str ):
262+ """Manually remove a client resource."""
263+ await self .resource_manager .remove (client_id )
264+ return {"status" : "success" , "client_id" : client_id }
265+
163266 # build websocket endpoint
164267 @self .app .websocket ("/ws" )
165268 async def ws_endpoint (ws : WebSocket ):
@@ -179,15 +282,10 @@ async def ws_endpoint(ws: WebSocket):
179282 filters = [t .strip () for t in filters_str .split ("," ) if t .strip ()]
180283
181284 logger .info (f"Register websocket { client_id } to session manager." )
182- self .ws_session_mgr .connections [client_id ] = ws
183285
184286 logger .info (f"Fork agent for websocket { client_id } " )
185287 agent = self .agent .clone ()
186288
187- logger .info (
188- f"clone agent \n model_name={ agent .model_name } \n instruction={ agent .instruction } \n "
189- )
190-
191289 # Mount MCPToolset when creating agent
192290 mcp_toolset_url = f"http://127.0.0.1:{ self .port } /mcp"
193291 mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY : client_id }
@@ -201,10 +299,18 @@ async def ws_endpoint(ws: WebSocket):
201299 tool_filter = filters ,
202300 )
203301 )
204- self .ws_agent_mgr [client_id ] = agent
205302
206303 logger .info (f"Create session service for websocket { client_id } " )
207- self .ws_session_service_mgr [client_id ] = InMemorySessionService ()
304+ session_service = InMemorySessionService ()
305+ artifact_service = InMemoryArtifactService ()
306+
307+ self .resource_manager .register (
308+ client_id = client_id ,
309+ websocket = ws ,
310+ agent = agent ,
311+ session_service = session_service ,
312+ artifact_service = artifact_service ,
313+ )
208314
209315 await ws .accept ()
210316 logger .info (f"Websocket { client_id } connected" )
@@ -213,7 +319,7 @@ async def ws_endpoint(ws: WebSocket):
213319 while True :
214320 raw = await ws .receive_text ()
215321 logger .debug (f"ws.receive_text() = { raw } " )
216- await self .ws_session_mgr .handle_ws_message (client_id , raw )
322+ await self .resource_manager .handle_ws_message (client_id , raw )
217323 except Exception as e :
218324 logger .warning (f"client { client_id } web socket connection closed: { e } " )
219325
@@ -227,12 +333,12 @@ class RunAgentRequestWithWsId(RunAgentRequest):
227333
228334 def _get_session_service (websocket_id : str ) -> InMemorySessionService :
229335 """Get session service for the websocket client."""
230- if websocket_id not in self .ws_session_service_mgr :
336+ resource = self .resource_manager .get (websocket_id )
337+ if not resource :
231338 raise HTTPException (
232- status_code = 404 ,
233- detail = f"WebSocket client { websocket_id } not found" ,
339+ status_code = 404 , detail = f"WebSocket client { websocket_id } not found"
234340 )
235- return self . ws_session_service_mgr [ websocket_id ]
341+ return resource . session_service
236342
237343 @self .app .post (
238344 "/apps/{app_name}/users/{user_id}/sessions" ,
@@ -291,11 +397,18 @@ async def create_session_with_id(
291397 return session
292398
293399 @self .app .post ("/run_sse" )
294- async def run_agent_sse (
295- req : RunAgentRequestWithWsId ,
296- ) -> StreamingResponse :
400+ async def run_agent_sse (req : RunAgentRequestWithWsId ) -> StreamingResponse :
297401 """Run agent with SSE streaming."""
298- session_service = _get_session_service (req .websocket_id )
402+ resource = self .resource_manager .get (req .websocket_id )
403+ if not resource :
404+ raise HTTPException (
405+ status_code = 404 ,
406+ detail = f"WebSocket client { req .websocket_id } not found" ,
407+ )
408+
409+ session_service = resource .session_service
410+ agent = resource .agent
411+ logger .debug (f"Using agent from websocket { req .websocket_id } " )
299412
300413 # Get session
301414 session = await session_service .get_session (
@@ -306,16 +419,6 @@ async def run_agent_sse(
306419 if not session :
307420 raise HTTPException (status_code = 404 , detail = "Session not found" )
308421
309- # Get agent for this websocket
310- if req .websocket_id in self .ws_agent_mgr :
311- agent = self .ws_agent_mgr [req .websocket_id ]
312- logger .debug (f"Using agent from websocket { req .websocket_id } " )
313- else :
314- raise HTTPException (
315- status_code = 404 ,
316- detail = f"WebSocket client { req .websocket_id } not found" ,
317- )
318-
319422 # Create runner
320423 runner = GoogleRunner (
321424 agent = agent ,
@@ -354,10 +457,7 @@ async def event_generator():
354457 content_event .actions .artifact_delta = {}
355458 artifact_event = event .model_copy (deep = True )
356459 artifact_event .content = None
357- events_to_stream = [
358- content_event ,
359- artifact_event ,
360- ]
460+ events_to_stream = [content_event , artifact_event ]
361461
362462 for event_to_stream in events_to_stream :
363463 sse_event = event_to_stream .model_dump_json (
@@ -367,7 +467,7 @@ async def event_generator():
367467 yield f"data: { sse_event } \n \n "
368468 except Exception as e :
369469 logger .exception (f"Error in event_generator: { e } " )
370- yield f"data: { json .dumps ({'error' : 'Internal server error' })} \n \n "
470+ yield f"data: { json .dumps ({'error' : str ( e ) })} \n \n "
371471
372472 return StreamingResponse (
373473 event_generator (),
@@ -391,8 +491,7 @@ async def mcp_proxy(path: str, request: Request):
391491 if not client_id :
392492 return Response ("client id not found" , status_code = 400 )
393493
394- ws = self .ws_session_mgr .connections .get (client_id )
395- if not ws :
494+ if not self .resource_manager .get (client_id ):
396495 return Response ("websocket `client_id` not connected" , status_code = 503 )
397496
398497 body = await request .body ()
@@ -409,7 +508,7 @@ async def mcp_proxy(path: str, request: Request):
409508
410509 logger .debug (f"[Reverse mcp proxy] Request from agent: { payload } " )
411510
412- resp = await self .ws_session_mgr .call_mcp_http (client_id , payload )
511+ resp = await self .resource_manager .call_mcp_http (client_id , payload )
413512
414513 logger .debug (f"[Reverse mcp proxy] Response from local: { resp } " )
415514
0 commit comments