Skip to content

Commit d2e9c93

Browse files
authored
Include agent in ToolContext tool calls (#2446)
Summary - propagate the current `Agent` into every new `ToolContext`, including realtime session and tool execution paths - extend `ToolContext` to accept an `agent` keyword, preserve backwards-compatible constructors, and add regression/unit coverage for the new behavior Testing - Not run (not requested)
1 parent a53d6bd commit d2e9c93

File tree

8 files changed

+128
-1
lines changed

8 files changed

+128
-1
lines changed

src/agents/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any:
599599
tool_call_id=context.tool_call_id,
600600
tool_arguments=context.tool_arguments,
601601
tool_call=context.tool_call,
602+
agent=context.agent,
602603
)
603604
if should_capture_tool_input:
604605
nested_context.tool_input = params_data

src/agents/realtime/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ async def _handle_tool_call(
600600
tool_name=event.name,
601601
tool_call_id=event.call_id,
602602
tool_arguments=event.arguments,
603+
agent=agent,
603604
)
604605
result = await func_tool.on_invoke_tool(tool_context, event.arguments)
605606

@@ -626,6 +627,7 @@ async def _handle_tool_call(
626627
tool_name=event.name,
627628
tool_call_id=event.call_id,
628629
tool_arguments=event.arguments,
630+
agent=agent,
629631
)
630632

631633
# Execute the handoff to get the new agent

src/agents/run_internal/tool_execution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
820820
context_wrapper,
821821
tool_call.call_id,
822822
tool_call=tool_call,
823+
agent=agent,
823824
)
824825
agent_hooks = agent.hooks
825826
if config.trace_include_sensitive_data:

src/agents/tool_context.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .usage import Usage
1010

1111
if TYPE_CHECKING:
12+
from .agent import AgentBase
1213
from .items import TResponseInputItem
1314
from .run_context import _ApprovalRecord
1415

@@ -44,6 +45,9 @@ class ToolContext(RunContextWrapper[TContext]):
4445
tool_call: ResponseFunctionToolCall | None = None
4546
"""The tool call object associated with this invocation."""
4647

