Skip to content
This repository was archived by the owner on Aug 29, 2025. It is now read-only.

Commit 1d9528f

Browse files
committed
Version 3.1.4: Clean up ChatterBox crash prevention and rename padding parameter
1 parent 100b081 commit 1d9528f

File tree

9 files changed

+329
-33
lines changed

9 files changed

+329
-33
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [3.1.4] - 2025-07-18
9+
10+
### Added
11+
12+
- Clean up ChatterBox crash prevention and rename padding parameter
813
## [3.1.3] - 2025-07-18
914

1015
### Fixed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
[![Forks][forks-shield]][forks-url]
77
[![Dynamic TOML Badge][version-shield]][version-url]
88

9-
# ComfyUI ChatterBox SRT Voice (diogod) v3.1.3
9+
# ComfyUI ChatterBox SRT Voice (diogod) v3.1.4
1010

1111
*This is a refactored node, originally created by [ShmuelRonen](https://github.com/ShmuelRonen/ComfyUI_ChatterBox_Voice).*
1212

chatterbox_srt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
# Version info
7-
__version__ = "3.1.3"
7+
__version__ = "3.1.4"
88
__author__ = "Diogod"
99

1010
# Import the new SRT modules

core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
# Version info
7-
__version__ = "3.1.3"
7+
__version__ = "3.1.4"
88
__author__ = "Diogod"
99

1010
# Make imports available at package level

core/chatterbox_subprocess.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#!/usr/bin/env python3
2+
"""
3+
ChatterBox TTS Subprocess Wrapper
4+
5+
This script runs ChatterBox TTS generation in an isolated subprocess to prevent
6+
CUDA crashes from affecting the main ComfyUI process.
7+
8+
Usage:
9+
python chatterbox_subprocess.py --text "Hello world" --reference_audio "path/to/ref.wav" --output "output.wav"
10+
"""
11+
12+
import sys
13+
import os
14+
import argparse
15+
import json
16+
import traceback
17+
import tempfile
18+
import torch
19+
import numpy as np
20+
from pathlib import Path
21+
22+
# Add the project root to Python path
23+
project_root = Path(__file__).parent.parent
24+
sys.path.insert(0, str(project_root))
25+
26+
def main():
27+
parser = argparse.ArgumentParser(description='ChatterBox TTS Subprocess')
28+
parser.add_argument('--text', required=True, help='Text to synthesize')
29+
parser.add_argument('--reference_audio', required=True, help='Path to reference audio')
30+
parser.add_argument('--output', required=True, help='Output audio file path')
31+
parser.add_argument('--device', default='auto', help='Device to use (auto/cuda/cpu)')
32+
parser.add_argument('--exaggeration', type=float, default=0.5, help='Exaggeration factor')
33+
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature')
34+
parser.add_argument('--cfg_weight', type=float, default=0.5, help='CFG weight')
35+
parser.add_argument('--seed', type=int, default=0, help='Random seed')
36+
37+
args = parser.parse_args()
38+
39+
try:
40+
# Import ChatterBox TTS modules
41+
from chatterbox.chatterbox import ChatterboxTTS
42+
import torchaudio
43+
44+
print(f"🔄 Subprocess: Loading ChatterBox TTS on {args.device}...")
45+
46+
# Initialize ChatterBox TTS
47+
if args.device == 'auto':
48+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
49+
else:
50+
device = args.device
51+
52+
chatterbox = ChatterboxTTS.from_pretrained(device=device)
53+
54+
print(f"🔄 Subprocess: Loading reference audio from {args.reference_audio}...")
55+
56+
# Load reference audio
57+
ref_audio, sample_rate = torchaudio.load(args.reference_audio)
58+
59+
# Ensure mono audio
60+
if ref_audio.shape[0] > 1:
61+
ref_audio = ref_audio.mean(dim=0, keepdim=True)
62+
63+
# Resample if necessary
64+
if sample_rate != chatterbox.sr:
65+
resampler = torchaudio.transforms.Resample(sample_rate, chatterbox.sr)
66+
ref_audio = resampler(ref_audio)
67+
68+
print(f"🔄 Subprocess: Generating speech for text: '{args.text[:50]}...'")
69+
70+
# Set seed for reproducibility
71+
if args.seed != 0:
72+
torch.manual_seed(args.seed)
73+
if torch.cuda.is_available():
74+
torch.cuda.manual_seed(args.seed)
75+
76+
# Generate audio
77+
generated_audio = chatterbox.generate(
78+
text=args.text,
79+
reference_audio=ref_audio,
80+
exaggeration=args.exaggeration,
81+
temperature=args.temperature,
82+
cfg_weight=args.cfg_weight
83+
)
84+
85+
print(f"🔄 Subprocess: Saving audio to {args.output}...")
86+
87+
# Save output audio
88+
torchaudio.save(args.output, generated_audio.cpu(), chatterbox.sr)
89+
90+
# Return success info
91+
duration = generated_audio.size(-1) / chatterbox.sr
92+
result = {
93+
'success': True,
94+
'output_path': args.output,
95+
'duration': duration,
96+
'sample_rate': chatterbox.sr,
97+
'audio_shape': list(generated_audio.shape)
98+
}
99+
100+
print(f"✅ Subprocess: Generation completed successfully ({duration:.2f}s)")
101+
print(json.dumps(result))
102+
103+
except Exception as e:
104+
error_result = {
105+
'success': False,
106+
'error': str(e),
107+
'traceback': traceback.format_exc()
108+
}
109+
print(f"❌ Subprocess: Generation failed: {e}")
110+
print(json.dumps(error_result))
111+
sys.exit(1)
112+
113+
if __name__ == "__main__":
114+
main()

nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Version and constants
2-
VERSION = "3.1.3"
2+
VERSION = "3.1.4"
33
IS_DEV = False # Set to False for release builds
44
VERSION_DISPLAY = f"v{VERSION}" + (" (dev)" if IS_DEV else "")
55
SEPARATOR = "=" * 70

nodes/srt_tts_node.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import tempfile
99
import os
1010
import hashlib
11+
import gc
1112
from typing import Dict, Any, Optional, List, Tuple
1213

1314
# Use direct file imports that work when loaded via importlib
@@ -153,6 +154,10 @@ def INPUT_TYPES(cls):
153154
"step": 0.5,
154155
"tooltip": "Maximum allowed deviation (in seconds) for timing adjustments in 'smart_natural' mode. Higher values allow more flexibility."
155156
}),
157+
"crash_protection_template": ("STRING", {
158+
"default": "hmm ,, {seg} hmm ,,",
159+
"tooltip": "Custom padding template for short text segments to prevent ChatterBox crashes. ChatterBox has a bug where text shorter than ~21 characters causes CUDA tensor errors in sequential generation. Use {seg} as placeholder for the original text. Examples: '...ummmmm {seg}' (default hesitation), '{seg}... yes... {seg}' (repetition), 'Well, {seg}' (natural prefix), or empty string to disable padding. This only affects ChatterBox nodes, not F5-TTS nodes."
160+
}),
156161
}
157162
}
158163

