@@ -14,6 +14,7 @@
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union, final

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
from torch._ops import OpOverloadPacket
from torch.fx import Node
from torch.types import Number
@@ -283,6 +284,41 @@ def to(self, *args, **kwargs) -> None:
        self._device_views = self._create_views(self._device_buffer)


class CacheConfig(BaseModel):
    """Cache configuration for attention-related dtypes."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
    mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
    delta_dtype: Optional[torch.dtype] = Field(
        default=torch.float32, description="Delta cache dtype. Defaults to float32."
    )

    @field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")
    @classmethod
    def _coerce_dtype(cls, value):
        if value is None or isinstance(value, torch.dtype):
            return value
        if isinstance(value, str):
            dtype = getattr(torch, value, None)
            assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
            return dtype
Comment on lines +301 to +309

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

Pydantic v2 field validator assert vs ValueError validation error handling

💡 Result:

In Pydantic v2 field validators (@field_validator / BeforeValidator / AfterValidator), you should signal “this input is invalid” by raising ValueError or AssertionError (including via assert ...). Pydantic catches those and aggregates them into a single ValidationError. You generally should not raise ValidationError yourself inside validation code. [1][2]

assert / AssertionError

  • Works and becomes a validation error (typically error type like assertion_error). [1][3]
  • Not safe for real validation logic because assert statements are skipped when Python runs with -O (optimize), meaning your validator checks can silently stop running. [4][1]

ValueError

  • The recommended/default way to fail validation (always executed regardless of -O). [1][2]
  • Produces a validation error (often value_error) that’s collected into the final ValidationError. [1][2]

What not to do (common gotcha)

  • If you raise other exception types (e.g., TypeError, IndexError, etc.), they are not treated as validation errors and can bubble up directly rather than being wrapped into a ValidationError. [5]

When you need richer error codes/messages

  • Use pydantic_core.PydanticCustomError for structured/custom error types and templated messages. [1]

Practical rule: use ValueError for user/data validation, reserve assert for developer invariants/tests (or only when you’re sure -O won’t be used). [4][2]

Sources: [1] https://docs.pydantic.dev/latest/concepts/validators/ , [2] https://docs.pydantic.dev/latest/errors/errors/ , [3] https://docs.pydantic.dev/2.0/usage/validation_errors/ , [4] https://docs.pydantic.dev/2.5/concepts/validators/ , [5] https://docs.pydantic.dev/2.3/usage/validators/
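
For illustration, a minimal self-contained sketch of the recommended pattern (hypothetical model and field names; assumes only pydantic v2 is installed). A ValueError raised inside a field validator is caught and aggregated into a ValidationError, and, unlike an assert, it still runs under `python -O`:

# Hypothetical demo: ValueError in a validator surfaces as a pydantic
# ValidationError with error type "value_error"; an assert-based check
# would vanish when Python runs with -O.
from pydantic import BaseModel, ValidationError, field_validator


class Cfg(BaseModel):
    dtype_name: str

    @field_validator("dtype_name")
    @classmethod
    def _check(cls, v: str) -> str:
        if not v.startswith("float"):
            raise ValueError(f"Invalid dtype name: {v!r}")  # survives -O
        return v


try:
    Cfg(dtype_name="int8")
except ValidationError as e:
    print(e.errors()[0]["type"])  # -> "value_error"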


🏁 Scripts executed (repository: NVIDIA/TensorRT-LLM):

# Check if the file exists and view the actual context around the problematic code
fd -t f "attention_interface.py" --exec wc -l {} \;

# Also check the imports to confirm this is using Pydantic v2
fd -t f "attention_interface.py" --exec head -50 {} \;

# View the specific lines mentioned in the review (301-309)
sed -n '301,309p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

# Also check a broader context to understand the class and imports
sed -n '280,320p' ./tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

# Check for the presence of copyright header in the file
head -20 ./tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py


Use ValueError instead of assert for dtype validation in Pydantic field validator.

assert statements are optimized away when Python runs with the -O flag, causing the validation check to be silently skipped. Pydantic v2 field validators should raise ValueError (or AssertionError directly, though ValueError is preferred) so that invalid inputs are always rejected.

Suggested change
-            assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
-            return dtype
+            if not isinstance(dtype, torch.dtype):
+                raise ValueError(f"Invalid dtype string: {value!r}")
+            return dtype
🤖 Prompt for AI Agents
In `tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py` around
lines 301-309: the field validator method _coerce_dtype should not use assert
for invalid dtype checks because asserts can be skipped under -O; replace the
assert with a raised ValueError (e.g., raise ValueError(f"Invalid dtype:
{value!r}"), or include the resolved dtype) so Pydantic always rejects bad
inputs; update the validation branch in the _coerce_dtype classmethod (decorated
with @field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")) to
raise ValueError when getattr(torch, value, None) is not a torch.dtype and to
return the resolved torch.dtype otherwise.

        return value

    def __or__(self, other: "CacheConfig") -> "CacheConfig":
        """Combine two CacheConfig objects field-wise using Python's `or` semantics."""
        if not isinstance(other, CacheConfig):
            raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}")
        merged_kwargs = {}
        for field_name in type(self).model_fields.keys():
            merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name)
        return CacheConfig(**merged_kwargs)

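A short usage sketch of the new model (illustrative values; behavior follows the validator and `__or__` above): string dtypes are coerced via `getattr(torch, ...)`, and `|` keeps the first non-None value field-wise.

# Usage sketch for CacheConfig (torch and CacheConfig as defined above):
a = CacheConfig(dtype="float16")             # "float16" -> torch.float16 via _coerce_dtype
b = CacheConfig(mamba_dtype=torch.bfloat16)
merged = a | b
assert merged.dtype == torch.float16         # first non-None value wins, field-wise
assert merged.mamba_dtype == torch.bfloat16
assert merged.delta_dtype == torch.float32   # default preserved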

class SequenceInfo:
    """An interface to hold information about how the sequence is laid out and stored in cache.

@@ -499,6 +535,12 @@ def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:

    def _get_arg(self, name: str) -> torch.Tensor:
        """Get the argument from the input buffer either on device or host."""
        # Handle special KV cache pool arguments for TRT-LLM attention
        if name == "kv_cache_pool_pointers":
            return self._kv_cache_pool_pointers
        elif name == "kv_cache_pool_mapping":
            return self._kv_cache_pool_mapping

        if name.endswith("_host"):
            arg = self._input_buffer.get_host_view(name.replace("_host", ""))
        else:
@@ -1003,6 +1045,38 @@ def run_host_prepare_for_attention_forward(self) -> None:
        for host_function, args in self._host_prepare_functions:
            host_function(**{arg: self._get_arg(arg) for arg in args})

    # KV Cache pool info for TRT-LLM attention (set by CacheInterface after KVCacheManager creation)
    _kv_cache_pool_pointers: Optional[torch.Tensor] = None
    _kv_cache_pool_mapping: Optional[torch.Tensor] = None

    def set_kv_cache_pool_info(
        self, pool_pointers: torch.Tensor, pool_mapping: torch.Tensor
    ) -> None:
        """Set KV cache pool pointers and mapping for TRT-LLM attention.

        This is called by CacheInterface after KVCacheManager is created,
        allowing TRT-LLM attention to use the same pool as AD's cache system.

        Args:
            pool_pointers: Pool pointer tensor from KVCacheManager.kv_cache_pool_pointers
            pool_mapping: Layer to pool mapping from KVCacheManager.kv_cache_pool_mapping
        """
        self._kv_cache_pool_pointers = pool_pointers
        self._kv_cache_pool_mapping = pool_mapping
        # Add to available_args so host prepare functions can request them
        self._available_args.add("kv_cache_pool_pointers")
        self._available_args.add("kv_cache_pool_mapping")

    @property
    def kv_cache_pool_pointers(self) -> Optional[torch.Tensor]:
        """Get KV cache pool pointers if set."""
        return self._kv_cache_pool_pointers

    @property
    def kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:
        """Get KV cache pool mapping if set."""
        return self._kv_cache_pool_mapping
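
A hedged wiring sketch of the intended call flow described in the docstring (the `seq_info` and `kv_cache_manager` names are illustrative, not part of this diff):

# Hypothetical wiring: after the KVCacheManager is created, the cache
# interface forwards its pool tensors so host-prepare functions can
# request them through _get_arg by name.
seq_info.set_kv_cache_pool_info(
    pool_pointers=kv_cache_manager.kv_cache_pool_pointers,
    pool_mapping=kv_cache_manager.kv_cache_pool_mapping,
)
assert seq_info.kv_cache_pool_pointers is not None  # now resolvable as an arg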


class ResourceHandler(ABC):
    """An abstract interface to handle a generic resource needed by attention operators.
@@ -1031,15 +1105,20 @@ class PagedResourceHandler(ManagedResourceHandler):
    The PagedResourceHandler can be used to handle resources that support paging such as kv-caches.
    """

-    def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
+    def __init__(
+        self, *token_shape: int, dtype: torch.dtype, kv_layout: Literal["NHD", "HND"] = "NHD"
+    ) -> None:
        """Initialize the PagedResourceHandler.

        Args:
-            page_shape: The shape of a single page of the resource.
+            token_shape: The shape of a single token's worth of data in the resource.
            dtype: The dtype of the resource.
+            kv_layout: Memory layout for KV cache. "NHD" = [blocks, tokens, kv_factor, heads, dim],
+                "HND" = [blocks, kv_factor, heads, tokens, dim]. Default is "NHD".
        """
        self.token_shape = token_shape
        self.dtype = dtype
+        self.kv_layout = kv_layout

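To make the two layouts concrete, a small sketch of the page shapes the new docstring describes (all sizes are made up; kv_factor=2 covers the K and V halves, and the token_shape passed to the handler is an assumption for illustration):

# Illustrative page shapes for the two layouts (torch as imported above):
blocks, tokens, kv_factor, heads, dim = 4, 16, 2, 8, 64
nhd = torch.empty(blocks, tokens, kv_factor, heads, dim)  # kv_layout="NHD"
hnd = torch.empty(blocks, kv_factor, heads, tokens, dim)  # kv_layout="HND"
handler = PagedResourceHandler(heads, dim, dtype=torch.float16, kv_layout="HND")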

class StateResourceHandler(ManagedResourceHandler):
@@ -1101,6 +1180,16 @@ def __call__(
    ) -> List[torch.Tensor]: ...


class GetCacheCallable(Protocol):
    def __call__(self, sequence_info: SequenceInfo) -> torch.Tensor: ...


class GetBufferCallable(GetCacheCallable):
    pass


CacheInitializerDict = Dict[str, GetCacheCallable]
BufferInitializerDict = Dict[str, GetBufferCallable]
AttentionLayout = Literal["bsnd", "bnsd"]

ResourceHandlerDict = Dict[str, ResourceHandler]
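
A hypothetical initializer conforming to the `GetCacheCallable` protocol; the `num_pages` and `page_size` attributes on `SequenceInfo`, and the head/dim sizes, are assumptions for illustration:

# Hypothetical cache initializer matching GetCacheCallable:
def init_kv_cache(sequence_info: SequenceInfo) -> torch.Tensor:
    # num_pages/page_size are illustrative SequenceInfo attributes;
    # 8 heads x 64 head_dim are assumed sizes.
    return torch.empty(
        sequence_info.num_pages, sequence_info.page_size, 8, 64,
        dtype=torch.float16,
    )

cache_initializers: CacheInitializerDict = {"kv_cache": init_kv_cache}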