Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/8001.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure Request ID Header Propagation in All Requests and Responses
1 change: 1 addition & 0 deletions changes/8160.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure `request_id` propagation for non-HTTP entry points (event handlers, background tasks, and sweeper)
21 changes: 19 additions & 2 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from ai.backend.common.auth import AgentAuthHandler, PublicKey, SecretKey
from ai.backend.common.bgtask.bgtask import ProgressReporter
from ai.backend.common.configs.redis import RedisConfig
from ai.backend.common.contexts.request_id import receive_request_id
from ai.backend.common.defs import RedisRole
from ai.backend.common.docker import ImageRef
from ai.backend.common.dto.agent.response import (
Expand Down Expand Up @@ -179,11 +180,22 @@ async def _inner(self_: AgentRPCServer, request: RPCMessage) -> Any:
try:
if request.body is None:
return await meth(self_)
return await meth(
request_id = request.body.get("request_id")
receive_request_id(request_id, f"RPC call from manager: {meth.__name__}")
result = await meth(
self_,
*request.body["args"],
**request.body["kwargs"],
)
if request_id:
if isinstance(result, dict):
result["request_id"] = request_id
else:
log.warning(
"Cannot attach request_id to non-dict RPC response: {}",
type(result).__name__,
)
return result
except (TimeoutError, asyncio.CancelledError):
raise
except ResourceError:
Expand Down Expand Up @@ -216,12 +228,17 @@ async def _inner(self_: AgentRPCServer, request: RPCMessage) -> Any:
try:
if request.body is None:
return await meth(self_)
request_id = request.body.get("request_id")
receive_request_id(request_id, f"RPC call from manager: {meth.__name__}")
res = await meth(
self_,
*request.body["args"],
**request.body["kwargs"],
)
return res.as_dict()
resp_dict = res.as_dict()
if request_id:
resp_dict["request_id"] = request_id
return resp_dict
except (TimeoutError, asyncio.CancelledError):
raise
except ResourceError:
Expand Down
5 changes: 4 additions & 1 deletion src/ai/backend/common/api_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pydantic import BaseModel, ConfigDict
from pydantic_core._pydantic_core import ValidationError

from ai.backend.common.contexts.request_id import bind_request_id
from ai.backend.common.types import StreamReader

from .exception import (
Expand Down Expand Up @@ -501,7 +502,9 @@ async def wrapped(first_arg: Any, request: web.Request) -> web.StreamResponse:

body_stream = result.body
status = result.status
resp = web.StreamResponse(status=status, headers=result.headers)
headers: dict[str, str] = dict(result.headers) if result.headers else {}
bind_request_id(headers, f"stream_api_handler response: {handler.__name__}")
resp = web.StreamResponse(status=status, headers=headers)

body_iter = body_stream.read()

Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/common/bgtask/bgtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TaskSetKey,
ValkeyBgtaskClient,
)
from ai.backend.common.contexts.request_id import ensure_request_id
from ai.backend.common.events.dispatcher import (
EventProducer,
)
Expand Down Expand Up @@ -492,6 +493,7 @@ async def _try_to_revive_task(
) -> Optional[BaseBackgroundTaskResult]:
return await self._task_registry.revive_task(task_name.value, task_info.body)

@ensure_request_id
async def _execute_new_task(
self,
task_name: BgtaskNameBase,
Expand Down Expand Up @@ -523,6 +525,7 @@ async def _execute_new_task(
last_message=last_message,
)

@ensure_request_id
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Middleware might not guarantee the use of the existing request ID.

async def _revive_task(
self, task_name: BgtaskNameBase, task_info: TaskInfo, task_key: BgTaskKey
) -> None:
Expand Down
81 changes: 79 additions & 2 deletions src/ai/backend/common/contexts/request_id.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
from __future__ import annotations

import functools
import logging
import uuid
from collections.abc import Iterator
from collections.abc import Callable, Coroutine, Iterator, MutableMapping
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Optional
from typing import Any, Final, Optional, ParamSpec, TypeVar

from ai.backend.logging import BraceStyleAdapter

P = ParamSpec("P")
T = TypeVar("T")

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

REQUEST_ID_HEADER: Final = "X-BackendAI-RequestID"

_request_id_var: ContextVar[str] = ContextVar("request_id")

Expand Down Expand Up @@ -35,3 +48,67 @@ def with_request_id(request_id: Optional[str] = None) -> Iterator[None]:
finally:
# Reset the context variable to its previous state
_request_id_var.reset(token)


def bind_request_id(
    target: MutableMapping[str, Any],
    context_description: str,
    *,
    key: str = REQUEST_ID_HEADER,
) -> None:
    """
    Copy the request ID from the current context into *target*, when one is set.

    When the context carries no request ID, this emits a warning instead of
    raising, so callers can proceed without propagation.

    :param target: The dict to add the request ID to (e.g., headers or request body)
    :param context_description: A description of the operation for logging (e.g., "wsproxy status query")
    :param key: The key name to use (default: "X-BackendAI-RequestID")
    """
    request_id = current_request_id()
    if not request_id:
        # Nothing to propagate; leave a trace so missing IDs are discoverable.
        log.warning("No request_id in context for {}", context_description)
        return
    target[key] = request_id


def receive_request_id(request_id: Optional[str], context_description: str) -> None:
    """
    Bind an incoming request ID to the current context, if one was provided.

    Unlike with_request_id(), this does not auto-generate a UUID for a missing
    ID, and it never resets the context variable afterwards. It is meant for
    fire-and-forget entry points (e.g., RPC handlers) whose context lives
    exactly as long as the request itself.

    :param request_id: The request ID received from the caller, or None
    :param context_description: A description of the operation for logging
    """
    if not request_id:
        # Caller sent no ID; log it so gaps in propagation can be traced.
        log.warning("No request_id in context for {}", context_description)
        return
    _request_id_var.set(request_id)


def ensure_request_id(
    func: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T]]:
    """
    Decorator guaranteeing that a request_id is bound in the context while
    *func* runs; a fresh UUID is generated when none is already set.

    Intended for non-HTTP entry points — background tasks, timer callbacks,
    CLI helpers — that never pass through the request-ID middleware.

    Example:
        @ensure_request_id
        async def my_background_task():
            # request_id is guaranteed to exist here
            ...
    """

    @functools.wraps(func)
    async def _ensured(*args: P.args, **kwargs: P.kwargs) -> T:
        # with_request_id() reuses an existing ID or mints a new one, and
        # restores the previous context state on exit.
        with with_request_id():
            return await func(*args, **kwargs)

    return _ensured
29 changes: 16 additions & 13 deletions src/ai/backend/common/events/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from aiotools.taskgroup import PersistentTaskGroup
from aiotools.taskgroup.types import AsyncExceptionHandler

from ai.backend.common.contexts.request_id import current_request_id
from ai.backend.common.contexts.request_id import current_request_id, with_request_id
from ai.backend.common.contexts.user import current_user
from ai.backend.common.message_queue.queue import AbstractMessageQueue
from ai.backend.common.message_queue.types import (
Expand Down Expand Up @@ -528,7 +528,8 @@ async def _handle(
log.debug("DISPATCH_{}(evh:{})", evh_type.name, evh.name)

# Apply all context variables from metadata if available
if metadata:
# If metadata exists but request_id is None, generate a new one
if metadata and metadata.request_id:
with metadata.apply_context():
if asyncio.iscoroutinefunction(cb):
# mypy cannot catch the meaning of asyncio.iscoroutinefunction().
Expand All @@ -542,17 +543,19 @@ async def _handle(
duration=time.perf_counter() - start,
)
else:
if asyncio.iscoroutinefunction(cb):
# mypy cannot catch the meaning of asyncio.iscoroutinefunction().
await cb(evh.context, source, event) # type: ignore
else:
cb(evh.context, source, event) # type: ignore
for post_callback in post_callbacks:
await post_callback.done()
self._metric_observer.observe_event_success(
event_type=event_type,
duration=time.perf_counter() - start,
)
# Generate a new request_id for event handlers without one
with with_request_id():
if asyncio.iscoroutinefunction(cb):
# mypy cannot catch the meaning of asyncio.iscoroutinefunction().
await cb(evh.context, source, event) # type: ignore
else:
cb(evh.context, source, event) # type: ignore
for post_callback in post_callbacks:
await post_callback.done()
self._metric_observer.observe_event_success(
event_type=event_type,
duration=time.perf_counter() - start,
)
except Exception as e:
self._metric_observer.observe_event_failure(
event_type=event_type,
Expand Down
14 changes: 8 additions & 6 deletions src/ai/backend/common/leader/tasks/event_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass
from typing import Final

from ai.backend.common.contexts.request_id import with_request_id
from ai.backend.common.events.dispatcher import EventProducer
from ai.backend.common.events.types import AbstractAnycastEvent
from ai.backend.common.leader.tasks.base import PeriodicTask
Expand Down Expand Up @@ -48,12 +49,13 @@ def __init__(

async def run(self) -> None:
"""Execute the task - produce an event."""
try:
event = self._spec.event_factory()
await self._event_producer.anycast_event(event)
log.debug(f"Event task {self._spec.name} produced event")
except Exception:
log.exception(f"Failed to produce event for task {self._spec.name}")
with with_request_id():
try:
event = self._spec.event_factory()
await self._event_producer.anycast_event(event)
log.debug(f"Event task {self._spec.name} produced event")
except Exception:
log.exception(f"Failed to produce event for task {self._spec.name}")

@property
def name(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion src/ai/backend/common/middlewares/request_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ async def request_id_middleware(request: web.Request, handler: Handler) -> web.S
with_request_id(request_id),
with_log_context_fields({"request_id": request_id}),
):
return await _handler(request)
response = await _handler(request)
if request_id:
response.headers[REQUEST_ID_HEADER] = request_id
return response
2 changes: 2 additions & 0 deletions src/ai/backend/manager/agent_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ai.backend.common import msgpack
from ai.backend.common.auth import ManagerAuthHandler, PublicKey, SecretKey
from ai.backend.common.contexts.request_id import bind_request_id
from ai.backend.common.types import AgentId
from ai.backend.logging import BraceStyleAdapter

Expand Down Expand Up @@ -46,6 +47,7 @@ async def _wrapped(*args, **kwargs):
"args": args,
"kwargs": kwargs,
}
bind_request_id(request_body, f"RPC call to agent: {name}", key="request_id")
self.peer.last_used = time.monotonic()
ret = await self.peer.invoke(name, request_body, order_key=self.order_key.get())
self.peer.last_used = time.monotonic()
Expand Down
5 changes: 4 additions & 1 deletion src/ai/backend/manager/api/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aiohttp import web

from ai.backend.common import validators as tx
from ai.backend.common.contexts.request_id import bind_request_id
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.errors.common import (
InternalServerError,
Expand Down Expand Up @@ -41,11 +42,13 @@ class WSProxyVersionQueryParams:
async def query_wsproxy_status(
wsproxy_addr: str,
) -> dict[str, Any]:
headers: dict[str, str] = {"Accept": "application/json"}
bind_request_id(headers, f"wsproxy status query: {wsproxy_addr}")
async with (
aiohttp.ClientSession() as session,
session.get(
wsproxy_addr + "/status",
headers={"Accept": "application/json"},
headers=headers,
) as resp,
):
try:
Expand Down
13 changes: 9 additions & 4 deletions src/ai/backend/manager/api/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
BaseFieldModel,
)
from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient
from ai.backend.common.contexts.request_id import bind_request_id
from ai.backend.common.exception import BackendAIError
from ai.backend.common.types import (
VFolderHostPermission,
Expand Down Expand Up @@ -2294,7 +2295,8 @@ async def get_fstab_contents(request: web.Request, params: Any) -> web.Response:
try:
client_timeout = aiohttp.ClientTimeout(total=10.0)
async with aiohttp.ClientSession(timeout=client_timeout) as sess:
headers = {"X-BackendAI-Watcher-Token": watcher_info["token"]}
headers: dict[str, str] = {"X-BackendAI-Watcher-Token": watcher_info["token"]}
bind_request_id(headers, f"watcher fstab request: agent={params['agent_id']}")
url = watcher_info["addr"] / "fstab"
async with sess.get(url, headers=headers, params=params) as watcher_resp:
if watcher_resp.status == 200:
Expand Down Expand Up @@ -2387,7 +2389,8 @@ async def _fetch_mounts(
) -> tuple[str, Mapping]:
async with sema:
watcher_info = await get_watcher_info(request, agent_id)
headers = {"X-BackendAI-Watcher-Token": watcher_info["token"]}
headers: dict[str, str] = {"X-BackendAI-Watcher-Token": watcher_info["token"]}
bind_request_id(headers, f"watcher mounts GET request: agent={agent_id}")
url = watcher_info["addr"] / "mounts"
try:
async with sess.get(url, headers=headers) as watcher_resp:
Expand Down Expand Up @@ -2506,7 +2509,8 @@ async def _mount(
async with sema:
watcher_info = await get_watcher_info(request, agent_id)
try:
headers = {"X-BackendAI-Watcher-Token": watcher_info["token"]}
headers: dict[str, str] = {"X-BackendAI-Watcher-Token": watcher_info["token"]}
bind_request_id(headers, f"watcher mount POST request: agent={agent_id}")
url = watcher_info["addr"] / "mounts"
async with sess.post(url, json=params, headers=headers) as resp:
if resp.status == 200:
Expand Down Expand Up @@ -2634,7 +2638,8 @@ async def _umount(
async with sema:
watcher_info = await get_watcher_info(request, agent_id)
try:
headers = {"X-BackendAI-Watcher-Token": watcher_info["token"]}
headers: dict[str, str] = {"X-BackendAI-Watcher-Token": watcher_info["token"]}
bind_request_id(headers, f"watcher umount DELETE request: agent={agent_id}")
url = watcher_info["addr"] / "mounts"
async with sess.delete(url, json=params, headers=headers) as resp:
if resp.status == 200:
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/cli/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import click
from alembic.config import Config

from ai.backend.common.contexts.request_id import ensure_request_id
from ai.backend.common.types import AgentId
from ai.backend.logging import BraceStyleAdapter
from ai.backend.logging.utils import enforce_debug_logging
Expand Down Expand Up @@ -59,6 +60,7 @@ def ping(cli_ctx: CLIContext, agent_id: str, alembic_config: str, timeout: float
from ai.backend.manager.agent_cache import AgentRPCCache
from ai.backend.manager.models.utils import create_async_engine

@ensure_request_id
async def _impl():
bootstrap_config = await cli_ctx.get_bootstrap_config()
manager_public_key, manager_secret_key = load_certificate(
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/clients/agent/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ai.backend.common import msgpack
from ai.backend.common.auth import ManagerAuthHandler
from ai.backend.common.contexts.request_id import ensure_request_id
from ai.backend.common.types import AgentId
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.errors.agent import AgentConnectionUnavailable
Expand Down Expand Up @@ -236,6 +237,7 @@ async def _health_check_loop(self) -> None:
await asyncio.sleep(self._spec.health_check_interval)
await self._check_all_health()

@ensure_request_id
async def _check_all_health(self) -> None:
"""Check health of all connections (using asyncio.gather)."""
async with self._lock:
Expand Down
Loading
Loading