@@ -161,6 +166,52 @@ def INPUT_TYPES(cls):
161166
FUNCTION = "generate_srt_speech"
162167
CATEGORY = "ChatterBox Voice"
163168

169+
def _pad_short_text_for_chatterbox(self, text: str, padding_template: str = "...ummmmm {seg}", min_length: int = 21) -> str:
170+
"""
171+
Add custom padding to short text to prevent ChatterBox crashes.
172+
173+
ChatterBox has a bug where short text segments cause CUDA tensor indexing errors
174+
in sequential generation scenarios. Adding meaningful tokens with custom templates
175+
prevents these crashes while allowing user customization.
176+
177+
Args:
178+
text: Input text to check and pad if needed
179+
padding_template: Custom template with {seg} placeholder for original text
180+
min_length: Minimum text length threshold (default: 21 characters)
181+
182+
Returns:
183+
Original text or text with custom padding template if too short
184+
"""
185+
stripped_text = text.strip()
186+
if len(stripped_text) < min_length:
187+
# If template is empty, disable padding
188+
if not padding_template.strip():
189+
return text
190+
# Replace {seg} placeholder with original text
191+
return padding_template.replace("{seg}", stripped_text)
192+
return text
193+
194+
def _safe_generate_tts_audio(self, text, audio_prompt, exaggeration, temperature, cfg_weight):
195+
"""
196+
Wrapper around generate_tts_audio - simplified to just call the base method.
197+
CUDA crash recovery was removed as it didn't work reliably.
198+
"""
199+
try:
200+
return self.generate_tts_audio(text, audio_prompt, exaggeration, temperature, cfg_weight)
201+
except Exception as e:
202+
error_msg = str(e)
203+
is_cuda_crash = ("srcIndex < srcSelectDimSize" in error_msg or
204+
"CUDA" in error_msg or
205+
"device-side assert" in error_msg or
206+
"an illegal memory access" in error_msg)
207+
208+
if is_cuda_crash:
209+
print(f"🚨 ChatterBox CUDA crash detected: '{text[:50]}...'")
210+
print(f"🛡️ This is a known ChatterBox bug with certain text patterns.")
211+
raise RuntimeError(f"ChatterBox CUDA crash occurred. Text: '{text[:50]}...' - Try using padding template or longer text, or restart ComfyUI.")
212+
else:
213+
raise
214+
164215
def _generate_segment_cache_key(self, subtitle_text: str, exaggeration: float, temperature: float,
165216
cfg_weight: float, seed: int, audio_prompt_component: str,
166217
model_source: str, device: str) -> str:
@@ -199,7 +250,8 @@ def _detect_overlaps(self, subtitles: List) -> bool:
199250
def generate_srt_speech(self, srt_content, device, exaggeration, temperature, cfg_weight, seed,
200251
timing_mode, reference_audio=None, audio_prompt_path="",
201252
max_stretch_ratio=2.0, min_stretch_ratio=0.5, fade_for_StretchToFit=0.01,
202-
enable_audio_cache=True, timing_tolerance=2.0):
253+
enable_audio_cache=True, timing_tolerance=2.0,
254+
crash_protection_template="hmm ,, {seg} hmm ,,"):
203255

