Skip to content

Commit 1a69cec

Browse files
committed
feat(client): add multi-tenant support for client
1 parent 7c135b0 commit 1a69cec

File tree

1 file changed

+169
-70
lines changed

1 file changed

+169
-70
lines changed

veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py

Lines changed: 169 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,18 @@
1414

1515
import asyncio
1616
import json
17+
import time
1718
import uuid
19+
import threading
20+
from dataclasses import dataclass, field
1821
from typing import TYPE_CHECKING, Any, Callable, Optional
1922

2023
from fastapi import FastAPI, HTTPException, Request, Response, WebSocket
2124
from fastapi.responses import StreamingResponse
22-
from google.adk.agents.run_config import RunConfig, StreamingMode
25+
from google.adk.agents.run_config import StreamingMode
2326
from google.adk.artifacts import InMemoryArtifactService
2427
from google.adk.cli.adk_web_server import RunAgentRequest
25-
from google.adk.runners import Runner as GoogleRunner
28+
from google.adk.runners import Runner as GoogleRunner, RunConfig
2629
from google.adk.sessions import InMemorySessionService, Session
2730
from google.adk.tools.mcp_tool.mcp_session_manager import (
2831
StreamableHTTPConnectionParams,
@@ -34,6 +37,8 @@
3437
from veadk import Runner
3538
from veadk.utils.logger import get_logger
3639

40+
from fastapi.middleware.cors import CORSMiddleware
41+
3742
if TYPE_CHECKING:
3843
from veadk import Agent
3944

@@ -48,47 +53,121 @@ class ExtraRoute(BaseModel):
4853
methods: list[str]
4954

5055

51-
class WebsocketSessionManager:
52-
def __init__(self):
53-
# ws id -> ws instance
54-
self.connections: dict[str, WebSocket] = {}
56+
@dataclass
57+
class ClientResource:
58+
websocket: WebSocket
59+
agent: "Agent"
60+
session_service: InMemorySessionService
61+
artifact_service: InMemoryArtifactService
62+
pending_requests: dict[str, asyncio.Future] = field(default_factory=dict)
63+
last_active_time: float = field(default_factory=time.time)
64+
65+
def update_activity(self):
66+
self.last_active_time = time.time()
67+
68+
69+
class ResourceManager:
70+
def __init__(self, timeout_seconds: int = 3600):
71+
self._lock: threading.Lock = threading.Lock()
72+
self.resources: dict[str, ClientResource] = {}
73+
self.timeout_seconds = timeout_seconds
74+
self.cleanup_task: Optional[asyncio.Task] = None
75+
76+
def register(
77+
self,
78+
client_id: str,
79+
websocket: WebSocket,
80+
agent: "Agent",
81+
session_service: InMemorySessionService,
82+
artifact_service: InMemoryArtifactService,
83+
):
84+
with self._lock:
85+
self.resources[client_id] = ClientResource(
86+
websocket=websocket,
87+
agent=agent,
88+
session_service=session_service,
89+
artifact_service=artifact_service,
90+
)
91+
logger.info(f"client {client_id} registered")
92+
93+
def get(self, client_id: str) -> Optional[ClientResource]:
94+
with self._lock:
95+
logger.info(f"get {client_id}")
96+
resource = self.resources.get(client_id)
97+
if resource:
98+
resource.update_activity()
99+
return resource
100+
101+
async def remove(self, client_id: str):
102+
if client_id in self.resources:
103+
resource = self.resources.pop(client_id)
104+
try:
105+
await resource.websocket.close()
106+
for fut in resource.pending_requests.values():
107+
if not fut.done():
108+
fut.cancel()
109+
except Exception as e:
110+
logger.warning(
111+
f"client {client_id} resource websocket close error: {e}"
112+
)
113+
pass
114+
115+
async def start_cleanup_loop(self):
116+
logger.info("ResourceManager: active cleanup loop")
117+
while True:
118+
await asyncio.sleep(60) # Check every minute
119+
logger.debug("cleanup loop running...")
120+
now = time.time()
121+
to_remove = []
122+
for client_id, resource in self.resources.items():
123+
logger.debug(
124+
f"check {client_id}, last_active_time={resource.last_active_time}, timeout={self.timeout_seconds}"
125+
)
126+
if now - resource.last_active_time > self.timeout_seconds:
127+
to_remove.append(client_id)
128+
129+
for client_id in to_remove:
130+
logger.info(f"Removing inactive client {client_id}")
131+
await self.remove(client_id)
132+
133+
def start(self):
134+
self.cleanup_task = asyncio.create_task(self.start_cleanup_loop())
55135

56-
# ws id -> msg id -> ret
57-
self.pendings: dict[str, dict[str, asyncio.Future]] = {}
136+
def stop(self):
137+
if self.cleanup_task:
138+
self.cleanup_task.cancel()
58139

59-
async def call_mcp_http(self, ws_id: str, request: dict):
140+
async def call_mcp_http(self, client_id: str, request: dict):
60141
"""Forward MCP request to client."""
61-
try:
62-
ws = self.connections[ws_id]
63-
except KeyError:
64-
logger.error(f"Websocket {ws_id} not found")
142+
resource = self.get(client_id)
143+
if not resource:
144+
logger.error(f"Client {client_id} not found")
65145
return b""
66146

67-
msg = {}
68-
69-
msg["id"] = str(uuid.uuid4())
70-
msg["type"] = "http_request"
71-
msg["payload"] = request
147+
ws = resource.websocket
148+
msg = {"id": str(uuid.uuid4()), "type": "http_request", "payload": request}
72149

73150
fut = asyncio.get_event_loop().create_future()
74151

75-
if ws_id not in self.pendings:
76-
self.pendings[ws_id] = {}
77-
78-
self.pendings[ws_id][msg["id"]] = fut
152+
resource.pending_requests[msg["id"]] = fut
79153

80154
await ws.send_text(json.dumps(msg))
81155
return await fut
82156

83-
async def handle_ws_message(self, ws_id: str, raw: str):
157+
async def handle_ws_message(self, client_id: str, raw: str):
158+
resource = self.get(client_id)
159+
if not resource:
160+
return
161+
84162
msg = json.loads(raw)
85163
if msg.get("type") != "http_response":
86164
return
87165

88166
req_id = msg["id"]
89-
fut = self.pendings[ws_id].pop(req_id, None)
167+
fut = resource.pending_requests.pop(req_id, None)
90168
if fut:
91169
fut.set_result(msg)
170+
# todo : 异常ID处理
92171

93172

94173
class ServerWithReverseMCP:
@@ -102,27 +181,36 @@ def __init__(
102181
extra_routes: list[ExtraRoute] | None = None,
103182
):
104183
self.agent = agent
105-
106184
self.host = host
107185
self.port = port
108-
109186
self.extra_routes = extra_routes
110187

111-
self.app = FastAPI(
112-
openapi_url=None,
113-
docs_url=None,
114-
redoc_url=None,
115-
swagger_ui_oauth2_redirect_url=None,
188+
self.app = FastAPI()
189+
origins = [
190+
"*", # 允许所有源(开发环境可用,生产环境不推荐)
191+
]
192+
193+
self.app.add_middleware(
194+
CORSMiddleware,
195+
allow_origins=origins, # 允许访问的源
196+
allow_credentials=True, # 允许携带Cookie
197+
allow_methods=["*"], # 允许所有HTTP方法(GET、POST、PUT等)
198+
allow_headers=["*"], # 允许所有请求头
116199
)
117200

118201
self.artifact_service = InMemoryArtifactService()
202+
self.resource_manager = ResourceManager()
119203

120204
# build routes for self.app
121205
self.build()
122206

123-
self.ws_session_mgr = WebsocketSessionManager()
124-
self.ws_agent_mgr: dict[str, "Agent"] = {}
125-
self.ws_session_service_mgr: dict[str, "InMemorySessionService"] = {}
207+
@self.app.on_event("startup")
208+
async def startup_event():
209+
self.resource_manager.start()
210+
211+
@self.app.on_event("shutdown")
212+
async def shutdown_event():
213+
self.resource_manager.stop()
126214

127215
def build(self):
128216
logger.info("Build routes for server with reverse mcp")
@@ -149,9 +237,18 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
149237
session_id = payload.session_id
150238
prompt = payload.prompt
151239

152-
agent = self.ws_agent_mgr[payload.websocket_id]
240+
resource = self.resource_manager.get(payload.websocket_id)
241+
if not resource:
242+
raise HTTPException(
243+
status_code=404, detail=f"Client {payload.websocket_id} not found"
244+
)
245+
agent = resource.agent
153246

154-
runner = Runner(app_name=payload.app_name, agent=agent)
247+
runner = Runner(
248+
app_name=payload.app_name,
249+
agent=agent,
250+
session_service=resource.session_service,
251+
)
155252
response = await runner.run(
156253
messages=[prompt],
157254
user_id=user_id,
@@ -160,6 +257,12 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
160257

161258
return InvokeResponse(response=response)
162259

260+
@self.app.delete("/management/clients/{client_id}")
261+
async def delete_client(client_id: str):
262+
"""Manually remove a client resource."""
263+
await self.resource_manager.remove(client_id)
264+
return {"status": "success", "client_id": client_id}
265+
163266
# build websocket endpoint
164267
@self.app.websocket("/ws")
165268
async def ws_endpoint(ws: WebSocket):
@@ -179,15 +282,10 @@ async def ws_endpoint(ws: WebSocket):
179282
filters = [t.strip() for t in filters_str.split(",") if t.strip()]
180283

181284
logger.info(f"Register websocket {client_id} to session manager.")
182-
self.ws_session_mgr.connections[client_id] = ws
183285

184286
logger.info(f"Fork agent for websocket {client_id}")
185287
agent = self.agent.clone()
186288

187-
logger.info(
188-
f"clone agent \n model_name={agent.model_name}\n instruction={agent.instruction}\n"
189-
)
190-
191289
# Mount MCPToolset when creating agent
192290
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
193291
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: client_id}
@@ -201,10 +299,18 @@ async def ws_endpoint(ws: WebSocket):
201299
tool_filter=filters,
202300
)
203301
)
204-
self.ws_agent_mgr[client_id] = agent
205302

