From aa9b3bcd005412038a2374afdf41b7264ebe6cc6 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 18 Nov 2025 13:53:47 +0000 Subject: [PATCH 1/3] Add MLX Whisper backend for Apple Silicon support This commit adds support for mlx-whisper as a backend option, optimized for Apple Silicon (M1/M2/M3) Macs. MLX leverages Apple's Neural Engine and GPU for hardware-accelerated inference. Changes: - Add MLX_WHISPER to BackendType enum with is_mlx_whisper() method - Create WhisperMLX transcriber wrapper (transcriber_mlx.py) * Maps standard model sizes to MLX community models * Implements transcribe() with language detection support * Returns segments compatible with WhisperLive base interface - Create ServeClientMLXWhisper backend class (mlx_whisper_backend.py) * Extends ServeClientBase with MLX-specific implementation * Supports single model mode for memory efficiency * Thread-safe model access with locking * Graceful fallback to faster_whisper if MLX unavailable - Update server.py initialize_client() to instantiate MLX backend - Update run_server.py CLI to include mlx_whisper in backend options - Add mlx-whisper dependency to requirements/server.txt The backend follows the same pluggable architecture as other backends (faster_whisper, tensorrt, openvino) and implements the required transcribe_audio() and handle_transcription_output() methods. Usage: python run_server.py --backend mlx_whisper --model small.en --- requirements/server.txt | 7 +- run_server.py | 2 +- whisper_live/backend/mlx_whisper_backend.py | 213 ++++++++++++++++++++ whisper_live/server.py | 38 +++- whisper_live/transcriber/transcriber_mlx.py | 197 ++++++++++++++++++ 5 files changed, 453 insertions(+), 4 deletions(-) create mode 100644 whisper_live/backend/mlx_whisper_backend.py create mode 100644 whisper_live/transcriber/transcriber_mlx.py diff --git a/requirements/server.txt b/requirements/server.txt index 76319ad6..c73bb076 100644 --- a/requirements/server.txt +++ b/requirements/server.txt @@ -19,5 +19,8 @@ librosa openvino openvino-genai openvino-tokenizers -optimum -optimum-intel \ No newline at end of file +optimum +optimum-intel + +# mlx (for Apple Silicon M1/M2/M3) +mlx-whisper \ No newline at end of file diff --git a/run_server.py b/run_server.py index 387574ec..64242341 100644 --- a/run_server.py +++ b/run_server.py @@ -10,7 +10,7 @@ parser.add_argument('--backend', '-b', type=str, default='faster_whisper', - help='Backends from ["tensorrt", "faster_whisper", "openvino"]') + help='Backends from ["tensorrt", "faster_whisper", "openvino", "mlx_whisper"]') parser.add_argument('--faster_whisper_custom_model_path', '-fw', type=str, default=None, help="Custom Faster Whisper Model") diff --git a/whisper_live/backend/mlx_whisper_backend.py b/whisper_live/backend/mlx_whisper_backend.py new file mode 100644 index 00000000..0a0062f6 --- /dev/null +++ b/whisper_live/backend/mlx_whisper_backend.py @@ -0,0 +1,213 @@ +""" +MLX Whisper Backend for WhisperLive + +This backend uses mlx-whisper for optimized transcription on Apple Silicon (M1/M2/M3). +MLX leverages Apple's Neural Engine and GPU for fast, efficient inference. +""" + +import json +import logging +import threading +from typing import Optional + +from whisper_live.transcriber.transcriber_mlx import WhisperMLX +from whisper_live.backend.base import ServeClientBase + + +class ServeClientMLXWhisper(ServeClientBase): + """ + Backend implementation for MLX Whisper on Apple Silicon. 
+ + This backend provides hardware-accelerated transcription using Apple's MLX + framework, optimized for M1/M2/M3 chips with Neural Engine and GPU support. + """ + + SINGLE_MODEL = None + SINGLE_MODEL_LOCK = threading.Lock() + + def __init__( + self, + websocket, + task: str = "transcribe", + language: Optional[str] = None, + client_uid: Optional[str] = None, + model: str = "small.en", + initial_prompt: Optional[str] = None, + vad_parameters: Optional[dict] = None, + use_vad: bool = True, + single_model: bool = False, + send_last_n_segments: int = 10, + no_speech_thresh: float = 0.45, + clip_audio: bool = False, + same_output_threshold: int = 7, + translation_queue=None, + ): + """ + Initialize MLX Whisper backend. + + Args: + websocket: WebSocket connection to the client + task (str): "transcribe" or "translate" + language (str, optional): Language code (e.g., "en", "es") + client_uid (str, optional): Unique client identifier + model (str): Model size or HuggingFace repo ID + initial_prompt (str, optional): Initial prompt for transcription + vad_parameters (dict, optional): VAD parameters (not used in MLX) + use_vad (bool): Whether to use VAD (not used in MLX) + single_model (bool): Share one model across all clients + send_last_n_segments (int): Number of recent segments to send + no_speech_thresh (float): Threshold for filtering silent segments + clip_audio (bool): Whether to clip audio with no valid segments + same_output_threshold (int): Threshold for repeated output filtering + translation_queue: Queue for translation (if enabled) + """ + super().__init__( + client_uid, + websocket, + send_last_n_segments, + no_speech_thresh, + clip_audio, + same_output_threshold, + translation_queue, + ) + + self.model_name = model + self.language = "en" if model.endswith(".en") else language + self.task = task + self.initial_prompt = initial_prompt + self.vad_parameters = vad_parameters or {"onset": 0.5} + self.use_vad = use_vad + + logging.info(f"Initializing MLX Whisper backend") + logging.info(f"Model: {self.model_name}") + logging.info(f"Language: {self.language}") + logging.info(f"Task: {self.task}") + logging.info(f"Using Apple Neural Engine and GPU acceleration") + + # Initialize model + try: + if single_model: + if ServeClientMLXWhisper.SINGLE_MODEL is None: + self.create_model() + ServeClientMLXWhisper.SINGLE_MODEL = self.transcriber + else: + self.transcriber = ServeClientMLXWhisper.SINGLE_MODEL + logging.info("Using shared MLX model instance") + else: + self.create_model() + except Exception as e: + logging.error(f"Failed to load MLX Whisper model: {e}") + self.websocket.send( + json.dumps({ + "uid": self.client_uid, + "status": "ERROR", + "message": f"Failed to load MLX Whisper model: {str(e)}. " + f"Make sure you're running on Apple Silicon (M1/M2/M3) " + f"and have mlx-whisper installed.", + }) + ) + self.websocket.close() + return + + # Start transcription thread + self.trans_thread = threading.Thread(target=self.speech_to_text) + self.trans_thread.start() + + # Send ready message to client + self.websocket.send( + json.dumps({ + "uid": self.client_uid, + "message": self.SERVER_READY, + "backend": "mlx_whisper", + }) + ) + logging.info(f"[MLX] Client {self.client_uid} initialized successfully") + + def create_model(self): + """ + Initialize the MLX Whisper model. + + This method loads the specified model using MLX for hardware acceleration + on Apple Silicon devices. 
+ """ + logging.info(f"Loading MLX Whisper model: {self.model_name}") + self.transcriber = WhisperMLX( + model_name=self.model_name, + path_or_hf_repo=None, + ) + logging.info("MLX Whisper model loaded successfully") + + def transcribe_audio(self, input_sample): + """ + Transcribe audio using MLX Whisper. + + Args: + input_sample (np.ndarray): Audio data as numpy array (16kHz) + + Returns: + List[MLXSegment]: List of transcribed segments with timing information + """ + # Acquire lock if using shared model + if ServeClientMLXWhisper.SINGLE_MODEL: + ServeClientMLXWhisper.SINGLE_MODEL_LOCK.acquire() + + try: + result = self.transcriber.transcribe( + input_sample, + language=self.language, + task=self.task, + initial_prompt=self.initial_prompt, + vad_filter=self.use_vad, + vad_parameters=self.vad_parameters, + ) + + # Auto-detect language if not set + if self.language is None and len(result) > 0: + detected_lang, prob = self.transcriber.detect_language(input_sample) + self.set_language(detected_lang) + logging.info(f"Detected language: {detected_lang} (probability: {prob:.2f})") + + return result + + except Exception as e: + logging.error(f"MLX transcription failed: {e}") + raise + + finally: + # Release lock if using shared model + if ServeClientMLXWhisper.SINGLE_MODEL: + ServeClientMLXWhisper.SINGLE_MODEL_LOCK.release() + + def handle_transcription_output(self, result, duration): + """ + Process and send transcription results to the client. + + Args: + result: Transcription result from transcribe_audio() + duration (float): Duration of the audio chunk in seconds + """ + segments = [] + + if len(result): + last_segment = self.update_segments(result, duration) + segments = self.prepare_segments(last_segment) + + if len(segments): + self.send_transcription_to_client(segments) + + def set_language(self, language: str): + """ + Set the transcription language. + + Args: + language (str): Language code (e.g., "en", "es", "fr") + """ + self.language = language + logging.info(f"Language set to: {language}") + + def cleanup(self): + """ + Clean up resources when client disconnects. 
+ """ + super().cleanup() + logging.info(f"[MLX] Client {self.client_uid} cleaned up") diff --git a/whisper_live/server.py b/whisper_live/server.py index 47a4543f..dc85f5e3 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -123,6 +123,7 @@ class BackendType(Enum): FASTER_WHISPER = "faster_whisper" TENSORRT = "tensorrt" OPENVINO = "openvino" + MLX_WHISPER = "mlx_whisper" @staticmethod def valid_types() -> List[str]: @@ -137,10 +138,13 @@ def is_faster_whisper(self) -> bool: def is_tensorrt(self) -> bool: return self == BackendType.TENSORRT - + def is_openvino(self) -> bool: return self == BackendType.OPENVINO + def is_mlx_whisper(self) -> bool: + return self == BackendType.MLX_WHISPER + class TranscriptionServer: RATE = 16000 @@ -242,6 +246,38 @@ def initialize_client( "Reverting to available backend: 'faster_whisper'" })) + if self.backend.is_mlx_whisper(): + try: + from whisper_live.backend.mlx_whisper_backend import ServeClientMLXWhisper + client = ServeClientMLXWhisper( + websocket, + language=options["language"], + task=options["task"], + client_uid=options["uid"], + model=options["model"], + initial_prompt=options.get("initial_prompt"), + vad_parameters=options.get("vad_parameters"), + use_vad=self.use_vad, + single_model=self.single_model, + send_last_n_segments=options.get("send_last_n_segments", 10), + no_speech_thresh=options.get("no_speech_thresh", 0.45), + clip_audio=options.get("clip_audio", False), + same_output_threshold=options.get("same_output_threshold", 10), + translation_queue=translation_queue, + ) + logging.info("Running MLX Whisper backend (Apple Silicon optimized).") + except Exception as e: + logging.error(f"MLX Whisper not supported: {e}") + self.backend = BackendType.FASTER_WHISPER + self.client_uid = options["uid"] + websocket.send(json.dumps({ + "uid": self.client_uid, + "status": "WARNING", + "message": "MLX Whisper not supported. " + "Make sure you're on Apple Silicon (M1/M2/M3) and have mlx-whisper installed. " + "Reverting to available backend: 'faster_whisper'" + })) + try: if self.backend.is_faster_whisper(): from whisper_live.backend.faster_whisper_backend import ServeClientFasterWhisper diff --git a/whisper_live/transcriber/transcriber_mlx.py b/whisper_live/transcriber/transcriber_mlx.py new file mode 100644 index 00000000..5c16f9b0 --- /dev/null +++ b/whisper_live/transcriber/transcriber_mlx.py @@ -0,0 +1,197 @@ +""" +MLX Whisper Model Wrapper for Apple Silicon + +This module provides a wrapper around mlx-whisper for optimized +transcription on Apple Silicon (M1/M2/M3) Macs using the MLX framework. +""" + +import logging +import numpy as np +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class MLXSegment: + """ + Represents a transcribed segment with timing information. + + Attributes: + start (float): Start time in seconds + end (float): End time in seconds + text (str): Transcribed text + no_speech_prob (float): Probability of no speech (0-1) + """ + start: float + end: float + text: str + no_speech_prob: float = 0.0 + + +class WhisperMLX: + """ + Wrapper around mlx-whisper for transcription on Apple Silicon. + + This class provides a consistent interface compatible with WhisperLive's + backend system while using MLX for hardware-accelerated inference. + """ + + def __init__( + self, + model_name: str = "mlx-community/whisper-small.en-mlx", + path_or_hf_repo: str = None, + ): + """ + Initialize the MLX Whisper model. + + Args: + model_name (str): Model name or size. 
Can be a standard size like "small.en", + "base", "medium", "large-v3", or a HuggingFace repo ID. + path_or_hf_repo (str, optional): Explicit path or HuggingFace repo. + Overrides model_name if provided. + """ + try: + import mlx_whisper + self.mlx_whisper = mlx_whisper + except ImportError: + raise ImportError( + "mlx-whisper is not installed. Install it with: pip install mlx-whisper" + ) + + self.model_name = path_or_hf_repo if path_or_hf_repo else model_name + + # Map standard model sizes to MLX community models + self.model_size_map = { + "tiny": "mlx-community/whisper-tiny-mlx", + "tiny.en": "mlx-community/whisper-tiny.en-mlx", + "base": "mlx-community/whisper-base-mlx", + "base.en": "mlx-community/whisper-base.en-mlx", + "small": "mlx-community/whisper-small-mlx", + "small.en": "mlx-community/whisper-small.en-mlx", + "medium": "mlx-community/whisper-medium-mlx", + "medium.en": "mlx-community/whisper-medium.en-mlx", + "large-v2": "mlx-community/whisper-large-v2-mlx", + "large-v3": "mlx-community/whisper-large-v3-mlx", + "turbo": "mlx-community/whisper-large-v3-turbo", + "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", + } + + # Convert standard size to MLX repo if needed + if self.model_name in self.model_size_map: + self.model_path = self.model_size_map[self.model_name] + logging.info(f"Mapping model size '{self.model_name}' to '{self.model_path}'") + else: + self.model_path = self.model_name + + logging.info(f"Loading MLX Whisper model: {self.model_path}") + logging.info("MLX will use Apple Neural Engine and GPU for acceleration") + + def transcribe( + self, + audio: np.ndarray, + language: Optional[str] = None, + task: str = "transcribe", + initial_prompt: Optional[str] = None, + vad_filter: bool = False, + vad_parameters: Optional[dict] = None, + ) -> List[MLXSegment]: + """ + Transcribe audio using MLX Whisper. 
+ + Args: + audio (np.ndarray): Audio data as numpy array (16kHz) + language (str, optional): Language code (e.g., "en", "es", "fr") + task (str): Task type - "transcribe" or "translate" + initial_prompt (str, optional): Initial prompt for the model + vad_filter (bool): Whether to use VAD filtering (not used in MLX) + vad_parameters (dict, optional): VAD parameters (not used in MLX) + + Returns: + List[MLXSegment]: List of transcribed segments with timing + """ + try: + # Ensure audio is float32 + if audio.dtype != np.float32: + audio = audio.astype(np.float32) + + # Normalize audio to [-1, 1] range if needed + if audio.max() > 1.0 or audio.min() < -1.0: + audio = audio / 32768.0 + + # Prepare transcription options + transcribe_opts = { + "path_or_hf_repo": self.model_path, + "task": task, + "verbose": False, + } + + if language: + transcribe_opts["language"] = language + + if initial_prompt: + transcribe_opts["initial_prompt"] = initial_prompt + + # Run transcription + result = self.mlx_whisper.transcribe( + audio, + **transcribe_opts + ) + + # Convert result to MLXSegment objects + segments = [] + + if isinstance(result, dict) and "segments" in result: + for seg in result["segments"]: + segment = MLXSegment( + start=seg.get("start", 0.0), + end=seg.get("end", 0.0), + text=seg.get("text", "").strip(), + no_speech_prob=seg.get("no_speech_prob", 0.0) + ) + segments.append(segment) + elif isinstance(result, dict) and "text" in result: + # If only text is returned (no segments), create a single segment + segment = MLXSegment( + start=0.0, + end=len(audio) / 16000.0, # Calculate duration + text=result["text"].strip(), + no_speech_prob=0.0 + ) + segments.append(segment) + + return segments + + except Exception as e: + logging.error(f"MLX transcription failed: {e}") + raise + + def detect_language(self, audio: np.ndarray) -> tuple: + """ + Detect the language of the audio. + + Args: + audio (np.ndarray): Audio data + + Returns: + tuple: (language_code, probability) + """ + try: + # MLX whisper doesn't have a separate language detection API + # We'll transcribe a small portion and infer from the result + result = self.mlx_whisper.transcribe( + audio[:16000], # First second + path_or_hf_repo=self.model_path, + task="transcribe", + verbose=False, + ) + + # Try to extract language from result + if isinstance(result, dict): + lang = result.get("language", "en") + return (lang, 1.0) + + return ("en", 1.0) + + except Exception as e: + logging.error(f"Language detection failed: {e}") + return ("en", 1.0) From 854df34e196535285d23d4efe7e4238d6c7f4e04 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 18 Nov 2025 14:25:51 +0000 Subject: [PATCH 2/3] Add MLX model path CLI argument and testing tools This commit adds: 1. MLX Model Path CLI Argument: - Add --mlx_model_path/-mlx argument to run_server.py - Server can now specify MLX model, overriding client's choice - Supports model sizes (small.en) and HF repos (mlx-community/whisper-large-v3-turbo) - Integrated into single_model mode support 2. Server-side MLX Model Configuration: - Update server.run() to accept mlx_model_path parameter - Update recv_audio(), handle_new_connection(), and initialize_client() - MLX backend now uses server-specified model when provided - Falls back to client-specified model if not set 3. 
Microphone Test Script (test_mlx_microphone.py): - Real-time transcription test with microphone input - Command-line args for host, port, model, language, translate - User-friendly interface with status messages - Saves output to SRT file - Proper error handling and cleanup 4. GPU Verification Tool (verify_mlx_gpu.py): - Checks if running on Apple Silicon (M1/M2/M3) - Verifies MLX and mlx-whisper installation - Tests GPU/Neural Engine access with sample computation - Optional MLX Whisper model loading test - Provides instructions for monitoring GPU usage: * Activity Monitor (GUI) * powermetrics (Terminal) * asitop (Third-party) - Comprehensive summary with actionable recommendations Usage Examples: # Server with specific MLX model python run_server.py --backend mlx_whisper --mlx_model_path small.en # Verify GPU functionality python verify_mlx_gpu.py # Test with microphone python test_mlx_microphone.py --model small.en --lang en --- run_server.py | 5 + test_mlx_microphone.py | 122 ++++++++++++++++++++++ verify_mlx_gpu.py | 228 +++++++++++++++++++++++++++++++++++++++++ whisper_live/server.py | 19 ++-- 4 files changed, 368 insertions(+), 6 deletions(-) create mode 100755 test_mlx_microphone.py create mode 100755 verify_mlx_gpu.py diff --git a/run_server.py b/run_server.py index 64242341..f3b09dd7 100644 --- a/run_server.py +++ b/run_server.py @@ -14,6 +14,10 @@ parser.add_argument('--faster_whisper_custom_model_path', '-fw', type=str, default=None, help="Custom Faster Whisper Model") + parser.add_argument('--mlx_model_path', '-mlx', + type=str, + default=None, + help='MLX Whisper model (e.g., "small.en", "mlx-community/whisper-large-v3-turbo")') parser.add_argument('--trt_model_path', '-trt', type=str, default=None, @@ -59,6 +63,7 @@ port=args.port, backend=args.backend, faster_whisper_custom_model_path=args.faster_whisper_custom_model_path, + mlx_model_path=args.mlx_model_path, whisper_tensorrt_path=args.trt_model_path, trt_multilingual=args.trt_multilingual, trt_py_session=args.trt_py_session, diff --git a/test_mlx_microphone.py b/test_mlx_microphone.py new file mode 100755 index 00000000..33492f41 --- /dev/null +++ b/test_mlx_microphone.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +""" +Test script for MLX Whisper backend with microphone input. + +This script connects to a WhisperLive server running the MLX backend +and streams audio from your microphone for real-time transcription. + +Usage: + python test_mlx_microphone.py --host localhost --port 9090 --model small.en + +Before running: + 1. Start the server: python run_server.py --backend mlx_whisper --port 9090 + 2. Install client dependencies: pip install -e . +""" + +import argparse +import logging +from whisper_live.client import Client + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) + + +def main(): + parser = argparse.ArgumentParser( + description='Test MLX Whisper backend with microphone input' + ) + parser.add_argument( + '--host', + type=str, + default='localhost', + help='Server host (default: localhost)' + ) + parser.add_argument( + '--port', + type=int, + default=9090, + help='Server port (default: 9090)' + ) + parser.add_argument( + '--model', + type=str, + default='small.en', + help='Model size (e.g., tiny, base, small.en, medium, large-v3, turbo)' + ) + parser.add_argument( + '--lang', + type=str, + default=None, + help='Language code (e.g., en, es, fr). Auto-detect if not specified.' 
+ ) + parser.add_argument( + '--translate', + action='store_true', + help='Translate to English instead of transcribing' + ) + parser.add_argument( + '--output', + type=str, + default='output_mlx.srt', + help='Output SRT file path (default: output_mlx.srt)' + ) + + args = parser.parse_args() + + print("=" * 70) + print("MLX Whisper Live Transcription Test") + print("=" * 70) + print(f"Server: {args.host}:{args.port}") + print(f"Model: {args.model}") + print(f"Language: {args.lang if args.lang else 'Auto-detect'}") + print(f"Task: {'Translate to English' if args.translate else 'Transcribe'}") + print(f"Output: {args.output}") + print("=" * 70) + print("\nStarting microphone recording...") + print("Speak into your microphone. Press Ctrl+C to stop.\n") + + try: + # Initialize client with MLX backend + client = Client( + host=args.host, + port=args.port, + lang=args.lang, + translate=args.translate, + model=args.model, + srt_file_path=args.output, + use_vad=True, + log_transcription=True, + ) + + # The client automatically starts recording when initialized + # Keep the script running until user interrupts + print("Recording... (Press Ctrl+C to stop)") + + # Wait indefinitely (client runs in background threads) + import time + while True: + time.sleep(1) + + except KeyboardInterrupt: + print("\n\nStopping recording...") + if 'client' in locals(): + client.close_websocket() + print(f"Transcription saved to: {args.output}") + print("\nThank you for using MLX Whisper!") + + except ConnectionRefusedError: + print(f"\n ERROR: Could not connect to server at {args.host}:{args.port}") + print("\nMake sure the server is running:") + print(f" python run_server.py --backend mlx_whisper --port {args.port}") + + except Exception as e: + logging.error(f"Error: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/verify_mlx_gpu.py b/verify_mlx_gpu.py new file mode 100755 index 00000000..42e8965e --- /dev/null +++ b/verify_mlx_gpu.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +""" +MLX GPU Verification Script + +This script verifies that MLX is installed and can access Apple's +Neural Engine and GPU on your Mac. 
+ +Usage: + python verify_mlx_gpu.py +""" + +import sys +import platform + + +def check_apple_silicon(): + """Check if running on Apple Silicon.""" + print("=" * 70) + print("Checking Apple Silicon Status") + print("=" * 70) + + machine = platform.machine() + print(f"Machine architecture: {machine}") + + if machine == "arm64": + print("✓ Running on Apple Silicon (M1/M2/M3)") + return True + else: + print("✗ NOT running on Apple Silicon") + print(" MLX is optimized for Apple Silicon Macs.") + print(f" Your architecture: {machine}") + return False + + +def check_mlx_installation(): + """Check if MLX is installed.""" + print("\n" + "=" * 70) + print("Checking MLX Installation") + print("=" * 70) + + try: + import mlx + import mlx.core as mx + print(f"✓ MLX installed: version {mlx.__version__}") + return True + except ImportError as e: + print("✗ MLX not installed") + print("\n Install with: pip install mlx") + return False + + +def check_mlx_whisper(): + """Check if mlx-whisper is installed.""" + print("\n" + "=" * 70) + print("Checking mlx-whisper Installation") + print("=" * 70) + + try: + import mlx_whisper + print("✓ mlx-whisper is installed") + return True + except ImportError: + print("✗ mlx-whisper not installed") + print("\n Install with: pip install mlx-whisper") + return False + + +def test_mlx_gpu(): + """Test MLX GPU operations.""" + print("\n" + "=" * 70) + print("Testing MLX GPU/Neural Engine Access") + print("=" * 70) + + try: + import mlx.core as mx + + # Create a simple array and perform computation + print("\nCreating test array and performing GPU computation...") + a = mx.array([1.0, 2.0, 3.0, 4.0]) + b = mx.array([5.0, 6.0, 7.0, 8.0]) + + # Force evaluation on GPU + c = a + b + mx.eval(c) + + print(f"Input a: {a}") + print(f"Input b: {b}") + print(f"Result (a + b): {c}") + print("\n✓ MLX can perform GPU computations successfully!") + + # Check available memory + print("\nChecking device memory...") + # MLX uses unified memory on Apple Silicon + print("✓ MLX uses unified memory architecture") + print(" (Shared between CPU, GPU, and Neural Engine)") + + return True + + except Exception as e: + print(f"✗ Error testing MLX GPU: {e}") + return False + + +def test_mlx_whisper_model(): + """Test loading a small MLX Whisper model.""" + print("\n" + "=" * 70) + print("Testing MLX Whisper Model Loading") + print("=" * 70) + + try: + import mlx_whisper + import numpy as np + + print("\nAttempting to load a tiny model (this may take a moment)...") + print("Model: mlx-community/whisper-tiny-mlx") + + # Create dummy audio (1 second of silence) + dummy_audio = np.zeros(16000, dtype=np.float32) + + # Try to transcribe (this will download and load the model) + result = mlx_whisper.transcribe( + dummy_audio, + path_or_hf_repo="mlx-community/whisper-tiny-mlx", + verbose=False + ) + + print("✓ Successfully loaded and tested MLX Whisper model!") + print("✓ Model is using Apple Neural Engine and GPU for inference") + + return True + + except Exception as e: + print(f"✗ Error testing MLX Whisper: {e}") + print("\n This might be due to:") + print(" - Network issues downloading the model") + print(" - Insufficient disk space") + print(" - Missing dependencies") + return False + + +def monitor_gpu_usage(): + """Provide instructions for monitoring GPU usage.""" + print("\n" + "=" * 70) + print("How to Monitor GPU Usage During Transcription") + print("=" * 70) + + print("\n1. 
Using Activity Monitor (GUI):") + print(" - Open Activity Monitor (Cmd+Space, type 'Activity Monitor')") + print(" - Go to the 'GPU' tab") + print(" - Look for the 'python' process") + print(" - Watch 'GPU Time' and 'GPU Memory' columns increase during transcription") + + print("\n2. Using powermetrics (Terminal):") + print(" Run this in a separate terminal while transcribing:") + print(" sudo powermetrics --samplers gpu_power -i 1000") + print(" (Shows GPU power consumption, higher = more GPU usage)") + + print("\n3. Using asitop (Third-party tool):") + print(" Install: pip install asitop") + print(" Run: sudo asitop") + print(" (Shows real-time GPU, Neural Engine, and CPU usage)") + + print("\n4. Quick verification:") + print(" - Run transcription with MLX backend") + print(" - Open Activity Monitor → Window → GPU History") + print(" - You should see spikes in GPU usage when audio is being transcribed") + + +def main(): + print("\n" + "=" * 70) + print(" MLX Whisper GPU Verification Tool") + print("=" * 70 + "\n") + + results = { + "Apple Silicon": check_apple_silicon(), + "MLX Installed": check_mlx_installation(), + "MLX Whisper Installed": check_mlx_whisper(), + } + + # Only run GPU tests if basic checks pass + if results["Apple Silicon"] and results["MLX Installed"]: + results["MLX GPU Test"] = test_mlx_gpu() + + if results["MLX Whisper Installed"]: + print("\nNote: The next test will download a small model (~40MB)") + response = input("Continue with model test? [y/N]: ").strip().lower() + if response == 'y': + results["MLX Whisper Model Test"] = test_mlx_whisper_model() + + # Always show monitoring instructions + monitor_gpu_usage() + + # Summary + print("\n" + "=" * 70) + print("Summary") + print("=" * 70) + + all_passed = all(results.values()) + + for check, passed in results.items(): + status = "✓ PASS" if passed else "✗ FAIL" + print(f"{status}: {check}") + + if all_passed: + print("\n🎉 All checks passed! MLX Whisper is ready to use.") + print("\nYou can now run the server with:") + print(" python run_server.py --backend mlx_whisper --mlx_model_path small.en") + print("\nAnd test with microphone:") + print(" python test_mlx_microphone.py --model small.en") + else: + print("\n⚠️ Some checks failed. 
Please address the issues above.") + + if not results.get("Apple Silicon"): + print("\n MLX requires Apple Silicon (M1/M2/M3) to run.") + + if not results.get("MLX Installed"): + print("\n Install MLX: pip install mlx") + + if not results.get("MLX Whisper Installed"): + print("\n Install mlx-whisper: pip install mlx-whisper") + + print("=" * 70 + "\n") + + return 0 if all_passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/whisper_live/server.py b/whisper_live/server.py index dc85f5e3..7381ce53 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -157,7 +157,7 @@ def __init__(self): def initialize_client( self, websocket, options, faster_whisper_custom_model_path, - whisper_tensorrt_path, trt_multilingual, trt_py_session=False, + mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=False, ): client: Optional[ServeClientBase] = None @@ -249,6 +249,10 @@ def initialize_client( if self.backend.is_mlx_whisper(): try: from whisper_live.backend.mlx_whisper_backend import ServeClientMLXWhisper + # Use server-provided model path if available, otherwise use client's model + if mlx_model_path is not None: + logging.info(f"Using server-specified MLX model: {mlx_model_path}") + options["model"] = mlx_model_path client = ServeClientMLXWhisper( websocket, language=options["language"], @@ -333,7 +337,7 @@ def get_audio_from_websocket(self, websocket): return np.frombuffer(frame_data, dtype=np.float32) def handle_new_connection(self, websocket, faster_whisper_custom_model_path, - whisper_tensorrt_path, trt_multilingual, trt_py_session=False): + mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=False): try: logging.info("New client connected") options = websocket.recv() @@ -347,7 +351,7 @@ def handle_new_connection(self, websocket, faster_whisper_custom_model_path, if self.backend.is_tensorrt(): self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE) self.initialize_client(websocket, options, faster_whisper_custom_model_path, - whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session) + mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session) return True except json.JSONDecodeError: logging.error("Failed to decode JSON from client") @@ -379,9 +383,10 @@ def process_audio_frames(self, websocket): return True def recv_audio(self, - websocket, + websocket, backend: BackendType = BackendType.FASTER_WHISPER, faster_whisper_custom_model_path=None, + mlx_model_path=None, whisper_tensorrt_path=None, trt_multilingual=False, trt_py_session=False): @@ -411,7 +416,7 @@ def recv_audio(self, """ self.backend = backend if not self.handle_new_connection(websocket, faster_whisper_custom_model_path, - whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session): + mlx_model_path, whisper_tensorrt_path, trt_multilingual, trt_py_session=trt_py_session): return try: @@ -433,6 +438,7 @@ def run(self, port=9090, backend="tensorrt", faster_whisper_custom_model_path=None, + mlx_model_path=None, whisper_tensorrt_path=None, trt_multilingual=False, trt_py_session=False, @@ -454,7 +460,7 @@ def run(self, if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path): raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.") if single_model: - if faster_whisper_custom_model_path or whisper_tensorrt_path: + if faster_whisper_custom_model_path or whisper_tensorrt_path or mlx_model_path: logging.info("Custom model option was provided. 
Switching to single model mode.") self.single_model = True # TODO: load model initially @@ -467,6 +473,7 @@ def run(self, self.recv_audio, backend=BackendType(backend), faster_whisper_custom_model_path=faster_whisper_custom_model_path, + mlx_model_path=mlx_model_path, whisper_tensorrt_path=whisper_tensorrt_path, trt_multilingual=trt_multilingual, trt_py_session=trt_py_session, From 4d8917bf02d972868e1701409b715deb597539ed Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 18 Nov 2025 15:51:02 +0000 Subject: [PATCH 3/3] Remove test scripts - not needed for PR --- test_mlx_microphone.py | 122 ---------------------- verify_mlx_gpu.py | 228 ----------------------------------------- 2 files changed, 350 deletions(-) delete mode 100755 test_mlx_microphone.py delete mode 100755 verify_mlx_gpu.py diff --git a/test_mlx_microphone.py b/test_mlx_microphone.py deleted file mode 100755 index 33492f41..00000000 --- a/test_mlx_microphone.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for MLX Whisper backend with microphone input. - -This script connects to a WhisperLive server running the MLX backend -and streams audio from your microphone for real-time transcription. - -Usage: - python test_mlx_microphone.py --host localhost --port 9090 --model small.en - -Before running: - 1. Start the server: python run_server.py --backend mlx_whisper --port 9090 - 2. Install client dependencies: pip install -e . -""" - -import argparse -import logging -from whisper_live.client import Client - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' -) - - -def main(): - parser = argparse.ArgumentParser( - description='Test MLX Whisper backend with microphone input' - ) - parser.add_argument( - '--host', - type=str, - default='localhost', - help='Server host (default: localhost)' - ) - parser.add_argument( - '--port', - type=int, - default=9090, - help='Server port (default: 9090)' - ) - parser.add_argument( - '--model', - type=str, - default='small.en', - help='Model size (e.g., tiny, base, small.en, medium, large-v3, turbo)' - ) - parser.add_argument( - '--lang', - type=str, - default=None, - help='Language code (e.g., en, es, fr). Auto-detect if not specified.' - ) - parser.add_argument( - '--translate', - action='store_true', - help='Translate to English instead of transcribing' - ) - parser.add_argument( - '--output', - type=str, - default='output_mlx.srt', - help='Output SRT file path (default: output_mlx.srt)' - ) - - args = parser.parse_args() - - print("=" * 70) - print("MLX Whisper Live Transcription Test") - print("=" * 70) - print(f"Server: {args.host}:{args.port}") - print(f"Model: {args.model}") - print(f"Language: {args.lang if args.lang else 'Auto-detect'}") - print(f"Task: {'Translate to English' if args.translate else 'Transcribe'}") - print(f"Output: {args.output}") - print("=" * 70) - print("\nStarting microphone recording...") - print("Speak into your microphone. Press Ctrl+C to stop.\n") - - try: - # Initialize client with MLX backend - client = Client( - host=args.host, - port=args.port, - lang=args.lang, - translate=args.translate, - model=args.model, - srt_file_path=args.output, - use_vad=True, - log_transcription=True, - ) - - # The client automatically starts recording when initialized - # Keep the script running until user interrupts - print("Recording... 
(Press Ctrl+C to stop)") - - # Wait indefinitely (client runs in background threads) - import time - while True: - time.sleep(1) - - except KeyboardInterrupt: - print("\n\nStopping recording...") - if 'client' in locals(): - client.close_websocket() - print(f"Transcription saved to: {args.output}") - print("\nThank you for using MLX Whisper!") - - except ConnectionRefusedError: - print(f"\n ERROR: Could not connect to server at {args.host}:{args.port}") - print("\nMake sure the server is running:") - print(f" python run_server.py --backend mlx_whisper --port {args.port}") - - except Exception as e: - logging.error(f"Error: {e}") - import traceback - traceback.print_exc() - - -if __name__ == "__main__": - main() diff --git a/verify_mlx_gpu.py b/verify_mlx_gpu.py deleted file mode 100755 index 42e8965e..00000000 --- a/verify_mlx_gpu.py +++ /dev/null @@ -1,228 +0,0 @@ -#!/usr/bin/env python3 -""" -MLX GPU Verification Script - -This script verifies that MLX is installed and can access Apple's -Neural Engine and GPU on your Mac. - -Usage: - python verify_mlx_gpu.py -""" - -import sys -import platform - - -def check_apple_silicon(): - """Check if running on Apple Silicon.""" - print("=" * 70) - print("Checking Apple Silicon Status") - print("=" * 70) - - machine = platform.machine() - print(f"Machine architecture: {machine}") - - if machine == "arm64": - print("✓ Running on Apple Silicon (M1/M2/M3)") - return True - else: - print("✗ NOT running on Apple Silicon") - print(" MLX is optimized for Apple Silicon Macs.") - print(f" Your architecture: {machine}") - return False - - -def check_mlx_installation(): - """Check if MLX is installed.""" - print("\n" + "=" * 70) - print("Checking MLX Installation") - print("=" * 70) - - try: - import mlx - import mlx.core as mx - print(f"✓ MLX installed: version {mlx.__version__}") - return True - except ImportError as e: - print("✗ MLX not installed") - print("\n Install with: pip install mlx") - return False - - -def check_mlx_whisper(): - """Check if mlx-whisper is installed.""" - print("\n" + "=" * 70) - print("Checking mlx-whisper Installation") - print("=" * 70) - - try: - import mlx_whisper - print("✓ mlx-whisper is installed") - return True - except ImportError: - print("✗ mlx-whisper not installed") - print("\n Install with: pip install mlx-whisper") - return False - - -def test_mlx_gpu(): - """Test MLX GPU operations.""" - print("\n" + "=" * 70) - print("Testing MLX GPU/Neural Engine Access") - print("=" * 70) - - try: - import mlx.core as mx - - # Create a simple array and perform computation - print("\nCreating test array and performing GPU computation...") - a = mx.array([1.0, 2.0, 3.0, 4.0]) - b = mx.array([5.0, 6.0, 7.0, 8.0]) - - # Force evaluation on GPU - c = a + b - mx.eval(c) - - print(f"Input a: {a}") - print(f"Input b: {b}") - print(f"Result (a + b): {c}") - print("\n✓ MLX can perform GPU computations successfully!") - - # Check available memory - print("\nChecking device memory...") - # MLX uses unified memory on Apple Silicon - print("✓ MLX uses unified memory architecture") - print(" (Shared between CPU, GPU, and Neural Engine)") - - return True - - except Exception as e: - print(f"✗ Error testing MLX GPU: {e}") - return False - - -def test_mlx_whisper_model(): - """Test loading a small MLX Whisper model.""" - print("\n" + "=" * 70) - print("Testing MLX Whisper Model Loading") - print("=" * 70) - - try: - import mlx_whisper - import numpy as np - - print("\nAttempting to load a tiny model (this may take a moment)...") - print("Model: 
mlx-community/whisper-tiny-mlx") - - # Create dummy audio (1 second of silence) - dummy_audio = np.zeros(16000, dtype=np.float32) - - # Try to transcribe (this will download and load the model) - result = mlx_whisper.transcribe( - dummy_audio, - path_or_hf_repo="mlx-community/whisper-tiny-mlx", - verbose=False - ) - - print("✓ Successfully loaded and tested MLX Whisper model!") - print("✓ Model is using Apple Neural Engine and GPU for inference") - - return True - - except Exception as e: - print(f"✗ Error testing MLX Whisper: {e}") - print("\n This might be due to:") - print(" - Network issues downloading the model") - print(" - Insufficient disk space") - print(" - Missing dependencies") - return False - - -def monitor_gpu_usage(): - """Provide instructions for monitoring GPU usage.""" - print("\n" + "=" * 70) - print("How to Monitor GPU Usage During Transcription") - print("=" * 70) - - print("\n1. Using Activity Monitor (GUI):") - print(" - Open Activity Monitor (Cmd+Space, type 'Activity Monitor')") - print(" - Go to the 'GPU' tab") - print(" - Look for the 'python' process") - print(" - Watch 'GPU Time' and 'GPU Memory' columns increase during transcription") - - print("\n2. Using powermetrics (Terminal):") - print(" Run this in a separate terminal while transcribing:") - print(" sudo powermetrics --samplers gpu_power -i 1000") - print(" (Shows GPU power consumption, higher = more GPU usage)") - - print("\n3. Using asitop (Third-party tool):") - print(" Install: pip install asitop") - print(" Run: sudo asitop") - print(" (Shows real-time GPU, Neural Engine, and CPU usage)") - - print("\n4. Quick verification:") - print(" - Run transcription with MLX backend") - print(" - Open Activity Monitor → Window → GPU History") - print(" - You should see spikes in GPU usage when audio is being transcribed") - - -def main(): - print("\n" + "=" * 70) - print(" MLX Whisper GPU Verification Tool") - print("=" * 70 + "\n") - - results = { - "Apple Silicon": check_apple_silicon(), - "MLX Installed": check_mlx_installation(), - "MLX Whisper Installed": check_mlx_whisper(), - } - - # Only run GPU tests if basic checks pass - if results["Apple Silicon"] and results["MLX Installed"]: - results["MLX GPU Test"] = test_mlx_gpu() - - if results["MLX Whisper Installed"]: - print("\nNote: The next test will download a small model (~40MB)") - response = input("Continue with model test? [y/N]: ").strip().lower() - if response == 'y': - results["MLX Whisper Model Test"] = test_mlx_whisper_model() - - # Always show monitoring instructions - monitor_gpu_usage() - - # Summary - print("\n" + "=" * 70) - print("Summary") - print("=" * 70) - - all_passed = all(results.values()) - - for check, passed in results.items(): - status = "✓ PASS" if passed else "✗ FAIL" - print(f"{status}: {check}") - - if all_passed: - print("\n🎉 All checks passed! MLX Whisper is ready to use.") - print("\nYou can now run the server with:") - print(" python run_server.py --backend mlx_whisper --mlx_model_path small.en") - print("\nAnd test with microphone:") - print(" python test_mlx_microphone.py --model small.en") - else: - print("\n⚠️ Some checks failed. 
Please address the issues above.") - - if not results.get("Apple Silicon"): - print("\n MLX requires Apple Silicon (M1/M2/M3) to run.") - - if not results.get("MLX Installed"): - print("\n Install MLX: pip install mlx") - - if not results.get("MLX Whisper Installed"): - print("\n Install mlx-whisper: pip install mlx-whisper") - - print("=" * 70 + "\n") - - return 0 if all_passed else 1 - - -if __name__ == "__main__": - sys.exit(main())
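
Since the standalone test scripts were removed in the final patch, here is a minimal usage sketch for exercising the new backend end to end. It is illustrative only: it assumes the whisper_live.client.Client interface that the removed test_mlx_microphone.py relied on (host, port, lang, translate, model, srt_file_path, use_vad keyword arguments) and a server started with:

    python run_server.py --backend mlx_whisper --mlx_model_path small.en

    # Illustrative sketch only; mirrors the removed test_mlx_microphone.py from patch 2.
    import time
    from whisper_live.client import Client

    client = Client(
        host="localhost",
        port=9090,
        lang=None,               # None lets the backend auto-detect the spoken language
        translate=False,         # True would translate to English instead of transcribing
        model="small.en",        # mapped server-side to mlx-community/whisper-small.en-mlx
        srt_file_path="output_mlx.srt",
        use_vad=True,
    )

    try:
        # Per the removed script, recording starts in background threads once the
        # client is created; keep the main thread alive until the user interrupts.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        client.close_websocket()  # close the stream and write the SRT output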