7 changes: 5 additions & 2 deletions requirements/server.txt
@@ -19,5 +19,8 @@ librosa
openvino
openvino-genai
openvino-tokenizers
optimum
optimum-intel

# mlx (for Apple Silicon M1/M2/M3)
mlx-whisper
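The new mlx-whisper dependency only works on Apple silicon, so it sits in its own commented block rather than with the cross-platform requirements. Before starting the server it can be worth confirming that the package imports and transcribes on its own; a minimal sketch, assuming a local sample.wav and an mlx-community model repo (both are placeholders, not fixed by this change):

```python
# Sanity check for the mlx-whisper dependency (Apple silicon only).
# "sample.wav" and the HF repo name are examples for illustration.
import mlx_whisper

result = mlx_whisper.transcribe(
    "sample.wav",
    path_or_hf_repo="mlx-community/whisper-tiny",  # fetched from Hugging Face on first use
)
print(result["text"])
```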
7 changes: 6 additions & 1 deletion run_server.py
@@ -10,10 +10,14 @@
parser.add_argument('--backend', '-b',
type=str,
default='faster_whisper',
help='Backends from ["tensorrt", "faster_whisper", "openvino"]')
help='Backends from ["tensorrt", "faster_whisper", "openvino", "mlx_whisper"]')
parser.add_argument('--faster_whisper_custom_model_path', '-fw',
type=str, default=None,
help="Custom Faster Whisper Model")
parser.add_argument('--mlx_model_path', '-mlx',
type=str,
default=None,
help='MLX Whisper model (e.g., "small.en", "mlx-community/whisper-large-v3-turbo")')
parser.add_argument('--trt_model_path', '-trt',
type=str,
default=None,
@@ -59,6 +63,7 @@
port=args.port,
backend=args.backend,
faster_whisper_custom_model_path=args.faster_whisper_custom_model_path,
mlx_model_path=args.mlx_model_path,
whisper_tensorrt_path=args.trt_model_path,
trt_multilingual=args.trt_multilingual,
trt_py_session=args.trt_py_session,
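With these flags the MLX backend is selected at server start, e.g. `python run_server.py --backend mlx_whisper -mlx mlx-community/whisper-large-v3-turbo`. The same configuration can be reached programmatically through TranscriptionServer.run(), mirroring the call site above; a minimal sketch, assuming run() also takes the host argument that lies outside this hunk:

```python
# Programmatic equivalent of the new CLI flags; host/port values are examples.
from whisper_live.server import TranscriptionServer

server = TranscriptionServer()
server.run(
    "0.0.0.0",
    port=9090,
    backend="mlx_whisper",                                  # new BackendType added by this PR
    mlx_model_path="mlx-community/whisper-large-v3-turbo",  # optional server-side model override
)
```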
213 changes: 213 additions & 0 deletions whisper_live/backend/mlx_whisper_backend.py
@@ -0,0 +1,213 @@
"""
MLX Whisper Backend for WhisperLive

This backend uses mlx-whisper for optimized transcription on Apple Silicon (M1/M2/M3).
MLX runs inference on the Apple silicon GPU via Metal, using unified memory for fast, efficient decoding.
"""

import json
import logging
import threading
from typing import Optional

from whisper_live.transcriber.transcriber_mlx import WhisperMLX
from whisper_live.backend.base import ServeClientBase


class ServeClientMLXWhisper(ServeClientBase):
"""
Backend implementation for MLX Whisper on Apple Silicon.

This backend provides hardware-accelerated transcription using Apple's MLX
framework, optimized for M1/M2/M3 chips using the GPU (Metal) and unified memory.
"""

SINGLE_MODEL = None
SINGLE_MODEL_LOCK = threading.Lock()

def __init__(
self,
websocket,
task: str = "transcribe",
language: Optional[str] = None,
client_uid: Optional[str] = None,
model: str = "small.en",
initial_prompt: Optional[str] = None,
vad_parameters: Optional[dict] = None,
use_vad: bool = True,
single_model: bool = False,
send_last_n_segments: int = 10,
no_speech_thresh: float = 0.45,
clip_audio: bool = False,
same_output_threshold: int = 7,
translation_queue=None,
):
"""
Initialize MLX Whisper backend.

Args:
websocket: WebSocket connection to the client
task (str): "transcribe" or "translate"
language (str, optional): Language code (e.g., "en", "es")
client_uid (str, optional): Unique client identifier
model (str): Model size or HuggingFace repo ID
initial_prompt (str, optional): Initial prompt for transcription
vad_parameters (dict, optional): VAD parameters (not used in MLX)
use_vad (bool): Whether to use VAD (not used in MLX)
single_model (bool): Share one model across all clients
send_last_n_segments (int): Number of recent segments to send
no_speech_thresh (float): Threshold for filtering silent segments
clip_audio (bool): Whether to clip audio with no valid segments
same_output_threshold (int): Threshold for repeated output filtering
translation_queue: Queue for translation (if enabled)
"""
super().__init__(
client_uid,
websocket,
send_last_n_segments,
no_speech_thresh,
clip_audio,
same_output_threshold,
translation_queue,
)

self.model_name = model
self.language = "en" if model.endswith(".en") else language
self.task = task
self.initial_prompt = initial_prompt
self.vad_parameters = vad_parameters or {"onset": 0.5}
self.use_vad = use_vad

logging.info("Initializing MLX Whisper backend")
logging.info(f"Model: {self.model_name}")
logging.info(f"Language: {self.language}")
logging.info(f"Task: {self.task}")
logging.info("Using Apple silicon GPU (Metal) acceleration via MLX")

# Initialize model
try:
if single_model:
if ServeClientMLXWhisper.SINGLE_MODEL is None:
self.create_model()
ServeClientMLXWhisper.SINGLE_MODEL = self.transcriber
else:
self.transcriber = ServeClientMLXWhisper.SINGLE_MODEL
logging.info("Using shared MLX model instance")
else:
self.create_model()
except Exception as e:
logging.error(f"Failed to load MLX Whisper model: {e}")
self.websocket.send(
json.dumps({
"uid": self.client_uid,
"status": "ERROR",
"message": f"Failed to load MLX Whisper model: {str(e)}. "
f"Make sure you're running on Apple Silicon (M1/M2/M3) "
f"and have mlx-whisper installed.",
})
)
self.websocket.close()
return

# Start transcription thread
self.trans_thread = threading.Thread(target=self.speech_to_text)
self.trans_thread.start()

# Send ready message to client
self.websocket.send(
json.dumps({
"uid": self.client_uid,
"message": self.SERVER_READY,
"backend": "mlx_whisper",
})
)
logging.info(f"[MLX] Client {self.client_uid} initialized successfully")

def create_model(self):
"""
Initialize the MLX Whisper model.

This method loads the specified model using MLX for hardware acceleration
on Apple Silicon devices.
"""
logging.info(f"Loading MLX Whisper model: {self.model_name}")
self.transcriber = WhisperMLX(
model_name=self.model_name,
path_or_hf_repo=None,
)
logging.info("MLX Whisper model loaded successfully")

def transcribe_audio(self, input_sample):
"""
Transcribe audio using MLX Whisper.

Args:
input_sample (np.ndarray): Audio data as numpy array (16kHz)

Returns:
List[MLXSegment]: List of transcribed segments with timing information
"""
# Acquire lock if using shared model
if ServeClientMLXWhisper.SINGLE_MODEL:
ServeClientMLXWhisper.SINGLE_MODEL_LOCK.acquire()

try:
result = self.transcriber.transcribe(
input_sample,
language=self.language,
task=self.task,
initial_prompt=self.initial_prompt,
vad_filter=self.use_vad,
vad_parameters=self.vad_parameters,
)

# Auto-detect language if not set
if self.language is None and len(result) > 0:
detected_lang, prob = self.transcriber.detect_language(input_sample)
self.set_language(detected_lang)
logging.info(f"Detected language: {detected_lang} (probability: {prob:.2f})")

return result

except Exception as e:
logging.error(f"MLX transcription failed: {e}")
raise

finally:
# Release lock if using shared model
if ServeClientMLXWhisper.SINGLE_MODEL:
ServeClientMLXWhisper.SINGLE_MODEL_LOCK.release()

def handle_transcription_output(self, result, duration):
"""
Process and send transcription results to the client.

Args:
result: Transcription result from transcribe_audio()
duration (float): Duration of the audio chunk in seconds
"""
segments = []

if len(result):
last_segment = self.update_segments(result, duration)
segments = self.prepare_segments(last_segment)

if len(segments):
self.send_transcription_to_client(segments)

def set_language(self, language: str):
"""
Set the transcription language.

Args:
language (str): Language code (e.g., "en", "es", "fr")
"""
self.language = language
logging.info(f"Language set to: {language}")

def cleanup(self):
"""
Clean up resources when client disconnects.
"""
super().cleanup()
logging.info(f"[MLX] Client {self.client_uid} cleaned up")
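The heavy lifting is delegated to WhisperMLX from whisper_live/transcriber/transcriber_mlx.py, which this PR adds but which is not shown in this excerpt. Judging from how it is called above (transcribe() taking language/task/initial_prompt plus VAD arguments and returning segment objects, and detect_language() returning a code and probability), a thin wrapper over mlx_whisper could look roughly like the sketch below; the class body, the MLXSegment shape, and the decision to accept but ignore the VAD arguments are assumptions, not the PR's actual implementation:

```python
# Hypothetical sketch of the WhisperMLX wrapper used above; the real
# transcriber_mlx.py is not part of this excerpt, so its internals are assumptions.
from dataclasses import dataclass
from typing import List, Optional

import mlx_whisper
import numpy as np


@dataclass
class MLXSegment:
    start: float
    end: float
    text: str
    no_speech_prob: float


class WhisperMLX:
    def __init__(self, model_name: str = "small.en", path_or_hf_repo: Optional[str] = None):
        # mlx_whisper loads weights per call, so only the repo/model choice is stored here.
        self.repo = path_or_hf_repo or f"mlx-community/whisper-{model_name}-mlx"

    def transcribe(self, audio: np.ndarray, language=None, task="transcribe",
                   initial_prompt=None, vad_filter=False, vad_parameters=None) -> List[MLXSegment]:
        # vad_filter / vad_parameters are accepted for API parity but not applied in this sketch.
        result = mlx_whisper.transcribe(
            audio,
            path_or_hf_repo=self.repo,
            language=language,
            task=task,
            initial_prompt=initial_prompt,
        )
        return [
            MLXSegment(s["start"], s["end"], s["text"], s.get("no_speech_prob", 0.0))
            for s in result["segments"]
        ]

    def detect_language(self, audio: np.ndarray):
        # Crude fallback: run a pass and report the language mlx_whisper settled on.
        result = mlx_whisper.transcribe(audio, path_or_hf_repo=self.repo)
        return result.get("language", "en"), 1.0
```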
57 changes: 50 additions & 7 deletions whisper_live/server.py
@@ -123,6 +123,7 @@ class BackendType(Enum):
FASTER_WHISPER = "faster_whisper"
TENSORRT = "tensorrt"
OPENVINO = "openvino"
MLX_WHISPER = "mlx_whisper"

@staticmethod
def valid_types() -> List[str]:
@@ -137,10 +138,13 @@ def is_faster_whisper(self) -> bool:

def is_tensorrt(self) -> bool:
return self == BackendType.TENSORRT

def is_openvino(self) -> bool:
return self == BackendType.OPENVINO

def is_mlx_whisper(self) -> bool:
return self == BackendType.MLX_WHISPER


class TranscriptionServer:
RATE = 16000
@@ -153,7 +157,7 @@ def __init__(self):

def initialize_client(
self, websocket, options, faster_whisper_custom_model_path,
whisper_tensorrt_path, trt_multilingual, trt_py_session=False,
mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=False,
):
client: Optional[ServeClientBase] = None

@@ -242,6 +246,42 @@ def initialize_client(
"Reverting to available backend: 'faster_whisper'"
}))

if self.backend.is_mlx_whisper():
try:
from whisper_live.backend.mlx_whisper_backend import ServeClientMLXWhisper
# Use server-provided model path if available, otherwise use client's model
if mlx_model_path is not None:
logging.info(f"Using server-specified MLX model: {mlx_model_path}")
options["model"] = mlx_model_path
client = ServeClientMLXWhisper(
websocket,
language=options["language"],
task=options["task"],
client_uid=options["uid"],
model=options["model"],
initial_prompt=options.get("initial_prompt"),
vad_parameters=options.get("vad_parameters"),
use_vad=self.use_vad,
single_model=self.single_model,
send_last_n_segments=options.get("send_last_n_segments", 10),
no_speech_thresh=options.get("no_speech_thresh", 0.45),
clip_audio=options.get("clip_audio", False),
same_output_threshold=options.get("same_output_threshold", 10),
translation_queue=translation_queue,
)
logging.info("Running MLX Whisper backend (Apple Silicon optimized).")
except Exception as e:
logging.error(f"MLX Whisper not supported: {e}")
self.backend = BackendType.FASTER_WHISPER
self.client_uid = options["uid"]
websocket.send(json.dumps({
"uid": self.client_uid,
"status": "WARNING",
"message": "MLX Whisper not supported. "
"Make sure you're on Apple Silicon (M1/M2/M3) and have mlx-whisper installed. "
"Reverting to available backend: 'faster_whisper'"
}))

try:
if self.backend.is_faster_whisper():
from whisper_live.backend.faster_whisper_backend import ServeClientFasterWhisper
@@ -297,7 +337,7 @@ def get_audio_from_websocket(self, websocket):
return np.frombuffer(frame_data, dtype=np.float32)

def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
whisper_tensorrt_path, trt_multilingual, trt_py_session=False):
mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=False):
try:
logging.info("New client connected")
options = websocket.recv()
@@ -311,7 +351,7 @@ def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
if self.backend.is_tensorrt():
self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
self.initialize_client(websocket, options, faster_whisper_custom_model_path,
whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session)
mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session)
return True
except json.JSONDecodeError:
logging.error("Failed to decode JSON from client")
@@ -343,9 +383,10 @@ def process_audio_frames(self, websocket):
return True

def recv_audio(self,
websocket,
backend: BackendType = BackendType.FASTER_WHISPER,
faster_whisper_custom_model_path=None,
mlx_model_path=None,
whisper_tensorrt_path=None,
trt_multilingual=False,
trt_py_session=False):
@@ -375,7 +416,7 @@ def recv_audio(self,
"""
self.backend = backend
if not self.handle_new_connection(websocket, faster_whisper_custom_model_path,
whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session):
mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session):
return

try:
@@ -397,6 +438,7 @@ def run(self,
port=9090,
backend="tensorrt",
faster_whisper_custom_model_path=None,
mlx_model_path=None,
whisper_tensorrt_path=None,
trt_multilingual=False,
trt_py_session=False,
@@ -418,7 +460,7 @@
if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path):
raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.")
if single_model:
if faster_whisper_custom_model_path or whisper_tensorrt_path:
if faster_whisper_custom_model_path or whisper_tensorrt_path or mlx_model_path:
logging.info("Custom model option was provided. Switching to single model mode.")
self.single_model = True
# TODO: load model initially
@@ -431,6 +473,7 @@
self.recv_audio,
backend=BackendType(backend),
faster_whisper_custom_model_path=faster_whisper_custom_model_path,
mlx_model_path=mlx_model_path,
whisper_tensorrt_path=whisper_tensorrt_path,
trt_multilingual=trt_multilingual,
trt_py_session=trt_py_session,
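Nothing changes for clients: the backend is fixed when the server starts, and the model a client requests (options["model"]) is only overridden when --mlx_model_path is supplied, as the initialize_client change above shows. A minimal client session against an MLX-backed server, assuming the project's existing TranscriptionClient API (keyword names taken from current usage, not from this diff):

```python
# Client-side usage is unchanged by this PR; sketch based on the existing
# whisper_live.client.TranscriptionClient interface.
from whisper_live.client import TranscriptionClient

client = TranscriptionClient(
    "localhost",
    9090,
    lang="en",
    model="small.en",    # ignored if the server was started with --mlx_model_path
    use_vad=True,
)
client("tests/jfk.wav")  # or client() to stream from the microphone
```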