|
| 1 | +"""Tests for realtime speech-to-text (Scribe) functionality. |
| 2 | +
|
| 3 | +These tests cover URL building, validation, and event handling behavior |
| 4 | +that doesn't require an actual WebSocket connection. |
| 5 | +""" |
| 6 | + |
| 7 | +from unittest.mock import AsyncMock, MagicMock, patch |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from elevenlabs.realtime.connection import RealtimeConnection, RealtimeEvents |
| 12 | +from elevenlabs.realtime.scribe import ( |
| 13 | + AudioFormat, |
| 14 | + CommitStrategy, |
| 15 | + ScribeRealtime, |
| 16 | +) |
| 17 | + |
| 18 | + |
class TestBuildWebsocketUrl:
    """Checks on the URL string produced by ScribeRealtime._build_websocket_url."""

    def setup_method(self):
        """Build a client with only an API key before each test."""
        self.scribe = ScribeRealtime(api_key="test-api-key")

    @staticmethod
    def _manual_url(client):
        """Build a URL from the three required arguments with the manual commit strategy."""
        return client._build_websocket_url(
            model_id="scribe_v2_realtime",
            audio_format="pcm_16000",
            commit_strategy="manual",
        )

    def test_builds_url_with_all_parameters(self):
        """Required and optional arguments are all expected in the query string."""
        arguments = {
            "model_id": "scribe_v2_realtime",
            "audio_format": "pcm_16000",
            "commit_strategy": "vad",
            "vad_silence_threshold_secs": 0.5,
            "vad_threshold": 0.3,
            "min_speech_duration_ms": 100,
            "min_silence_duration_ms": 200,
            "language_code": "es",
            "include_timestamps": True,
        }
        url = self.scribe._build_websocket_url(**arguments)

        # Scheme, host and path expected for a client built without a base_url.
        assert url.startswith("wss://api.elevenlabs.io/v1/speech-to-text/realtime?")

        expected_fragments = (
            # Required parameters
            "model_id=scribe_v2_realtime",
            "audio_format=pcm_16000",
            "commit_strategy=vad",
            # Optional parameters
            "vad_silence_threshold_secs=0.5",
            "vad_threshold=0.3",
            "min_speech_duration_ms=100",
            "min_silence_duration_ms=200",
            "language_code=es",
            # The boolean is expected in lowercase form.
            "include_timestamps=true",
        )
        for fragment in expected_fragments:
            assert fragment in url

    def test_optional_parameters_omitted_when_none(self):
        """Arguments passed as None must not appear in the URL at all."""
        url = self.scribe._build_websocket_url(
            model_id="scribe_v2_realtime",
            audio_format="pcm_16000",
            commit_strategy="manual",
            vad_silence_threshold_secs=None,
            language_code=None,
        )

        for omitted_name in ("vad_silence_threshold_secs", "language_code"):
            assert omitted_name not in url

    def test_url_converts_https_to_wss(self):
        """An https base URL yields a wss URL with a single scheme prefix."""
        client = ScribeRealtime(
            api_key="test-api-key",
            base_url="https://api.elevenlabs.io",
        )
        url = self._manual_url(client)

        assert url.startswith("wss://")
        # Guard against the scheme being prepended twice.
        assert not url.startswith("wss://wss://")

    def test_url_converts_http_to_ws(self):
        """An http base URL yields a ws URL that keeps host and port."""
        client = ScribeRealtime(
            api_key="test-api-key",
            base_url="http://localhost:8080",
        )
        url = self._manual_url(client)

        assert url.startswith("ws://localhost:8080")
| 97 | + |
| 98 | + |
class TestConnectValidation:
    """Checks that ScribeRealtime.connect rejects incomplete option dicts."""

    def setup_method(self):
        """Build a client with only an API key before each test."""
        self.scribe = ScribeRealtime(api_key="test-api-key")

    @pytest.mark.asyncio
    async def test_connect_requires_model_id(self):
        """An options dict without model_id is rejected with ValueError."""
        no_options = {}
        with pytest.raises(ValueError, match="model_id is required"):
            await self.scribe.connect(no_options)  # type: ignore

    @pytest.mark.asyncio
    async def test_connect_audio_mode_requires_format_and_sample_rate(self):
        """Audio mode needs audio_format and sample_rate together."""
        incomplete_option_sets = (
            # neither audio_format nor sample_rate
            {"model_id": "scribe_v2_realtime"},
            # missing sample_rate
            {
                "model_id": "scribe_v2_realtime",
                "audio_format": AudioFormat.PCM_16000,
            },
            # missing audio_format
            {
                "model_id": "scribe_v2_realtime",
                "sample_rate": 16000,
            },
        )

        for options in incomplete_option_sets:
            with pytest.raises(ValueError, match="audio_format and sample_rate are required"):
                await self.scribe.connect(options)  # type: ignore

    @pytest.mark.asyncio
    @patch("elevenlabs.realtime.scribe.websocket_connect")
    async def test_connect_url_mode_requires_url(self, mock_ws_connect):
        """URL mode rejects an empty url string with ValueError."""
        options = {
            "model_id": "scribe_v2_realtime",
            "url": "",
        }
        # websocket_connect is patched so no real connection can be attempted.
        with pytest.raises(ValueError, match="url is required"):
            await self.scribe.connect(options)  # type: ignore
