
Commit 4676609

Fix Whisper (#167)
Fixes the Whisper export, which had been broken for a while due to transformers changes, plus miscellaneous `seq2seq`-related cleanups. Now runs on both portable and XNNPack.
1 parent 0e6054b commit 4676609


5 files changed: +104 additions, -82 deletions


optimum/commands/export/executorch.py

Lines changed: 1 addition & 3 deletions
@@ -185,9 +185,7 @@ def run(self):
                    "--qlinear_packing_format can only be used when --device is set to CUDA (e.g., 'cuda', 'cuda:0', etc.)"
                )
            if not self.args.qlinear or self.args.qlinear != "4w":
-                raise ValueError(
-                    "--qlinear_packing_format can only be used when --qlinear is set to '4w'"
-                )
+                raise ValueError("--qlinear_packing_format can only be used when --qlinear is set to '4w'")
        qlinear_encoder_packing_format = getattr(self.args, "qlinear_encoder_packing_format", None)
        if qlinear_encoder_packing_format:
            if not device or not device.startswith("cuda"):

optimum/executorch/modeling.py

Lines changed: 27 additions & 19 deletions
@@ -40,7 +40,10 @@
 from transformers.processing_utils import ProcessorMixin
 from transformers.utils import is_offline_mode

-from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
+from executorch.extension.pybindings.portable_lib import (
+    ExecuTorchModule,
+    _load_for_executorch,
+)
 from executorch.kernels import quantized  # noqa

 from ..exporters import TasksManager

@@ -460,7 +463,7 @@ def __init__(
        if not hasattr(self, "encoder"):
            raise AttributeError("Expected attribute 'encoder' not found in the instance.")
        if not hasattr(self, "text_decoder"):
-            raise AttributeError("Expected attribute 'decoder' not found in the instance.")
+            raise AttributeError("Expected attribute 'text_decoder' not found in the instance.")
        metadata = self.decoder.method_names()
        if "use_kv_cache" in metadata:
            self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]

@@ -495,7 +498,10 @@ def forward(
            encoder_outputs = self.encoder.forward((input_ids,))[0]
            self.stats.on_prompt_eval_end()

-        result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
+        result = (
+            self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0],
+            encoder_outputs,
+        )
        self.stats.on_model_execution_end()
        return result

@@ -1022,29 +1028,27 @@ def __init__(
        config: "PretrainedConfig",
    ):
        super().__init__(models=models, config=config)
-        if not hasattr(self, "encoder"):
-            raise AttributeError("Expected attribute 'encoder' not found in the instance.")
-        if not hasattr(self, "text_decoder"):
-            raise AttributeError("Expected attribute 'decoder' not found in the instance.")
-        metadata = self.decoder.method_names()
+        if not hasattr(self, "model"):
+            raise AttributeError("Expected attribute 'model' not found in the instance.")
+        metadata = self.model.method_names()
        if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
+            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
        if "get_max_seq_len" in metadata:
-            self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
+            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
        if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
+            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
        if "get_dtype" in metadata:
-            self.dtype = self.decoder.run_method("get_dtype")[0]
+            self.dtype = self.model.run_method("get_dtype")[0]
        if "get_bos_id" in metadata:
-            self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
+            self.bos_token_id = self.model.run_method("get_bos_id")[0]
        if "get_eos_id" in metadata:
-            self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
+            self.eos_token_id = self.model.run_method("get_eos_id")[0]
        if "get_vocab_size" in metadata:
-            self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
+            self.vocab_size = self.model.run_method("get_vocab_size")[0]
        if "max_hidden_seq_length" in metadata:
-            self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
+            self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
        if "decoder_start_token_id" in metadata:
-            self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]
+            self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]

    def forward(
        self,

@@ -1056,10 +1060,13 @@ def forward(
        is_first_prediction = encoder_outputs is None
        self.stats.on_model_execution_start()
        if is_first_prediction:
-            encoder_outputs = self.encoder.forward((input_features,))[0]
+            encoder_outputs = self.model.run_method("encoder", (input_features,))[0]
            self.stats.on_prompt_eval_end()

-        result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
+        result = (
+            self.model.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0],
+            encoder_outputs,
+        )
        self.stats.on_model_execution_end()
        return result

@@ -1117,6 +1124,7 @@ def generate(
            if not first_token_generated:
                self.stats.on_first_token()
                first_token_generated = True
+
            # Get next token
            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
            generated_ids.append(next_token)
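
Note on the runtime change above: the speech seq2seq model is now a single multi-method program (loaded as `self.model`) rather than separate encoder/decoder modules. A minimal sketch of how such a program could be inspected and run from Python, using only the pybindings calls that appear in this diff (`_load_for_executorch`, `method_names`, `run_method`); the file name, feature shape, and start-token id are illustrative assumptions, not taken from the repo.

# Hypothetical sketch: inspect and run a multi-method Whisper program.
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("model.pte")          # assumed artifact path
print(module.method_names())                        # expected to include "encoder" and "text_decoder"

# Whisper log-mel features: (batch, mel bins, frames); shape here is illustrative.
input_features = torch.zeros((1, 80, 3000), dtype=torch.float32)
encoder_outputs = module.run_method("encoder", (input_features,))[0]

decoder_input_ids = torch.tensor([[50258]], dtype=torch.long)  # illustrative start token
cache_position = torch.tensor([0], dtype=torch.long)
logits = module.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0]
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()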

optimum/exporters/executorch/integrations.py

Lines changed: 65 additions & 49 deletions
@@ -22,13 +22,17 @@
 from transformers import (
     AutoConfig,
     AutoProcessor,
+    DynamicCache,
+    EncoderDecoderCache,
     PreTrainedModel,
     StaticCache,
     T5ForConditionalGeneration,
     WhisperForConditionalGeneration,
 )
-from transformers.generation.configuration_utils import GenerationConfig
-from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap
+from transformers.integrations.executorch import (
+    TorchExportableModuleForDecoderOnlyLM,
+    sdpa_mask_without_vmap,
+)
 from transformers.masking_utils import AttentionMaskInterface
 from transformers.modeling_utils import AttentionInterface

@@ -50,7 +54,10 @@ def prepare_export_inputs(self):
            {
                "role": "user",
                "content": [
-                    {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
+                    {
+                        "type": "image",
+                        "url": "https://llava-vl.github.io/static/images/view.jpg",
+                    },
                ],
            },
        ]

@@ -337,7 +344,10 @@ def export(
            mutated_gm,
            args=(),
            # For the ET runner, it's important to have cache position as the 2nd arg.
-            kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
+            kwargs={
+                "inputs_embeds": inputs_embeds,
+                "cache_position": cache_position,
+            },
            dynamic_shapes=dynamic_shapes,
            strict=True,
        )

@@ -400,7 +410,12 @@ class CausalLMExportableModule(torch.nn.Module):
    """

    def __init__(
-        self, model, max_seq_len=2048, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False
+        self,
+        model,
+        max_seq_len=2048,
+        use_custom_kv_cache=False,
+        use_custom_sdpa=False,
+        disable_dynamic_shapes=False,
    ):
        super().__init__()
        self.model = model

@@ -497,7 +512,10 @@ def export(

        with torch.no_grad():
            exported_program = exportable_module.export(
-                input_ids=input_ids, cache_position=cache_position, dynamic_shapes=dynamic_shapes, strict=strict
+                input_ids=input_ids,
+                cache_position=cache_position,
+                dynamic_shapes=dynamic_shapes,
+                strict=strict,
            )
        # Apply RemoveTransposes pass to remove
        # any back-to-back transpose ops that are not needed

@@ -645,26 +663,38 @@ def __init__(self, model, max_static_cache_length, batch_size):
        self.proj_out = model.lm_head
        self.config = model.config

-        # Initialize static cache
-        self.static_cache = StaticCache(
+        # Initialize self attention cache
+        self.self_attention_cache = StaticCache(
            config=self.config,
            max_batch_size=batch_size,
            max_cache_len=max_static_cache_length,
-            device="cpu",
+            device=model.device,
            dtype=torch.float32,
        )
-
-        # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+        head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
+        num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
+        self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model.device)
+
+        # Initialize cross attention cache
+        self.dynamic_cache = DynamicCache(config=self.config)
+        self.cache = EncoderDecoderCache(self.self_attention_cache, self.dynamic_cache)
+
+        # Register cache buffers to make them exportable.
+        # Cross attention cache buffer is not registered since it's not actually being used atm.
+        for i in range(len(self.self_attention_cache)):
+            self.register_buffer(
+                f"self_attention_key_cache_{i}", self.self_attention_cache.layers[i].keys, persistent=False
+            )
+            self.register_buffer(
+                f"self_attention_value_cache_{i}", self.self_attention_cache.layers[i].values, persistent=False
+            )

    def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
        # Get outputs from decoder
        outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
-            past_key_values=self.static_cache,
+            past_key_values=self.cache,
            use_cache=True,
            cache_position=cache_position,
        )

@@ -679,26 +709,18 @@ def __init__(
        self,
        model: PreTrainedModel,
        batch_size=1,
-        max_hidden_seq_length=4096,
-        cache_implementation="static",
-        max_cache_length=1024,
+        max_seq_len=1024,
+        max_hidden_seq_len=4096,
    ):
        super().__init__()

-        self.full_model = model
+        self.model = model
        self.encoder = model.get_encoder()
        self.config = model.config
-        self.max_hidden_seq_length = max_hidden_seq_length
-        self.generation_config = GenerationConfig(
-            use_cache=True,
-            max_length=max_cache_length,
-            cache_implementation=cache_implementation,
-            cache_config={
-                "batch_size": batch_size,
-                "max_cache_len": max_cache_length,
-            },
-        )
-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        self.max_hidden_seq_len = max_hidden_seq_len
+        self.batch_size = batch_size
+        self.max_seq_len = max_seq_len
+        if isinstance(self.model, WhisperForConditionalGeneration):
            self._processor = AutoProcessor.from_pretrained(model.config._name_or_path)
            self._expected_encoder_input_shape = torch.Size(
                (

@@ -707,33 +729,27 @@ def __init__(
                    self._processor.feature_extractor.nb_max_frames,
                )
            )
-        additional_configs = {}
-        additional_configs["max_hidden_seq_length"] = max_hidden_seq_length
        # Metadata to be recorded in the pte model file
-        self.metadata = save_config_to_constant_methods(
-            self.config,
-            self.generation_config,
-            **additional_configs,
-        )
+        self.metadata = save_config_to_constant_methods(self.config, get_max_seq_len=max_seq_len)
        self.exported_encoder = None
        self.exported_decoder = None

    def _export_encoder(self, encoder_input_ids):
        wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()

        # Define dynamic sequence length for encoder
-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        if isinstance(self.model, WhisperForConditionalGeneration):
            assert (
                encoder_input_ids.shape == self._expected_encoder_input_shape
            ), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape}, passed shape: {encoder_input_ids.shape}.
            For more infromation, please refer to the Whisper preprocessor config."""
            dynamic_shapes = None
-        elif isinstance(self.full_model, T5ForConditionalGeneration):
-            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+        elif isinstance(self.model, T5ForConditionalGeneration):
+            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
            dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}}
        else:
            raise ValueError(
-                f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export."
+                f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule encoder export."
            )

        # Export the encoder

@@ -749,27 +765,27 @@ def _export_encoder(self, encoder_input_ids):
    def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
        wrapped_decoder = (
            Seq2SeqLMDecoderExportableModuleWithStaticCache(
-                model=self.full_model,
-                max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
-                batch_size=self.generation_config.cache_config.get("batch_size"),
+                model=self.model,
+                max_static_cache_length=self.max_seq_len,
+                batch_size=self.batch_size,
            )
            .to("cpu")
            .eval()
        )

-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        if isinstance(self.model, WhisperForConditionalGeneration):
            dynamic_shapes = None
-        elif isinstance(self.full_model, T5ForConditionalGeneration):
+        elif isinstance(self.model, T5ForConditionalGeneration):
            # Define dynamic dimension for encoder output sequence length
-            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
            dynamic_shapes = {
                "decoder_input_ids": None,
                "encoder_hidden_states": {1: encoder_seq_len_dim},
                "cache_position": None,
            }
        else:
            raise ValueError(
-                f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export."
+                f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule decoder export."
            )

        # Export the decoder

@@ -791,7 +807,7 @@ def export(
        cache_position=None,
    ) -> Dict[str, ExportedProgram]:
        if encoder_input_ids is None:
-            if isinstance(self.full_model, WhisperForConditionalGeneration):
+            if isinstance(self.model, WhisperForConditionalGeneration):
                example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
            else:
                example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
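
The key structural change in `Seq2SeqLMDecoderExportableModuleWithStaticCache` is the cache: instead of a single `StaticCache`, the decoder now pairs a pre-allocated self-attention cache with a dynamic cross-attention cache inside an `EncoderDecoderCache`. A condensed sketch of that wiring, mirroring the calls in the diff; the checkpoint id is an example and exact constructor signatures depend on the installed transformers version.

# Sketch of the cache setup used by the decoder wrapper above (assumed, mirrors the diff).
import torch
from transformers import AutoConfig, DynamicCache, EncoderDecoderCache, StaticCache

config = AutoConfig.from_pretrained("openai/whisper-tiny")  # example checkpoint
self_attention_cache = StaticCache(
    config=config,
    max_batch_size=1,
    max_cache_len=1024,
    device="cpu",
    dtype=torch.float32,
)
cross_attention_cache = DynamicCache(config=config)
# Self-attention KV buffers are pre-allocated (and registered as buffers for export);
# cross-attention entries stay dynamic and are not registered.
cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)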

optimum/exporters/executorch/tasks/asr.py

Lines changed: 4 additions & 4 deletions
@@ -46,13 +46,13 @@ def load_seq2seq_speech_model(model_name_or_path: str, **kwargs) -> Seq2SeqLMExportableModule:
    """
    device = "cpu"
    batch_size = 1
-    max_hidden_seq_length = kwargs.get("max_hidden_seq_length", 4096)
-    max_cache_length = kwargs.get("max_cache_length", 1024)
+    max_hidden_seq_len = kwargs.get("max_hidden_seq_len", 4096)
+    max_seq_len = kwargs.get("max_seq_len", 1024)

    full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path).to(device).eval()
    return Seq2SeqLMExportableModule(
        full_model,
        batch_size=batch_size,
-        max_hidden_seq_length=max_hidden_seq_length,
-        max_cache_length=max_cache_length,
+        max_seq_len=max_seq_len,
+        max_hidden_seq_len=max_hidden_seq_len,
    )
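
With the renamed kwargs, a caller of this task loader passes `max_seq_len` / `max_hidden_seq_len` instead of the old names. A minimal sketch under those assumptions; the function and defaults come from the diff above, the model id is an example, and calling `export()` with no arguments relies on the default-input path shown in the integrations diff.

# Sketch: building and exporting the Whisper module with the renamed kwargs (assumed usage).
from optimum.exporters.executorch.tasks.asr import load_seq2seq_speech_model

exportable = load_seq2seq_speech_model(
    "openai/whisper-tiny",
    max_seq_len=1024,         # decoder KV-cache length (previously max_cache_length)
    max_hidden_seq_len=4096,  # encoder hidden sequence bound (previously max_hidden_seq_length)
)
exported_programs = exportable.export()  # Dict[str, ExportedProgram], per the export() signature above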

tests/models/test_modeling_whisper.py

Lines changed: 7 additions & 7 deletions
@@ -49,8 +49,7 @@ def test_whisper_export_to_executorch(self):
            shell=True,
            check=True,
        )
-        self.assertTrue(os.path.exists(f"{tempdir}/executorch/encoder.pte"))
-        self.assertTrue(os.path.exists(f"{tempdir}/executorch/decoder.pte"))
+        self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
        model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(f"{tempdir}/executorch")
        self._test_whisper_transcription(model_id, model)

@@ -59,16 +58,17 @@ def _test_whisper_transcription(self, model_id: str, model: ExecuTorchModelForSpeechSeq2Seq):
        processor = AutoProcessor.from_pretrained(model_id)

        self.assertIsInstance(model, ExecuTorchModelForSpeechSeq2Seq)
-        self.assertTrue(hasattr(model, "encoder"))
-        self.assertIsInstance(model.encoder, ExecuTorchModule)
-        self.assertTrue(hasattr(model, "text_decoder"))
-        self.assertIsInstance(model.decoder, ExecuTorchModule)
+        self.assertTrue(hasattr(model, "model"))
+        self.assertIsInstance(model.model, ExecuTorchModule)

        dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
        sample = dataset[0]["audio"]

        input_features = processor(
-            sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
+            sample["array"],
+            return_tensors="pt",
+            truncation=False,
+            sampling_rate=sample["sampling_rate"],
        ).input_features
        # Current implementation of the transcibe method accepts up to 30 seconds of audio, therefore I trim the audio here.
        input_features_trimmed = input_features[:, :, :3000].contiguous()
