@@ -121,8 +121,6 @@ class InvokeRequest(BaseModel):
121121
122122 websocket_id : str
123123
124- mcp_tool_filter : Optional [list [str ]] = None
125-
126124 class InvokeResponse (BaseModel ):
127125 """Response model for /invoke endpoint"""
128126
@@ -137,35 +135,6 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
137135
138136 agent = self .ws_agent_mgr [payload .websocket_id ]
139137
140- mcp_toolset_url = f"http://127.0.0.1:{ self .port } /mcp"
141- mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY : payload .websocket_id }
142-
143- has_mcp_toolset = False
144- for tool in agent .tools :
145- if isinstance (tool , MCPToolset ):
146- if hasattr (tool , "_connection_params" ):
147- conn_params = tool ._connection_params
148- if (
149- hasattr (conn_params , "url" )
150- and conn_params .url == mcp_toolset_url
151- and hasattr (conn_params , "headers" )
152- and conn_params .headers == mcp_toolset_headers
153- ):
154- has_mcp_toolset = True
155- break
156-
157- if not has_mcp_toolset :
158- logger .debug ("Mount fake MCPToolset to agent" )
159- agent .tools .append (
160- MCPToolset (
161- connection_params = StreamableHTTPConnectionParams (
162- url = mcp_toolset_url ,
163- headers = mcp_toolset_headers ,
164- ),
165- tool_filter = payload .mcp_tool_filter ,
166- )
167- )
168-
169138 runner = Runner (app_name = payload .app_name , agent = agent )
170139 response = await runner .run (
171140 messages = [prompt ],
@@ -179,18 +148,42 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
179148 @self .app .websocket ("/ws" )
180149 async def ws_endpoint (ws : WebSocket ):
181150 client_id = ws .query_params .get ("id" )
151+
182152 if not client_id :
183153 await ws .close (
184154 code = 400 ,
185155 reason = "WebSocket `id` is required like `/ws?id=my_id`" ,
186156 )
187157 return
188158
159+ # Parse mcp_tool_filter from query params, comma-separated string
160+ mcp_tool_filter_str = ws .query_params .get ("mcp_tool_filter" )
161+ mcp_tool_filter = None
162+ if mcp_tool_filter_str :
163+ mcp_tool_filter = [
164+ t .strip () for t in mcp_tool_filter_str .split ("," ) if t .strip ()
165+ ]
166+
189167 logger .info (f"Register websocket { client_id } to session manager." )
190168 self .ws_session_mgr .connections [client_id ] = ws
191169
192170 logger .info (f"Fork agent for websocket { client_id } " )
193- self .ws_agent_mgr [client_id ] = self .agent .clone ()
171+ agent = self .agent .clone ()
172+
173+ # Mount MCPToolset when creating agent
174+ mcp_toolset_url = f"http://127.0.0.1:{ self .port } /mcp"
175+ mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY : client_id }
176+ logger .debug (f"Mount MCPToolset to agent for websocket { client_id } " )
177+ agent .tools .append (
178+ MCPToolset (
179+ connection_params = StreamableHTTPConnectionParams (
180+ url = mcp_toolset_url ,
181+ headers = mcp_toolset_headers ,
182+ ),
183+ tool_filter = mcp_tool_filter ,
184+ )
185+ )
186+ self .ws_agent_mgr [client_id ] = agent
194187
195188 logger .info (f"Create session service for websocket { client_id } " )
196189 self .ws_session_service_mgr [client_id ] = InMemorySessionService ()
@@ -209,7 +202,6 @@ class CreateSessionRequest(BaseModel):
209202
210203 class RunAgentRequestWithWsId (RunAgentRequest ):
211204 websocket_id : str
212- mcp_tool_filter : Optional [list [str ]] = None
213205
214206 def _get_session_service (websocket_id : str ) -> InMemorySessionService :
215207 """Get session service for the websocket client."""
@@ -299,36 +291,6 @@ async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
299291 detail = f"WebSocket client { req .websocket_id } not found" ,
300292 )
301293
302- # Mount MCPToolset if needed
303- mcp_toolset_url = f"http://127.0.0.1:{ self .port } /mcp"
304- mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY : req .websocket_id }
305-
306- has_mcp_toolset = False
307- for tool in agent .tools :
308- if isinstance (tool , MCPToolset ):
309- if hasattr (tool , "_connection_params" ):
310- conn_params = tool ._connection_params
311- if (
312- hasattr (conn_params , "url" )
313- and conn_params .url == mcp_toolset_url
314- and hasattr (conn_params , "headers" )
315- and conn_params .headers == mcp_toolset_headers
316- ):
317- has_mcp_toolset = True
318- break
319-
320- if not has_mcp_toolset :
321- logger .debug ("Mount fake MCPToolset to agent for SSE" )
322- agent .tools .append (
323- MCPToolset (
324- connection_params = StreamableHTTPConnectionParams (
325- url = mcp_toolset_url ,
326- headers = mcp_toolset_headers ,
327- ),
328- tool_filter = req .mcp_tool_filter ,
329- )
330- )
331-
332294 # Create runner
333295 runner = GoogleRunner (
334296 agent = agent ,
0 commit comments