206303
logger.info(f"Create session service for websocket {client_id}")
207-
self.ws_session_service_mgr[client_id] = InMemorySessionService()
304+
session_service = InMemorySessionService()
305+
artifact_service = InMemoryArtifactService()
306+
307+
self.resource_manager.register(
308+
client_id=client_id,
309+
websocket=ws,
310+
agent=agent,
311+
session_service=session_service,
312+
artifact_service=artifact_service,
313+
)
208314

209315
await ws.accept()
210316
logger.info(f"Websocket {client_id} connected")
@@ -213,7 +319,7 @@ async def ws_endpoint(ws: WebSocket):
213319
while True:
214320
raw = await ws.receive_text()
215321
logger.debug(f"ws.receive_text() = {raw}")
216-
await self.ws_session_mgr.handle_ws_message(client_id, raw)
322+
await self.resource_manager.handle_ws_message(client_id, raw)
217323
except Exception as e:
218324
logger.warning(f"client {client_id} web socket connection closed: {e}")
219325

@@ -227,12 +333,12 @@ class RunAgentRequestWithWsId(RunAgentRequest):
227333

228334
def _get_session_service(websocket_id: str) -> InMemorySessionService:
229335
"""Get session service for the websocket client."""
230-
if websocket_id not in self.ws_session_service_mgr:
336+
resource = self.resource_manager.get(websocket_id)
337+
if not resource:
231338
raise HTTPException(
232-
status_code=404,
233-
detail=f"WebSocket client {websocket_id} not found",
339+
status_code=404, detail=f"WebSocket client {websocket_id} not found"
234340
)
235-
return self.ws_session_service_mgr[websocket_id]
341+
return resource.session_service
236342

