Commits (20)
45ba769  Add troubleshooting guide for common installation and usage issues (Dec 9, 2025)
573d852  Add quantization support to reduce VRAM requirements (Dec 10, 2025)
54b594b  Add quantization support to reduce VRAM requirements (Dec 10, 2025)
0328c1e  Merge branch 'microsoft:main' into add-quantization-support (maitrisavaliya, Dec 10, 2025)
e3e4d69  Delete utils/quantization,py (maitrisavaliya, Dec 10, 2025)
cdde460  Update realtime_model_inference_from_file.py (maitrisavaliya, Dec 10, 2025)
62565c4  Delete TROUBLESHOOTING.md (maitrisavaliya, Dec 10, 2025)
276ad09  Update realtime_model_inference_from_file.py (maitrisavaliya, Dec 10, 2025)
15ca0ac  Update realtime_model_inference_from_file.py (maitrisavaliya, Dec 10, 2025)
c2a5bbf  Update vram_utils.py (maitrisavaliya, Dec 10, 2025)
8b0c2cf  Merge branch 'main' into add-quantization-support (maitrisavaliya, Dec 17, 2025)
188ffce  Apply suggestion from @Copilot (maitrisavaliya, Feb 13, 2026)
a0918a3  Apply suggestion from @Copilot (maitrisavaliya, Feb 13, 2026)
4d5140a  Apply suggestion from @Copilot (maitrisavaliya, Feb 13, 2026)
0bf0f0d  Apply suggestion from @Copilot (maitrisavaliya, Feb 13, 2026)
1b08105  Apply suggestion from @Copilot (maitrisavaliya, Feb 13, 2026)
ef5eaa2  Clarify help message for quantization option (maitrisavaliya, Feb 13, 2026)
a83b636  Refactor VRAM utility functions and logging (maitrisavaliya, Feb 13, 2026)
6e97a77  Enhance CUDA device handling and add VRAM info (maitrisavaliya, Feb 13, 2026)
ea332ac  Fix device_map condition for model loading (maitrisavaliya, Feb 13, 2026)
48 changes: 43 additions & 5 deletions demo/realtime_model_inference_from_file.py
@@ -6,7 +6,8 @@
import time
import torch
import copy

from vibevoice.utils.vram_utils import get_available_vram_gb, print_vram_info
from vibevoice.utils.quantization import get_quantization_config, apply_selective_quantization
from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference
from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor
from transformers.utils import logging
@@ -129,6 +130,13 @@ def parse_args():
default=1.5,
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)",
)
parser.add_argument(
"--quantization",
type=str,
default="fp16",
choices=["fp16", "8bit", "4bit"],
help="Quantization level: fp16 (default, ~20GB), 8bit (~12GB), or 4bit (~7GB)"
)

return parser.parse_args()

@@ -146,6 +154,14 @@ def main():
args.device = "cpu"

print(f"Using device: {args.device}")

# VRAM Detection and Quantization Info (NEW)
if args.device == "cuda":
available_vram = get_available_vram_gb()
print_vram_info(available_vram, args.model_path, args.quantization)
elif args.quantization != "fp16":
print(f"Warning: Quantization ({args.quantization}) only works with CUDA. Using full precision.")
args.quantization = "fp16"

# Initialize voice mapper
voice_mapper = VoiceMapper()
@@ -164,7 +180,7 @@
print("Error: No valid scripts found in the txt file")
return

full_script = scripts.replace("’", "'").replace('“', '"').replace('”', '"')
full_script = scripts.replace("'", "'").replace('"', '"').replace('"', '"')
Collaborator:
Pay attention to your code agent. DO NOT introduce bugs like this.

Author (@maitrisavaliya, Dec 10, 2025):
Will pay attention to this, and I have corrected it.
What are your thoughts on the quantization approach? Is it going in the right direction, or should I change something?


print(f"Loading processor & model from {args.model_path}")
processor = VibeVoiceStreamingProcessor.from_pretrained(args.model_path)
@@ -180,6 +196,15 @@
load_dtype = torch.float32
attn_impl_primary = "sdpa"
print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

# Get quantization configuration (NEW)
quant_config = get_quantization_config(args.quantization)

if quant_config:
print(f"Using {args.quantization} quantization...")
else:
print("Using full precision (fp16)...")

# Load model with device-specific logic
try:
if args.device == "mps":
@@ -191,12 +216,25 @@
)
model.to("mps")
elif args.device == "cuda":
# MODIFIED SECTION - Add quantization support
model_kwargs = {
"torch_dtype": load_dtype,
"device_map": "cuda",
"attn_implementation": attn_impl_primary,
}

# Add quantization config if specified
if quant_config:
model_kwargs.update(quant_config)

model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
device_map="cuda",
attn_implementation=attn_impl_primary,
**model_kwargs
)

# Apply selective quantization if needed (NEW)
if args.quantization in ["8bit", "4bit"]:
model = apply_selective_quantization(model, args.quantization)
else: # cpu
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
...
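
For readers skimming the diff, here is a condensed, hedged sketch of the CUDA loading path the hunks above add. The import paths and helper names follow the diff; the model path, dtype, and attention implementation are placeholders, and the demo's MPS/CPU branches and error handling are omitted.

```python
# Condensed sketch of the new CUDA path (illustrative; values are placeholders).
import torch
from vibevoice.utils.quantization import get_quantization_config, apply_selective_quantization
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
    VibeVoiceStreamingForConditionalGenerationInference,
)

quantization = "8bit"  # one of "fp16", "8bit", "4bit"
quant_config = get_quantization_config(quantization)

model_kwargs = {
    "torch_dtype": torch.bfloat16,     # the demo passes its device-dependent load_dtype here
    "device_map": "cuda",
    "attn_implementation": "sdpa",     # placeholder for the demo's attn_impl_primary
}
if quant_config:
    model_kwargs.update(quant_config)  # adds load_in_8bit=... or quantization_config=...

model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
    "path/to/VibeVoice-7B",            # placeholder model path
    **model_kwargs,
)
if quantization in ["8bit", "4bit"]:
    model = apply_selective_quantization(model, quantization)
```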
Empty file added utils/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions utils/quantization.py
@@ -0,0 +1,113 @@
"""Quantization utilities for VibeVoice models."""

import logging
from typing import Optional
import torch

logger = logging.getLogger(__name__)


def get_quantization_config(quantization: str = "fp16") -> Optional[dict]:
"""
Get quantization configuration for model loading.

Args:
quantization: Quantization level ("fp16", "8bit", or "4bit")

Returns:
dict: Quantization config for from_pretrained, or None for fp16
"""
if quantization == "fp16" or quantization == "full":
return None

if quantization == "8bit":
try:
import bitsandbytes as bnb
logger.info("Using 8-bit quantization (selective LLM only)")
return {
"load_in_8bit": True,
"llm_int8_threshold": 6.0,
}
except ImportError:
logger.error(
"8-bit quantization requires bitsandbytes. "
"Install with: pip install bitsandbytes"
)
raise

elif quantization == "4bit":
try:
import bitsandbytes as bnb
from transformers import BitsAndBytesConfig

logger.info("Using 4-bit NF4 quantization (selective LLM only)")
return {
"quantization_config": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
}
except ImportError:
logger.error(
"4-bit quantization requires bitsandbytes. "
"Install with: pip install bitsandbytes"
)
raise

else:
raise ValueError(
f"Invalid quantization: {quantization}. "
f"Must be one of: fp16, 8bit, 4bit"
)

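For reference, a minimal sketch of what the helper above returns at each level (the keys mirror the branches in the code):

```python
cfg = get_quantization_config("fp16")  # -> None: load the model in its usual dtype
cfg = get_quantization_config("8bit")  # -> {"load_in_8bit": True, "llm_int8_threshold": 6.0}
cfg = get_quantization_config("4bit")  # -> {"quantization_config": BitsAndBytesConfig(load_in_4bit=True, ...)}

# Either dict is meant to be unpacked into from_pretrained(**model_kwargs),
# as the demo script's CUDA branch does above.
```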

def apply_selective_quantization(model, quantization: str):
    """
    Apply selective quantization only to safe components.

    This function identifies which modules should be quantized and which
    should remain at full precision for audio quality preservation.

    Args:
        model: The VibeVoice model
        quantization: Quantization level ("8bit" or "4bit")
    """
    if quantization == "fp16":
        return model

    logger.info("Applying selective quantization...")

    # Components to KEEP at full precision (audio-critical)
    keep_fp_components = [
        "diffusion_head",
        "acoustic_connector",
        "semantic_connector",
        "acoustic_tokenizer",
        "semantic_tokenizer",
        "vae",
    ]

    # Only quantize the LLM (Qwen2.5) component
    quantize_components = ["llm", "language_model"]

    for name, module in model.named_modules():
        # Check if this module should stay at full precision
        should_keep_fp = any(comp in name for comp in keep_fp_components)
        should_quantize = any(comp in name for comp in quantize_components)

        if should_keep_fp:
            # Ensure audio components stay at full precision
            if hasattr(module, 'weight') and module.weight.dtype != torch.float32:
                module.weight.data = module.weight.data.to(torch.bfloat16)
            logger.debug(f"Keeping {name} at full precision (audio-critical)")

        elif should_quantize:
            logger.debug(f"Quantized {name} to {quantization}")

    logger.info(f"✓ Selective {quantization} quantization applied")
    logger.info(" • LLM: Quantized")
    logger.info(" • Audio components: Full precision")

    return model
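
The function above mostly logs the intended split; the actual weight replacement happens inside bitsandbytes when the model is loaded with the config from get_quantization_config. A hedged sketch of one way to check, after loading, which submodules really were quantized (the bnb.nn.Linear8bitLt / Linear4bit class names are assumed from recent bitsandbytes releases):

```python
import bitsandbytes as bnb

def summarize_quantized_modules(model):
    """Rough check: count bitsandbytes-quantized linears vs. everything else."""
    quantized, kept = [], []
    for name, module in model.named_modules():
        if isinstance(module, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
            quantized.append(name)
        elif hasattr(module, "weight") and module.weight is not None:
            kept.append(name)
    print(f"{len(quantized)} quantized linear layers, {len(kept)} modules left in fp16/bf16")
    return quantized, kept
```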
87 changes: 87 additions & 0 deletions utils/vram_utils.py
@@ -0,0 +1,87 @@
"""VRAM detection and quantization recommendation utilities."""

import torch
import logging

logger = logging.getLogger(__name__)


def get_available_vram_gb() -> float:
"""
Get available VRAM in GB.

Returns:
float: Available VRAM in GB, or 0 if no CUDA device available
"""
if not torch.cuda.is_available():
return 0.0

try:
# Get first CUDA device
device = torch.device("cuda:0")
# Get total and allocated memory
total = torch.cuda.get_device_properties(device).total_memory
allocated = torch.cuda.memory_allocated(device)
available = (total - allocated) / (1024 ** 3) # Convert to GB
return available
except Exception as e:
logger.warning(f"Could not detect VRAM: {e}")
return 0.0

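Note that total_memory minus memory_allocated only subtracts memory allocated by the current PyTorch process, so it can overstate what is really free on a shared GPU. A hedged alternative sketch using torch.cuda.mem_get_info (assumed available in recent PyTorch releases), which reports device-wide free memory:

```python
import torch

def get_free_vram_gb(device_index: int = 0) -> float:
    """Device-wide free VRAM in GB, including memory held by other processes."""
    if not torch.cuda.is_available():
        return 0.0
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device_index)
    return free_bytes / (1024 ** 3)
```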

def suggest_quantization(available_vram_gb: float, model_name: str = "VibeVoice-7B") -> str:
    """
    Suggest quantization level based on available VRAM.

    Args:
        available_vram_gb: Available VRAM in GB
        model_name: Name of the model being loaded

    Returns:
        str: Suggested quantization level ("fp16", "8bit", or "4bit")
    """
    # VibeVoice-7B memory requirements (approximate)
    # Full precision (fp16/bf16): ~20GB
    # 8-bit quantization: ~12GB
    # 4-bit quantization: ~7GB

    if "1.5B" in model_name:
        # 1.5B model is smaller, adjust thresholds
        if available_vram_gb >= 8:
            return "fp16"
        elif available_vram_gb >= 6:
            return "8bit"
        else:
            return "4bit"
    else:
        # Assume 7B model
        if available_vram_gb >= 22:
            return "fp16"
        elif available_vram_gb >= 14:
            return "8bit"
        else:
            return "4bit"


def print_vram_info(available_vram_gb: float, model_name: str, quantization: str = "fp16"):
    """
    Print VRAM information and quantization recommendation.

    Args:
        available_vram_gb: Available VRAM in GB
        model_name: Name of the model being loaded
        quantization: Current quantization setting
    """
    logger.info(f"Available VRAM: {available_vram_gb:.1f}GB")

    suggested = suggest_quantization(available_vram_gb, model_name)

    if suggested != quantization and quantization == "fp16":
        logger.warning(
            f"⚠️ Low VRAM detected ({available_vram_gb:.1f}GB). "
            f"Recommended: --quantization {suggested}"
        )
        logger.warning(
            f" Example: python demo/inference_from_file.py "
            f"--model_path {model_name} --quantization {suggested} ..."
        )
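
Taken together, a minimal usage sketch of these helpers as the demo script calls them (the model name reuses the module's default and is otherwise a placeholder):

```python
if __name__ == "__main__":
    # Assumes the module-level imports above (torch, logging) are in scope.
    logging.basicConfig(level=logging.INFO)
    vram = get_available_vram_gb()
    print_vram_info(vram, "VibeVoice-7B", quantization="fp16")
    print("Suggested quantization:", suggest_quantization(vram, "VibeVoice-7B"))
```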