Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/8628.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Prevent `scoped_query` from overriding `project` param with `group_id`
2 changes: 0 additions & 2 deletions src/ai/backend/manager/api/gql_legacy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,6 @@ async def wrapped(
kwargs["domain_name"] = domain_name
if group_id is not None:
kwargs["group_id"] = group_id
if kwargs.get("project") is not None:
kwargs["project"] = group_id
kwargs[user_key] = user_id
return await resolve_func(root, info, *args, **kwargs)

Expand Down
54 changes: 54 additions & 0 deletions tests/unit/manager/api/test_gql_legacy_scoped_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Regression test for scoped_query: the `project` parameter must not be
silently overridden by `group_id`. (BA-4280)
"""

from __future__ import annotations

import uuid
from typing import Any
from unittest.mock import MagicMock

import graphene
import pytest

from ai.backend.manager.api.gql_legacy.base import scoped_query
from ai.backend.manager.models.user import UserRole


class TestScopedQuery:
"""Test fixes for scoped_query bugs BA-4280."""

@pytest.fixture
def mock_graphene_info(self) -> MagicMock:
"""Mock GraphQL ResolveInfo with SUPERADMIN context."""
ctx = MagicMock()
ctx.user = {"role": UserRole.SUPERADMIN, "domain_name": "default", "uuid": uuid.uuid4()}
ctx.access_key = "test-key"
info = MagicMock(spec=graphene.ResolveInfo)
info.context = ctx
return info

@pytest.mark.asyncio
async def test_project_param_preserved(self, mock_graphene_info: MagicMock) -> None:
"""Regression: project was silently overridden by group_id in scoped_query."""
project_id = uuid.uuid4()
group_id = uuid.uuid4()
received: dict[str, Any] = {}

@scoped_query(autofill_user=False, user_key="user_uuid")
async def _mock_resolver(
_root: Any,
_info: graphene.ResolveInfo,
*,
project: uuid.UUID | None = None,
group_id: uuid.UUID | None = None,
domain_name: str | None = None,
user_uuid: uuid.UUID | None = None,
) -> None:
received.update(project=project, group_id=group_id)

await _mock_resolver(None, mock_graphene_info, project=project_id, group_id=group_id)

assert received["project"] == project_id
assert received["group_id"] == group_id
Loading