 from transformers import (
     AutoConfig,
     AutoProcessor,
+    DynamicCache,
+    EncoderDecoderCache,
     PreTrainedModel,
     StaticCache,
     T5ForConditionalGeneration,
     WhisperForConditionalGeneration,
 )
-from transformers.generation.configuration_utils import GenerationConfig
-from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap
+from transformers.integrations.executorch import (
+    TorchExportableModuleForDecoderOnlyLM,
+    sdpa_mask_without_vmap,
+)
 from transformers.masking_utils import AttentionMaskInterface
 from transformers.modeling_utils import AttentionInterface
 
@@ -50,7 +54,10 @@ def prepare_export_inputs(self):
             {
                 "role": "user",
                 "content": [
-                    {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
+                    {
+                        "type": "image",
+                        "url": "https://llava-vl.github.io/static/images/view.jpg",
+                    },
                 ],
             },
         ]
@@ -337,7 +344,10 @@ def export(
             mutated_gm,
             args=(),
             # For the ET runner, it's important to have cache position as the 2nd arg.
-            kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
+            kwargs={
+                "inputs_embeds": inputs_embeds,
+                "cache_position": cache_position,
+            },
             dynamic_shapes=dynamic_shapes,
             strict=True,
         )
@@ -400,7 +410,12 @@ class CausalLMExportableModule(torch.nn.Module):
     """
 
     def __init__(
-        self, model, max_seq_len=2048, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False
+        self,
+        model,
+        max_seq_len=2048,
+        use_custom_kv_cache=False,
+        use_custom_sdpa=False,
+        disable_dynamic_shapes=False,
     ):
         super().__init__()
         self.model = model
@@ -497,7 +512,10 @@ def export(
 
         with torch.no_grad():
             exported_program = exportable_module.export(
-                input_ids=input_ids, cache_position=cache_position, dynamic_shapes=dynamic_shapes, strict=strict
+                input_ids=input_ids,
+                cache_position=cache_position,
+                dynamic_shapes=dynamic_shapes,
+                strict=strict,
             )
         # Apply RemoveTransposes pass to remove
         # any back-to-back transpose ops that are not needed
@@ -645,26 +663,38 @@ def __init__(self, model, max_static_cache_length, batch_size):
         self.proj_out = model.lm_head
         self.config = model.config
 
-        # Initialize static cache
-        self.static_cache = StaticCache(
+        # Initialize self attention cache
+        self.self_attention_cache = StaticCache(
             config=self.config,
             max_batch_size=batch_size,
             max_cache_len=max_static_cache_length,
-            device="cpu",
+            device=model.device,
             dtype=torch.float32,
         )
-
-        # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+        head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
+        num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
+        self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model.device)
+
+        # Initialize cross attention cache
+        self.dynamic_cache = DynamicCache(config=self.config)
+        self.cache = EncoderDecoderCache(self.self_attention_cache, self.dynamic_cache)
+
+        # Register cache buffers to make them exportable.
+        # Cross attention cache buffer is not registered since it's not actually being used atm.
+        for i in range(len(self.self_attention_cache)):
+            self.register_buffer(
+                f"self_attention_key_cache_{i}", self.self_attention_cache.layers[i].keys, persistent=False
+            )
+            self.register_buffer(
+                f"self_attention_value_cache_{i}", self.self_attention_cache.layers[i].values, persistent=False
+            )
 
     def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
         # Get outputs from decoder
         outputs = self.decoder(
             input_ids=decoder_input_ids,
             encoder_hidden_states=encoder_hidden_states,
-            past_key_values=self.static_cache,
+            past_key_values=self.cache,
             use_cache=True,
             cache_position=cache_position,
         )
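Note on the cache types used above: the self-attention cache is a preallocated `StaticCache` (fixed-shape buffers that can be registered and traced by `torch.export`), while cross-attention gets a `DynamicCache`, since its entries are produced once from the encoder output; both are bundled into an `EncoderDecoderCache` and handed to the decoder as `past_key_values`. A minimal standalone sketch of that composition, assuming a `transformers` version that provides `StaticCache.early_initialization` and `DynamicCache(config=...)` as used in this diff, with `t5-small` as a purely illustrative config:

```python
import torch
from transformers import AutoConfig, DynamicCache, EncoderDecoderCache, StaticCache

config = AutoConfig.from_pretrained("t5-small")  # any encoder-decoder config; example only
batch_size, max_cache_len = 1, 1024

# Self-attention cache: fixed-size buffers, friendly to torch.export and buffer registration.
self_attention_cache = StaticCache(
    config=config,
    max_batch_size=batch_size,
    max_cache_len=max_cache_len,
    device="cpu",
    dtype=torch.float32,
)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
# Materialize the per-layer key/value tensors up front so they exist before export.
self_attention_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, "cpu")

# Cross-attention cache: filled once from the encoder states, so a dynamic cache is enough.
cross_attention_cache = DynamicCache(config=config)

# Combined cache object that the decoder consumes via past_key_values=...
cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
```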
@@ -679,26 +709,18 @@ def __init__(
         self,
         model: PreTrainedModel,
         batch_size=1,
-        max_hidden_seq_length=4096,
-        cache_implementation="static",
-        max_cache_length=1024,
+        max_seq_len=1024,
+        max_hidden_seq_len=4096,
     ):
         super().__init__()
 
-        self.full_model = model
+        self.model = model
         self.encoder = model.get_encoder()
         self.config = model.config
-        self.max_hidden_seq_length = max_hidden_seq_length
-        self.generation_config = GenerationConfig(
-            use_cache=True,
-            max_length=max_cache_length,
-            cache_implementation=cache_implementation,
-            cache_config={
-                "batch_size": batch_size,
-                "max_cache_len": max_cache_length,
-            },
-        )
-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        self.max_hidden_seq_len = max_hidden_seq_len
+        self.batch_size = batch_size
+        self.max_seq_len = max_seq_len
+        if isinstance(self.model, WhisperForConditionalGeneration):
             self._processor = AutoProcessor.from_pretrained(model.config._name_or_path)
             self._expected_encoder_input_shape = torch.Size(
                 (
@@ -707,33 +729,27 @@ def __init__(
                     self._processor.feature_extractor.nb_max_frames,
                 )
             )
-        additional_configs = {}
-        additional_configs["max_hidden_seq_length"] = max_hidden_seq_length
         # Metadata to be recorded in the pte model file
-        self.metadata = save_config_to_constant_methods(
-            self.config,
-            self.generation_config,
-            **additional_configs,
-        )
+        self.metadata = save_config_to_constant_methods(self.config, get_max_seq_len=max_seq_len)
         self.exported_encoder = None
         self.exported_decoder = None
 
     def _export_encoder(self, encoder_input_ids):
         wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()
 
         # Define dynamic sequence length for encoder
-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        if isinstance(self.model, WhisperForConditionalGeneration):
             assert (
                 encoder_input_ids.shape == self._expected_encoder_input_shape
             ), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape}, passed shape: {encoder_input_ids.shape}.
             For more information, please refer to the Whisper preprocessor config."""
             dynamic_shapes = None
-        elif isinstance(self.full_model, T5ForConditionalGeneration):
-            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+        elif isinstance(self.model, T5ForConditionalGeneration):
+            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
             dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}}
         else:
             raise ValueError(
-                f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export."
+                f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule encoder export."
             )
 
         # Export the encoder
@@ -749,27 +765,27 @@ def _export_encoder(self, encoder_input_ids):
     def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
         wrapped_decoder = (
             Seq2SeqLMDecoderExportableModuleWithStaticCache(
-                model=self.full_model,
-                max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
-                batch_size=self.generation_config.cache_config.get("batch_size"),
+                model=self.model,
+                max_static_cache_length=self.max_seq_len,
+                batch_size=self.batch_size,
             )
             .to("cpu")
             .eval()
         )
 
-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        if isinstance(self.model, WhisperForConditionalGeneration):
             dynamic_shapes = None
-        elif isinstance(self.full_model, T5ForConditionalGeneration):
+        elif isinstance(self.model, T5ForConditionalGeneration):
             # Define dynamic dimension for encoder output sequence length
-            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
             dynamic_shapes = {
                 "decoder_input_ids": None,
                 "encoder_hidden_states": {1: encoder_seq_len_dim},
                 "cache_position": None,
             }
         else:
             raise ValueError(
-                f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export."
+                f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule decoder export."
             )
 
         # Export the decoder
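Side note on the dynamic-shape spec above: only dimension 1 of `encoder_hidden_states` is marked dynamic (bounded by `max_hidden_seq_len`), while `decoder_input_ids` and `cache_position` stay static. A rough, self-contained sketch of the same `torch.export.Dim` pattern on a made-up toy module (not the actual wrapped decoder from this PR):

```python
import torch


class ToyDecoder(torch.nn.Module):
    # Hypothetical stand-in for the wrapped seq2seq decoder; shapes are illustrative only.
    def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
        # Combine all inputs so the exported graph depends on each of them.
        return encoder_hidden_states.mean(dim=1) + decoder_input_ids.float() + cache_position.float()


max_hidden_seq_len = 4096
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=max_hidden_seq_len)

dynamic_shapes = {
    "decoder_input_ids": None,  # fixed shape
    "encoder_hidden_states": {1: encoder_seq_len_dim},  # sequence length may vary up to the max
    "cache_position": None,  # fixed shape
}

example_kwargs = {
    "decoder_input_ids": torch.zeros((1, 1), dtype=torch.long),
    "encoder_hidden_states": torch.zeros((1, 16, 8), dtype=torch.float32),
    "cache_position": torch.tensor([0], dtype=torch.long),
}

exported = torch.export.export(
    ToyDecoder().eval(),
    args=(),
    kwargs=example_kwargs,
    dynamic_shapes=dynamic_shapes,
    strict=True,
)
```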
@@ -791,7 +807,7 @@ def export(
         cache_position=None,
     ) -> Dict[str, ExportedProgram]:
         if encoder_input_ids is None:
-            if isinstance(self.full_model, WhisperForConditionalGeneration):
+            if isinstance(self.model, WhisperForConditionalGeneration):
                 example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
             else:
                 example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)