Commit 1c6f579

Fix tool_choice=required to return after tool execution
When tool_choice is 'required', the user's intent is to force exactly one tool call. After the tool executes, return immediately with the function call and result; don't continue to call the model again.

This fixes integration tests that were failing with empty text responses: with tool_choice=required, the model would keep returning function calls instead of text.

Also adds regression tests for:
- conversation_id propagation between tool iterations (from PR #3664)
- tool_choice=required returns after tool execution
1 parent 98754f7 commit 1c6f579
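In caller terms, the intended behavior after this change looks roughly like the sketch below. It is a minimal illustration pieced together from the tests in this commit; the get_weather tool, the demo coroutine, and the client object are placeholders (any chat client built on the framework's function-invocation layer), not part of the diff.

from agent_framework import tool

@tool(name="get_weather", approval_mode="never_require")
def get_weather(city: str) -> str:
    # Hypothetical tool, used only for illustration.
    return f"It is sunny in {city}."

async def demo(client) -> None:
    # tool_choice="required" forces exactly one tool call; after this commit the
    # invocation loop returns right after that call executes instead of asking
    # the model for a follow-up text answer.
    response = await client.get_response(
        "What's the weather in Seattle?",
        options={"tool_choice": "required", "tools": [get_weather]},
    )
    assert response.messages[0].contents[0].type == "function_call"
    assert response.messages[1].contents[0].type == "function_result"
    # Expected: no trailing model-generated text in the response.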

File tree

3 files changed: +173 −3 lines changed


python/packages/azure-ai/tests/test_azure_ai_client.py

Lines changed: 17 additions & 3 deletions
@@ -1391,12 +1391,26 @@ async def test_integration_options(

     assert response is not None
     assert isinstance(response, ChatResponse)
-    assert response.text is not None, f"No text in response for option '{option_name}'"
-    assert len(response.text) > 0, f"Empty response for option '{option_name}'"
+
+    # For tool_choice="required", we return after tool execution without a model text response
+    is_required_tool_choice = option_name == "tool_choice" and (
+        option_value == "required" or (isinstance(option_value, dict) and option_value.get("mode") == "required")
+    )
+
+    if is_required_tool_choice:
+        # Response should have function call and function result, but no text from model
+        assert len(response.messages) >= 2, f"Expected function call + result for {option_name}"
+        has_function_call = any(c.type == "function_call" for msg in response.messages for c in msg.contents)
+        has_function_result = any(c.type == "function_result" for msg in response.messages for c in msg.contents)
+        assert has_function_call, f"No function call in response for {option_name}"
+        assert has_function_result, f"No function result in response for {option_name}"
+    else:
+        assert response.text is not None, f"No text in response for option '{option_name}'"
+        assert len(response.text) > 0, f"Empty response for option '{option_name}'"

     # Validate based on option type
     if needs_validation:
-        if option_name.startswith("tool_choice"):
+        if option_name.startswith("tool_choice") and not is_required_tool_choice:
             # Should have called the weather function
             text = response.text.lower()
             assert "sunny" in text or "seattle" in text, f"Tool not invoked for {option_name}"

python/packages/core/agent_framework/_tools.py

Lines changed: 10 additions & 0 deletions
@@ -2194,6 +2194,11 @@ async def _get_response() -> ChatResponse:
                     break
                 errors_in_a_row = result["errors_in_a_row"]

+                # When tool_choice is 'required', return after tool execution
+                # The user's intent is to force exactly one tool call and get the result
+                if mutable_options.get("tool_choice") == "required":
+                    return response
+
                 if response.conversation_id is not None:
                     # For conversation-based APIs, the server already has the function call message.
                     # Only send the new function result message (added by _handle_function_call_results).
@@ -2300,6 +2305,11 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                 if result["action"] != "continue":
                     return

+                # When tool_choice is 'required', return after tool execution
+                # The user's intent is to force exactly one tool call and get the result
+                if mutable_options.get("tool_choice") == "required":
+                    return
+
                 if response.conversation_id is not None:
                     # For conversation-based APIs, the server already has the function call message.
                     # Only send the new function result message (the last one added by _handle_function_call_results).
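Conceptually, both hunks add the same early exit to the function-invocation loop. The following is a self-contained, simplified sketch of that control flow, not the framework's actual code; run_with_tools, call_model, and execute_tool are illustrative stand-ins.

def run_with_tools(call_model, execute_tool, messages, options, max_iterations=5):
    """Simplified sketch of a function-invocation loop with the new early return."""
    for _ in range(max_iterations):
        contents = call_model(messages, options)  # one model round-trip
        calls = [c for c in contents if c["type"] == "function_call"]
        if not calls:
            return messages + contents  # plain text answer: nothing to execute
        results = [execute_tool(c) for c in calls]  # run the requested tool(s)
        messages = messages + contents + results
        # New behavior: tool_choice="required" means the caller forced exactly one
        # tool call, so hand back the call + result without another model round-trip.
        if options.get("tool_choice") == "required":
            return messages
    return messages


# Tiny demo: a fake model that always requests a tool call would normally loop
# max_iterations times; with tool_choice="required" it stops after one call.
def fake_model(messages, options):
    return [{"type": "function_call", "name": "get_time", "args": {}}]

def fake_tool(call):
    return {"type": "function_result", "name": call["name"], "value": "12:00"}

print(run_with_tools(fake_model, fake_tool, [], {"tool_choice": "required"}))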

python/packages/core/tests/core/test_function_invocation_logic.py

Lines changed: 146 additions & 0 deletions
@@ -2626,3 +2626,149 @@ def test_func(arg1: str) -> str:
     assert conversation_ids_received[1] == "stream_conv_after_first", (
         "streaming: conversation_id should be updated in options after receiving new conversation_id from API"
     )
+
+
+async def test_tool_choice_required_returns_after_tool_execution():
+    """Test that tool_choice='required' returns after tool execution without another model call.
+
+    When tool_choice is 'required', the user's intent is to force exactly one tool call.
+    After the tool executes, we should return the response with the function call and result,
+    not continue to call the model again.
+    """
+    from collections.abc import AsyncIterable, MutableSequence, Sequence
+    from typing import Any
+    from unittest.mock import patch
+
+    from agent_framework import (
+        BaseChatClient,
+        ChatMessage,
+        ChatResponse,
+        ChatResponseUpdate,
+        Content,
+        ResponseStream,
+        Role,
+        tool,
+    )
+    from agent_framework._middleware import ChatMiddlewareLayer
+    from agent_framework._tools import FunctionInvocationLayer
+
+    class TrackingChatClient(
+        ChatMiddlewareLayer,
+        FunctionInvocationLayer,
+        BaseChatClient,
+    ):
+        def __init__(self) -> None:
+            super().__init__(function_middleware=[])
+            self.run_responses: list[ChatResponse] = []
+            self.streaming_responses: list[list[ChatResponseUpdate]] = []
+            self.call_count: int = 0
+
+        def _inner_get_response(
+            self,
+            *,
+            messages: MutableSequence[ChatMessage],
+            stream: bool,
+            options: dict[str, Any],
+            **kwargs: Any,
+        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
+            if stream:
+                return self._get_streaming_response(messages=messages, options=options, **kwargs)
+
+            async def _get() -> ChatResponse:
+                self.call_count += 1
+                if not self.run_responses:
+                    return ChatResponse(messages=ChatMessage(role="assistant", text="done"))
+                return self.run_responses.pop(0)
+
+            return _get()
+
+        def _get_streaming_response(
+            self,
+            *,
+            messages: MutableSequence[ChatMessage],
+            options: dict[str, Any],
+            **kwargs: Any,
+        ) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
+            async def _stream() -> AsyncIterable[ChatResponseUpdate]:
+                self.call_count += 1
+                if not self.streaming_responses:
+                    yield ChatResponseUpdate(text="done", role="assistant", is_finished=True)
+                    return
+                response = self.streaming_responses.pop(0)
+                for update in response:
+                    yield update
+
+            def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
+                return ChatResponse.from_chat_response_updates(updates)
+
+            return ResponseStream(_stream(), finalizer=_finalize)
+
+    @tool(name="test_func", approval_mode="never_require")
+    def test_func(arg1: str) -> str:
+        return f"Result {arg1}"
+
+    # Test non-streaming: should only call model once, then return with function call + result
+    with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 5):
+        client = TrackingChatClient()
+
+        client.run_responses = [
+            ChatResponse(
+                messages=ChatMessage(
+                    role="assistant",
+                    contents=[Content.from_function_call(call_id="call_1", name="test_func", arguments='{"arg1": "v1"}')],
+                ),
+            ),
+            # This second response should NOT be consumed
+            ChatResponse(
+                messages=ChatMessage(role="assistant", text="this should not be reached"),
+            ),
+        ]
+
+        response = await client.get_response(
+            "hello",
+            options={"tool_choice": "required", "tools": [test_func]},
+        )
+
+        # Should only call model once - after tool execution, return immediately
+        assert client.call_count == 1
+        # Response should contain function call and function result
+        assert len(response.messages) == 2
+        assert response.messages[0].role == Role.ASSISTANT
+        assert response.messages[0].contents[0].type == "function_call"
+        assert response.messages[1].role == Role.TOOL
+        assert response.messages[1].contents[0].type == "function_result"
+        # Second response should still be in queue (not consumed)
+        assert len(client.run_responses) == 1
+
+    # Test streaming version too
+    with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 5):
+        streaming_client = TrackingChatClient()
+
+        streaming_client.streaming_responses = [
+            [
+                ChatResponseUpdate(
+                    contents=[Content.from_function_call(call_id="call_2", name="test_func", arguments='{"arg1": "v2"}')],
+                    role="assistant",
+                ),
+            ],
+            # This second response should NOT be consumed
+            [
+                ChatResponseUpdate(text="this should not be reached", role="assistant", is_finished=True),
+            ],
+        ]
+
+        response_stream = streaming_client.get_response(
+            "hello",
+            stream=True,
+            options={"tool_choice": "required", "tools": [test_func]},
+        )
+        updates = []
+        async for update in response_stream:
+            updates.append(update)
+
+        # Should only call model once
+        assert streaming_client.call_count == 1
+        # Should have function call update and function result update
+        assert len(updates) == 2
+        # Second streaming response should still be in queue (not consumed)
+        assert len(streaming_client.streaming_responses) == 1
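One way to run just this new regression test locally (assuming the core package's test dependencies are installed) would be:

pytest python/packages/core/tests/core/test_function_invocation_logic.py -k tool_choice_required_returns_after_tool_execution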