48+
agent: AgentBase[Any] | None = None
49+
"""The active agent for this tool call, when available."""
50+
4751
def __init__(
4852
self,
4953
context: TContext,
@@ -53,6 +57,7 @@ def __init__(
5357
tool_arguments: str | object = _MISSING,
5458
tool_call: ResponseFunctionToolCall | None = None,
5559
*,
60+
agent: AgentBase[Any] | None = None,
5661
turn_input: list[TResponseInputItem] | None = None,
5762
_approvals: dict[str, _ApprovalRecord] | None = None,
5863
tool_input: Any | None = None,
@@ -80,13 +85,15 @@ def __init__(
8085
else cast(str, tool_call_id)
8186
)
8287
self.tool_call = tool_call
88+
self.agent = agent
8389

8490
@classmethod
8591
def from_agent_context(
8692
cls,
8793
context: RunContextWrapper[TContext],
8894
tool_call_id: str,
8995
tool_call: ResponseFunctionToolCall | None = None,
96+
agent: AgentBase[Any] | None = None,
9097
) -> ToolContext:
9198
"""
9299
Create a ToolContext from a RunContextWrapper.
@@ -99,12 +106,16 @@ def from_agent_context(
99106
tool_args = (
100107
tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments()
101108
)
109+
tool_agent = agent
110+
if tool_agent is None and isinstance(context, ToolContext):
111+
tool_agent = context.agent
102112

103113
tool_context = cls(
104114
tool_name=tool_name,
105115
tool_call_id=tool_call_id,
106116
tool_arguments=tool_args,
107117
tool_call=tool_call,
118+
agent=tool_agent,
108119
**base_values,
109120
)
110121
return tool_context

tests/realtime/test_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,7 @@ async def test_function_tool_execution_success(
989989
call_args = mock_function_tool.on_invoke_tool.call_args
990990
tool_context = call_args[0][0]
991991
assert isinstance(tool_context, ToolContext)
992+
assert tool_context.agent == mock_agent
992993
assert call_args[0][1] == '{"param": "value"}'
993994

994995
# Verify tool output was sent to model

tests/test_agent_runner.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from agents.run_internal.tool_use_tracker import AgentToolUseTracker
6464
from agents.run_state import RunState
6565
from agents.tool import ComputerTool, FunctionToolResult, function_tool
66+
from agents.tool_context import ToolContext
6667
from agents.usage import Usage
6768

6869
from .fake_model import FakeModel
@@ -437,6 +438,36 @@ async def test_tool_call_runs():
437438
)
438439

439440

441+
@pytest.mark.asyncio
442+
async def test_tool_call_context_includes_current_agent() -> None:
443+
model = FakeModel()
444+
captured_contexts: list[ToolContext[Any]] = []
445+
446+
@function_tool(name_override="foo")
447+
def foo(context: ToolContext[Any]) -> str:
448+
captured_contexts.append(context)
449+
return "tool_result"
450+
451+
agent = Agent(
452+
name="test",
453+
model=model,
454+
tools=[foo],
455+
)
456+
457+
model.add_multiple_turn_outputs(
458+
[
459+
[get_function_tool_call("foo", "{}")],
460+
[get_text_message("done")],
461+
]
462+
)
463+
464+
result = await Runner.run(agent, input="user_message")
465+
466+
assert result.final_output == "done"
467+
assert len(captured_contexts) == 1
468+
assert captured_contexts[0].agent is agent
469+
470+
440471
@pytest.mark.asyncio
441472
async def test_handoffs():
442473
model = FakeModel()

tests/test_source_compat_constructors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,19 @@ def test_tool_context_v070_positional_constructor_still_works() -> None:
8383
assert context.tool_name == "tool_name"
8484
assert context.tool_call_id == "call_id"
8585
assert context.tool_arguments == '{"x":1}'
86+
assert context.agent is None
87+
88+
89+
def test_tool_context_supports_agent_keyword_argument() -> None:
90+
usage = Usage()
91+
agent = Agent(name="agent")
92+
context = ToolContext(None, usage, "tool_name", "call_id", '{"x":1}', None, agent=agent)
93+
94+
assert context.usage is usage
95+
assert context.tool_name == "tool_name"
96+
assert context.tool_call_id == "call_id"
97+
assert context.tool_arguments == '{"x":1}'
98+
assert context.agent is agent
8699

87100

88101
def test_run_result_v070_positional_constructor_still_works() -> None:

tests/test_tool_context.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from openai.types.responses import ResponseFunctionToolCall
33

4+
from agents import Agent
45
from agents.run_context import RunContextWrapper
56
from agents.tool_context import ToolContext
67
from tests.utils.hitl import make_context_wrapper
@@ -30,9 +31,75 @@ def test_tool_context_from_agent_context_populates_fields() -> None:
3031
arguments='{"a": 1}',
3132
)
3233
ctx = make_context_wrapper()
34+
agent = Agent(name="agent")
3335

34-
tool_ctx = ToolContext.from_agent_context(ctx, tool_call_id="call-123", tool_call=tool_call)
36+
tool_ctx = ToolContext.from_agent_context(
37+
ctx,
38+
tool_call_id="call-123",
39+
tool_call=tool_call,
40+
agent=agent,
41+
)
3542

3643
assert tool_ctx.tool_name == "test_tool"
3744
assert tool_ctx.tool_call_id == "call-123"
3845
assert tool_ctx.tool_arguments == '{"a": 1}'
46+
assert tool_ctx.agent is agent
47+
48+
49+
def test_tool_context_agent_none_by_default() -> None:
50+
tool_call = ResponseFunctionToolCall(
51+
type="function_call",
52+
name="test_tool",
53+
call_id="call-1",
54+
arguments="{}",
55+
)
56+
ctx = make_context_wrapper()
57+
58+
tool_ctx = ToolContext.from_agent_context(ctx, tool_call_id="call-1", tool_call=tool_call)
59+
60+
assert tool_ctx.agent is None
61+
62+
63+
def test_tool_context_constructor_accepts_agent_keyword() -> None:
64+
agent = Agent(name="direct-agent")
65+
tool_ctx: ToolContext[dict[str, object]] = ToolContext(
66+
context={},
67+
tool_name="my_tool",
68+
tool_call_id="call-2",
69+
tool_arguments="{}",
70+
agent=agent,
71+
)
72+
73+
assert tool_ctx.agent is agent
74+
75+
76+
def test_tool_context_from_tool_context_inherits_agent() -> None:
77+
original_call = ResponseFunctionToolCall(
78+
type="function_call",
79+
name="test_tool",
80+
call_id="call-3",
81+
arguments="{}",
82+
)
83+
derived_call = ResponseFunctionToolCall(
84+
type="function_call",
85+
name="test_tool",
86+
call_id="call-4",
87+
arguments="{}",
88+
)
89+
agent = Agent(name="origin-agent")
90+
parent_context: ToolContext[dict[str, object]] = ToolContext(
91+
context={},
92+
tool_name="test_tool",
93+
tool_call_id="call-3",
94+
tool_arguments="{}",
95+
tool_call=original_call,
96+
agent=agent,
97+
)
98+
99+
derived_context = ToolContext.from_agent_context(
100+
parent_context,
101+
tool_call_id="call-4",
102+
tool_call=derived_call,
103+
)
104+
105+
assert derived_context.agent is agent

0 commit comments

Comments
 (0)