Skip to content

Commit c6d886b

Browse files
authored
Auto-recreate connection on loop change (#2098)
* Auto-recreate connection on loop change * Fix test
1 parent 48151d6 commit c6d886b

File tree

13 files changed

+217
-22
lines changed

13 files changed

+217
-22
lines changed

docs/connections.rst

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,54 @@ providing isolation between different contexts (useful for testing).
162162
conn3 = Tortoise.get_connection("default")
163163
assert conn is not conn3
164164
165+
Event Loop Handling
166+
===================
167+
168+
Some database drivers (asyncpg, aiomysql) bind their connection pools to the event loop
169+
that created them. If the loop changes -- for example, when using function-scoped pytest
170+
fixtures or Starlette's ``TestClient`` -- the old pool becomes unusable.
171+
172+
Tortoise handles this automatically: when ``ConnectionHandler.get()`` detects that the
173+
current event loop differs from the one the connection was created on, it transparently
174+
creates a fresh connection.
175+
176+
**In production**, a loop change usually indicates a bug (e.g., mixing sync/async code).
177+
A ``TortoiseLoopSwitchWarning`` is emitted so you can investigate:
178+
179+
.. code-block:: python
180+
181+
import warnings
182+
from tortoise.warnings import TortoiseLoopSwitchWarning
183+
184+
# Suppress if you know what you're doing
185+
warnings.filterwarnings("ignore", category=TortoiseLoopSwitchWarning)
186+
187+
**In tests**, ``tortoise_test_context()`` suppresses this warning automatically.
188+
No special configuration needed.
189+
190+
.. list-table:: Backend Loop Binding
191+
:header-rows: 1
192+
:widths: 30 20 50
193+
194+
* - Backend
195+
- Bound?
196+
- Notes
197+
* - asyncpg
198+
- Yes
199+
- Pool stores loop at creation time
200+
* - aiomysql/asyncmy
201+
- Yes
202+
- Pool stores loop at creation time
203+
* - psycopg
204+
- No
205+
- Uses running loop per-operation
206+
* - aiosqlite
207+
- Partial
208+
- Grabs loop per-operation, not at creation
209+
* - asyncodbc (MSSQL/Oracle)
210+
- No
211+
- Per-operation loop resolution
212+
165213
API Reference
166214
=============
167215

docs/contrib/unittest.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,31 @@ For tests that require multiple database connections:
123123
await ctx.generate_schemas()
124124
yield ctx
125125
126+
Event Loop Isolation
127+
====================
128+
129+
Some backends (asyncpg, aiomysql) bind connection pools to the event loop that created
130+
them. ``tortoise_test_context()`` handles this transparently -- if the event loop changes
131+
between tests, connections are automatically recreated.
132+
133+
This means you **don't** need ``loop_scope="session"`` or any special pytest-asyncio
134+
configuration. The simplest setup works:
135+
136+
.. code-block:: toml
137+
138+
# pyproject.toml -- no loop_scope overrides needed
139+
[tool.pytest.ini_options]
140+
asyncio_mode = "auto"
141+
142+
If you use ``TortoiseContext`` directly (without ``tortoise_test_context``), you may see
143+
a ``TortoiseLoopSwitchWarning`` when the loop changes. Suppress it with:
144+
145+
.. code-block:: python
146+
147+
import warnings
148+
from tortoise.warnings import TortoiseLoopSwitchWarning
149+
warnings.filterwarnings("ignore", category=TortoiseLoopSwitchWarning)
150+
126151
Testing Database Capabilities
127152
=============================
128153

tests/test_connection.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import asyncio
2+
import warnings
13
from unittest.mock import AsyncMock, Mock, PropertyMock, call, patch
24

35
import pytest
46

57
from tortoise import BaseDBAsyncClient, ConfigurationError
68
from tortoise.connection import ConnectionHandler
9+
from tortoise.warnings import TortoiseLoopSwitchWarning
710

811

912
@pytest.fixture
@@ -206,9 +209,10 @@ def test_create_connection_db_info_not_str(
206209

207210

208211
def test_get_alias_present(conn_handler):
209-
conn_handler._storage = {"default": "some_connection"}
212+
mock_conn = Mock(_check_loop=Mock(return_value=True))
213+
conn_handler._storage = {"default": mock_conn}
210214
ret_val = conn_handler.get("default")
211-
assert ret_val == "some_connection"
215+
assert ret_val is mock_conn
212216

213217

214218
@patch("tortoise.connection.ConnectionHandler._create_connection")
@@ -246,10 +250,12 @@ def test_reset(conn_handler):
246250

247251
@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock)
248252
def test_all(mocked_db_config, conn_handler):
249-
conn_handler._storage = {"default": "some_conn", "other": "some_other_conn"}
253+
mock_conn_1 = Mock(_check_loop=Mock(return_value=True))
254+
mock_conn_2 = Mock(_check_loop=Mock(return_value=True))
255+
conn_handler._storage = {"default": mock_conn_1, "other": mock_conn_2}
250256
mocked_db_config.return_value = {"default": {}, "other": {}}
251257
ret_val = conn_handler.all()
252-
assert set(ret_val) == {"some_conn", "some_other_conn"}
258+
assert set(ret_val) == {mock_conn_1, mock_conn_2}
253259

254260

255261
@pytest.mark.asyncio
@@ -282,3 +288,50 @@ async def test_close_all_without_discard(mocked_db_config, conn_handler):
282288
conn_1.close.assert_awaited_once()
283289
conn_2.close.assert_awaited_once()
284290
assert conn_handler._storage == {"default": conn_1, "other": conn_2}
291+
292+
293+
# --- Event loop validation tests ---
294+
295+
296+
@pytest.mark.asyncio
297+
async def test_check_loop_returns_true_when_not_bound():
298+
client = Mock(spec=BaseDBAsyncClient)
299+
client._bound_loop = None
300+
client._check_loop = BaseDBAsyncClient._check_loop.__get__(client)
301+
assert client._check_loop() is True
302+
303+
304+
@pytest.mark.asyncio
305+
async def test_check_loop_returns_true_on_same_loop():
306+
client = Mock(spec=BaseDBAsyncClient)
307+
client._bound_loop = asyncio.get_running_loop()
308+
client._check_loop = BaseDBAsyncClient._check_loop.__get__(client)
309+
assert client._check_loop() is True
310+
311+
312+
@pytest.mark.asyncio
313+
async def test_check_loop_returns_false_on_different_loop():
314+
client = Mock(spec=BaseDBAsyncClient)
315+
client._bound_loop = asyncio.new_event_loop()
316+
client._check_loop = BaseDBAsyncClient._check_loop.__get__(client)
317+
assert client._check_loop() is False
318+
319+
320+
@patch("tortoise.connection.ConnectionHandler._create_connection")
321+
def test_get_reconnects_on_loop_change(mocked_create_connection, conn_handler):
322+
"""When _check_loop() returns False, get() should warn and create a new connection."""
323+
stale_conn = Mock(_check_loop=Mock(return_value=False))
324+
fresh_conn = Mock(_check_loop=Mock(return_value=True))
325+
mocked_create_connection.return_value = fresh_conn
326+
conn_handler._storage = {"default": stale_conn}
327+
328+
with warnings.catch_warnings(record=True) as w:
329+
warnings.simplefilter("always")
330+
ret_val = conn_handler.get("default")
331+
332+
assert ret_val is fresh_conn
333+
assert conn_handler._storage["default"] is fresh_conn
334+
mocked_create_connection.assert_called_once_with("default")
335+
loop_warnings = [x for x in w if issubclass(x.category, TortoiseLoopSwitchWarning)]
336+
assert len(loop_warnings) == 1
337+
assert "different event loop" in str(loop_warnings[0].message)

tests/test_query_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ async def test_execute_pypika_explicit_connection_with_multiple_configured() ->
278278

279279
class DummyClient:
280280
query_class = type("QueryClass", (), {"SQL_CONTEXT": None})
281+
_bound_loop = None
282+
283+
def _check_loop(self) -> bool:
284+
return True
281285

282286
async def execute_query_dict_with_affected(self, query, values=None):
283287
return [], 0

tortoise/backends/asyncpg/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ async def create_connection(self, with_db: bool) -> None:
5959
}
6060
try:
6161
self._pool = await self.create_pool(password=self.password, **self._template)
62+
await self._post_connect()
6263
self.log.debug("Created connection pool %s with params: %s", self._pool, self._template)
6364
except asyncpg.InvalidCatalogNameError as ex:
6465
msg = "Can't establish connection to "

