|
| 1 | +import asyncio |
| 2 | +import warnings |
1 | 3 | from unittest.mock import AsyncMock, Mock, PropertyMock, call, patch |
2 | 4 |
|
3 | 5 | import pytest |
4 | 6 |
|
5 | 7 | from tortoise import BaseDBAsyncClient, ConfigurationError |
6 | 8 | from tortoise.connection import ConnectionHandler |
| 9 | +from tortoise.warnings import TortoiseLoopSwitchWarning |
7 | 10 |
|
8 | 11 |
|
9 | 12 | @pytest.fixture |
@@ -206,9 +209,10 @@ def test_create_connection_db_info_not_str( |
206 | 209 |
|
207 | 210 |
|
208 | 211 | def test_get_alias_present(conn_handler): |
209 | | - conn_handler._storage = {"default": "some_connection"} |
| 212 | + mock_conn = Mock(_check_loop=Mock(return_value=True)) |
| 213 | + conn_handler._storage = {"default": mock_conn} |
210 | 214 | ret_val = conn_handler.get("default") |
211 | | - assert ret_val == "some_connection" |
| 215 | + assert ret_val is mock_conn |
212 | 216 |
|
213 | 217 |
|
214 | 218 | @patch("tortoise.connection.ConnectionHandler._create_connection") |
@@ -246,10 +250,12 @@ def test_reset(conn_handler): |
246 | 250 |
|
247 | 251 | @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) |
248 | 252 | def test_all(mocked_db_config, conn_handler): |
249 | | - conn_handler._storage = {"default": "some_conn", "other": "some_other_conn"} |
| 253 | + mock_conn_1 = Mock(_check_loop=Mock(return_value=True)) |
| 254 | + mock_conn_2 = Mock(_check_loop=Mock(return_value=True)) |
| 255 | + conn_handler._storage = {"default": mock_conn_1, "other": mock_conn_2} |
250 | 256 | mocked_db_config.return_value = {"default": {}, "other": {}} |
251 | 257 | ret_val = conn_handler.all() |
252 | | - assert set(ret_val) == {"some_conn", "some_other_conn"} |
| 258 | + assert set(ret_val) == {mock_conn_1, mock_conn_2} |
253 | 259 |
|
254 | 260 |
|
255 | 261 | @pytest.mark.asyncio |
@@ -282,3 +288,50 @@ async def test_close_all_without_discard(mocked_db_config, conn_handler): |
282 | 288 | conn_1.close.assert_awaited_once() |
283 | 289 | conn_2.close.assert_awaited_once() |
284 | 290 | assert conn_handler._storage == {"default": conn_1, "other": conn_2} |
| 291 | + |
| 292 | + |
| 293 | +# --- Event loop validation tests --- |
| 294 | + |
| 295 | + |
| 296 | +@pytest.mark.asyncio |
| 297 | +async def test_check_loop_returns_true_when_not_bound(): |
| 298 | + client = Mock(spec=BaseDBAsyncClient) |
| 299 | + client._bound_loop = None |
| 300 | + client._check_loop = BaseDBAsyncClient._check_loop.__get__(client) |
| 301 | + assert client._check_loop() is True |
| 302 | + |
| 303 | + |
| 304 | +@pytest.mark.asyncio |
| 305 | +async def test_check_loop_returns_true_on_same_loop(): |
| 306 | + client = Mock(spec=BaseDBAsyncClient) |
| 307 | + client._bound_loop = asyncio.get_running_loop() |
| 308 | + client._check_loop = BaseDBAsyncClient._check_loop.__get__(client) |
| 309 | + assert client._check_loop() is True |
| 310 | + |
| 311 | + |
| 312 | +@pytest.mark.asyncio |
| 313 | +async def test_check_loop_returns_false_on_different_loop(): |
| 314 | + client = Mock(spec=BaseDBAsyncClient) |
| 315 | + client._bound_loop = asyncio.new_event_loop() |
| 316 | + client._check_loop = BaseDBAsyncClient._check_loop.__get__(client) |
| 317 | + assert client._check_loop() is False |
| 318 | + |
| 319 | + |
| 320 | +@patch("tortoise.connection.ConnectionHandler._create_connection") |
| 321 | +def test_get_reconnects_on_loop_change(mocked_create_connection, conn_handler): |
| 322 | + """When _check_loop() returns False, get() should warn and create a new connection.""" |
| 323 | + stale_conn = Mock(_check_loop=Mock(return_value=False)) |
| 324 | + fresh_conn = Mock(_check_loop=Mock(return_value=True)) |
| 325 | + mocked_create_connection.return_value = fresh_conn |
| 326 | + conn_handler._storage = {"default": stale_conn} |
| 327 | + |
| 328 | + with warnings.catch_warnings(record=True) as w: |
| 329 | + warnings.simplefilter("always") |
| 330 | + ret_val = conn_handler.get("default") |
| 331 | + |
| 332 | + assert ret_val is fresh_conn |
| 333 | + assert conn_handler._storage["default"] is fresh_conn |
| 334 | + mocked_create_connection.assert_called_once_with("default") |
| 335 | + loop_warnings = [x for x in w if issubclass(x.category, TortoiseLoopSwitchWarning)] |
| 336 | + assert len(loop_warnings) == 1 |
| 337 | + assert "different event loop" in str(loop_warnings[0].message) |
0 commit comments