237343
@self.app.post(
238344
"/apps/{app_name}/users/{user_id}/sessions",
@@ -291,11 +397,18 @@ async def create_session_with_id(
291397
return session
292398

293399
@self.app.post("/run_sse")
294-
async def run_agent_sse(
295-
req: RunAgentRequestWithWsId,
296-
) -> StreamingResponse:
400+
async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
297401
"""Run agent with SSE streaming."""
298-
session_service = _get_session_service(req.websocket_id)
402+
resource = self.resource_manager.get(req.websocket_id)
403+
if not resource:
404+
raise HTTPException(
405+
status_code=404,
406+
detail=f"WebSocket client {req.websocket_id} not found",
407+
)
408+
409+
session_service = resource.session_service
410+
agent = resource.agent
411+
logger.debug(f"Using agent from websocket {req.websocket_id}")
299412

300413
# Get session
301414
session = await session_service.get_session(
@@ -306,16 +419,6 @@ async def run_agent_sse(
306419
if not session:
307420
raise HTTPException(status_code=404, detail="Session not found")
308421

309-
# Get agent for this websocket
310-
if req.websocket_id in self.ws_agent_mgr:
311-
agent = self.ws_agent_mgr[req.websocket_id]
312-
logger.debug(f"Using agent from websocket {req.websocket_id}")
313-
else:
314-
raise HTTPException(
315-
status_code=404,
316-
detail=f"WebSocket client {req.websocket_id} not found",
317-
)
318-
319422
# Create runner
320423
runner = GoogleRunner(
321424
agent=agent,
@@ -354,10 +457,7 @@ async def event_generator():
354457
content_event.actions.artifact_delta = {}
355458
artifact_event = event.model_copy(deep=True)
356459
artifact_event.content = None
357-
events_to_stream = [
358-
content_event,
359-
artifact_event,
360-
]
460+
events_to_stream = [content_event, artifact_event]
361461

362462
for event_to_stream in events_to_stream:
363463
sse_event = event_to_stream.model_dump_json(
@@ -367,7 +467,7 @@ async def event_generator():
367467
yield f"data: {sse_event}\n\n"
368468
except Exception as e:
369469
logger.exception(f"Error in event_generator: {e}")
370-
yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
470+
yield f"data: {json.dumps({'error': str(e)})}\n\n"
371471

372472
return StreamingResponse(
373473
event_generator(),
@@ -391,8 +491,7 @@ async def mcp_proxy(path: str, request: Request):
391491
if not client_id:
392492
return Response("client id not found", status_code=400)
393493

394-
ws = self.ws_session_mgr.connections.get(client_id)
395-
if not ws:
494+
if not self.resource_manager.get(client_id):
396495
return Response("websocket `client_id` not connected", status_code=503)
397496

398497
body = await request.body()
@@ -409,7 +508,7 @@ async def mcp_proxy(path: str, request: Request):
409508

410509
logger.debug(f"[Reverse mcp proxy] Request from agent: {payload}")
411510

412-
resp = await self.ws_session_mgr.call_mcp_http(client_id, payload)
511+
resp = await self.resource_manager.call_mcp_http(client_id, payload)
413512

414513
logger.debug(f"[Reverse mcp proxy] Response from local: {resp}")
415514

0 commit comments

Comments
 (0)