Remove ChunkedHybridCache from benchmark_inference.py #2733
IvanYashchuk wants to merge 3 commits into main
Conversation
@kshitij12345, @riccardofelluga could you please review the change?
kshitij12345
left a comment
With pjnl-20251113 (and transformers version 4.55.4), running python thunder/benchmarks/benchmark_inference.py --model-name meta-llama/Llama-4-Maverick-17B-128E --mode eager --input-length 1024 --output-length 32 --batch-size 1 --num-iterations 20 --num-layers 2
leads to
Warming up with 10 iterations...
Traceback (most recent call last):
File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 733, in <module>
main()
File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 722, in main
benchmark.run_benchmark()
File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 458, in run_benchmark
input_ids, past_key_values = self.generate_batch()
^^^^^^^^^^^^^^^^^^^^^
File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 342, in generate_batch
past_key_values = StaticCache(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1451, in __init__
super().__init__(layer_classes=StaticLayer, *args, **kwargs)
File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1110, in __init__
self.append_new_layers(self.num_hidden_layers - 1)
File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1172, in append_new_layers
new_layer = new_layer_class(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: StaticLayer.__init__() missing 1 required positional argument: 'batch_size'

# Transformers deprecated HybridChunkedCache in favour of static in 4.55.x
past_key_values = StaticCache(
    config=self.hf_config,
    max_batch_size=input_ids.shape[0],
Looking at the error here, I think max_batch_size is required.
Thank you for running it with transformers version 4.55.4! I was running with the latest release. Need to update the requirements pin first before merging this change.
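For context, a minimal repro sketch of the failing call and the fix this comment suggests (assumptions: the transformers 4.55.x behaviour from the traceback above, and access to the gated meta-llama config):

from transformers import AutoConfig
from transformers.cache_utils import StaticCache

config = AutoConfig.from_pretrained("meta-llama/Llama-4-Maverick-17B-128E")
if hasattr(config, "text_config"):
    config = config.text_config  # Llama-4 wraps the text model config
config.num_hidden_layers = 2  # keep the example small

# Without max_batch_size the constructor raises on 4.55.4:
#   TypeError: StaticLayer.__init__() missing 1 required positional argument: 'batch_size'
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=1024 + 32)
print(cache.layers[0].keys.shape)  # batch dimension comes from max_batch_size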
    max_batch_size=input_ids.shape[0],
    max_cache_len=input_ids.shape[1] + self.config.output_length,
    device=DEVICE,
    dtype=torch.bfloat16,
Also, device and dtype seem necessary -
from transformers.cache_utils import StaticCache
from transformers import AutoConfig, AutoModelForCausalLM
import torch
model_id = "meta-llama/Llama-4-Maverick-17B-128E"
config = AutoConfig.from_pretrained(model_id)
if hasattr(config, "text_config"):
config = config.text_config
config.num_hidden_layers = 2
past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256)
print(past_key_values.layers[0].keys.dtype) # torch.float32
print(past_key_values.layers[0].keys.device) # cpu
past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256, dtype=torch.bfloat16, device="cuda")
print(past_key_values.layers[0].keys.dtype) # torch.bfloat16
print(past_key_values.layers[0].keys.device) # cuda:0
riccardofelluga
left a comment
Good idea to move on to the StaticCache. It just needs a couple of fixes to the args of the object.
Does perf improve?
dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
past_key_values = StaticCache(
    config=self.hf_config,
Suggested change:
    config=self.hf_config,
    max_batch_size=input_ids.shape[0],
past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
past_key_values = StaticCache(
    config=self.hf_config,
    max_cache_len=input_ids.shape[1] + self.config.output_length,
Also device and dtype seem to be required:
RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_bmm)

Suggested change:
    max_cache_len=input_ids.shape[1] + self.config.output_length,
    device=DEVICE,
    dtype=torch.bfloat16,
It's not moving on, it's already used because of
for layer_idx in range(self.hf_config.num_hidden_layers):
    # key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored
    # https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822
    dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
We also need to preserve hf_config.num_key_value_heads // WORLD_SIZE for the distributed setting.
The patch can be something like the following:
diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 212f5f8e..13af8175 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -339,9 +339,15 @@ class InferenceBenchmark:
         input_length = self.config.input_length
         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
+        import copy
+        hf_config = copy.copy(self.hf_config)
+        hf_config.num_key_value_heads //= WORLD_SIZE
         past_key_values = StaticCache(
-            config=self.hf_config,
+            config=hf_config,
             max_cache_len=input_ids.shape[1] + self.config.output_length,
+            max_batch_size=batch_size,
+            dtype=torch.bfloat16,
+            device=DEVICE,
         )
         return input_ids, past_key_values
ChunkedHybridCache used in the inference benchmark is deprecated and should be replaced with StaticCache (https://github.com/huggingface/transformers/blob/ce40ca0d4c7d2e0a3f8bd3ddc30f29c6a105efb5/src/transformers/cache_utils.py#L1356).
This PR also removes unused keyword arguments when initializing StaticCache.
cc @crcrpar
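For completeness, a sketch of what the cache setup in generate_batch could look like once the review feedback above is folded in. It mirrors the patch proposed in the thread; build_inputs_and_cache is a hypothetical standalone helper, and DEVICE and WORLD_SIZE are stand-ins for the benchmark's module-level globals:

import copy
import torch
from transformers.cache_utils import StaticCache

# Hypothetical stand-ins for the benchmark's globals.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WORLD_SIZE = 1


def build_inputs_and_cache(hf_config, vocab_size, batch_size, input_length, output_length):
    """Hypothetical helper mirroring the generate_batch patch from the review thread."""
    input_ids = torch.randint(0, vocab_size, (batch_size, input_length), device=DEVICE)

    # Copy the config so the per-rank KV-head count does not leak back into
    # the config used to build the model (relevant in the distributed setting).
    hf_config = copy.copy(hf_config)
    hf_config.num_key_value_heads //= WORLD_SIZE

    past_key_values = StaticCache(
        config=hf_config,
        max_batch_size=batch_size,                    # required by transformers 4.55.x
        max_cache_len=input_length + output_length,
        device=DEVICE,                                # otherwise the cache is allocated on CPU
        dtype=torch.bfloat16,                         # otherwise it defaults to float32
    )
    return input_ids, past_key_values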