
Draft - DO NOT REVIEW - feat(autodeploy): Add TRT-LLM attention backend with CUDA graph support #11283

Open
MrGeva wants to merge 3 commits into NVIDIA:main from nv-auto-deploy:eg/trtllm_attn_v2_squashed

Conversation

@MrGeva
Collaborator

@MrGeva MrGeva commented Feb 4, 2026

Add TRT-LLM attention backend (thop.attention) for Auto-Deploy with full CUDA graph support and optimized metadata preparation.

Key features:

  • Unified KV cache with HND layout matching thop.attention kernel
  • Pre-allocated CPU/GPU buffers to avoid allocation overhead
  • Shared tensors across all 32 layers (single update per forward)
  • Vectorized GPU block offset computation via torch.searchsorted (a rough sketch follows this list)
  • Host prepare function running outside CUDA graph for metadata updates
  • Support for FP8 KV cache with quantization scales
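
A rough sketch of the torch.searchsorted block-offset step referenced above is shown below. This is an illustrative example only; the function and variable names (compute_block_offsets, pages_per_seq, page_ids) are assumptions, not the PR's actual symbols.

import torch

def compute_block_offsets(pages_per_seq: torch.Tensor, page_ids: torch.Tensor) -> torch.Tensor:
    """Scatter a flat list of allocated page ids into a [num_seqs, max_blocks] table."""
    num_seqs = pages_per_seq.numel()
    max_blocks = int(pages_per_seq.max().item())
    # Exclusive prefix sum: cu_num_pages[i] = number of pages before sequence i.
    cu_num_pages = torch.zeros(num_seqs + 1, dtype=torch.long, device=page_ids.device)
    cu_num_pages[1:] = torch.cumsum(pages_per_seq.to(torch.long), dim=0)
    # Map every flat page position to its owning sequence without a Python loop.
    positions = torch.arange(page_ids.numel(), device=page_ids.device)
    seq_idx = torch.searchsorted(cu_num_pages, positions, right=True) - 1
    page_idx = positions - cu_num_pages[seq_idx]
    offsets = torch.zeros(num_seqs, max_blocks, dtype=torch.int32, device=page_ids.device)
    offsets[seq_idx, page_idx] = page_ids.to(torch.int32)
    return offsets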

Performance:

  • ~6650 tokens/sec throughput on Llama-3.1-8B-Instruct-FP8
  • ~1.8% faster than PTCacheBackend baseline

Architecture:

  • TrtllmAttentionGlobalState: Singleton managing shared tensors and buffers
  • TrtllmLayerState: Per-layer state linked to global shared tensors
  • TrtllmAttentionConfig: Runtime configuration (page size, batch size, etc.)
  • host_prepare_fn: Called before CUDA graph replay to update metadata (a minimal sketch follows this list)
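
To make the host_prepare_fn idea concrete, here is a minimal, hypothetical sketch (not the PR's code) of refreshing metadata buffers in place before replaying a captured CUDA graph; the buffer names and shapes are assumptions, and warm-up iterations before capture are omitted for brevity.

import torch

def host_prepare_fn(seq_lens_cpu: torch.Tensor, seq_lens_gpu: torch.Tensor, new_lens: list) -> None:
    # Write into the pinned CPU buffer, then copy into the same GPU tensor the graph
    # was captured with; buffer addresses must not change between replays.
    n = len(new_lens)
    seq_lens_cpu[:n] = torch.tensor(new_lens, dtype=seq_lens_cpu.dtype)
    seq_lens_gpu[:n].copy_(seq_lens_cpu[:n], non_blocking=True)

seq_lens_gpu = torch.zeros(8, dtype=torch.int32, device="cuda")
seq_lens_cpu = torch.zeros(8, dtype=torch.int32, pin_memory=True)
x = torch.randn(8, 16, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    y = x * seq_lens_gpu.unsqueeze(1)  # stand-in for the attention kernel

host_prepare_fn(seq_lens_cpu, seq_lens_gpu, [3, 5, 7])  # runs outside the graph
graph.replay()  # replays with the updated metadata, same buffer addresses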

Files:

  • trtllm_attention.py: Main implementation with attention kernel wrapper
  • attention_interface.py: Interface extensions for TRT-LLM backend
  • kvcache.py: Transform integration for cached attention
  • llm_args.py: Configuration options for TRT-LLM backend

Summary by CodeRabbit

  • New Features
    • Integrated TRT-LLM attention backend for Auto-Deploy with optimized KV cache management, paging support, and shared workspace handling.
    • Enhanced KV cache configuration with flexible data type options and pool integration for efficient memory usage.
    • Expanded configuration framework with new setup options for both simplified and advanced use cases.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@MrGeva MrGeva requested a review from a team as a code owner February 4, 2026 17:12
@MrGeva MrGeva requested a review from nvchenghaoz February 4, 2026 17:12
@MrGeva MrGeva changed the title from "feat(autodeploy): Add TRT-LLM attention backend with CUDA graph support" to "Draft - DO NOT REVIEW - feat(autodeploy): Add TRT-LLM attention backend with CUDA graph support" on Feb 4, 2026
@coderabbitai
Contributor

coderabbitai bot commented Feb 4, 2026

📝 Walkthrough

Walkthrough

This pull request introduces a comprehensive TRT-LLM attention backend for Auto-Deploy with KV cache pool support, adding configuration models, per-layer state management, and kernel integration. It also modernizes the configuration system with the new AutoDeployConfig and LlmArgs classes.

Changes

  • Attention Configuration (tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py): Introduces CacheConfig Pydantic model with dtype coercion and __or__ merge operator; extends SequenceInfo with KV cache pool integration via internal placeholders, setters, and public properties; adds protocol/type aliases for cache and buffer initializers.
  • TRT-LLM Attention Backend (tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py): Implements comprehensive TRT-LLM attention backend with workspace resource handler (64 MB shared buffer), paged KV cache management with HND layout, per-layer state tracking via TrtllmLayerState, global state singleton managing GPU/CPU buffers and pool pointers, host metadata preparation orchestration, custom MHA wrapper with QKV fusion, and AttentionRegistry integration hooks.
  • LLM Configuration (tensorrt_llm/_torch/auto_deploy/llm_args.py): Introduces AutoDeployConfig and LlmArgs classes replacing TorchLlmArgs; expands public fields for KV cache dtype/config, tokenizer lifecycle, attention page size, performance statistics; implements validators for synchronization, KV cache constraints, and model configuration; includes to_dict and to_llm_kwargs serialization methods.
  • Transform KV Cache Config (tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py): Adds cache_config field of type CacheConfig to InsertCachedAttentionConfig with default factory for custom cache configuration.
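
For orientation, the sketch below shows roughly what a CacheConfig with dtype coercion and an __or__ merge can look like. It is a hedged approximation: only the dtype field is shown, and the merge logic is an assumption about how such an operator typically behaves, not the PR's actual implementation.

from typing import Optional

import torch
from pydantic import BaseModel, field_validator


class CacheConfig(BaseModel):
    """Hypothetical sketch of a cache config with dtype coercion and merge support."""

    model_config = {"arbitrary_types_allowed": True}

    dtype: Optional[torch.dtype] = None

    @field_validator("dtype", mode="before")
    @classmethod
    def _coerce_dtype(cls, value):
        # Accept strings like "float16" and coerce them to torch dtypes.
        if value is None or isinstance(value, torch.dtype):
            return value
        if isinstance(value, str):
            dtype = getattr(torch, value, None)
            if not isinstance(dtype, torch.dtype):
                raise ValueError(f"Invalid dtype: {value!r}")
            return dtype
        raise ValueError(f"Unsupported dtype value: {value!r}")

    def __or__(self, other: "CacheConfig") -> "CacheConfig":
        # Fields set on self win; unset fields fall back to other.
        merged = {
            name: getattr(self, name) if getattr(self, name) is not None else getattr(other, name)
            for name in type(self).model_fields
        }
        return CacheConfig(**merged)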

Sequence Diagram(s)

sequenceDiagram
    participant Client as Auto-Deploy
    participant Config as TrtllmAttentionConfig
    participant GlobalState as TrtllmAttentionGlobalState
    participant LayerState as TrtllmLayerState
    participant KVCache as TrtllmKVResourceHandler
    participant Workspace as TrtllmWorkspaceResourceHandler
    participant Kernel as TRT-LLM Kernel

    Client->>Config: configure(sequence_info)
    Config->>GlobalState: allocate_workspace()
    GlobalState->>Workspace: allocate()
    Workspace-->>GlobalState: workspace tensor (64MB)

    loop Per Layer
        Client->>GlobalState: get_or_create_layer_state(layer_idx)
        GlobalState->>LayerState: init_from_shared()
        GlobalState->>KVCache: allocate(sequence_info)
        KVCache-->>GlobalState: paged kv_cache
    end

    Client->>Client: trtllm_mha_with_cache(Q, K, V)
    Client->>GlobalState: _prepare_trtllm_metadata()
    GlobalState-->>Client: block_offsets, pool_info
    Client->>LayerState: set_kv_cache_pool_info()
    Client->>Kernel: attention(fused_QKV, workspace, pool_pointers, pool_mapping)
    Kernel-->>Client: attention output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~65 minutes

🚥 Pre-merge checks | ❌ 3
❌ Failed checks (2 warnings, 1 inconclusive)
  • Description check (⚠️ Warning): The PR description contains technical details about the implementation but lacks the structured template sections (Title format, Description, Test Coverage, and Checklist items) required by the repository. Resolution: Reformat the description to follow the template: add a properly formatted title with [Ticket][type], provide concise Description and Test Coverage sections, and complete the PR Checklist by marking items reviewed.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 60.00%, which is below the required threshold of 80.00%. Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check (❓ Inconclusive): The title is partially related to the changeset but includes a 'Draft - DO NOT REVIEW' prefix that is not descriptive and suggests the PR is incomplete. Resolution: Remove the 'Draft - DO NOT REVIEW' prefix and ensure the title follows the repository template: [JIRA/Issue][type] Summary (e.g., [TRTLLM-1234][feat] Add TRT-LLM attention backend with CUDA graph support).




Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (2)

1-2: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to reflect 2026 modifications.

The file was modified in 2026, but the header still states 2025. Please update it to the latest meaningful modification year.

✏️ Suggested update
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines, All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.


49-55: ⚠️ Potential issue | 🟠 Major

Wire cache_config into cache initialization path.

InsertCachedAttentionConfig.cache_config is defined but unused—_apply() passes cm.kv_cache_config (type KvCacheConfig) to get_cache_initializers() instead of self.config.cache_config (type CacheConfig). This causes a type mismatch for the TRT-LLM backend, which expects CacheConfig and accesses cache_config.dtype. The new field must be wired into the initializer call to take effect.

🔧 Suggested fix
-            for k, resource_handler in attn_descriptor.get_cache_initializers(
-                attn_node, cm.kv_cache_config
-            ).items():
+            cache_config = self.config.cache_config
+            for k, resource_handler in attn_descriptor.get_cache_initializers(
+                attn_node, cache_config
+            ).items():
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)

1-10: ⚠️ Potential issue | 🟠 Major

Add the NVIDIA copyright header to this source file.

The file currently starts with a module docstring and has no NVIDIA SPDX header. Please add the required header with the latest modification year.

As per coding guidelines, All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.

tensorrt_llm/_torch/auto_deploy/llm_args.py (1)

1-5: ⚠️ Potential issue | 🟠 Major

Add the NVIDIA copyright header to this file.

This source file is missing the required NVIDIA SPDX header with the latest modification year.

As per coding guidelines, All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.

🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py`:
- Around line 301-309: The field validator method _coerce_dtype should not use
assert for invalid dtype checks because asserts can be skipped under -O; replace
the assert with raising a ValueError (e.g., raise ValueError(f"Invalid dtype:
{value!r}") or include the resolved dtype) so Pydantic always rejects bad
inputs; update the validation branch in the _coerce_dtype classmethod (decorated
with `@field_validator`("dtype", "mamba_dtype", "delta_dtype", mode="before")) to
raise ValueError when getattr(torch, value, None) is not a torch.dtype and
return the resolved torch.dtype otherwise.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py`:
- Around line 206-217: The shared pool mapping
(_shared_host_kv_cache_pool_mapping) is being written by _init_pool_pointers but
layers still use their per-layer host_kv_cache_pool_mapping (returned by the
fast path) which remains all-zero; fix by wiring layers to the shared mapping in
init_from_shared (set self.host_kv_cache_pool_mapping =
global_state._shared_host_kv_cache_pool_mapping) or, alternatively, update each
layer’s host_kv_cache_pool_mapping inside _init_pool_pointers when
_shared_host_kv_cache_pool_mapping is populated so the fast-path (which reads
state.host_kv_cache_pool_mapping) sees the correct mapping; apply the same
change to the corresponding duplicate initialization block elsewhere (the other
init_from_shared/_init_pool_pointers usage).
- Around line 1-2: Update the SPDX header year from 2025 to 2026 in the file
header comment at the top of trtllm_attention.py: change the copyright line that
currently reads "Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES." to
"Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES." so the SPDX header (first
two lines) reflects the latest modification year.
- Around line 890-952: The call to thop.attention currently hardcodes
q_scaling=1.0 but the function receives a scale parameter (variable name scale)
that should be applied; change the q_scaling argument in the thop.attention
invocation to pass the existing local variable scale (or a computed fallback
like (1.0 / (head_dim**0.5)) if scale is None) instead of 1.0 so q_scaling
reflects the source attention op; look for the thop.attention call and replace
the literal 1.0 q_scaling argument with scale (or computed fallback).

In `@tensorrt_llm/_torch/auto_deploy/llm_args.py`:
- Around line 241-257: The update_attn_page_size validator currently only
inspects self.transforms and can miss shortcut fields set via attn_backend;
update the validator (update_attn_page_size) to also check self.attn_backend
(e.g., treat "triton" or "torch" the same as transforms backends) and set
self.attn_page_size = self.max_seq_len when attn_backend indicates a
cached-attention backend, or alternatively move this validator to run after
update_transforms_with_shortcuts so transforms are already
synchronized—reference update_attn_page_size, self.transforms,
self.attn_backend, and update_transforms_with_shortcuts when making the change.
- Around line 6-7: Replace the direct pydantic Field import with the project
wrapper that handles status metadata: change the import line so Field is
imported from tensorrt_llm.llmapi.llm_args (while keeping PrivateAttr,
ValidationInfo, field_validator, model_validator from pydantic and BaseSettings,
SettingsConfigDict from pydantic_settings) so code using Field(status=...) is
converted into json_schema_extra correctly; update the import statement that
currently references pydantic.Field to instead reference the custom Field from
tensorrt_llm.llmapi.llm_args.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)

1048-1068: Initialize KV cache pool fields in __init__ to avoid class-level state.

_kv_cache_pool_pointers and _kv_cache_pool_mapping are defined at class scope, which makes them shared defaults. Since they are externally visible via properties, initialize them per-instance in __init__ instead.

♻️ Suggested fix
@@
-        self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
+        self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
+        self._kv_cache_pool_pointers: Optional[torch.Tensor] = None
+        self._kv_cache_pool_mapping: Optional[torch.Tensor] = None
@@
-    # KV Cache pool info for TRT-LLM attention (set by CacheInterface after KVCacheManager creation)
-    _kv_cache_pool_pointers: Optional[torch.Tensor] = None
-    _kv_cache_pool_mapping: Optional[torch.Tensor] = None

As per coding guidelines, Initialize all externally visible members of a Python class in the constructor.

Comment on lines +301 to +309
    @field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")
    @classmethod
    def _coerce_dtype(cls, value):
        if value is None or isinstance(value, torch.dtype):
            return value
        if isinstance(value, str):
            dtype = getattr(torch, value, None)
            assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
            return dtype

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

Pydantic v2 field validator assert vs ValueError validation error handling

💡 Result:

In Pydantic v2 field validators (@field_validator / BeforeValidator / AfterValidator), you should signal “this input is invalid” by raising ValueError or AssertionError (including via assert ...). Pydantic catches those and aggregates them into a single ValidationError. You generally should not raise ValidationError yourself inside validation code. [1][2]

assert / AssertionError

  • Works and becomes a validation error (typically error type like assertion_error). [1][3]
  • Not safe for real validation logic because assert statements are skipped when Python runs with -O (optimize), meaning your validator checks can silently stop running. [4][1]

ValueError

  • The recommended/default way to fail validation (always executed regardless of -O). [1][2]
  • Produces a validation error (often value_error) that’s collected into the final ValidationError. [1][2]

What not to do (common gotcha)

  • If you raise other exception types (e.g., TypeError, IndexError, etc.), they are not treated as validation errors and can bubble up directly rather than being wrapped into a ValidationError. [5]

When you need richer error codes/messages

  • Use pydantic_core.PydanticCustomError for structured/custom error types and templated messages. [1]

Practical rule: use ValueError for user/data validation, reserve assert for developer invariants/tests (or only when you’re sure -O won’t be used). [4][2]

Sources: [1] https://docs.pydantic.dev/latest/concepts/validators/ , [2] https://docs.pydantic.dev/latest/errors/errors/ , [3] https://docs.pydantic.dev/2.0/usage/validation_errors/ , [4] https://docs.pydantic.dev/2.5/concepts/validators/ , [5] https://docs.pydantic.dev/2.3/usage/validators/


🏁 Script executed:

# Check if the file exists and view the actual context around the problematic code
fd -t f "attention_interface.py" --exec wc -l {} \;

Repository: NVIDIA/TensorRT-LLM

Length of output: 135


🏁 Script executed:

# Also check the imports to confirm this is using Pydantic v2
fd -t f "attention_interface.py" --exec head -50 {} \;

Repository: NVIDIA/TensorRT-LLM

Length of output: 2134


🏁 Script executed:

# View the specific lines mentioned in the review (301-309)
sed -n '301,309p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 453


🏁 Script executed:

# Also check a broader context to understand the class and imports
sed -n '280,320p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1739


🏁 Script executed:

# Check for the presence of copyright header in the file
head -20 ./tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1066


Use ValueError instead of assert for dtype validation in Pydantic field validator.

assert statements are optimized away when Python runs with the -O flag, causing the validation check to silently skip. Pydantic v2 field validators should raise ValueError (or AssertionError directly, but ValueError is preferred) to ensure invalid inputs are always rejected.

Suggested change
-            assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
-            return dtype
+            if not isinstance(dtype, torch.dtype):
+                raise ValueError(f"Invalid dtype string: {value!r}")
+            return dtype
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py` around
lines 301 - 309, The field validator method _coerce_dtype should not use assert
for invalid dtype checks because asserts can be skipped under -O; replace the
assert with raising a ValueError (e.g., raise ValueError(f"Invalid dtype:
{value!r}") or include the resolved dtype) so Pydantic always rejects bad
inputs; update the validation branch in the _coerce_dtype classmethod (decorated
with `@field_validator`("dtype", "mamba_dtype", "delta_dtype", mode="before")) to
raise ValueError when getattr(torch, value, None) is not a torch.dtype and
return the resolved torch.dtype otherwise.

Comment on lines +1 to +2
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

This is a new 2026 file, but the header still lists 2025.

✏️ Suggested update
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines, All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.

🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py` around lines
1 - 2, Update the SPDX header year from 2025 to 2026 in the file header comment
at the top of trtllm_attention.py: change the copyright line that currently
reads "Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES." to "Copyright (c)
2026 NVIDIA CORPORATION & AFFILIATES." so the SPDX header (first two lines)
reflects the latest modification year.

Comment on lines 206 to 217
    def init_from_shared(self, global_state: "TrtllmAttentionGlobalState") -> None:
        """Initialize layer to use shared tensors from global state."""
        # All layers share the same tensors (single KV cache pool)
        self.sequence_length = global_state._shared_sequence_length
        self.context_lengths = global_state._shared_context_lengths
        self.kv_cache_block_offsets = global_state._shared_kv_cache_block_offsets
        self.host_past_key_value_lengths = global_state._shared_host_past_key_value_lengths
        self.host_context_lengths = global_state._shared_host_context_lengths
        self.host_request_types = global_state._shared_host_request_types
        self.host_total_kv_lens = global_state._shared_host_total_kv_lens
        self.host_kv_cache_pool_pointers = global_state._shared_host_kv_cache_pool_pointers
        # Keep host_kv_cache_pool_mapping from __post_init__ - it's layer-specific

⚠️ Potential issue | 🟠 Major

Pool mapping written to shared tensor, but fast path returns per-layer mapping.

_init_pool_pointers() updates _shared_host_kv_cache_pool_mapping, yet the fast-path returns state.host_kv_cache_pool_mapping (per-layer) which is never updated there. This can leave mapping all-zero for CUDA graph capture/replay unless the initial non-capture path ran first. Consider wiring layer states to the shared mapping or updating each layer’s mapping in _init_pool_pointers().

✅ One possible fix (share mapping across layers)
     def init_from_shared(self, global_state: "TrtllmAttentionGlobalState") -> None:
@@
-        # Keep host_kv_cache_pool_mapping from __post_init__ - it's layer-specific
+        self.host_kv_cache_pool_mapping = global_state._shared_host_kv_cache_pool_mapping

Also applies to: 345-370

🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py` around lines
206 - 217, The shared pool mapping (_shared_host_kv_cache_pool_mapping) is being
written by _init_pool_pointers but layers still use their per-layer
host_kv_cache_pool_mapping (returned by the fast path) which remains all-zero;
fix by wiring layers to the shared mapping in init_from_shared (set
self.host_kv_cache_pool_mapping =
global_state._shared_host_kv_cache_pool_mapping) or, alternatively, update each
layer’s host_kv_cache_pool_mapping inside _init_pool_pointers when
_shared_host_kv_cache_pool_mapping is populated so the fast-path (which reads
state.host_kv_cache_pool_mapping) sees the correct mapping; apply the same
change to the corresponding duplicate initialization block elsewhere (the other
init_from_shared/_init_pool_pointers usage).

Comment on lines 271 to 324
device = "cuda"
max_blocks_per_seq = (max_context_length + tokens_per_block - 1) // tokens_per_block

# Shared device tensors
self._shared_sequence_length = torch.zeros(
max_num_requests, dtype=torch.int32, device=device
)
self._shared_context_lengths = torch.zeros(
max_num_requests, dtype=torch.int32, device=device
)
self._shared_kv_cache_block_offsets = torch.zeros(
1, max_num_requests, 2, max_blocks_per_seq, dtype=torch.int32, device=device
)

# Shared host tensors (pinned memory)
self._shared_host_past_key_value_lengths = torch.zeros(
max_num_requests, dtype=torch.int32, device="cpu", pin_memory=True
)
self._shared_host_context_lengths = torch.zeros(
max_num_requests, dtype=torch.int32, device="cpu", pin_memory=True
)
self._shared_host_request_types = torch.zeros(
max_num_requests, dtype=torch.int32, device="cpu", pin_memory=True
)
self._shared_host_total_kv_lens = torch.zeros(
2, dtype=torch.int64, device="cpu", pin_memory=True
)
self._shared_host_kv_cache_pool_pointers = torch.zeros(
1, 2, dtype=torch.int64, device="cpu", pin_memory=True
)
# Pool mapping: [num_layers, 2] - layer to pool mapping
# Using 64 as max layers (typical transformer max)
self._shared_host_kv_cache_pool_mapping = torch.zeros(
64, 2, dtype=torch.int32, device="cpu", pin_memory=True
)

self._shared_tensors_initialized = True

def _init_cpu_buffers(self, max_seqs: int) -> None:
"""Initialize pre-allocated CPU buffers to avoid tensor allocation in hot path."""
if self._cpu_buffers_initialized:
return

# Pre-allocate pinned CPU buffers for intermediate computations
self._cpu_input_seq_lens = torch.zeros(
max_seqs, dtype=torch.int32, device="cpu", pin_memory=True
)
self._cpu_seq_len_with_cache = torch.zeros(
max_seqs, dtype=torch.int32, device="cpu", pin_memory=True
)
self._cpu_past_kv_lens = torch.zeros(
max_seqs, dtype=torch.int32, device="cpu", pin_memory=True
)
self._cpu_cu_num_pages = torch.zeros(
max_seqs + 1, dtype=torch.long, device="cpu", pin_memory=True
)
self._cpu_pages_per_seq = torch.zeros(
max_seqs, dtype=torch.int32, device="cpu", pin_memory=True
)
self._cpu_buffers_initialized = True

def _init_gpu_buffers(self, max_pages: int, max_seqs: int) -> None:
"""Initialize pre-allocated GPU buffers for vectorized operations."""
if self._gpu_buffers_initialized:
return

# Pre-allocate buffers with max sizes to avoid per-call allocations
self._gpu_cu_pages = torch.zeros(max_seqs + 1, dtype=torch.long, device="cuda")
self._gpu_page_positions = torch.arange(max_pages, dtype=torch.long, device="cuda")
self._gpu_seq_idx = torch.zeros(max_pages, dtype=torch.long, device="cuda")
self._gpu_page_idx = torch.zeros(max_pages, dtype=torch.long, device="cuda")
self._gpu_base_offset = torch.zeros(max_pages, dtype=torch.int32, device="cuda")
self._gpu_buffers_initialized = True

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's check the file existence and size
wc -l tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 130


🏁 Script executed:

# Read the relevant section (lines 260-350 to include context)
sed -n '260,350p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 4166


🏁 Script executed:

# Also read lines around 705-707
sed -n '695,720p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1177


🏁 Script executed:

# Search for all .cuda() calls in the file
rg '\.cuda\(' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 204


🏁 Script executed:

# Search for device="cuda" or device="cpu" patterns
rg 'device\s*=\s*["\x27](cuda|cpu)' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 1660


🏁 Script executed:

# Search for _trtllm_config and _sequence_info usage
rg '_trtllm_config|_sequence_info' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 1589


🏁 Script executed:

# Find SequenceInfo class definition
rg 'class SequenceInfo' --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 148


🏁 Script executed:

# Search for imports at the top of the file to understand where SequenceInfo comes from
head -50 tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1964


🏁 Script executed:

# Check if device is accessed from sequence_info anywhere else in codebase
rg 'sequence_info\.device|_sequence_info\.device' --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 657


🏁 Script executed:

# Look for how device is determined in similar context (q.device pattern)
sed -n '1,100p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 3694


🏁 Script executed:

# Find where _init_shared_tensors is called
rg '_init_shared_tensors' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 194


🏁 Script executed:

# Check when _sequence_info is set (should be during configure)
sed -n '135,160p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1309


🏁 Script executed:

# Look at SequenceInfo class definition
sed -n '1,100p' tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 4502


🏁 Script executed:

# Find the SequenceInfo dataclass definition
rg -A 20 'class SequenceInfo' tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1360


🏁 Script executed:

# Get more of the SequenceInfo definition to find device attribute
sed -n '/^class SequenceInfo:/,/^class /p' tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py | head -150

Repository: NVIDIA/TensorRT-LLM

Length of output: 8941


🏁 Script executed:

# Also check if SequenceInfo is a dataclass with device field
rg '@dataclass' tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py -A 30 | grep -A 30 'SequenceInfo'

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Search for where device property is defined on SequenceInfo
rg 'def device|self\.device' tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 236


🏁 Script executed:

# Get the device property definition
sed -n '118,125p' tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 325


🏁 Script executed:

# Also check the other device property at line 511
sed -n '511,520p' tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 418


🏁 Script executed:

# Look at the allocate method more carefully to understand the call flow
sed -n '130,160p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1462


🏁 Script executed:

# Find the full allocate method that calls _init_shared_tensors
sed -n '130,410p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 4509


🏁 Script executed:

# Look at where _init_shared_tensors is called (line 387)
sed -n '375,400p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1053


🏁 Script executed:

# Check the host_prepare method to understand the call sequence
sed -n '375,450p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2977


🏁 Script executed:

# Find where get_or_create_layer_state is called
rg 'get_or_create_layer_state' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 158


🏁 Script executed:

# Look at the forward method where get_or_create_layer_state is called
sed -n '600,750p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 6622


🏁 Script executed:

# Check if _init_shared_tensors is called before configure or if there's a guarantee
sed -n '845,900p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1898


🏁 Script executed:

# Look at the full forward flow to understand when allocate vs get_or_create_layer_state is called
rg 'def forward|allocate\(|get_or_create_layer_state' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py -n -B 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 558


🏁 Script executed:

# Check if _sequence_info is guaranteed to be set before _init_shared_tensors is called
# by looking at the global state initialization and when configure is called
sed -n '200,270p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 4318


🏁 Script executed:

# Look at TrtllmAttentionConfig to see the structure
sed -n '1050,1120p' tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2531


Hard-coded "cuda" and .cuda() calls break non-default GPU usage (cuda:1, etc.).

Shared buffers and metadata tensors are forced onto the default CUDA device. When the model runs on cuda:1 or other non-default devices, this causes device-mismatch errors.

The _trtllm_config._sequence_info is available at the time these allocations occur (set during configure() before _init_shared_tensors() is called), so use _trtllm_config._sequence_info.device for device-aware allocations.

For the tensor copies at lines 705-706, use .to(state.sequence_length.device) pattern instead of .cuda() to ensure tensors move to the correct target device.

Affected locations:

  • Line 271: device = "cuda"
  • Lines 338-342: device="cuda" in _init_gpu_buffers()
  • Lines 705-706: .cuda() calls in _prepare_trtllm_metadata()
  • Lines 1109-1110: device="cuda" in FP8 KV cache setup
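
A minimal sketch of the device-aware pattern this comment asks for is shown below; the attribute names (sequence_info.device, state.sequence_length) follow the comment's wording and are assumptions about the PR's code rather than its actual API.

import torch

def init_shared_tensors(max_num_requests: int, device: torch.device) -> dict:
    # Allocate GPU metadata on the model's device instead of hard-coding "cuda".
    return {
        "sequence_length": torch.zeros(max_num_requests, dtype=torch.int32, device=device),
        "context_lengths": torch.zeros(max_num_requests, dtype=torch.int32, device=device),
    }

# For the tensor copies, prefer moving onto the device of an existing state tensor
# rather than calling .cuda(), so non-default devices such as cuda:1 keep working:
#   seq_len_gpu = seq_len_cpu.to(state.sequence_length.device, non_blocking=True)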

Comment on lines +890 to +934
    # Compute softmax scale
    # sm_scale = scale if scale is not None else (1.0 / (head_dim**0.5))

    # Attention window (full attention)
    attention_window_size = max_context_length

    # Pack parameters for thop.attention
    rotary_embedding_scales = [1.0, 1.0, 1.0]
    rotary_embedding_max_position_info = [max_context_length, max_context_length]
    spec_decoding_bool_params = [False, False, False]
    spec_decoding_tensor_params = [None, None, None]

    # Add extra params for newer TRT-LLM versions
    sm_version = get_sm_version()
    if sm_version >= 89:  # Ada/Hopper
        spec_decoding_tensor_params.extend([None, None, None])

    mla_tensor_params = [None, None]

    try:
        thop.attention(
            qkv_fused,  # q (actually fused QKV)
            None,  # k (None when using fused QKV)
            None,  # v (None when using fused QKV)
            output,  # output
            None,  # output_sf (NVFP4)
            workspace_buffer,  # workspace
            sequence_length,  # sequence_length
            host_past_key_value_lengths,  # host_past_key_value_lengths
            host_total_kv_lens,  # host_total_kv_lens
            context_lengths,  # context_lengths
            host_context_lengths,  # host_context_lengths
            host_request_types,  # host_request_types
            kv_cache_block_offsets,  # kv_cache_block_offsets
            host_kv_cache_pool_pointers,  # host_kv_cache_pool_pointers
            host_kv_cache_pool_mapping,  # host_kv_cache_pool_mapping
            None,  # cache_indirection (beam search)
            _trtllm_config._kv_scale_orig_quant,  # kv_scale_orig_quant (FP8 KV cache)
            _trtllm_config._kv_scale_quant_orig,  # kv_scale_quant_orig (FP8 KV cache)
            None,  # out_scale
            None,  # rotary_inv_freq
            None,  # rotary_cos_sin
            None,  # latent_cache (MLA)
            None,  # q_pe (MLA)
            None,  # block_ids_per_seq
            None,  # attention_sinks
            True,  # is_fused_qkv (Q contains [Q,K,V] concatenated)
            True,  # update_kv_cache
            1,  # predicted_tokens_per_seq
            layer_idx,  # layer_idx
            num_heads,  # num_heads
            num_kv_heads,  # num_kv_heads
            head_dim,  # head_size
            kv_cache.shape[3],  # tokens_per_block - use actual value from kv_cache shape!
            max_num_requests,  # max_num_requests
            max_context_length,  # max_context_length
            attention_window_size,  # attention_window_size
            0,  # sink_token_length
            1,  # beam_width
            int(AttentionMaskType.causal),  # mask_type
            _trtllm_config._quant_mode,  # quant_mode (128 for FP8 KV cache, 0 otherwise)
            1.0,  # q_scaling (scaling factor applied to Q, typically 1.0)
            0,  # position_embedding_type (none - RoPE applied outside)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's find and examine the file
find . -type f -name "trtllm_attention.py" | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 127


🏁 Script executed:

# Once found, let's look at the specific lines mentioned in the review
# Lines 890-952 should show the thop.attention call
fd "trtllm_attention.py" -x wc -l {}

Repository: NVIDIA/TensorRT-LLM

Length of output: 132


🏁 Script executed:

# Let's check the function signature to see if scale is a parameter
ast-grep --pattern 'def $FUNC_NAME($$$scale$$$) {
  $$$
  thop.attention(
    $$$
  )
  $$$
}'

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Find the file
git ls-files | grep -i "trtllm_attention.py"

Repository: NVIDIA/TensorRT-LLM

Length of output: 125


🏁 Script executed:

# Read the relevant section with context
sed -n '850,900p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1725


🏁 Script executed:

# Read the exact lines in the review
sed -n '890,952p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2977


🏁 Script executed:

# Let's find the function that contains this code - look for function def before line 890
sed -n '800,950p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py | head -150

Repository: NVIDIA/TensorRT-LLM

Length of output: 6104


🏁 Script executed:

# Find the function signature - look for "def " before line 890
sed -n '700,890p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py | grep -n "def " | tail -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 92


🏁 Script executed:

# Let's search for the function that contains this thop.attention call
# First, find all function definitions
rg "^def " ./tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 213


🏁 Script executed:

# Search for scale parameter in the file to see if it's used anywhere
rg "scale" ./tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py -B 2 -A 2 -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 3672


🏁 Script executed:

# Let's find the exact line where q_scaling is hardcoded to 1.0
sed -n '945,960p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 993


🏁 Script executed:

# Let's also check what the full function signature looks like for context
sed -n '766,800p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1169


Pass the scale parameter through to q_scaling instead of hardcoding 1.0.

The function receives scale as a parameter (line 786), extracted from the source attention op, but it is never used. The hardcoded q_scaling=1.0 will produce incorrect results if the source attention specifies a non-default scale. The commented-out code at lines 890-891 shows this was intended to be fixed.

🔧 Suggested fix
-    # Compute softmax scale
-    # sm_scale = scale if scale is not None else (1.0 / (head_dim**0.5))
+    q_scaling = scale if scale is not None else (1.0 / (head_dim**0.5))
@@
-            1.0,  # q_scaling (scaling factor applied to Q, typically 1.0)
+            q_scaling,  # q_scaling (scale from source op or 1/sqrt(head_dim))
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py` around lines
890 - 952, The call to thop.attention currently hardcodes q_scaling=1.0 but the
function receives a scale parameter (variable name scale) that should be
applied; change the q_scaling argument in the thop.attention invocation to pass
the existing local variable scale (or a computed fallback like (1.0 /
(head_dim**0.5)) if scale is None) instead of 1.0 so q_scaling reflects the
source attention op; look for the thop.attention call and replace the literal
1.0 q_scaling argument with scale (or computed fallback).

Comment on lines +6 to 7
from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "llm_args.py" | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 142


🏁 Script executed:

cat -n tensorrt_llm/_torch/auto_deploy/llm_args.py | head -50

Repository: NVIDIA/TensorRT-LLM

Length of output: 1988


🏁 Script executed:

cat -n tensorrt_llm/llmapi/llm_args.py | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 4765


🏁 Script executed:

grep -n "status=" tensorrt_llm/_torch/auto_deploy/llm_args.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 228


🏁 Script executed:

sed -n '225,245p' tensorrt_llm/_torch/auto_deploy/llm_args.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1045


🏁 Script executed:

grep -n "status=" tensorrt_llm/_torch/auto_deploy/llm_args.py | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 228




🌐 Web query:

Pydantic v2 Field function accept arbitrary keyword arguments status parameter

💡 Result:

In Pydantic v2, pydantic.Field(...) does not accept arbitrary extra keyword arguments for JSON Schema anymore (so a custom kwarg like status=... is not supported as a free-form parameter). Instead, put custom schema keys under json_schema_extra. [1]

from pydantic import BaseModel, Field

class M(BaseModel):
    x: int = Field(json_schema_extra={'status': 'experimental'})

Field explicitly provides json_schema_extra for this purpose. [2]

If you pass keywords that v2 considers removed/unsupported, Pydantic raises a usage error (e.g., the documented “Keyword arguments removed” error). [3]

Sources
[1] Pydantic v2 Migration Guide – “Field no longer supports arbitrary keyword arguments… use json_schema_extra” (docs.pydantic.dev)
[2] Pydantic v2 Field API reference (shows json_schema_extra parameter) (docs.pydantic.dev)
[3] Pydantic “Usage Errors” – “Keyword arguments removed” (docs.pydantic.dev)

Citations:


🏁 Script executed:

grep -r "from.*llmapi.llm_args import" tensorrt_llm --include="*.py" -A 10 | grep -E "(Field|from.*llmapi)" | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 3365


🏁 Script executed:

grep -r "status=" tensorrt_llm --include="*.py" | grep -v "_torch/auto_deploy" | head -10

Repository: NVIDIA/TensorRT-LLM

Length of output: 1042


🏁 Script executed:

python3 << 'EOF'
# Test what happens when we pass status= to pydantic.Field
try:
    from pydantic import BaseModel, Field
    
    class TestModel(BaseModel):
        x: int = Field(default=1, status="prototype")
    
    print("SUCCESS: Field accepted status parameter")
except TypeError as e:
    print(f"ERROR (TypeError): {e}")
except Exception as e:
    print(f"ERROR ({type(e).__name__}): {e}")
EOF

Repository: NVIDIA/TensorRT-LLM

Length of output: 421


🏁 Script executed:

python3 -c "import pydantic; print(pydantic.__version__)"

Repository: NVIDIA/TensorRT-LLM

Length of output: 69


🏁 Script executed:

grep -r "pydantic" tensorrt_llm/requirements* setup.py pyproject.toml 2>/dev/null | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

sed -n '225,250p' tensorrt_llm/_torch/auto_deploy/llm_args.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 1414


🏁 Script executed:

python3 << 'EOF'
# Verify that the custom Field from llmapi.llm_args works correctly with status=
import sys
sys.path.insert(0, '/tmp/repo')

try:
    from tensorrt_llm.llmapi.llm_args import Field
    from pydantic import BaseModel
    
    class TestModel(BaseModel):
        x: int = Field(default=1, status="prototype")
    
    # Check the schema
    schema = TestModel.model_json_schema()
    print("Custom Field wrapper result:")
    print(f"  Field properties: {schema['properties']['x']}")
    
except Exception as e:
    print(f"ERROR: {e}")
EOF

Repository: NVIDIA/TensorRT-LLM

Length of output: 145


🏁 Script executed:

sed -n '1,20p' tensorrt_llm/_torch/auto_deploy/llm_args.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 681


🏁 Script executed:

sed -n '1,100p' tensorrt_llm/llmapi/llm_args.py | grep -A 2 "^def Field"

Repository: NVIDIA/TensorRT-LLM

Length of output: 184


Import custom Field wrapper from llmapi.llm_args to handle status= metadata correctly.

Pydantic v2 doesn't natively support the status= keyword argument on Field. While v2.12.5 (currently in use) accepts it with a deprecation warning, this will be removed in Pydantic v3.0. The codebase has a custom Field wrapper in tensorrt_llm.llmapi.llm_args that properly handles the status parameter by converting it to json_schema_extra. Update the import to use this wrapper instead of pydantic.Field directly.

✅ Suggested import adjustment
-from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator
+from pydantic import PrivateAttr, ValidationInfo, field_validator, model_validator
@@
-from ...llmapi.llm_args import (
+from ...llmapi.llm_args import (
+    Field,
     BaseLlmArgs,
     BuildConfig,
     EagleDecodingConfig,
     KvCacheConfig,
     SamplerType,
     _ParallelConfig,
 )
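
For context, a Field wrapper of this kind might look roughly like the sketch below. This is an assumption about the wrapper's behavior, not the actual definition in tensorrt_llm.llmapi.llm_args.

from typing import Any, Optional

from pydantic import Field as PydanticField


def Field(*args: Any, status: Optional[str] = None, **kwargs: Any) -> Any:
    # Fold the non-standard `status` keyword into json_schema_extra so Pydantic v2
    # accepts it (and Pydantic v3 will not break on it).
    if status is not None:
        extra = dict(kwargs.pop("json_schema_extra", None) or {})
        extra["status"] = status
        kwargs["json_schema_extra"] = extra
    return PydanticField(*args, **kwargs)
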
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/llm_args.py` around lines 6 - 7, Replace the
direct pydantic Field import with the project wrapper that handles status
metadata: change the import line so Field is imported from
tensorrt_llm.llmapi.llm_args (while keeping PrivateAttr, ValidationInfo,
field_validator, model_validator from pydantic and BaseSettings,
SettingsConfigDict from pydantic_settings) so code using Field(status=...) is
converted into json_schema_extra correctly; update the import statement that
currently references pydantic.Field to instead reference the custom Field from
tensorrt_llm.llmapi.llm_args.

Comment on lines 241 to +257
    ### VALIDATION #################################################################################
    @model_validator(mode="after")
    # TODO: discuss what to do with this once we fully transition to the new inference optimizer
    def update_attn_page_size(self):
        # NOTE force attn_page_size to equal max_seq_len for triton backend
        if self.transforms.get("insert_cached_attention", {}).get("backend") in [
            "triton",
            "torch",
        ]:
            self.attn_page_size = self.max_seq_len
        # NOTE: (hg) For transformers mode. This is ugly.
        if self.transforms.get("transformers_replace_cached_attn", {}).get("backend") in [
            "triton",
            "torch",
        ]:
            self.attn_page_size = self.max_seq_len
        return self

⚠️ Potential issue | 🟠 Major

update_attn_page_size can miss attn_backend shortcuts.

This validator reads only self.transforms, but it runs before update_transforms_with_shortcuts. If a user sets attn_backend="triton" without explicitly populating transforms, attn_page_size won’t be forced to max_seq_len. Consider using self.attn_backend directly or moving this validator after shortcut synchronization.

🔧 Minimal fix (use shortcut field directly)
-        if self.transforms.get("insert_cached_attention", {}).get("backend") in [
-            "triton",
-            "torch",
-        ]:
+        backend = self.attn_backend
+        if backend in ["triton", "torch"] or self.transforms.get(
+            "insert_cached_attention", {}
+        ).get("backend") in ["triton", "torch"]:
             self.attn_page_size = self.max_seq_len
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/llm_args.py` around lines 241 - 257, The
update_attn_page_size validator currently only inspects self.transforms and can
miss shortcut fields set via attn_backend; update the validator
(update_attn_page_size) to also check self.attn_backend (e.g., treat "triton" or
"torch" the same as transforms backends) and set self.attn_page_size =
self.max_seq_len when attn_backend indicates a cached-attention backend, or
alternatively move this validator to run after update_transforms_with_shortcuts
so transforms are already synchronized—reference update_attn_page_size,
self.transforms, self.attn_backend, and update_transforms_with_shortcuts when
making the change.

@MrGeva MrGeva force-pushed the eg/trtllm_attn_v2_squashed branch from 9d9868b to c4b169e on February 5, 2026 06:41
Add TRT-LLM attention backend (thop.attention) for Auto-Deploy with full
CUDA graph support and optimized metadata preparation.

Key features:
- Unified KV cache with HND layout matching thop.attention kernel
- Pre-allocated CPU/GPU buffers to avoid allocation overhead
- Shared tensors across all 32 layers (single update per forward)
- Vectorized GPU block offset computation via torch.searchsorted
- Host prepare function running outside CUDA graph for metadata updates
- Support for FP8 KV cache with quantization scales

Performance:
- ~6650 tokens/sec throughput on Llama-3.1-8B-Instruct-FP8
- ~1.8% faster than PTCacheBackend baseline

Architecture:
- TrtllmAttentionGlobalState: Singleton managing shared tensors and buffers
- TrtllmLayerState: Per-layer state linked to global shared tensors
- TrtllmAttentionConfig: Runtime configuration (page size, batch size, etc.)
- host_prepare_fn: Called before CUDA graph replay to update metadata

Files:
- trtllm_attention.py: Main implementation with attention kernel wrapper
- attention_interface.py: Interface extensions for TRT-LLM backend
- kvcache.py: Transform integration for cached attention
- llm_args.py: Configuration options for TRT-LLM backend

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@MrGeva MrGeva force-pushed the eg/trtllm_attn_v2_squashed branch from c4b169e to b0d1d1a on February 5, 2026 06:49
…ager

Fix garbage output when using thop.attention with AD's KVCacheManager by
correctly configuring the interleaved K/V block layout.

Key changes:
- TrtllmKVResourceHandler now extends PagedResourceHandler for proper
  KVCacheManager integration with HND layout support
- Configure KVCacheManager to use SELF cache type (kv_factor=2) when
  handlers request HND layout, avoiding memory-doubling copies
- Fix pool pointers: K ptr = AD's base address, V ptr = 0 (kernel uses
  block offsets to locate V)
- Fix pool mapping: Use AD's layer offsets directly
- Fix block offsets: Use multiplier = num_layers * kv_factor (64) with
  K = base_offsets and V = base_offsets + 1 for interleaved layout

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
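
The pool-pointer convention described in the commit above can be pictured with a small illustrative sketch. The [1, 2] pointer-tensor shape matches the shared tensor shown earlier in this PR; the function name and arguments are hypothetical.

import torch

def set_pool_pointers(host_kv_cache_pool_pointers: torch.Tensor, k_base_ptr: int) -> None:
    # Per the commit text: K pointer = AD's KV cache base address, V pointer = 0;
    # the kernel locates V blocks through the interleaved block offsets instead.
    host_kv_cache_pool_pointers[0, 0] = k_base_ptr
    host_kv_cache_pool_pointers[0, 1] = 0
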
@MrGeva MrGeva force-pushed the eg/trtllm_attn_v2_squashed branch from 65a72d7 to 74de291 on February 5, 2026 12:53
…sizes

Fix tensor size mismatch error when attention ops are created with a smaller
max_batch_size (default 64) but the actual batch size from AD's cache config
is larger (e.g., 384 for Nano).

Changes:
- _init_shared_tensors now supports reallocation if current size < requested
- _init_cpu_buffers and _init_gpu_buffers also support reallocation
- get_or_create_layer_state re-links layer states after reallocation
- _prepare_trtllm_metadata checks at runtime if reallocation is needed

This fixes: RuntimeError: The size of tensor a (64) must match the size of
tensor b (384) at non-singleton dimension 0

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
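
The grow-only reallocation described in this commit can be pictured with a small hypothetical sketch; the class, method, and attribute names are illustrative, not the PR's.

import torch

class SharedBuffers:
    def __init__(self) -> None:
        self.sequence_length = None  # allocated lazily

    def ensure_capacity(self, max_num_requests: int, device: str = "cuda") -> bool:
        """(Re)allocate shared tensors when the requested batch size outgrows them."""
        current = 0 if self.sequence_length is None else self.sequence_length.numel()
        if current >= max_num_requests:
            return False  # large enough already; nothing to re-link
        self.sequence_length = torch.zeros(max_num_requests, dtype=torch.int32, device=device)
        return True  # caller must re-link per-layer states to the new tensors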