@@ -402,9 +402,12 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
         # get kv cache stats for both model and draft model
         kv_stats = py_executor.resource_manager.resource_managers.get(
             ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
-        kv_stats_draft = py_executor.resource_manager.resource_managers.get(
-            ResourceManagerType.DRAFT_KV_CACHE_MANAGER).get_kv_cache_stats(
-        ) if self._draft_model_engine is not None else None
+        # Get draft KV cache stats if present (either from two-model mode or one-model
+        # mode with separate draft KV cache)
+        draft_kv_cache_manager = py_executor.resource_manager.resource_managers.get(
+            ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
+        kv_stats_draft = draft_kv_cache_manager.get_kv_cache_stats(
+        ) if draft_kv_cache_manager is not None else None

         # get total allocated bytes
         allocated_bytes = kv_stats.allocated_bytes + (
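
Illustrative sketch (not part of the diff; the rest of the allocated_bytes expression lies outside this hunk): the draft term is assumed to contribute zero when no separate draft KV cache manager exists, which covers both the two-model and one-model speculative setups.

    # Hypothetical continuation, assuming both stats objects expose `allocated_bytes`.
    allocated_bytes = kv_stats.allocated_bytes + (
        kv_stats_draft.allocated_bytes if kv_stats_draft is not None else 0)
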
@@ -525,83 +528,42 @@ def _create_one_model_draft_kv_cache_manager(
         Create a KV cache manager for draft model layers in one-model mode
         when target and draft have different KV cache layouts.
         """
-        draft_pretrained_config = self._draft_config.pretrained_config
-        quant_config = self._draft_config.quant_config
-
-        # Determine KV cache dtype
-        if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
-        ):
-            kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
-        elif quant_config is not None and quant_config.quant_mode.has_fp4_kv_cache(
-        ):
-            kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4
-        else:
-            kv_cache_dtype = str_dtype_to_binding(
-                torch_dtype_to_str(self._model_engine.dtype))
-
-        num_draft_layers = draft_pretrained_config.num_hidden_layers
-
         # Get target model's num_hidden_layers to compute correct layer indices.
         # Draft model layers in one-model mode start at target_num_layers.
         target_pretrained_config = self._model_engine.model.model_config.pretrained_config
         target_num_layers = target_pretrained_config.num_hidden_layers
+        num_draft_layers = self._draft_config.pretrained_config.num_hidden_layers

         # Create layer_mask: False for target layers, True for draft layers.
         # This ensures the draft KV cache manager uses the correct layer indices
         # (e.g., layers 32, 33, ... instead of 0, 1, ...).
         layer_mask = [False] * target_num_layers + [True] * num_draft_layers

-        if is_mla(draft_pretrained_config):
-            # Draft uses MLA
-            return self._kv_cache_manager_cls(
-                self._kv_cache_config,
-                tensorrt_llm.bindings.internal.batch_manager.CacheType.
-                SELFKONLY,
-                num_layers=num_draft_layers,
-                num_kv_heads=1,
-                head_dim=draft_pretrained_config.kv_lora_rank +
-                draft_pretrained_config.qk_rope_head_dim,
-                tokens_per_block=self._tokens_per_block,
-                max_seq_len=self._max_seq_len,
-                max_batch_size=self._max_batch_size,
-                mapping=self._mapping,
-                dtype=kv_cache_dtype,
-                spec_config=self._speculative_config,
-                max_num_tokens=self._max_num_tokens,
-                is_draft=True,
-                is_estimating_kv_cache=estimating_kv_cache,
-                execution_stream=self._execution_stream,
-                layer_mask=layer_mask,
-            )
-        else:
-            # Draft uses standard attention
-            hidden_size = draft_pretrained_config.hidden_size
-            num_attention_heads = draft_pretrained_config.num_attention_heads
-            num_key_value_heads = getattr(draft_pretrained_config,
-                                          'num_key_value_heads',
-                                          num_attention_heads)
-            head_dim = getattr(draft_pretrained_config, "head_dim", None)
-            if not isinstance(head_dim, int):
-                head_dim = hidden_size // num_attention_heads
-
-            return self._kv_cache_manager_cls(
-                self._kv_cache_config,
-                tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
-                num_layers=num_draft_layers,
-                num_kv_heads=num_key_value_heads,
-                head_dim=head_dim,
-                tokens_per_block=self._tokens_per_block,
-                max_seq_len=self._max_seq_len,
-                max_batch_size=self._max_batch_size,
-                mapping=self._mapping,
-                dtype=kv_cache_dtype,
-                spec_config=self._speculative_config,
-                max_num_tokens=self._max_num_tokens,
-                is_draft=True,
-                is_estimating_kv_cache=estimating_kv_cache,
-                execution_stream=self._execution_stream,
-                layer_mask=layer_mask,
-            )
+        # Get the appropriate KV cache manager class for the draft model
+        draft_kv_cache_manager_cls = get_kv_cache_manager_cls(
+            self._draft_config)
+
+        return _create_kv_cache_manager(
+            model_engine=None,
+            kv_cache_manager_cls=draft_kv_cache_manager_cls,
+            mapping=self._mapping,
+            kv_cache_config=self._kv_cache_config,
+            tokens_per_block=self._tokens_per_block,
+            max_seq_len=self._max_seq_len,
+            max_batch_size=self._max_batch_size,
+            spec_config=self._speculative_config,
+            sparse_attn_config=None,  # Not applicable for draft in one-model mode
+            max_num_tokens=self._max_num_tokens,
+            max_beam_width=self._max_beam_width,
+            kv_connector_manager=None,  # Not supported for draft models
+            estimating_kv_cache=estimating_kv_cache,
+            execution_stream=self._execution_stream,
+            # One-model draft specific overrides
+            model_config=self._draft_config,
+            dtype=self._model_engine.dtype,
+            is_draft=True,
+            layer_mask=layer_mask,
+        )

     def build_managers(self,
                        resources: Dict,
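
Illustrative sketch (hypothetical sizes, not part of the diff): with a 32-layer target and a 3-layer draft, the mask built above makes the draft KV cache manager address layers 32, 33 and 34 rather than 0, 1 and 2.

    target_num_layers = 32  # hypothetical target depth
    num_draft_layers = 3    # hypothetical draft depth
    layer_mask = [False] * target_num_layers + [True] * num_draft_layers
    draft_layer_indices = [i for i, is_draft in enumerate(layer_mask) if is_draft]
    # draft_layer_indices -> [32, 33, 34]
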
@@ -641,7 +603,7 @@ def teardown_managers(self, resources: Dict) -> None:


 def _create_kv_cache_manager(
-        model_engine: PyTorchModelEngine,
+        model_engine: Optional[PyTorchModelEngine],
         kv_cache_manager_cls,
         mapping: Mapping,
         kv_cache_config: KvCacheConfig,
@@ -654,13 +616,42 @@ def _create_kv_cache_manager(
         max_beam_width: int,
         kv_connector_manager: Optional[KvCacheConnectorManager],
         estimating_kv_cache: bool,
-        execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
+        execution_stream: Optional[torch.cuda.Stream] = None,
+        # Optional overrides for one-model draft case (when model_engine is None)
+        model_config: Optional[ModelConfig] = None,
+        dtype: Optional[torch.dtype] = None,
+        is_draft: Optional[bool] = None,
+        layer_mask: Optional[List[bool]] = None) -> KVCacheManager:
     """
+    Create a KVCacheManager instance.
+
+    Args:
+        model_engine: The model engine (can be None if model_config is provided)
+        model_config: Optional ModelConfig to use instead of extracting from model_engine
+        dtype: Optional dtype override (required if model_engine is None)
+        is_draft: Optional is_draft flag override (required if model_engine is None)
+        layer_mask: Optional layer mask for one-model draft KV cache
+
     Returns:
-        A KVCacheManager instance for the given model_engine
+        A KVCacheManager instance
     """
-    config = model_engine.model.model_config.pretrained_config
-    quant_config = model_engine.model.model_config.quant_config
+    # Extract config from model_engine or use provided model_config
+    if model_config is not None:
+        config = model_config.pretrained_config
+        quant_config = model_config.quant_config
+        _model_config = model_config
+    else:
+        config = model_engine.model.model_config.pretrained_config
+        quant_config = model_engine.model.model_config.quant_config
+        _model_config = model_engine.model.model_config
+
+    # Determine dtype
+    if dtype is None:
+        dtype = model_engine.dtype
+
+    # Determine is_draft
+    if is_draft is None:
+        is_draft = model_engine.is_draft_model

     hidden_size = config.hidden_size
     num_attention_heads = config.num_attention_heads
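
Illustrative note (a sketch, not part of the diff): each override falls back to the engine only when it is not supplied, so existing callers that pass a model engine are unchanged, while the one-model draft path passes model_config, dtype and is_draft explicitly with model_engine=None.

    # Hypothetical helper expressing the pattern above: prefer the explicit
    # override, otherwise derive the value lazily from the model engine.
    def _resolve(override, fallback_fn):
        return override if override is not None else fallback_fn()
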
@@ -676,8 +667,7 @@ def _create_kv_cache_manager(
     ):
         kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4
     else:
-        kv_cache_dtype = str_dtype_to_binding(
-            torch_dtype_to_str(model_engine.dtype))
+        kv_cache_dtype = str_dtype_to_binding(torch_dtype_to_str(dtype))

     num_hidden_layers = config.num_hidden_layers

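A minimal sketch of the non-quantized fallback (assuming str_dtype_to_binding and torch_dtype_to_str come from tensorrt_llm._utils, as elsewhere in this file): when neither FP8 nor NVFP4 KV cache quantization is configured, the KV cache dtype now follows the resolved dtype argument instead of always reading model_engine.dtype.

    import torch
    from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str

    dtype = torch.bfloat16  # hypothetical resolved model dtype
    kv_cache_dtype = str_dtype_to_binding(torch_dtype_to_str(dtype))
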
@@ -695,12 +685,13 @@ def _create_kv_cache_manager(
             dtype=kv_cache_dtype,
             spec_config=spec_config,
             max_beam_width=max_beam_width,
-            is_draft=model_engine.is_draft_model,
+            is_draft=is_draft,
             kv_connector_manager=kv_connector_manager
             if not estimating_kv_cache else None,
             sparse_attn_config=sparse_attn_config,
             is_estimating_kv_cache=estimating_kv_cache,
             execution_stream=execution_stream,
+            layer_mask=layer_mask,
         )
     elif is_nemotron_hybrid(config):
         if max_beam_width > 1:
@@ -712,9 +703,10 @@ def _create_kv_cache_manager(
712703 "Connector manager is not supported for MambaHybridCacheManager."
713704 )
714705
715- config = model_engine .model .model_config .pretrained_config
716706 num_layers = config .hybrid_override_pattern .count ("*" )
717- layer_mask = [char == "*" for char in config .hybrid_override_pattern ]
707+ hybrid_layer_mask = [
708+ char == "*" for char in config .hybrid_override_pattern
709+ ]
718710 mamba_num_layers = config .hybrid_override_pattern .count ("M" )
719711 mamba_layer_mask = [
720712 char == "M" for char in config .hybrid_override_pattern
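
Illustrative sketch (hypothetical pattern, not part of the diff): for a hybrid_override_pattern such as "M*MM*", the "*" entries become attention layers tracked by the KV cache and the "M" entries become Mamba layers, so the two masks are complementary.

    pattern = "M*MM*"                                 # hypothetical
    num_layers = pattern.count("*")                   # 2 attention layers
    hybrid_layer_mask = [c == "*" for c in pattern]   # [False, True, False, False, True]
    mamba_num_layers = pattern.count("M")             # 3 Mamba layers
    mamba_layer_mask = [c == "M" for c in pattern]    # [True, False, True, True, False]
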
@@ -729,12 +721,13 @@ def _create_kv_cache_manager(
             mamba_num_layers,
             mamba_layer_mask,
             config.torch_dtype,
-            model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
+            quant_config.mamba_ssm_cache_dtype
+            if quant_config is not None else None,
             # kv cache parameters
             kv_cache_config,
             tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
             num_layers=num_layers,
-            layer_mask=layer_mask,
+            layer_mask=hybrid_layer_mask,
             num_kv_heads=num_key_value_heads,
             head_dim=head_dim,
             tokens_per_block=tokens_per_block,
@@ -755,13 +748,12 @@ def _create_kv_cache_manager(
         raise NotImplementedError(
             "Connector manager is not supported for MambaHybridCacheManager."
         )
-        config = model_engine.model.model_config.pretrained_config
         mamba_layer_mask = [
             True if i %
             config.full_attention_interval != config.full_attention_interval -
             1 else False for i in range(num_hidden_layers)
         ]
-        layer_mask = [
+        hybrid_layer_mask = [
             False if i %
             config.full_attention_interval != config.full_attention_interval -
             1 else True for i in range(num_hidden_layers)
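
Illustrative sketch (hypothetical values, not part of the diff): with full_attention_interval = 4 and 8 hidden layers, every fourth layer is full attention and the rest are Mamba, so the renamed hybrid_layer_mask is the complement of mamba_layer_mask.

    full_attention_interval = 4   # hypothetical
    num_hidden_layers = 8         # hypothetical
    mamba_layer_mask = [(i % full_attention_interval) != full_attention_interval - 1
                        for i in range(num_hidden_layers)]
    hybrid_layer_mask = [(i % full_attention_interval) == full_attention_interval - 1
                         for i in range(num_hidden_layers)]
    # mamba_layer_mask  -> [True, True, True, False, True, True, True, False]
    # hybrid_layer_mask -> [False, False, False, True, False, False, False, True]
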
@@ -779,12 +771,13 @@ def _create_kv_cache_manager(
             num_mamba_layers,
             mamba_layer_mask,
             config.torch_dtype,
-            model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
+            quant_config.mamba_ssm_cache_dtype
+            if quant_config is not None else None,
             # kv cache parameters
             kv_cache_config,
             tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
             num_layers=num_layers,
-            layer_mask=layer_mask,
+            layer_mask=hybrid_layer_mask,
             num_kv_heads=num_key_value_heads,
             head_dim=head_dim,
             tokens_per_block=tokens_per_block,
@@ -800,7 +793,7 @@ def _create_kv_cache_manager(
         # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCacheManager
         is_vswa = kv_cache_config.max_attention_window is not None and len(
             set(kv_cache_config.max_attention_window)) > 1
-        binding_model_config = model_engine.model.model_config.get_bindings_model_config(
+        binding_model_config = _model_config.get_bindings_model_config(
             tokens_per_block=tokens_per_block) if is_vswa else None

         kv_cache_manager = kv_cache_manager_cls(
@@ -818,12 +811,13 @@ def _create_kv_cache_manager(
818811 max_num_tokens = max_num_tokens ,
819812 model_config = binding_model_config ,
820813 max_beam_width = max_beam_width ,
821- is_draft = model_engine . is_draft_model ,
814+ is_draft = is_draft ,
822815 kv_connector_manager = kv_connector_manager
823816 if not estimating_kv_cache else None ,
824817 sparse_attn_config = sparse_attn_config ,
825818 is_estimating_kv_cache = estimating_kv_cache ,
826819 execution_stream = execution_stream ,
820+ layer_mask = layer_mask ,
827821 )
828822 return kv_cache_manager
829823