| 143 | + |
| 144 | + |
class TestConnectEnumHandling:
    """Tests for correct enum value extraction when building WebSocket URLs.

    Regression tests to ensure AudioFormat and CommitStrategy enums are
    converted to their string values (e.g., 'pcm_16000') rather than being
    passed as enum objects (which would result in 'AudioFormat.PCM_16000').
    """

    def setup_method(self):
        """Set up test fixtures"""
        self.scribe = ScribeRealtime(api_key="test-api-key")

    @staticmethod
    def _install_silent_websocket(mock_ws_connect):
        """Make the patched websocket_connect resolve to a socket that yields no messages.

        Returns the MagicMock standing in for the websocket.
        """
        mock_websocket = MagicMock()
        # Configure the pre-built __aiter__ rather than replacing it: MagicMock
        # wraps __aiter__.return_value in an async iterator, so `async for`
        # over the mock completes immediately. Overwriting __aiter__ with a
        # callable that returns a plain iterator breaks `async for`, because
        # that iterator has no __anext__.
        # NOTE(review): presumably connect() starts a message handler that
        # iterates the websocket -- confirm against the implementation.
        mock_websocket.__aiter__.return_value = []
        mock_ws_connect.return_value = mock_websocket
        return mock_websocket

    @pytest.mark.asyncio
    @patch("elevenlabs.realtime.scribe.websocket_connect", new_callable=AsyncMock)
    async def test_connect_audio_uses_enum_values_in_url(self, mock_ws_connect):
        """Test that AudioFormat and CommitStrategy enum values are correctly extracted.

        This is a regression test: previously, if .value was not called on enums,
        the URL would contain 'AudioFormat.PCM_16000' instead of 'pcm_16000'.
        """
        self._install_silent_websocket(mock_ws_connect)

        await self.scribe.connect({
            "model_id": "scribe_v2_realtime",
            "audio_format": AudioFormat.PCM_16000,
            "sample_rate": 16000,
            "commit_strategy": CommitStrategy.VAD,
        })

        # Verify websocket_connect was called
        mock_ws_connect.assert_awaited_once()

        # Extract the URL that was passed to websocket_connect
        call_args = mock_ws_connect.call_args
        url = call_args[0][0]  # First positional argument

        # Verify the URL contains the string values, not enum representations
        assert "audio_format=pcm_16000" in url, \
            f"URL should contain 'audio_format=pcm_16000', not enum repr. Got: {url}"
        assert "AudioFormat" not in url, \
            f"URL should not contain 'AudioFormat' enum name. Got: {url}"

        assert "commit_strategy=vad" in url, \
            f"URL should contain 'commit_strategy=vad', not enum repr. Got: {url}"
        assert "CommitStrategy" not in url, \
            f"URL should not contain 'CommitStrategy' enum name. Got: {url}"

    @pytest.mark.asyncio
    @patch("elevenlabs.realtime.scribe.websocket_connect", new_callable=AsyncMock)
    async def test_connect_audio_default_commit_strategy_is_manual(self, mock_ws_connect):
        """Test that the default commit strategy is MANUAL when not specified."""
        self._install_silent_websocket(mock_ws_connect)

        await self.scribe.connect({
            "model_id": "scribe_v2_realtime",
            "audio_format": AudioFormat.PCM_16000,
            "sample_rate": 16000,
            # commit_strategy not specified
        })

        # The URL is the first positional argument given to websocket_connect.
        url = mock_ws_connect.call_args[0][0]
        assert "commit_strategy=manual" in url
| 212 | + |
| 213 | + |
class TestRealtimeConnectionEventHandling:
    """Checks on handler registration (on) and dispatch (_emit) of RealtimeConnection."""

    def setup_method(self):
        """Build a connection around a mocked websocket before each test."""
        websocket = MagicMock()
        self.mock_websocket = websocket
        self.connection = RealtimeConnection(
            websocket=websocket,
            current_sample_rate=16000,
            ffmpeg_process=None,
        )

    def test_emit_calls_all_registered_handlers(self):
        """Handlers for one event run in registration order with the emitted payload."""
        invocations = []

        def make_recorder(label):
            # Each recorder tags what it receives so call order can be asserted.
            def record(data):
                invocations.append((label, data))

            return record

        for label in ("handler1", "handler2"):
            self.connection.on("test_event", make_recorder(label))

        self.connection._emit("test_event", {"value": 42})

        expected = [
            ("handler1", {"value": 42}),
            ("handler2", {"value": 42}),
        ]
        assert invocations == expected

    def test_emit_isolates_handler_exceptions(self, capsys):
        """A raising handler must not stop later handlers or escape from _emit."""
        delivered = []

        def failing_handler(data):
            raise ValueError("Handler error")

        def recording_handler(data):
            delivered.append(data)

        # The failing handler is registered first so it runs before the recorder.
        self.connection.on("test_event", failing_handler)
        self.connection.on("test_event", recording_handler)

        # Must return normally despite the ValueError raised inside.
        self.connection._emit("test_event", {"value": "test"})

        assert delivered == [{"value": "test"}]
        # The failure is expected to be reported on stdout.
        stdout_text = capsys.readouterr().out
        assert "Error in event handler" in stdout_text

    def test_emit_with_no_handlers_does_not_raise(self):
        """Emitting an event nobody subscribed to completes without raising."""
        self.connection._emit("nonexistent_event", {"data": "test"})

    def test_handlers_receive_correct_arguments(self):
        """All positional arguments given to _emit reach the handler, in order."""
        collected = []

        def collect(*args):
            collected.extend(args)

        event = RealtimeEvents.PARTIAL_TRANSCRIPT
        self.connection.on(event, collect)
        self.connection._emit(event, "arg1", "arg2", {"key": "value"})

        assert collected == ["arg1", "arg2", {"key": "value"}]
0 commit comments