diff --git a/requirements/server.txt b/requirements/server.txt
index 76319ad6..c73bb076 100644
--- a/requirements/server.txt
+++ b/requirements/server.txt
@@ -19,5 +19,8 @@ librosa
 openvino
 openvino-genai
 openvino-tokenizers
-optimum
-optimum-intel
\ No newline at end of file
+optimum
+optimum-intel
+
+# mlx (for Apple Silicon M1/M2/M3)
+mlx-whisper
\ No newline at end of file
diff --git a/run_server.py b/run_server.py
index 387574ec..f3b09dd7 100644
--- a/run_server.py
+++ b/run_server.py
@@ -10,10 +10,14 @@
     parser.add_argument('--backend', '-b',
                         type=str,
                         default='faster_whisper',
-                        help='Backends from ["tensorrt", "faster_whisper", "openvino"]')
+                        help='Backends from ["tensorrt", "faster_whisper", "openvino", "mlx_whisper"]')
     parser.add_argument('--faster_whisper_custom_model_path', '-fw',
                         type=str, default=None,
                         help="Custom Faster Whisper Model")
+    parser.add_argument('--mlx_model_path', '-mlx',
+                        type=str,
+                        default=None,
+                        help='MLX Whisper model (e.g., "small.en", "mlx-community/whisper-large-v3-turbo")')
     parser.add_argument('--trt_model_path', '-trt',
                         type=str,
                         default=None,
@@ -59,6 +63,7 @@
         port=args.port,
         backend=args.backend,
         faster_whisper_custom_model_path=args.faster_whisper_custom_model_path,
+        mlx_model_path=args.mlx_model_path,
         whisper_tensorrt_path=args.trt_model_path,
         trt_multilingual=args.trt_multilingual,
         trt_py_session=args.trt_py_session,
diff --git a/whisper_live/backend/mlx_whisper_backend.py b/whisper_live/backend/mlx_whisper_backend.py
new file mode 100644
index 00000000..0a0062f6
--- /dev/null
+++ b/whisper_live/backend/mlx_whisper_backend.py
@@ -0,0 +1,213 @@
+"""
+MLX Whisper Backend for WhisperLive
+
+This backend uses mlx-whisper for optimized transcription on Apple Silicon (M1/M2/M3).
+MLX leverages Apple's Neural Engine and GPU for fast, efficient inference.
+"""
+
+import json
+import logging
+import threading
+from typing import Optional
+
+from whisper_live.transcriber.transcriber_mlx import WhisperMLX
+from whisper_live.backend.base import ServeClientBase
+
+
+class ServeClientMLXWhisper(ServeClientBase):
+    """
+    Backend implementation for MLX Whisper on Apple Silicon.
+
+    This backend provides hardware-accelerated transcription using Apple's MLX
+    framework, optimized for M1/M2/M3 chips with Neural Engine and GPU support.
+    """
+
+    SINGLE_MODEL = None
+    SINGLE_MODEL_LOCK = threading.Lock()
+
+    def __init__(
+        self,
+        websocket,
+        task: str = "transcribe",
+        language: Optional[str] = None,
+        client_uid: Optional[str] = None,
+        model: str = "small.en",
+        initial_prompt: Optional[str] = None,
+        vad_parameters: Optional[dict] = None,
+        use_vad: bool = True,
+        single_model: bool = False,
+        send_last_n_segments: int = 10,
+        no_speech_thresh: float = 0.45,
+        clip_audio: bool = False,
+        same_output_threshold: int = 7,
+        translation_queue=None,
+    ):
+        """
+        Initialize the MLX Whisper backend.
+
+        Args:
+            websocket: WebSocket connection to the client
+            task (str): "transcribe" or "translate"
+            language (str, optional): Language code (e.g., "en", "es")
+            client_uid (str, optional): Unique client identifier
+            model (str): Model size or HuggingFace repo ID
+            initial_prompt (str, optional): Initial prompt for transcription
+            vad_parameters (dict, optional): VAD parameters (not used in MLX)
+            use_vad (bool): Whether to use VAD (not used in MLX)
+            single_model (bool): Share one model across all clients
+            send_last_n_segments (int): Number of recent segments to send
+            no_speech_thresh (float): Threshold for filtering silent segments
+            clip_audio (bool): Whether to clip audio with no valid segments
+            same_output_threshold (int): Threshold for repeated output filtering
+            translation_queue: Queue for translation (if enabled)
+        """
+        super().__init__(
+            client_uid,
+            websocket,
+            send_last_n_segments,
+            no_speech_thresh,
+            clip_audio,
+            same_output_threshold,
+            translation_queue,
+        )
+
+        self.model_name = model
+        self.language = "en" if model.endswith(".en") else language
+        self.task = task
+        self.initial_prompt = initial_prompt
+        self.vad_parameters = vad_parameters or {"onset": 0.5}
+        self.use_vad = use_vad
+
+        logging.info(f"Initializing MLX Whisper backend")
+        logging.info(f"Model: {self.model_name}")
+        logging.info(f"Language: {self.language}")
+        logging.info(f"Task: {self.task}")
+        logging.info(f"Using Apple Neural Engine and GPU acceleration")
+
+        # Initialize model
+        try:
+            if single_model:
+                if ServeClientMLXWhisper.SINGLE_MODEL is None:
+                    self.create_model()
+                    ServeClientMLXWhisper.SINGLE_MODEL = self.transcriber
+                else:
+                    self.transcriber = ServeClientMLXWhisper.SINGLE_MODEL
+                    logging.info("Using shared MLX model instance")
+            else:
+                self.create_model()
+        except Exception as e:
+            logging.error(f"Failed to load MLX Whisper model: {e}")
+            self.websocket.send(
+                json.dumps({
+                    "uid": self.client_uid,
+                    "status": "ERROR",
+                    "message": f"Failed to load MLX Whisper model: {str(e)}. "
+                               f"Make sure you're running on Apple Silicon (M1/M2/M3) "
+                               f"and have mlx-whisper installed.",
+                })
+            )
+            self.websocket.close()
+            return
+
+        # Start transcription thread
+        self.trans_thread = threading.Thread(target=self.speech_to_text)
+        self.trans_thread.start()
+
+        # Send ready message to client
+        self.websocket.send(
+            json.dumps({
+                "uid": self.client_uid,
+                "message": self.SERVER_READY,
+                "backend": "mlx_whisper",
+            })
+        )
+        logging.info(f"[MLX] Client {self.client_uid} initialized successfully")
+
+    def create_model(self):
+        """
+        Initialize the MLX Whisper model.
+
+        This method loads the specified model using MLX for hardware acceleration
+        on Apple Silicon devices.
+        """
+        logging.info(f"Loading MLX Whisper model: {self.model_name}")
+        self.transcriber = WhisperMLX(
+            model_name=self.model_name,
+            path_or_hf_repo=None,
+        )
+        logging.info("MLX Whisper model loaded successfully")
+
+    def transcribe_audio(self, input_sample):
+        """
+        Transcribe audio using MLX Whisper.
+
+        Args:
+            input_sample (np.ndarray): Audio data as numpy array (16kHz)
+
+        Returns:
+            List[MLXSegment]: List of transcribed segments with timing information
+        """
+        # Acquire lock if using shared model
+        if ServeClientMLXWhisper.SINGLE_MODEL:
+            ServeClientMLXWhisper.SINGLE_MODEL_LOCK.acquire()
+
+        try:
+            result = self.transcriber.transcribe(
+                input_sample,
+                language=self.language,
+                task=self.task,
+                initial_prompt=self.initial_prompt,
+                vad_filter=self.use_vad,
+                vad_parameters=self.vad_parameters,
+            )
+
+            # Auto-detect language if not set
+            if self.language is None and len(result) > 0:
+                detected_lang, prob = self.transcriber.detect_language(input_sample)
+                self.set_language(detected_lang)
+                logging.info(f"Detected language: {detected_lang} (probability: {prob:.2f})")
+
+            return result
+
+        except Exception as e:
+            logging.error(f"MLX transcription failed: {e}")
+            raise
+
+        finally:
+            # Release lock if using shared model
+            if ServeClientMLXWhisper.SINGLE_MODEL:
+                ServeClientMLXWhisper.SINGLE_MODEL_LOCK.release()
+
+    def handle_transcription_output(self, result, duration):
+        """
+        Process and send transcription results to the client.
+
+        Args:
+            result: Transcription result from transcribe_audio()
+            duration (float): Duration of the audio chunk in seconds
+        """
+        segments = []
+
+        if len(result):
+            last_segment = self.update_segments(result, duration)
+            segments = self.prepare_segments(last_segment)
+
+        if len(segments):
+            self.send_transcription_to_client(segments)
+
+    def set_language(self, language: str):
+        """
+        Set the transcription language.
+
+        Args:
+            language (str): Language code (e.g., "en", "es", "fr")
+        """
+        self.language = language
+        logging.info(f"Language set to: {language}")
+
+    def cleanup(self):
+        """
+        Clean up resources when client disconnects.
+        """
+        super().cleanup()
+        logging.info(f"[MLX] Client {self.client_uid} cleaned up")
diff --git a/whisper_live/server.py b/whisper_live/server.py
index 47a4543f..7381ce53 100644
--- a/whisper_live/server.py
+++ b/whisper_live/server.py
@@ -123,6 +123,7 @@ class BackendType(Enum):
     FASTER_WHISPER = "faster_whisper"
     TENSORRT = "tensorrt"
     OPENVINO = "openvino"
+    MLX_WHISPER = "mlx_whisper"
 
     @staticmethod
     def valid_types() -> List[str]:
@@ -137,10 +138,13 @@ def is_faster_whisper(self) -> bool:
         return self == BackendType.FASTER_WHISPER
 
     def is_tensorrt(self) -> bool:
         return self == BackendType.TENSORRT
-    
+
     def is_openvino(self) -> bool:
         return self == BackendType.OPENVINO
 
+    def is_mlx_whisper(self) -> bool:
+        return self == BackendType.MLX_WHISPER
+
 class TranscriptionServer:
     RATE = 16000
@@ -153,7 +157,7 @@ def __init__(self):
 
     def initialize_client(
         self, websocket, options, faster_whisper_custom_model_path,
-        whisper_tensorrt_path, trt_multilingual, trt_py_session=False,
+        mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=False,
     ):
         client: Optional[ServeClientBase] = None
 
@@ -242,6 +246,42 @@ def initialize_client(
                                "Reverting to available backend: 'faster_whisper'"
                 }))
 
+        if self.backend.is_mlx_whisper():
+            try:
+                from whisper_live.backend.mlx_whisper_backend import ServeClientMLXWhisper
+                # Use server-provided model path if available, otherwise use client's model
+                if mlx_model_path is not None:
+                    logging.info(f"Using server-specified MLX model: {mlx_model_path}")
+                    options["model"] = mlx_model_path
+                client = ServeClientMLXWhisper(
+                    websocket,
+                    language=options["language"],
+                    task=options["task"],
+                    client_uid=options["uid"],
+                    model=options["model"],
+                    initial_prompt=options.get("initial_prompt"),
+                    vad_parameters=options.get("vad_parameters"),
+                    use_vad=self.use_vad,
+                    single_model=self.single_model,
+                    send_last_n_segments=options.get("send_last_n_segments", 10),
+                    no_speech_thresh=options.get("no_speech_thresh", 0.45),
+                    clip_audio=options.get("clip_audio", False),
+                    same_output_threshold=options.get("same_output_threshold", 10),
+                    translation_queue=translation_queue,
+                )
+                logging.info("Running MLX Whisper backend (Apple Silicon optimized).")
+            except Exception as e:
+                logging.error(f"MLX Whisper not supported: {e}")
+                self.backend = BackendType.FASTER_WHISPER
+                self.client_uid = options["uid"]
+                websocket.send(json.dumps({
+                    "uid": self.client_uid,
+                    "status": "WARNING",
+                    "message": "MLX Whisper not supported. "
+                               "Make sure you're on Apple Silicon (M1/M2/M3) and have mlx-whisper installed. "
" + "Reverting to available backend: 'faster_whisper'" + })) + try: if self.backend.is_faster_whisper(): from whisper_live.backend.faster_whisper_backend import ServeClientFasterWhisper @@ -297,7 +337,7 @@ def get_audio_from_websocket(self, websocket): return np.frombuffer(frame_data, dtype=np.float32) def handle_new_connection(self, websocket, faster_whisper_custom_model_path, - whisper_tensorrt_path, trt_multilingual, trt_py_session=False): + mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=False): try: logging.info("New client connected") options = websocket.recv() @@ -311,7 +351,7 @@ def handle_new_connection(self, websocket, faster_whisper_custom_model_path, if self.backend.is_tensorrt(): self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE) self.initialize_client(websocket, options, faster_whisper_custom_model_path, - whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session) + mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session) return True except json.JSONDecodeError: logging.error("Failed to decode JSON from client") @@ -343,9 +383,10 @@ def process_audio_frames(self, websocket): return True def recv_audio(self, - websocket, + websocket, backend: BackendType = BackendType.FASTER_WHISPER, faster_whisper_custom_model_path=None, + mlx_model_path=None, whisper_tensorrt_path=None, trt_multilingual=False, trt_py_session=False): @@ -375,7 +416,7 @@ def recv_audio(self, """ self.backend = backend if not self.handle_new_connection(websocket, faster_whisper_custom_model_path, - whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session): + mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session): return try: @@ -397,6 +438,7 @@ def run(self, port=9090, backend="tensorrt", faster_whisper_custom_model_path=None, + mlx_model_path=None, whisper_tensorrt_path=None, trt_multilingual=False, trt_py_session=False, @@ -418,7 +460,7 @@ def run(self, if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path): raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.") if single_model: - if faster_whisper_custom_model_path or whisper_tensorrt_path: + if faster_whisper_custom_model_path or whisper_tensorrt_path or mlx_model_path: logging.info("Custom model option was provided. Switching to single model mode.") self.single_model = True # TODO: load model initially @@ -431,6 +473,7 @@ def run(self, self.recv_audio, backend=BackendType(backend), faster_whisper_custom_model_path=faster_whisper_custom_model_path, + mlx_model_path=mlx_model_path, whisper_tensorrt_path=whisper_tensorrt_path, trt_multilingual=trt_multilingual, trt_py_session=trt_py_session, diff --git a/whisper_live/transcriber/transcriber_mlx.py b/whisper_live/transcriber/transcriber_mlx.py new file mode 100644 index 00000000..5c16f9b0 --- /dev/null +++ b/whisper_live/transcriber/transcriber_mlx.py @@ -0,0 +1,197 @@ +""" +MLX Whisper Model Wrapper for Apple Silicon + +This module provides a wrapper around mlx-whisper for optimized +transcription on Apple Silicon (M1/M2/M3) Macs using the MLX framework. +""" + +import logging +import numpy as np +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class MLXSegment: + """ + Represents a transcribed segment with timing information. 
+
+    Attributes:
+        start (float): Start time in seconds
+        end (float): End time in seconds
+        text (str): Transcribed text
+        no_speech_prob (float): Probability of no speech (0-1)
+    """
+    start: float
+    end: float
+    text: str
+    no_speech_prob: float = 0.0
+
+
+class WhisperMLX:
+    """
+    Wrapper around mlx-whisper for transcription on Apple Silicon.
+
+    This class provides a consistent interface compatible with WhisperLive's
+    backend system while using MLX for hardware-accelerated inference.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "mlx-community/whisper-small.en-mlx",
+        path_or_hf_repo: Optional[str] = None,
+    ):
+        """
+        Initialize the MLX Whisper model.
+
+        Args:
+            model_name (str): Model name or size. Can be a standard size like "small.en",
+                "base", "medium", "large-v3", or a HuggingFace repo ID.
+            path_or_hf_repo (str, optional): Explicit path or HuggingFace repo.
+                Overrides model_name if provided.
+        """
+        try:
+            import mlx_whisper
+            self.mlx_whisper = mlx_whisper
+        except ImportError:
+            raise ImportError(
+                "mlx-whisper is not installed. Install it with: pip install mlx-whisper"
+            )
+
+        self.model_name = path_or_hf_repo if path_or_hf_repo else model_name
+
+        # Map standard model sizes to MLX community models
+        self.model_size_map = {
+            "tiny": "mlx-community/whisper-tiny-mlx",
+            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
+            "base": "mlx-community/whisper-base-mlx",
+            "base.en": "mlx-community/whisper-base.en-mlx",
+            "small": "mlx-community/whisper-small-mlx",
+            "small.en": "mlx-community/whisper-small.en-mlx",
+            "medium": "mlx-community/whisper-medium-mlx",
+            "medium.en": "mlx-community/whisper-medium.en-mlx",
+            "large-v2": "mlx-community/whisper-large-v2-mlx",
+            "large-v3": "mlx-community/whisper-large-v3-mlx",
+            "turbo": "mlx-community/whisper-large-v3-turbo",
+            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
+        }
+
+        # Convert standard size to MLX repo if needed
+        if self.model_name in self.model_size_map:
+            self.model_path = self.model_size_map[self.model_name]
+            logging.info(f"Mapping model size '{self.model_name}' to '{self.model_path}'")
+        else:
+            self.model_path = self.model_name
+
+        logging.info(f"Loading MLX Whisper model: {self.model_path}")
+        logging.info("MLX will use Apple Neural Engine and GPU for acceleration")
+
+    def transcribe(
+        self,
+        audio: np.ndarray,
+        language: Optional[str] = None,
+        task: str = "transcribe",
+        initial_prompt: Optional[str] = None,
+        vad_filter: bool = False,
+        vad_parameters: Optional[dict] = None,
+    ) -> List[MLXSegment]:
+        """
+        Transcribe audio using MLX Whisper.
+
+        Args:
+            audio (np.ndarray): Audio data as numpy array (16kHz)
+            language (str, optional): Language code (e.g., "en", "es", "fr")
+            task (str): Task type - "transcribe" or "translate"
+            initial_prompt (str, optional): Initial prompt for the model
+            vad_filter (bool): Whether to use VAD filtering (not used in MLX)
+            vad_parameters (dict, optional): VAD parameters (not used in MLX)
+
+        Returns:
+            List[MLXSegment]: List of transcribed segments with timing
+        """
+        try:
+            # Ensure audio is float32
+            if audio.dtype != np.float32:
+                audio = audio.astype(np.float32)
+
+            # Normalize audio to [-1, 1] range if needed
+            if audio.max() > 1.0 or audio.min() < -1.0:
+                audio = audio / 32768.0
+
+            # Prepare transcription options
+            transcribe_opts = {
+                "path_or_hf_repo": self.model_path,
+                "task": task,
+                "verbose": False,
+            }
+
+            if language:
+                transcribe_opts["language"] = language
+
+            if initial_prompt:
+                transcribe_opts["initial_prompt"] = initial_prompt
+
+            # Run transcription
+            result = self.mlx_whisper.transcribe(
+                audio,
+                **transcribe_opts
+            )
+
+            # Convert result to MLXSegment objects
+            segments = []
+
+            if isinstance(result, dict) and "segments" in result:
+                for seg in result["segments"]:
+                    segment = MLXSegment(
+                        start=seg.get("start", 0.0),
+                        end=seg.get("end", 0.0),
+                        text=seg.get("text", "").strip(),
+                        no_speech_prob=seg.get("no_speech_prob", 0.0)
+                    )
+                    segments.append(segment)
+            elif isinstance(result, dict) and "text" in result:
+                # If only text is returned (no segments), create a single segment
+                segment = MLXSegment(
+                    start=0.0,
+                    end=len(audio) / 16000.0,  # Calculate duration
+                    text=result["text"].strip(),
+                    no_speech_prob=0.0
+                )
+                segments.append(segment)
+
+            return segments
+
+        except Exception as e:
+            logging.error(f"MLX transcription failed: {e}")
+            raise
+
+    def detect_language(self, audio: np.ndarray) -> tuple:
+        """
+        Detect the language of the audio.
+
+        Args:
+            audio (np.ndarray): Audio data
+
+        Returns:
+            tuple: (language_code, probability)
+        """
+        try:
+            # MLX whisper doesn't have a separate language detection API
+            # We'll transcribe a small portion and infer from the result
+            result = self.mlx_whisper.transcribe(
+                audio[:16000],  # First second
+                path_or_hf_repo=self.model_path,
+                task="transcribe",
+                verbose=False,
+            )
+
+            # Try to extract language from result
+            if isinstance(result, dict):
+                lang = result.get("language", "en")
+                return (lang, 1.0)
+
+            return ("en", 1.0)
+
+        except Exception as e:
+            logging.error(f"Language detection failed: {e}")
+            return ("en", 1.0)
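
Usage note (not part of the patch): a minimal sketch of how the new backend could be started once this diff is applied, based only on the arguments added above. The --port flag and the leading host argument of TranscriptionServer.run() come from the pre-existing code and are assumptions here; "small.en" is resolved to "mlx-community/whisper-small.en-mlx" by the WhisperMLX size map shown in this diff.

    # CLI, mirroring the new run_server.py flags (assumes --port exists in the unmodified script):
    #   python run_server.py --port 9090 --backend mlx_whisper --mlx_model_path small.en

    # Programmatic equivalent, mirroring TranscriptionServer.run() as extended above:
    from whisper_live.server import TranscriptionServer

    server = TranscriptionServer()
    server.run(
        "localhost",                 # host (pre-existing positional argument; illustrative value)
        port=9090,
        backend="mlx_whisper",       # new BackendType added in this diff
        mlx_model_path="small.en",   # mapped to an mlx-community repo by WhisperMLX
    )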