Skip to content

Commit 2410772

Browse files
committed
Global fallback for fastapi
1 parent c467923 commit 2410772

File tree

7 files changed

+608
-529
lines changed

7 files changed

+608
-529
lines changed

examples/fastapi/_tests.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,12 @@ async def client() -> ClientManagerType:
5757

5858
@pytest.fixture(scope="module")
5959
async def client_east() -> ClientManagerType:
60+
# app_east uses _enable_global_fallback=False, so we need to explicitly
61+
# enter the context from app.state to make it current for tests
6062
async with client_manager(app_east) as c:
61-
yield c
63+
ctx = app_east.state._tortoise_context
64+
with ctx: # Enter context to make it current via contextvar
65+
yield c
6266

6367

6468
class UserTester:

examples/fastapi/main_custom_timezone.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
@asynccontextmanager
1111
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
1212
# app startup
13+
# Disable global fallback since this is the secondary app in tests
14+
# (main app already uses global fallback). Context is stored in app.state.
1315
async with register_orm(
1416
app,
1517
use_tz=False,
1618
timezone="Asia/Shanghai",
1719
add_exception_handlers=True,
20+
_enable_global_fallback=False,
1821
):
1922
# db connected
2023
yield

tests/contrib/test_fastapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ async def test_await(
3131
use_tz=False,
3232
timezone="UTC",
3333
_create_db=False,
34+
_enable_global_fallback=True,
3435
)
3536
await orm.close_orm()
3637
mocked_close_connections.assert_awaited_once()