204256
def _process():
205257
# Check if SRT support is available
@@ -320,9 +372,16 @@ def _process():
320372
print(f"📺 Generating SRT segment {i+1}/{len(subtitles)} (Seq {subtitle.sequence})...")
321373
else:
322374
print(f"🎭 Generating SRT segment {i+1}/{len(subtitles)} (Seq {subtitle.sequence}) using '{char}'")
323-
# Generate new audio for this character segment
324-
char_wav = self.generate_tts_audio(
325-
segment_text, char_audio, exaggeration, temperature, cfg_weight
375+
# BUGFIX: Pad short text with custom template to prevent ChatterBox sequential generation crashes
376+
processed_segment_text = self._pad_short_text_for_chatterbox(segment_text, crash_protection_template)
377+
378+
# DEBUG: Show actual text being sent to ChatterBox when padding might occur
379+
if len(segment_text.strip()) < 21:
380+
print(f"🔍 DEBUG: Original text: '{segment_text}' → Processed: '{processed_segment_text}' (len: {len(processed_segment_text)})")
381+
382+
# Generate new audio for this character segment with CUDA recovery
383+
char_wav = self._safe_generate_tts_audio(
384+
processed_segment_text, char_audio, exaggeration, temperature, cfg_weight
326385
)
327386

328387
if enable_audio_cache:
@@ -354,8 +413,16 @@ def _process():
354413
# Generate new audio
355414
print(f"📺 Generating SRT segment {i+1}/{len(subtitles)} (Seq {subtitle.sequence})...")
356415

357-
wav = self.generate_tts_audio(
358-
subtitle.text, audio_prompt, exaggeration, temperature, cfg_weight
416+
# BUGFIX: Pad short text with custom template to prevent ChatterBox sequential generation crashes
417+
processed_subtitle_text = self._pad_short_text_for_chatterbox(subtitle.text, crash_protection_template)
418+
419+
# DEBUG: Show actual text being sent to ChatterBox when padding might occur
420+
if len(subtitle.text.strip()) < 21:
421+
print(f"🔍 DEBUG: Original text: '{subtitle.text}' → Processed: '{processed_subtitle_text}' (len: {len(processed_subtitle_text)})")
422+
423+
# Generate new audio with CUDA recovery
424+
wav = self._safe_generate_tts_audio(
425+
processed_subtitle_text, audio_prompt, exaggeration, temperature, cfg_weight
359426
)
360427
natural_duration = self.AudioTimingUtils.get_audio_duration(wav, self.tts_model.sr)
361428

0 commit comments

Comments
 (0)