tortoise/backends/base/client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class BaseDBAsyncClient(abc.ABC):
118118
_connection: Any
119119
_parent: BaseDBAsyncClient
120120
_pool: Any
121+
_bound_loop: asyncio.AbstractEventLoop | None = None
121122
connection_name: str
122123
query_class: type[Query] = Query
123124
executor_class: type[BaseExecutor] = BaseExecutor
@@ -129,6 +130,20 @@ def __init__(self, connection_name: str, fetch_inserted: bool = True, **kwargs:
129130
self.connection_name = connection_name
130131
self.fetch_inserted = fetch_inserted
131132

133+
def _check_loop(self) -> bool:
134+
"""Check if the current event loop matches the one this client was created on."""
135+
try:
136+
current = asyncio.get_running_loop()
137+
except RuntimeError:
138+
return True # No running loop — can't validate
139+
if self._bound_loop is None:
140+
return True # Not yet bound (pool not created yet)
141+
return self._bound_loop is current
142+
143+
async def _post_connect(self) -> None:
144+
"""Called after pool/connection is created. Records the bound loop."""
145+
self._bound_loop = asyncio.get_running_loop()
146+
132147
async def create_connection(self, with_db: bool) -> None:
133148
"""
134149
Establish a DB connection.

tortoise/backends/mysql/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ async def create_connection(self, with_db: bool) -> None:
140140
hours = timezone.now().utcoffset().seconds / 3600 # type: ignore
141141
tz = f"{int(hours):+d}:{int((hours % 1) * 60):02d}"
142142
await cursor.execute(f"SET time_zone='{tz}';")
143+
await self._post_connect()
143144
self.log.debug("Created connection %s pool with params: %s", self._pool, self._template)
144145
except errors.OperationalError:
145146
raise DBConnectionError(f"Can't connect to MySQL server: {self._template}")

tortoise/backends/odbc/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ async def create_connection(self, with_db: bool) -> None:
9191
self._pool = await asyncodbc.create_pool(
9292
**self._template,
9393
)
94+
await self._post_connect()
9495
self.log.debug("Created connection %s pool with params: %s", self._pool, self._template)
9596
except pyodbc.InterfaceError:
9697
raise DBConnectionError(f"Can't establish connection to database {self.database}")

tortoise/backends/psycopg/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ async def create_connection(self, with_db: bool) -> None:
104104
# Immediately test the connection because the test suite expects it to check if the
105105
# connection is valid.
106106
await self._pool.open(wait=True, timeout=extra["timeout"])
107+
await self._post_connect()
107108
self.log.debug("Created connection pool %s with params: %s", self._pool, self._template)
108109
except (psycopg.errors.InvalidCatalogName, psycopg_pool.PoolTimeout):
109110
raise exceptions.DBConnectionError(

tortoise/backends/sqlite/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ async def create_connection(self, with_db: bool) -> None:
8484
for pragma, val in self.pragmas.items():
8585
cursor = await self._connection.execute(f"PRAGMA {pragma}={val}")
8686
await cursor.close()
87+
await self._post_connect()
8788
self.log.debug(
8889
"Created connection %s with params: filename=%s %s",
8990
self._connection,

0 commit comments

Comments
 (0)