|
19 | 19 |
|
20 | 20 | from fastapi import FastAPI, HTTPException, Request, Response, WebSocket |
21 | 21 | from fastapi.responses import StreamingResponse |
| 22 | +from google.adk.agents.run_config import StreamingMode |
22 | 23 | from google.adk.artifacts import InMemoryArtifactService |
23 | 24 | from google.adk.cli.adk_web_server import RunAgentRequest |
24 | | -from google.adk.runners import Runner as GoogleRunner |
| 25 | +from google.adk.runners import Runner as GoogleRunner, RunConfig |
25 | 26 | from google.adk.sessions import InMemorySessionService, Session |
26 | 27 | from google.adk.tools.mcp_tool.mcp_session_manager import ( |
27 | 28 | StreamableHTTPConnectionParams, |
28 | 29 | ) |
29 | 30 | from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset |
| 31 | +from google.adk.utils.context_utils import Aclosing |
30 | 32 | from pydantic import BaseModel |
31 | 33 |
|
32 | 34 | from veadk import Runner |
@@ -297,19 +299,44 @@ async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse: |
297 | 299 | artifact_service=self.artifact_service, |
298 | 300 | ) |
299 | 301 |
|
| 302 | + # Determine streaming mode from request |
| 303 | + stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE |
| 304 | + |
300 | 305 | async def event_generator(): |
301 | 306 | try: |
302 | | - async for event in runner.run_async( |
303 | | - user_id=req.user_id, |
304 | | - session_id=req.session_id, |
305 | | - new_message=req.new_message, |
306 | | - state_delta=req.state_delta, |
307 | | - ): |
308 | | - event_json = event.model_dump_json( |
309 | | - exclude_none=True, by_alias=True |
| 307 | + async with Aclosing( |
| 308 | + runner.run_async( |
| 309 | + user_id=req.user_id, |
| 310 | + session_id=req.session_id, |
| 311 | + new_message=req.new_message, |
| 312 | + state_delta=req.state_delta, |
| 313 | + run_config=RunConfig(streaming_mode=stream_mode), |
| 314 | + invocation_id=req.invocation_id, |
310 | 315 | ) |
311 | | - logger.debug(f"SSE event: {event_json}") |
312 | | - yield f"data: {event_json}\n\n" |
| 316 | + ) as agen: |
| 317 | + async for event in agen: |
| 318 | + # ADK Web renders artifacts from `actions.artifactDelta` |
| 319 | + # during part processing *and* during action processing |
| 320 | + # 1) the original event with `artifactDelta` cleared (content) |
| 321 | + # 2) a content-less "action-only" event carrying `artifactDelta` |
| 322 | + events_to_stream = [event] |
| 323 | + if ( |
| 324 | + event.actions.artifact_delta |
| 325 | + and event.content |
| 326 | + and event.content.parts |
| 327 | + ): |
| 328 | + content_event = event.model_copy(deep=True) |
| 329 | + content_event.actions.artifact_delta = {} |
| 330 | + artifact_event = event.model_copy(deep=True) |
| 331 | + artifact_event.content = None |
| 332 | + events_to_stream = [content_event, artifact_event] |
| 333 | + |
| 334 | + for event_to_stream in events_to_stream: |
| 335 | + sse_event = event_to_stream.model_dump_json( |
| 336 | + exclude_none=True, by_alias=True |
| 337 | + ) |
| 338 | + logger.debug(f"SSE event: {sse_event}") |
| 339 | + yield f"data: {sse_event}\n\n" |
313 | 340 | except Exception as e: |
314 | 341 | logger.exception(f"Error in event_generator: {e}") |
315 | 342 | yield f"data: {json.dumps({'error': str(e)})}\n\n" |
|
0 commit comments