Skip to content

Commit 47e48f8

Browse files
committed
fix(mcp): reset client session stack state after cleanup
1 parent 5eb4a33 commit 47e48f8

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

src/agents/mcp/server.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,12 @@ async def cleanup(self):
709709
else:
710710
logger.error(f"Error cleaning up server: {e}")
711711
finally:
712+
# Always reset the exit stack so we don't retain callbacks/references from the
713+
# previous connection. This keeps teardown deterministic and allows reconnecting
714+
# with a fresh stack even if cleanup encountered recoverable errors.
715+
self.exit_stack = AsyncExitStack()
712716
self.session = None
717+
self.server_initialize_result = None
713718

714719

715720
class MCPServerStdioParams(TypedDict):

tests/mcp/test_connect_disconnect.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
from .helpers import DummyStreamsContextManager, tee
99

1010

11+
class CountingStreamsContextManager:
12+
def __init__(self, counter: dict[str, int]):
13+
self.counter = counter
14+
15+
async def __aenter__(self):
16+
self.counter["enter"] += 1
17+
return (object(), object())
18+
19+
async def __aexit__(self, exc_type, exc_val, exc_tb):
20+
self.counter["exit"] += 1
21+
22+
1123
@pytest.mark.asyncio
1224
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
1325
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@@ -67,3 +79,33 @@ async def test_manual_connect_disconnect_works(
6779

6880
await server.cleanup()
6981
assert server.session is None, "Server should be disconnected"
82+
83+
84+
@pytest.mark.asyncio
85+
@patch("agents.mcp.server.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
86+
@patch("agents.mcp.server.stdio_client")
87+
async def test_cleanup_resets_exit_stack_and_reconnects(
88+
mock_stdio_client: AsyncMock, mock_initialize: AsyncMock
89+
):
90+
counter = {"enter": 0, "exit": 0}
91+
mock_stdio_client.side_effect = lambda params: CountingStreamsContextManager(counter)
92+
93+
server = MCPServerStdio(
94+
params={
95+
"command": tee,
96+
},
97+
cache_tools_list=True,
98+
)
99+
100+
await server.connect()
101+
original_exit_stack = server.exit_stack
102+
103+
await server.cleanup()
104+
assert server.session is None
105+
assert server.exit_stack is not original_exit_stack
106+
assert server.server_initialize_result is None
107+
assert counter == {"enter": 1, "exit": 1}
108+
109+
await server.connect()
110+
await server.cleanup()
111+
assert counter == {"enter": 2, "exit": 2}

0 commit comments

Comments
 (0)