Skip to content

fix: multi-source nodes can be skipped under max_turns#7187

Open
majiayu000 wants to merge 2 commits intomicrosoft:mainfrom
majiayu000:fix/6728-tests
Open

fix: multi-source nodes can be skipped under max_turns#7187
majiayu000 wants to merge 2 commits intomicrosoft:mainfrom
majiayu000:fix/6728-tests

Conversation

@majiayu000
Copy link

Why are these changes needed?

When a node has multiple incoming edges, the current "all" activation only enqueues the target after the final parent runs. In a graph like P -> U, P -> E, U -> E with max_turns=2, E never runs because it would be scheduled in a third turn. This change queues the target once any parent triggers, so E can run in the second turn alongside U and avoids being skipped under tight turn budgets.

Repro

import asyncio
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.teams import GraphFlow
from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import DiGraph, DiGraphEdge, DiGraphNode
from autogen_core import CancellationToken

class EchoAgent(BaseChatAgent):
    def __init__(self, name: str) -> None:
        super().__init__(name, f"Echo agent {name}")

    @property
    def produced_message_types(self):
        return (TextMessage,)

    async def on_messages(self, messages, cancellation_token: CancellationToken) -> Response:
        return Response(chat_message=TextMessage(content="ping", source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass

async def main():
    p = EchoAgent("P")
    u = EchoAgent("U")
    e = EchoAgent("E")
    graph = DiGraph(
        nodes={
            "P": DiGraphNode(name="P", edges=[DiGraphEdge(target="U"), DiGraphEdge(target="E")]),
            "U": DiGraphNode(name="U", edges=[DiGraphEdge(target="E")]),
            "E": DiGraphNode(name="E", edges=[]),
        }
    )
    team = GraphFlow(participants=[p, u, e], graph=graph, max_turns=2)
    result = await team.run(task="Start")
    print([m.source for m in result.messages])

asyncio.run(main())

Output (before):

['user', 'P', 'U']  # E missing because max_turns=2 is reached

Output (after):

['user', 'P', 'U', 'E']  # U/E order may vary

Related issue number

  • N/A

Checks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants