Skip to content

Commit 22f4599

Browse files
sjy3zakahan
andauthored
fix(janus): fix the remote web socket closed expection (#510)
* feat: add create_session and run_sse on reverse mcp app * feat: reverse mcp with session_service_mgr * fix: filter_mcp_tools * chore: rename filter * chore: support streaming mode * fix: filtered_headers of content-length * fix: fix the remote web socket closed expection * fix: fix the trusted_mcp_components test --------- Co-authored-by: hanzhi.421 <hanzhi.421@bytedance.com>
1 parent c5e7293 commit 22f4599

File tree

3 files changed

+234
-24
lines changed

3 files changed

+234
-24
lines changed

tests/tools/mcp_tool/test_trusted_mcp_components.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ async def aclose(self):
170170
connection_params=self.mock_http_params, errlog=sys.stderr
171171
)
172172
manager._sessions = {}
173-
manager._session_lock = asyncio.Lock()
173+
# manager._session_lock = asyncio.Lock()
174174

175175
# Call create_session
176176
headers = {"x-trusted-mcp": "true"}
@@ -207,7 +207,7 @@ async def run_test():
207207
connection_params=self.mock_http_params, errlog=sys.stderr
208208
)
209209
manager._sessions = {}
210-
manager._session_lock = asyncio.Lock()
210+
# manager._session_lock = asyncio.Lock()
211211

212212
# Set up an existing session
213213
existing_session = mock.MagicMock()

veadk/toolkits/apps/reverse_mcp/client_with_reverse_mcp.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,24 @@
2323

2424

2525
class ClientWithReverseMCP:
26-
def __init__(self, ws_url: str, mcp_server_url: str, client_id: str):
26+
def __init__(
27+
self,
28+
ws_url: str,
29+
mcp_server_url: str,
30+
client_id: str,
31+
filters: list[str] | None = None,
32+
):
2733
"""Start a client with reverse mcp,
2834
2935
Args:
3036
ws_url: The url of the websocket server (cloud). Like example.com:8000
3137
mcp_server_url: The url of the mcp server (local).
38+
client_id: The client id for the websocket connection.
39+
filters: Optional list of tool names to filter (whitelist). If None, all tools are available.
3240
"""
3341
self.ws_url = f"ws://{ws_url}/ws?id={client_id}"
42+
if filters:
43+
self.ws_url += f"&filters={','.join(filters)}"
3444
self.mcp_server_url = mcp_server_url
3545

3646
# set timeout for httpx client

veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py

Lines changed: 221 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515
import asyncio
1616
import json
1717
import uuid
18-
from typing import TYPE_CHECKING
19-
20-
from fastapi import FastAPI, Request, Response, WebSocket
18+
from typing import TYPE_CHECKING, Any, Optional
19+
20+
from fastapi import FastAPI, HTTPException, Request, Response, WebSocket
21+
from fastapi.responses import StreamingResponse
22+
from google.adk.agents.run_config import StreamingMode
23+
from google.adk.artifacts import InMemoryArtifactService
24+
from google.adk.cli.adk_web_server import RunAgentRequest
25+
from google.adk.runners import Runner as GoogleRunner, RunConfig
26+
from google.adk.sessions import InMemorySessionService, Session
2127
from google.adk.tools.mcp_tool.mcp_session_manager import (
2228
StreamableHTTPConnectionParams,
2329
)
2430
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
31+
from google.adk.utils.context_utils import Aclosing
2532
from pydantic import BaseModel
2633

