From 71e7d284d734e27344622de245df90da3328afdb Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Thu, 13 Nov 2025 15:11:53 +0200
Subject: [PATCH 1/3] Remove deprecated HybridChunkedCache from benchmark_inference.py

---
 thunder/benchmarks/benchmark_inference.py | 39 ++++++++---------------
 1 file changed, 14 insertions(+), 25 deletions(-)

diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 4639c5df4e..9daefa996b 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -32,7 +32,7 @@
 from tqdm import tqdm
 import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
-from transformers.cache_utils import HybridChunkedCache, StaticCache
+from transformers.cache_utils import StaticCache
 from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
 from torch.distributed.tensor.placement_types import Shard
 from torch.distributed.tensor import DTensor
@@ -335,35 +335,24 @@ def _load_model(self) -> torch.nn.Module:
 
         return model
 
-    def generate_batch(self) -> tuple[torch.Tensor, HybridChunkedCache]:
+    def generate_batch(self) -> tuple[torch.Tensor, StaticCache]:
         """Generate a batch of input tokens"""
         batch_size = self.config.batch_size
         input_length = self.config.input_length
 
         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
-        if LooseVersion(transformers.__version__) >= LooseVersion("4.55"):
-            # Transformers deprecated HybridChunkedCache in favour of static in 4.55.x
-            past_key_values = StaticCache(
-                config=self.hf_config,
-                max_batch_size=input_ids.shape[0],
-                max_cache_len=input_ids.shape[1] + self.config.output_length,
-                device=DEVICE,
-                dtype=torch.bfloat16,
-            )
-        else:
-            past_key_values = HybridChunkedCache(
-                self.hf_config, input_ids.shape[0], input_ids.shape[1] + self.config.output_length
-            )
-            for layer_idx in range(self.hf_config.num_hidden_layers):
-                # key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored
-                # https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822
-                dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
-                past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
+        past_key_values = StaticCache(
+            config=self.hf_config,
+            max_batch_size=input_ids.shape[0],
+            max_cache_len=input_ids.shape[1] + self.config.output_length,
+            device=DEVICE,
+            dtype=torch.bfloat16,
+        )
 
         return input_ids, past_key_values
 
     def get_next_token(
-        self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache | StaticCache
+        self, input_ids: torch.Tensor, past_key_values: StaticCache
     ) -> torch.Tensor:
         start_pos = past_key_values.get_seq_length()
         cache_position = start_pos + torch.arange(0, input_ids.shape[1], device=start_pos.device, dtype=start_pos.dtype)
@@ -376,7 +365,7 @@ def get_next_token(
         next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
         return next_token
 
-    def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor:
+    def prefill(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         """
         Prefill phase: Process the entire input prompt at once.
         Returns the next token.
@@ -385,7 +374,7 @@ def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache)
         """
         return self.get_next_token(input_ids, past_key_values)
 
-    def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor:
+    def decode_one_token(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         """
         Decode phase: Generate a single token given the current sequence.
         Returns the next token.
@@ -402,7 +391,7 @@ def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunk
     # [rank1]: RuntimeError: Cannot set version_counter for inference tensor
     # @torch.inference_mode()
     def generate(
-        self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: HybridChunkedCache
+        self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: StaticCache
     ) -> dict[str, Any]:
         """
         Generate tokens using separate prefill and decode phases.
@@ -431,7 +420,7 @@ def generate(
         }
 
     def measure_inference_step(
-        self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache, max_new_tokens: int
+        self, input_ids: torch.Tensor, past_key_values: StaticCache, max_new_tokens: int
     ) -> dict[str, float]:
         """Measure a single inference step with detailed timing using separate prefill/decode"""
         # Generate tokens with separate prefill/decode tracking

From ef6c46ba874d0a633757bf10860ed0be1c007be7 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Thu, 13 Nov 2025 15:14:10 +0200
Subject: [PATCH 2/3] Remove ignored kwargs from StaticCache construction

---
 thunder/benchmarks/benchmark_inference.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 9daefa996b..f67136d712 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -343,10 +343,7 @@ def generate_batch(self) -> tuple[torch.Tensor, StaticCache]:
         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
         past_key_values = StaticCache(
             config=self.hf_config,
-            max_batch_size=input_ids.shape[0],
             max_cache_len=input_ids.shape[1] + self.config.output_length,
-            device=DEVICE,
-            dtype=torch.bfloat16,
         )
 
         return input_ids, past_key_values

From 89273349ae4280f56f633d233aef2ab489921db7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 13 Nov 2025 13:21:49 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/benchmarks/benchmark_inference.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index f67136d712..212f5f8e0d 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -22,7 +22,6 @@
 import warnings
 from typing import Any
 from collections.abc import Callable
-from looseversion import LooseVersion
 
 import torch
 import torch.distributed as dist
@@ -30,7 +29,6 @@
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel, ColwiseParallel
 from tqdm import tqdm
-import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.cache_utils import StaticCache
 from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
@@ -348,9 +346,7 @@ def generate_batch(self) -> tuple[torch.Tensor, StaticCache]:
 
         return input_ids, past_key_values
 
-    def get_next_token(
-        self, input_ids: torch.Tensor, past_key_values: StaticCache
-    ) -> torch.Tensor:
+    def get_next_token(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         start_pos = past_key_values.get_seq_length()
         cache_position = start_pos + torch.arange(0, input_ids.shape[1], device=start_pos.device, dtype=start_pos.dtype)
         with torch.no_grad():
@@ -387,9 +383,7 @@ def decode_one_token(self, input_ids: torch.Tensor, past_key_values: StaticCache
     # [rank1]: ~^^^^^
    # [rank1]: RuntimeError: Cannot set version_counter for inference tensor
     # @torch.inference_mode()
-    def generate(
-        self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: StaticCache
-    ) -> dict[str, Any]:
+    def generate(self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: StaticCache) -> dict[str, Any]:
         """
         Generate tokens using separate prefill and decode phases.
         Returns detailed metrics for both phases.
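
Note (not part of the patches): a minimal, self-contained sketch of the cache setup this series converges on, assuming a transformers release new enough (>= 4.55, per the version check removed in patch 1) that StaticCache only needs the model config and a maximum cache length and picks up batch size, device, and dtype lazily. The tiny LlamaConfig and the sizes below are illustrative stand-ins, not the benchmark's real configuration.

import torch
from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

# Illustrative toy config; the benchmark loads the real model config instead.
hf_config = LlamaConfig(
    vocab_size=1000,
    hidden_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
batch_size, input_length, output_length = 2, 128, 32  # illustrative sizes

# Random prompt tokens, mirroring generate_batch().
input_ids = torch.randint(0, hf_config.vocab_size, (batch_size, input_length))

# Post-series construction: config plus max_cache_len is all that is passed;
# the max_batch_size/device/dtype kwargs dropped in patch 2 were already ignored.
past_key_values = StaticCache(
    config=hf_config,
    max_cache_len=input_length + output_length,
)

# The cache is then threaded through prefill/decode, e.g.
# model(input_ids, past_key_values=past_key_values, cache_position=...),
# as get_next_token() does in the benchmark.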