tortoise/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ async def init(
300300
routers: list[str | type] | None = None,
301301
table_name_generator: Callable[[type[Model]], str] | None = None,
302302
init_connections: bool = True,
303+
_enable_global_fallback: bool = False,
303304
) -> TortoiseContext:
304305
"""
305306
Sets up Tortoise-ORM: loads apps and models, configures database connections but does not
@@ -368,6 +369,10 @@ async def init(
368369
:param init_connections:
369370
When ``False``, skips initializing connection clients while still loading apps
370371
and validating connection names against the config.
372+
:param _enable_global_fallback:
373+
When ``True``, stores the context as a global fallback for cross-task access.
374+
This is used by RegisterTortoise (FastAPI) where asgi-lifespan runs lifespan
375+
in a background task. Default is ``False`` for pure context isolation.
371376
372377
:raises ConfigurationError: For any configuration error
373378
@@ -425,6 +430,7 @@ async def init(
425430
routers=routers,
426431
table_name_generator=table_name_generator,
427432
init_connections=init_connections,
433+
_enable_global_fallback=_enable_global_fallback,
428434
)
429435

430436
return ctx

tortoise/context.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,55 @@
4545
"tortoise_context", default=None
4646
)
4747

48+
# Optional global fallback context for cross-task access.
49+
# This is used by RegisterTortoise (FastAPI) where asgi-lifespan runs lifespan
50+
# in a background task, but requests/tests run in a different task.
51+
# Disabled by default; enabled via Tortoise.init(_enable_global_fallback=True).
52+
_global_context: TortoiseContext | None = None
53+
4854

4955
def get_current_context() -> TortoiseContext | None:
5056
"""
5157
Get the currently active TortoiseContext, or None if no context is active.
5258
59+
Checks the contextvar first (for proper isolation), then falls back to
60+
the global context if one was set via _enable_global_fallback.
61+
5362
Returns:
5463
The current TortoiseContext if one is active, None otherwise.
5564
"""
56-
return _current_context.get()
65+
ctx = _current_context.get()
66+
if ctx is not None:
67+
return ctx
68+
return _global_context
69+
70+
71+
def set_global_context(ctx: TortoiseContext) -> None:
72+
"""
73+
Set the global fallback context for cross-task access.
74+
75+
This is used by RegisterTortoise (FastAPI) where asgi-lifespan runs lifespan
76+
in a background task, but requests/tests run in a different task.
77+
The global context allows these cross-task scenarios to work without
78+
explicit context passing.
79+
80+
Args:
81+
ctx: The TortoiseContext to set as global fallback.
82+
83+
Raises:
84+
ConfigurationError: If a global context is already set. Only one global
85+
context can be active at a time. For multiple isolated contexts,
86+
use explicit TortoiseContext() without global fallback.
87+
"""
88+
global _global_context
89+
if _global_context is not None:
90+
raise ConfigurationError(
91+
"Global context fallback is already enabled by another Tortoise.init() call. "
92+
"Only one global context can be active at a time. "
93+
"Use explicit TortoiseContext() for multiple isolated contexts, "
94+
"or set _enable_global_fallback=False for secondary apps."
95+
)
96+
_global_context = ctx
5797

5898

5999
def require_context() -> TortoiseContext:
@@ -230,6 +270,7 @@ async def init(
230270
routers: list[str | type] | None = None,
231271
table_name_generator: Callable[[type[Model]], str] | None = None,
232272
init_connections: bool = True,
273+
_enable_global_fallback: bool = False,
233274
) -> None:
234275
"""
235276
Initialize this context with database configuration.
@@ -255,6 +296,8 @@ async def init(
255296
table_name_generator: Optional callable to generate table names.
256297
init_connections: If False, skips initializing connection clients while still
257298
loading apps and validating connection names against the config.
299+
_enable_global_fallback: If True, sets this context as the global fallback
300+
for cross-task access (e.g., asgi-lifespan scenarios). Default is False.
258301
259302
Raises:
260303
ConfigurationError: If configuration is invalid or incomplete.
@@ -334,6 +377,10 @@ async def init(
334377

335378
self._inited = True
336379

380+
# Set global fallback for cross-task access if enabled
381+
if _enable_global_fallback:
382+
set_global_context(self)
383+
337384
def _init_timezone(self, use_tz: bool, timezone: str) -> None:
338385
"""Initialize timezone settings for this context."""
339386
self._use_tz = use_tz
@@ -478,9 +525,13 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
478525
"""
479526
Exit the async context manager, close connections, and restore previous context.
480527
"""
528+
global _global_context
481529
await self.close_connections()
482530
self._apps = None
483531
self._inited = False
532+
# Clear global context if this context was set as the global fallback
533+
if _global_context is self:
534+
_global_context = None
484535
self.__exit__(exc_type, exc_val, exc_tb)
485536

486537

@@ -558,5 +609,6 @@ async def test_create_user(db):
558609
"TortoiseContext",
559610
"get_current_context",
560611
"require_context",
612+
"set_global_context",
561613
"tortoise_test_context",
562614
]

tortoise/contrib/fastapi/__init__.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from tortoise import Tortoise
1111
from tortoise.connection import get_connections
12+
from tortoise.context import TortoiseContext
1213
from tortoise.exceptions import DoesNotExist, IntegrityError
1314
from tortoise.log import logger
1415

@@ -103,6 +104,10 @@ class RegisterTortoise(AbstractAsyncContextManager):
103104
A boolean that specifies if datetime will be timezone-aware by default or not.
104105
timezone:
105106
Timezone to use, default is UTC.
107+
_enable_global_fallback:
108+
If True, enables global context fallback for cross-task access (e.g., when
109+
using asgi-lifespan which runs lifespan in a background task). Default is True.
110+
Set to False when running multiple apps in the same process to avoid conflicts.
106111
107112
Raises
108113
------
@@ -122,6 +127,7 @@ def __init__(
122127
use_tz: bool = False,
123128
timezone: str = "UTC",
124129
_create_db: bool = False,
130+
_enable_global_fallback: bool = True,
125131
) -> None:
126132
self.app = app
127133
self.config = config
@@ -132,6 +138,8 @@ def __init__(
132138
self.use_tz = use_tz
133139
self.timezone = timezone
134140
self._create_db = _create_db
141+
self._enable_global_fallback = _enable_global_fallback
142+
self._context: TortoiseContext | None = None
135143

136144
if add_exception_handlers and app is not None:
137145
from starlette.middleware.exceptions import ExceptionMiddleware
@@ -151,24 +159,32 @@ async def wrap_middleware_call(self, *args, **kw) -> None:
151159

152160
ExceptionMiddleware.__call__ = wrap_middleware_call # type:ignore
153161

154-
async def init_orm(self) -> None: # pylint: disable=W0612
155-
await Tortoise.init(
162+
async def init_orm(self) -> TortoiseContext: # pylint: disable=W0612
163+
self._context = await Tortoise.init(
156164
config=self.config,
157165
config_file=self.config_file,
158166
db_url=self.db_url,
159167
modules=self.modules,
160168
use_tz=self.use_tz,
161169
timezone=self.timezone,
162170
_create_db=self._create_db,
171+
_enable_global_fallback=self._enable_global_fallback,
163172
)
173+
# Store context in app.state for explicit access when global fallback is disabled
174+
if self.app is not None:
175+
self.app.state._tortoise_context = self._context
164176
logger.info("Tortoise-ORM started, %s, %s", get_connections()._get_storage(), Tortoise.apps)
165177
if self.generate_schemas:
166178
logger.info("Tortoise-ORM generating schema")
167179
await Tortoise.generate_schemas()
180+
return self._context
168181

169-
@staticmethod
170-
async def close_orm() -> None: # pylint: disable=W0612
182+
async def close_orm(self) -> None: # pylint: disable=W0612
171183
await Tortoise.close_connections()
184+
# Clear context from app.state
185+
if self.app is not None and hasattr(self.app.state, "_tortoise_context"):
186+
delattr(self.app.state, "_tortoise_context")
187+
self._context = None
172188
logger.info("Tortoise-ORM shutdown")
173189

174190
def __call__(self, *args, **kwargs) -> Self:

0 commit comments

Comments
 (0)