2734
from veadk import Runner
@@ -93,11 +100,15 @@ def __init__(
93100
self.port = port
94101

95102
self.app = FastAPI()
103+
104+
self.artifact_service = InMemoryArtifactService()
105+
96106
# build routes for self.app
97107
self.build()
98108

99109
self.ws_session_mgr = WebsocketSessionManager()
100110
self.ws_agent_mgr: dict[str, "Agent"] = {}
111+
self.ws_session_service_mgr: dict[str, "InMemorySessionService"] = {}
101112

102113
def build(self):
103114
logger.info("Build routes for server with reverse mcp")
@@ -126,19 +137,6 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
126137

127138
agent = self.ws_agent_mgr[payload.websocket_id]
128139

129-
if not agent.tools:
130-
logger.debug("Mount fake MCPToolset to agent")
131-
132-
# we hard code the mcp url with `/mcp` to obey the mcp protocol
133-
agent.tools.append(
134-
MCPToolset(
135-
connection_params=StreamableHTTPConnectionParams(
136-
url=f"http://127.0.0.1:{self.port}/mcp",
137-
headers={REVERSE_MCP_HEADER_KEY: payload.websocket_id},
138-
),
139-
)
140-
)
141-
142140
runner = Runner(app_name=payload.app_name, agent=agent)
143141
response = await runner.run(
144142
messages=[prompt],
@@ -152,28 +150,213 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
152150
@self.app.websocket("/ws")
153151
async def ws_endpoint(ws: WebSocket):
154152
client_id = ws.query_params.get("id")
153+
155154
if not client_id:
156155
await ws.close(
157156
code=400,
158157
reason="WebSocket `id` is required like `/ws?id=my_id`",
159158
)
160159
return
161160

161+
# Parse filters from query params, comma-separated string
162+
filters_str = ws.query_params.get("filters")
163+
filters = None
164+
if filters_str:
165+
filters = [t.strip() for t in filters_str.split(",") if t.strip()]
166+
162167
logger.info(f"Register websocket {client_id} to session manager.")
163168
self.ws_session_mgr.connections[client_id] = ws
164169

165170
logger.info(f"Fork agent for websocket {client_id}")
166-
self.ws_agent_mgr[client_id] = self.agent.clone()
171+
agent = self.agent.clone()
172+
173+
logger.info(
174+
f"clone agent \n model_name={agent.model_name}\n instruction={agent.instruction}\n"
175+
)
176+
177+
# Mount MCPToolset when creating agent
178+
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
179+
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: client_id}
180+
logger.debug(f"Mount MCPToolset to agent for websocket {client_id}")
181+
agent.tools.append(
182+
MCPToolset(
183+
connection_params=StreamableHTTPConnectionParams(
184+
url=mcp_toolset_url,
185+
headers=mcp_toolset_headers,
186+
),
187+
tool_filter=filters,
188+
)
189+
)
190+
self.ws_agent_mgr[client_id] = agent
191+
192+
logger.info(f"Create session service for websocket {client_id}")
193+
self.ws_session_service_mgr[client_id] = InMemorySessionService()
167194

168195
await ws.accept()
169196
logger.info(f"Websocket {client_id} connected")
170197

171-
while True:
172-
raw = await ws.receive_text()
173-
await self.ws_session_mgr.handle_ws_message(client_id, raw)
198+
try:
199+
while True:
200+
raw = await ws.receive_text()
201+
logger.debug(f"ws.receive_text() = {raw}")
202+
await self.ws_session_mgr.handle_ws_message(client_id, raw)
203+
except Exception as e:
204+
logger.warning(f"client {client_id} web socket connection closed: {e}")
205+
206+
class CreateSessionRequest(BaseModel):
207+
state: Optional[dict[str, Any]] = None
208+
session_id: Optional[str] = None
209+
websocket_id: str
210+
211+
class RunAgentRequestWithWsId(RunAgentRequest):
212+
websocket_id: str
213+
214+
def _get_session_service(websocket_id: str) -> InMemorySessionService:
215+
"""Get session service for the websocket client."""
216+
if websocket_id not in self.ws_session_service_mgr:
217+
raise HTTPException(
218+
status_code=404, detail=f"WebSocket client {websocket_id} not found"
219+
)
220+
return self.ws_session_service_mgr[websocket_id]
221+
222+
@self.app.post(
223+
"/apps/{app_name}/users/{user_id}/sessions",
224+
response_model_exclude_none=True,
225+
)
226+
async def create_session(
227+
app_name: str,
228+
user_id: str,
229+
req: CreateSessionRequest,
230+
) -> Session:
231+
"""Create a new session."""
232+
session_id = req.session_id if req.session_id else str(uuid.uuid4())
233+
session = Session(
234+
app_name=app_name,
235+
user_id=user_id,
236+
id=session_id,
237+
state=req.state if req.state else {},
238+
)
239+
session_service = _get_session_service(req.websocket_id)
240+
await session_service.create_session(
241+
app_name=app_name,
242+
user_id=user_id,
243+
session_id=session_id,
244+
state=req.state if req.state else {},
245+
)
246+
logger.info(
247+
f"Created session: {session_id} for user {user_id} in app {app_name}"
248+
)
249+
return session
250+
251+
@self.app.post(
252+
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
253+
response_model_exclude_none=True,
254+
)
255+
async def create_session_with_id(
256+
app_name: str,
257+
user_id: str,
258+
session_id: str,
259+
req: CreateSessionRequest,
260+
) -> Session:
261+
"""Create a session with specific ID."""
262+
session_service = _get_session_service(req.websocket_id)
263+
await session_service.create_session(
264+
app_name=app_name,
265+
user_id=user_id,
266+
session_id=session_id,
267+
state=req.state if req.state else {},
268+
)
269+
session = Session(
270+
app_name=app_name,
271+
user_id=user_id,
272+
id=session_id,
273+
state=req.state if req.state else {},
274+
)
275+
logger.info(f"Created session with ID: {session_id} for user {user_id}")
276+
return session
277+
278+
@self.app.post("/run_sse")
279+
async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
280+
"""Run agent with SSE streaming."""
281+
session_service = _get_session_service(req.websocket_id)
282+
283+
# Get session
284+
session = await session_service.get_session(
285+
app_name=req.app_name,
286+
user_id=req.user_id,
287+
session_id=req.session_id,
288+
)
289+
if not session:
290+
raise HTTPException(status_code=404, detail="Session not found")
291+
292+
# Get agent for this websocket
293+
if req.websocket_id in self.ws_agent_mgr:
294+
agent = self.ws_agent_mgr[req.websocket_id]
295+
logger.debug(f"Using agent from websocket {req.websocket_id}")
296+
else:
297+
raise HTTPException(
298+
status_code=404,
299+
detail=f"WebSocket client {req.websocket_id} not found",
300+
)
301+
302+
# Create runner
303+
runner = GoogleRunner(
304+
agent=agent,
305+
app_name=req.app_name,
306+
session_service=session_service,
307+
artifact_service=self.artifact_service,
308+
)
309+
310+
# Determine streaming mode from request
311+
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
312+
313+
async def event_generator():
314+
try:
315+
async with Aclosing(
316+
runner.run_async(
317+
user_id=req.user_id,
318+
session_id=req.session_id,
319+
new_message=req.new_message,
320+
state_delta=req.state_delta,
321+
run_config=RunConfig(streaming_mode=stream_mode),
322+
invocation_id=req.invocation_id,
323+
)
324+
) as agen:
325+
async for event in agen:
326+
# ADK Web renders artifacts from `actions.artifactDelta`
327+
# during part processing *and* during action processing
328+
# 1) the original event with `artifactDelta` cleared (content)
329+
# 2) a content-less "action-only" event carrying `artifactDelta`
330+
events_to_stream = [event]
331+
if (
332+
event.actions.artifact_delta
333+
and event.content
334+
and event.content.parts
335+
):
336+
content_event = event.model_copy(deep=True)
337+
content_event.actions.artifact_delta = {}
338+
artifact_event = event.model_copy(deep=True)
339+
artifact_event.content = None
340+
events_to_stream = [content_event, artifact_event]
341+
342+
for event_to_stream in events_to_stream:
343+
sse_event = event_to_stream.model_dump_json(
344+
exclude_none=True, by_alias=True
345+
)
346+
logger.debug(f"SSE event: {sse_event}")
347+
yield f"data: {sse_event}\n\n"
348+
except Exception as e:
349+
logger.exception(f"Error in event_generator: {e}")
350+
yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
351+
352+
return StreamingResponse(
353+
event_generator(),
354+
media_type="text/event-stream",
355+
)
174356

175357
# build the fake MPC server,
176358
# and intercept all requests to the client websocket client.
359+
# NOTE: This catch-all route must be defined LAST
177360
@self.app.api_route("/{path:path}", methods=["GET", "POST"])
178361
async def mcp_proxy(path: str, request: Request):
179362
client_id = request.headers.get(REVERSE_MCP_HEADER_KEY)
@@ -202,10 +385,27 @@ async def mcp_proxy(path: str, request: Request):
202385

203386
logger.debug(f"[Reverse mcp proxy] Response from local: {resp}")
204387

388+
# Filter hop-by-hop headers to avoid Content-Length mismatch
389+
headers = resp["payload"]["headers"]
390+
hop_by_hop_headers = {
391+
"content-length",
392+
"transfer-encoding",
393+
"connection",
394+
"keep-alive",
395+
"proxy-authenticate",
396+
"proxy-authorization",
397+
"te",
398+
"trailers",
399+
"upgrade",
400+
}
401+
filtered_headers = {
402+
k: v for k, v in headers.items() if k.lower() not in hop_by_hop_headers
403+
}
404+
205405
return Response(
206406
content=resp["payload"]["body"], # type: ignore
207407
status_code=resp["payload"]["status"], # type: ignore
208-
headers=resp["payload"]["headers"], # type: ignore
408+
headers=filtered_headers, # type: ignore
209409
)
210410

211411
def run(self):

0 commit comments

Comments
 (0)