@@ -2626,3 +2626,149 @@ def test_func(arg1: str) -> str:
     assert conversation_ids_received[1] == "stream_conv_after_first", (
         "streaming: conversation_id should be updated in options after receiving new conversation_id from API"
     )
+
+
+async def test_tool_choice_required_returns_after_tool_execution():
+    """Test that tool_choice='required' returns after tool execution without another model call.
+
+    When tool_choice is 'required', the user's intent is to force exactly one tool call.
+    After the tool executes, we should return the response with the function call and result,
+    not continue to call the model again.
+    """
+    from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence
+    from typing import Any
+    from unittest.mock import patch
+
+    from agent_framework import (
+        BaseChatClient,
+        ChatMessage,
+        ChatResponse,
+        ChatResponseUpdate,
+        Content,
+        ResponseStream,
+        Role,
+        tool,
+    )
+    from agent_framework._middleware import ChatMiddlewareLayer
+    from agent_framework._tools import FunctionInvocationLayer
+
+    class TrackingChatClient(
+        ChatMiddlewareLayer,
+        FunctionInvocationLayer,
+        BaseChatClient,
+    ):
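+        """Fake chat client that serves queued responses and counts model calls."""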
+        def __init__(self) -> None:
+            super().__init__(function_middleware=[])
+            self.run_responses: list[ChatResponse] = []
+            self.streaming_responses: list[list[ChatResponseUpdate]] = []
+            self.call_count: int = 0
+
+        def _inner_get_response(
+            self,
+            *,
+            messages: MutableSequence[ChatMessage],
+            stream: bool,
+            options: dict[str, Any],
+            **kwargs: Any,
+        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
+            if stream:
+                return self._get_streaming_response(messages=messages, options=options, **kwargs)
+
+            async def _get() -> ChatResponse:
+                self.call_count += 1
+                if not self.run_responses:
+                    return ChatResponse(messages=ChatMessage(role="assistant", text="done"))
+                return self.run_responses.pop(0)
+
+            return _get()
+
+        def _get_streaming_response(
+            self,
+            *,
+            messages: MutableSequence[ChatMessage],
+            options: dict[str, Any],
+            **kwargs: Any,
+        ) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
+            async def _stream() -> AsyncIterable[ChatResponseUpdate]:
+                self.call_count += 1
+                if not self.streaming_responses:
+                    yield ChatResponseUpdate(text="done", role="assistant", is_finished=True)
+                    return
+                response = self.streaming_responses.pop(0)
+                for update in response:
+                    yield update
+
+            def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
+                return ChatResponse.from_chat_response_updates(updates)
+
+            return ResponseStream(_stream(), finalizer=_finalize)
+
+    @tool(name="test_func", approval_mode="never_require")
+    def test_func(arg1: str) -> str:
+        return f"Result {arg1}"
+
+    # Test non-streaming: should only call model once, then return with function call + result
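+    # DEFAULT_MAX_ITERATIONS is patched to 5, so a second model call would be permitted if the client kept iterating.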
+    with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 5):
+        client = TrackingChatClient()
+
+        client.run_responses = [
+            ChatResponse(
+                messages=ChatMessage(
+                    role="assistant",
+                    contents=[Content.from_function_call(call_id="call_1", name="test_func", arguments='{"arg1": "v1"}')],
+                ),
+            ),
+            # This second response should NOT be consumed
+            ChatResponse(
+                messages=ChatMessage(role="assistant", text="this should not be reached"),
+            ),
+        ]
+
+        response = await client.get_response(
+            "hello",
+            options={"tool_choice": "required", "tools": [test_func]},
+        )
+
+        # Should only call model once - after tool execution, return immediately
+        assert client.call_count == 1
+        # Response should contain function call and function result
+        assert len(response.messages) == 2
+        assert response.messages[0].role == Role.ASSISTANT
+        assert response.messages[0].contents[0].type == "function_call"
+        assert response.messages[1].role == Role.TOOL
+        assert response.messages[1].contents[0].type == "function_result"
+        # Second response should still be in queue (not consumed)
+        assert len(client.run_responses) == 1
+
+    # Test streaming version too
+    with patch("agent_framework._tools.DEFAULT_MAX_ITERATIONS", 5):
+        streaming_client = TrackingChatClient()
+
+        streaming_client.streaming_responses = [
+            [
+                ChatResponseUpdate(
+                    contents=[Content.from_function_call(call_id="call_2", name="test_func", arguments='{"arg1": "v2"}')],
+                    role="assistant",
+                ),
+            ],
+            # This second response should NOT be consumed
+            [
+                ChatResponseUpdate(text="this should not be reached", role="assistant", is_finished=True),
+            ],
+        ]
+
+        response_stream = streaming_client.get_response(
+            "hello",
+            stream=True,
+            options={"tool_choice": "required", "tools": [test_func]},
+        )
+        updates = []
+        async for update in response_stream:
+            updates.append(update)
+
+        # Should only call model once
+        assert streaming_client.call_count == 1
+        # Should have function call update and function result update
+        assert len(updates) == 2
+        # Second streaming response should still be in queue (not consumed)
+        assert len(streaming_client.streaming_responses) == 1