Skip to content

Commit 3ab9b30

Browse files
committed
fix: filter_mcp_tools
1 parent 4b9f2cf commit 3ab9b30

File tree

2 files changed

+36
-64
lines changed

2 files changed

+36
-64
lines changed

veadk/toolkits/apps/reverse_mcp/client_with_reverse_mcp.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,24 @@
2323

2424

2525
class ClientWithReverseMCP:
26-
def __init__(self, ws_url: str, mcp_server_url: str, client_id: str):
26+
def __init__(
27+
self,
28+
ws_url: str,
29+
mcp_server_url: str,
30+
client_id: str,
31+
mcp_tool_filter: list[str] | None = None,
32+
):
2733
"""Start a client with reverse mcp,
2834
2935
Args:
3036
ws_url: The url of the websocket server (cloud). Like example.com:8000
3137
mcp_server_url: The url of the mcp server (local).
38+
client_id: The client id for the websocket connection.
39+
mcp_tool_filter: Optional list of tool names to filter. If None, all tools are available.
3240
"""
3341
self.ws_url = f"ws://{ws_url}/ws?id={client_id}"
42+
if mcp_tool_filter:
43+
self.ws_url += f"&mcp_tool_filter={','.join(mcp_tool_filter)}"
3444
self.mcp_server_url = mcp_server_url
3545

3646
# set timeout for httpx client

veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py

Lines changed: 25 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,6 @@ class InvokeRequest(BaseModel):
121121

122122
websocket_id: str
123123

124-
mcp_tool_filter: Optional[list[str]] = None
125-
126124
class InvokeResponse(BaseModel):
127125
"""Response model for /invoke endpoint"""
128126

@@ -137,35 +135,6 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
137135

138136
agent = self.ws_agent_mgr[payload.websocket_id]
139137

140-
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
141-
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: payload.websocket_id}
142-
143-
has_mcp_toolset = False
144-
for tool in agent.tools:
145-
if isinstance(tool, MCPToolset):
146-
if hasattr(tool, "_connection_params"):
147-
conn_params = tool._connection_params
148-
if (
149-
hasattr(conn_params, "url")
150-
and conn_params.url == mcp_toolset_url
151-
and hasattr(conn_params, "headers")
152-
and conn_params.headers == mcp_toolset_headers
153-
):
154-
has_mcp_toolset = True
155-
break
156-
157-
if not has_mcp_toolset:
158-
logger.debug("Mount fake MCPToolset to agent")
159-
agent.tools.append(
160-
MCPToolset(
161-
connection_params=StreamableHTTPConnectionParams(
162-
url=mcp_toolset_url,
163-
headers=mcp_toolset_headers,
164-
),
165-
tool_filter=payload.mcp_tool_filter,
166-
)
167-
)
168-
169138
runner = Runner(app_name=payload.app_name, agent=agent)
170139
response = await runner.run(
171140
messages=[prompt],
@@ -179,18 +148,42 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
179148
@self.app.websocket("/ws")
180149
async def ws_endpoint(ws: WebSocket):
181150
client_id = ws.query_params.get("id")
151+
182152
if not client_id:
183153
await ws.close(
184154
code=400,
185155
reason="WebSocket `id` is required like `/ws?id=my_id`",
186156
)
187157
return
188158

159+
# Parse mcp_tool_filter from query params, comma-separated string
160+
mcp_tool_filter_str = ws.query_params.get("mcp_tool_filter")
161+
mcp_tool_filter = None
162+
if mcp_tool_filter_str:
163+
mcp_tool_filter = [
164+
t.strip() for t in mcp_tool_filter_str.split(",") if t.strip()
165+
]
166+
189167
logger.info(f"Register websocket {client_id} to session manager.")
190168
self.ws_session_mgr.connections[client_id] = ws
191169

192170
logger.info(f"Fork agent for websocket {client_id}")
193-
self.ws_agent_mgr[client_id] = self.agent.clone()
171+
agent = self.agent.clone()
172+
173+
# Mount MCPToolset when creating agent
174+
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
175+
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: client_id}
176+
logger.debug(f"Mount MCPToolset to agent for websocket {client_id}")
177+
agent.tools.append(
178+
MCPToolset(
179+
connection_params=StreamableHTTPConnectionParams(
180+
url=mcp_toolset_url,
181+
headers=mcp_toolset_headers,
182+
),
183+
tool_filter=mcp_tool_filter,
184+
)
185+
)
186+
self.ws_agent_mgr[client_id] = agent
194187

195188
logger.info(f"Create session service for websocket {client_id}")
196189
self.ws_session_service_mgr[client_id] = InMemorySessionService()
@@ -209,7 +202,6 @@ class CreateSessionRequest(BaseModel):
209202

210203
class RunAgentRequestWithWsId(RunAgentRequest):
211204
websocket_id: str
212-
mcp_tool_filter: Optional[list[str]] = None
213205

214206
def _get_session_service(websocket_id: str) -> InMemorySessionService:
215207
"""Get session service for the websocket client."""
@@ -299,36 +291,6 @@ async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
299291
detail=f"WebSocket client {req.websocket_id} not found",
300292
)
301293

302-
# Mount MCPToolset if needed
303-
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
304-
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: req.websocket_id}
305-
306-
has_mcp_toolset = False
307-
for tool in agent.tools:
308-
if isinstance(tool, MCPToolset):
309-
if hasattr(tool, "_connection_params"):
310-
conn_params = tool._connection_params
311-
if (
312-
hasattr(conn_params, "url")
313-
and conn_params.url == mcp_toolset_url
314-
and hasattr(conn_params, "headers")
315-
and conn_params.headers == mcp_toolset_headers
316-
):
317-
has_mcp_toolset = True
318-
break
319-
320-
if not has_mcp_toolset:
321-
logger.debug("Mount fake MCPToolset to agent for SSE")
322-
agent.tools.append(
323-
MCPToolset(
324-
connection_params=StreamableHTTPConnectionParams(
325-
url=mcp_toolset_url,
326-
headers=mcp_toolset_headers,
327-
),
328-
tool_filter=req.mcp_tool_filter,
329-
)
330-
)
331-
332294
# Create runner
333295
runner = GoogleRunner(
334296
agent=agent,

0 commit comments

Comments
 (0)