@@ -460,29 +460,27 @@ def __init__(
         config: "PretrainedConfig",
     ):
         super().__init__(models=models, config=config)
-        if not hasattr(self, "encoder"):
-            raise AttributeError("Expected attribute 'encoder' not found in the instance.")
-        if not hasattr(self, "text_decoder"):
-            raise AttributeError("Expected attribute 'text_decoder' not found in the instance.")
-        metadata = self.decoder.method_names()
+        if not hasattr(self, "model"):
+            raise AttributeError("Expected attribute 'model' not found in the instance.")
+        metadata = self.model.method_names()
         if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
+            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
         if "get_max_seq_len" in metadata:
-            self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
+            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
         if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
+            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
         if "get_dtype" in metadata:
-            self.dtype = self.decoder.run_method("get_dtype")[0]
+            self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
-            self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
+            self.bos_token_id = self.model.run_method("get_bos_id")[0]
         if "get_eos_id" in metadata:
-            self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
+            self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
-            self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
+            self.vocab_size = self.model.run_method("get_vocab_size")[0]
         if "max_hidden_seq_length" in metadata:
-            self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
+            self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
         if "decoder_start_token_id" in metadata:
-            self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]
+            self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]
 
     def forward(
         self,
@@ -491,15 +489,14 @@ def forward(
         cache_position: torch.Tensor,
         encoder_outputs: Optional[torch.Tensor] = None,
     ):
-        # Encode if needed (first prediction pass)
         is_first_prediction = encoder_outputs is None
         self.stats.on_model_execution_start()
         if is_first_prediction:
-            encoder_outputs = self.encoder.forward((input_ids,))[0]
+            encoder_outputs = self.model.run_method("encoder", (input_ids,))[0]
             self.stats.on_prompt_eval_end()
 
         result = (
-            self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0],
+            self.model.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0],
             encoder_outputs,
         )
         self.stats.on_model_execution_end()
@@ -530,9 +527,6 @@ def generate(
         Returns:
             List[int]: List of generated token IDs.
 
-        Note:
-            Temporarily implemented this method in Python due to limited access to ExecuTorch's c++ LLM runner via pybind.
-            Expect improvements to the pybind interface in ExecuTorch version 0.4.1.
         """
         self.device = torch.device("cpu")
         if max_seq_len is None:
@@ -550,7 +544,6 @@
         encoder_input_ids = input_ids
         encoder_outputs = None
         generated_ids = [0]
-
        first_token_generated = False
 
         # Generate tokens one by one