Add quantization support #163
Demo inference script (modified):

@@ -6,7 +6,8 @@
 import time
 import torch
 import copy
+from vibevoice.utils.vram_utils import get_available_vram_gb, print_vram_info
+from vibevoice.utils.quantization import get_quantization_config, apply_selective_quantization
 from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference
 from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor
 from transformers.utils import logging

@@ -129,6 +130,13 @@ def parse_args():
         default=1.5,
         help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)",
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+        default="fp16",
+        choices=["fp16", "8bit", "4bit"],
+        help="Quantization level: fp16 (default, ~20GB), 8bit (~12GB), or 4bit (~7GB)"
+    )

     return parser.parse_args()

@@ -146,6 +154,14 @@ def main():
         args.device = "cpu"

     print(f"Using device: {args.device}")

+    # VRAM Detection and Quantization Info (NEW)
+    if args.device == "cuda":
+        available_vram = get_available_vram_gb()
+        print_vram_info(available_vram, args.model_path, args.quantization)
+    elif args.quantization != "fp16":
+        print(f"Warning: Quantization ({args.quantization}) only works with CUDA. Using full precision.")
+        args.quantization = "fp16"
+
     # Initialize voice mapper
     voice_mapper = VoiceMapper()
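With the detection in place, a constrained GPU can also be pointed at a lower-precision run explicitly, e.g. `python demo/inference_from_file.py --model_path <model> --quantization 4bit ...` (this mirrors the hint that print_vram_info prints below; the exact demo filename may differ from this script). The default stays fp16, so existing invocations behave as before.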
@@ -164,7 +180,7 @@
         print("Error: No valid scripts found in the txt file")
         return

-    full_script = scripts.replace("’", "'").replace('“', '"').replace('”', '"')
+    full_script = scripts.replace("'", "'").replace('"', '"').replace('"', '"')

     print(f"Loading processor & model from {args.model_path}")
     processor = VibeVoiceStreamingProcessor.from_pretrained(args.model_path)

@@ -180,6 +196,15 @@ def main():
         load_dtype = torch.float32
         attn_impl_primary = "sdpa"
     print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

+    # Get quantization configuration (NEW)
+    quant_config = get_quantization_config(args.quantization)
+
+    if quant_config:
+        print(f"Using {args.quantization} quantization...")
+    else:
+        print("Using full precision (fp16)...")
+
     # Load model with device-specific logic
     try:
         if args.device == "mps":

@@ -191,12 +216,25 @@
             )
             model.to("mps")
         elif args.device == "cuda":
+            # MODIFIED SECTION - Add quantization support
+            model_kwargs = {
+                "torch_dtype": load_dtype,
+                "device_map": "cuda",
+                "attn_implementation": attn_impl_primary,
+            }
+
+            # Add quantization config if specified
+            if quant_config:
+                model_kwargs.update(quant_config)
+
             model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                 args.model_path,
-                torch_dtype=load_dtype,
-                device_map="cuda",
-                attn_implementation=attn_impl_primary,
+                **model_kwargs
             )
+
+            # Apply selective quantization if needed (NEW)
+            if args.quantization in ["8bit", "4bit"]:
+                model = apply_selective_quantization(model, args.quantization)
         else:  # cpu
             model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                 args.model_path,
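The quantization kwargs built above are ordinary Hugging Face `from_pretrained` keyword arguments, so the same merge pattern can be exercised with any transformers model. A minimal sketch, not part of the diff, using a placeholder checkpoint id:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same kwargs shape the demo's CUDA branch builds.
model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "cuda"}

# Equivalent of get_quantization_config("4bit"); requires bitsandbytes.
model_kwargs.update({
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
})

# Placeholder checkpoint id, for illustration only.
model = AutoModelForCausalLM.from_pretrained("org/placeholder-7b", **model_kwargs)
```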
New file: vibevoice/utils/quantization.py (@@ -0,0 +1,113 @@):

"""Quantization utilities for VibeVoice models."""

import logging
from typing import Optional
import torch

logger = logging.getLogger(__name__)


def get_quantization_config(quantization: str = "fp16") -> Optional[dict]:
    """
    Get quantization configuration for model loading.

    Args:
        quantization: Quantization level ("fp16", "8bit", or "4bit")

    Returns:
        dict: Quantization config for from_pretrained, or None for fp16
    """
    if quantization == "fp16" or quantization == "full":
        return None

    if quantization == "8bit":
        try:
            import bitsandbytes as bnb
            logger.info("Using 8-bit quantization (selective LLM only)")
            return {
                "load_in_8bit": True,
                "llm_int8_threshold": 6.0,
            }
        except ImportError:
            logger.error(
                "8-bit quantization requires bitsandbytes. "
                "Install with: pip install bitsandbytes"
            )
            raise

    elif quantization == "4bit":
        try:
            import bitsandbytes as bnb
            from transformers import BitsAndBytesConfig

            logger.info("Using 4-bit NF4 quantization (selective LLM only)")
            return {
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                )
            }
        except ImportError:
            logger.error(
                "4-bit quantization requires bitsandbytes. "
                "Install with: pip install bitsandbytes"
            )
            raise

    else:
        raise ValueError(
            f"Invalid quantization: {quantization}. "
            f"Must be one of: fp16, 8bit, 4bit"
        )


def apply_selective_quantization(model, quantization: str):
    """
    Apply selective quantization only to safe components.

    This function identifies which modules should be quantized and which
    should remain at full precision for audio quality preservation.

    Args:
        model: The VibeVoice model
        quantization: Quantization level ("8bit" or "4bit")
    """
    if quantization == "fp16":
        return model

    logger.info("Applying selective quantization...")

    # Components to KEEP at full precision (audio-critical)
    keep_fp_components = [
        "diffusion_head",
        "acoustic_connector",
        "semantic_connector",
        "acoustic_tokenizer",
        "semantic_tokenizer",
        "vae",
    ]

    # Only quantize the LLM (Qwen2.5) component
    quantize_components = ["llm", "language_model"]

    for name, module in model.named_modules():
        # Check if this module should stay at full precision
        should_keep_fp = any(comp in name for comp in keep_fp_components)
        should_quantize = any(comp in name for comp in quantize_components)

        if should_keep_fp:
            # Ensure audio components stay at full precision
            if hasattr(module, 'weight') and module.weight.dtype != torch.float32:
                module.weight.data = module.weight.data.to(torch.bfloat16)
            logger.debug(f"Keeping {name} at full precision (audio-critical)")

        elif should_quantize:
            logger.debug(f"Quantized {name} to {quantization}")

    logger.info(f"✓ Selective {quantization} quantization applied")
    logger.info("  • LLM: Quantized")
    logger.info("  • Audio components: Full precision")

    return model
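A quick usage sketch (not part of the diff) of what each level maps to; the 8bit and 4bit branches require bitsandbytes to be installed:

```python
from vibevoice.utils.quantization import get_quantization_config

assert get_quantization_config("fp16") is None  # full precision: no extra kwargs
print(get_quantization_config("8bit"))          # {'load_in_8bit': True, 'llm_int8_threshold': 6.0}
print(get_quantization_config("4bit"))          # {'quantization_config': BitsAndBytesConfig(...)}

try:
    get_quantization_config("int2")
except ValueError as err:
    print(err)  # Invalid quantization: int2. Must be one of: fp16, 8bit, 4bit
```

The weight quantization itself happens inside `from_pretrained` via the bitsandbytes kwargs; `apply_selective_quantization` then walks `named_modules` to re-cast and log the audio-critical components, so it is cheap to call after loading.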
New file: vibevoice/utils/vram_utils.py (@@ -0,0 +1,87 @@):

"""VRAM detection and quantization recommendation utilities."""

import torch
import logging

logger = logging.getLogger(__name__)


def get_available_vram_gb() -> float:
    """
    Get available VRAM in GB.

    Returns:
        float: Available VRAM in GB, or 0 if no CUDA device available
    """
    if not torch.cuda.is_available():
        return 0.0

    try:
        # Get first CUDA device
        device = torch.device("cuda:0")
        # Get total and allocated memory
        total = torch.cuda.get_device_properties(device).total_memory
        allocated = torch.cuda.memory_allocated(device)
        available = (total - allocated) / (1024 ** 3)  # Convert to GB
        return available
    except Exception as e:
        logger.warning(f"Could not detect VRAM: {e}")
        return 0.0


def suggest_quantization(available_vram_gb: float, model_name: str = "VibeVoice-7B") -> str:
    """
    Suggest quantization level based on available VRAM.

    Args:
        available_vram_gb: Available VRAM in GB
        model_name: Name of the model being loaded

    Returns:
        str: Suggested quantization level ("fp16", "8bit", or "4bit")
    """
    # VibeVoice-7B memory requirements (approximate):
    #   Full precision (fp16/bf16): ~20GB
    #   8-bit quantization: ~12GB
    #   4-bit quantization: ~7GB

    if "1.5B" in model_name:
        # 1.5B model is smaller, adjust thresholds
        if available_vram_gb >= 8:
            return "fp16"
        elif available_vram_gb >= 6:
            return "8bit"
        else:
            return "4bit"
    else:
        # Assume 7B model
        if available_vram_gb >= 22:
            return "fp16"
        elif available_vram_gb >= 14:
            return "8bit"
        else:
            return "4bit"


def print_vram_info(available_vram_gb: float, model_name: str, quantization: str = "fp16"):
    """
    Print VRAM information and quantization recommendation.

    Args:
        available_vram_gb: Available VRAM in GB
        model_name: Name of the model being loaded
        quantization: Current quantization setting
    """
    logger.info(f"Available VRAM: {available_vram_gb:.1f}GB")

    suggested = suggest_quantization(available_vram_gb, model_name)

    if suggested != quantization and quantization == "fp16":
        logger.warning(
            f"⚠️ Low VRAM detected ({available_vram_gb:.1f}GB). "
            f"Recommended: --quantization {suggested}"
        )
        logger.warning(
            f"   Example: python demo/inference_from_file.py "
            f"--model_path {model_name} --quantization {suggested} ..."
        )
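And the thresholds that suggest_quantization encodes, as a small sketch (not part of the diff):

```python
from vibevoice.utils.vram_utils import suggest_quantization

# 7B checkpoints: >= 22 GB -> fp16, >= 14 GB -> 8bit, otherwise 4bit
assert suggest_quantization(24.0) == "fp16"
assert suggest_quantization(16.0) == "8bit"
assert suggest_quantization(10.0) == "4bit"

# 1.5B checkpoints: >= 8 GB -> fp16, >= 6 GB -> 8bit, otherwise 4bit
assert suggest_quantization(8.0, "VibeVoice-1.5B") == "fp16"
assert suggest_quantization(6.5, "VibeVoice-1.5B") == "8bit"
assert suggest_quantization(4.0, "VibeVoice-1.5B") == "4bit"
```

One caveat: get_available_vram_gb subtracts only torch.cuda.memory_allocated, which tracks this process's PyTorch allocations, so VRAM held by other processes is not counted; torch.cuda.mem_get_info() reports driver-level free memory if that distinction matters.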