Skip to content

Commit a94695b

Browse files
committed
Add tests
1 parent 5450942 commit a94695b

File tree

2 files changed

+281
-1
lines changed

2 files changed

+281
-1
lines changed

src/elevenlabs/realtime/scribe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _build_websocket_url(
389389
if language_code is not None:
390390
params.append(f"language_code={language_code}")
391391
if include_timestamps is not None:
392-
params.append(f"include_timestamps={include_timestamps}")
392+
params.append(f"include_timestamps={str(include_timestamps).lower()}")
393393

394394
query_string = "&".join(params)
395395
return f"{base}/v1/speech-to-text/realtime?{query_string}"

tests/test_stt_realtime.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
"""Tests for realtime speech-to-text (Scribe) functionality.
2+
3+
These tests cover URL building, validation, and event handling behavior
4+
that don't require an actual WebSocket connection.
5+
"""
6+
7+
from unittest.mock import AsyncMock, MagicMock, patch
8+
9+
import pytest
10+
11+
from elevenlabs.realtime.connection import RealtimeConnection, RealtimeEvents
12+
from elevenlabs.realtime.scribe import (
13+
AudioFormat,
14+
CommitStrategy,
15+
ScribeRealtime,
16+
)
17+
18+
19+
class TestBuildWebsocketUrl:
20+
"""Tests for _build_websocket_url helper method"""
21+
22+
def setup_method(self):
23+
"""Set up test fixtures"""
24+
self.scribe = ScribeRealtime(api_key="test-api-key")
25+
26+
def test_builds_url_with_all_parameters(self):
27+
"""Test URL construction with required and optional parameters"""
28+
url = self.scribe._build_websocket_url(
29+
model_id="scribe_v2_realtime",
30+
audio_format="pcm_16000",
31+
commit_strategy="vad",
32+
vad_silence_threshold_secs=0.5,
33+
vad_threshold=0.3,
34+
min_speech_duration_ms=100,
35+
min_silence_duration_ms=200,
36+
language_code="es",
37+
include_timestamps=True
38+
)
39+
40+
# Base URL structure
41+
assert url.startswith("wss://api.elevenlabs.io/v1/speech-to-text/realtime?")
42+
43+
# Required parameters
44+
assert "model_id=scribe_v2_realtime" in url
45+
assert "audio_format=pcm_16000" in url
46+
assert "commit_strategy=vad" in url
47+
48+
# Optional parameters
49+
assert "vad_silence_threshold_secs=0.5" in url
50+
assert "vad_threshold=0.3" in url
51+
assert "min_speech_duration_ms=100" in url
52+
assert "min_silence_duration_ms=200" in url
53+
assert "language_code=es" in url
54+
assert "include_timestamps=true" in url
55+
56+
def test_optional_parameters_omitted_when_none(self):
57+
"""Test that None parameters are not included in URL"""
58+
url = self.scribe._build_websocket_url(
59+
model_id="scribe_v2_realtime",
60+
audio_format="pcm_16000",
61+
commit_strategy="manual",
62+
vad_silence_threshold_secs=None,
63+
language_code=None
64+
)
65+
66+
assert "vad_silence_threshold_secs" not in url
67+
assert "language_code" not in url
68+
69+
def test_url_converts_https_to_wss(self):
70+
"""Test that https base URLs are converted to wss"""
71+
scribe = ScribeRealtime(
72+
api_key="test-api-key",
73+
base_url="https://api.elevenlabs.io"
74+
)
75+
url = scribe._build_websocket_url(
76+
model_id="scribe_v2_realtime",
77+
audio_format="pcm_16000",
78+
commit_strategy="manual"
79+
)
80+
81+
assert url.startswith("wss://")
82+
assert not url.startswith("wss://wss://")
83+
84+
def test_url_converts_http_to_ws(self):
85+
"""Test that http base URLs are converted to ws"""
86+
scribe = ScribeRealtime(
87+
api_key="test-api-key",
88+
base_url="http://localhost:8080"
89+
)
90+
url = scribe._build_websocket_url(
91+
model_id="scribe_v2_realtime",
92+
audio_format="pcm_16000",
93+
commit_strategy="manual"
94+
)
95+
96+
assert url.startswith("ws://localhost:8080")
97+
98+
99+
class TestConnectValidation:
100+
"""Tests for connect method validation"""
101+
102+
def setup_method(self):
103+
"""Set up test fixtures"""
104+
self.scribe = ScribeRealtime(api_key="test-api-key")
105+
106+
@pytest.mark.asyncio
107+
async def test_connect_requires_model_id(self):
108+
"""Test that connect raises error without model_id"""
109+
with pytest.raises(ValueError, match="model_id is required"):
110+
await self.scribe.connect({}) # type: ignore
111+
112+
@pytest.mark.asyncio
113+
async def test_connect_audio_mode_requires_format_and_sample_rate(self):
114+
"""Test that audio mode requires both audio_format and sample_rate"""
115+
with pytest.raises(ValueError, match="audio_format and sample_rate are required"):
116+
await self.scribe.connect({
117+
"model_id": "scribe_v2_realtime"
118+
}) # type: ignore
119+
120+
with pytest.raises(ValueError, match="audio_format and sample_rate are required"):
121+
await self.scribe.connect({
122+
"model_id": "scribe_v2_realtime",
123+
"audio_format": AudioFormat.PCM_16000
124+
# missing sample_rate
125+
}) # type: ignore
126+
127+
with pytest.raises(ValueError, match="audio_format and sample_rate are required"):
128+
await self.scribe.connect({
129+
"model_id": "scribe_v2_realtime",
130+
"sample_rate": 16000
131+
# missing audio_format
132+
}) # type: ignore
133+
134+
@pytest.mark.asyncio
135+
@patch("elevenlabs.realtime.scribe.websocket_connect")
136+
async def test_connect_url_mode_requires_url(self, mock_ws_connect):
137+
"""Test that URL mode requires non-empty url parameter"""
138+
with pytest.raises(ValueError, match="url is required"):
139+
await self.scribe.connect({
140+
"model_id": "scribe_v2_realtime",
141+
"url": ""
142+
}) # type: ignore
143+
144+
145+
class TestConnectEnumHandling:
146+
"""Tests for correct enum value extraction when building WebSocket URLs.
147+
148+
Regression tests to ensure AudioFormat and CommitStrategy enums are
149+
converted to their string values (e.g., 'pcm_16000') rather than being
150+
passed as enum objects (which would result in 'AudioFormat.PCM_16000').
151+
"""
152+
153+
def setup_method(self):
154+
"""Set up test fixtures"""
155+
self.scribe = ScribeRealtime(api_key="test-api-key")
156+
157+
@pytest.mark.asyncio
158+
@patch("elevenlabs.realtime.scribe.websocket_connect", new_callable=AsyncMock)
159+
async def test_connect_audio_uses_enum_values_in_url(self, mock_ws_connect):
160+
"""Test that AudioFormat and CommitStrategy enum values are correctly extracted.
161+
162+
This is a regression test: previously, if .value was not called on enums,
163+
the URL would contain 'AudioFormat.PCM_16000' instead of 'pcm_16000'.
164+
"""
165+
mock_websocket = MagicMock()
166+
mock_ws_connect.return_value = mock_websocket
167+
# Mock the async iterator for the websocket (needed for message handler)
168+
mock_websocket.__aiter__ = MagicMock(return_value=iter([]))
169+
170+
await self.scribe.connect({
171+
"model_id": "scribe_v2_realtime",
172+
"audio_format": AudioFormat.PCM_16000,
173+
"sample_rate": 16000,
174+
"commit_strategy": CommitStrategy.VAD
175+
})
176+
177+
# Verify websocket_connect was called
178+
mock_ws_connect.assert_awaited_once()
179+
180+
# Extract the URL that was passed to websocket_connect
181+
call_args = mock_ws_connect.call_args
182+
url = call_args[0][0] # First positional argument
183+
184+
# Verify the URL contains the string values, not enum representations
185+
assert "audio_format=pcm_16000" in url, \
186+
f"URL should contain 'audio_format=pcm_16000', not enum repr. Got: {url}"
187+
assert "AudioFormat" not in url, \
188+
f"URL should not contain 'AudioFormat' enum name. Got: {url}"
189+
190+
assert "commit_strategy=vad" in url, \
191+
f"URL should contain 'commit_strategy=vad', not enum repr. Got: {url}"
192+
assert "CommitStrategy" not in url, \
193+
f"URL should not contain 'CommitStrategy' enum name. Got: {url}"
194+
195+
@pytest.mark.asyncio
196+
@patch("elevenlabs.realtime.scribe.websocket_connect", new_callable=AsyncMock)
197+
async def test_connect_audio_default_commit_strategy_is_manual(self, mock_ws_connect):
198+
"""Test that the default commit strategy is MANUAL when not specified."""
199+
mock_websocket = MagicMock()
200+
mock_ws_connect.return_value = mock_websocket
201+
mock_websocket.__aiter__ = MagicMock(return_value=iter([]))
202+
203+
await self.scribe.connect({
204+
"model_id": "scribe_v2_realtime",
205+
"audio_format": AudioFormat.PCM_16000,
206+
"sample_rate": 16000
207+
# commit_strategy not specified
208+
})
209+
210+
url = mock_ws_connect.call_args[0][0]
211+
assert "commit_strategy=manual" in url
212+
213+
214+
class TestRealtimeConnectionEventHandling:
215+
"""Tests for RealtimeConnection event handling behavior"""
216+
217+
def setup_method(self):
218+
"""Set up test fixtures"""
219+
self.mock_websocket = MagicMock()
220+
self.connection = RealtimeConnection(
221+
websocket=self.mock_websocket,
222+
current_sample_rate=16000,
223+
ffmpeg_process=None
224+
)
225+
226+
def test_emit_calls_all_registered_handlers(self):
227+
"""Test that emitting an event calls all registered handlers in order"""
228+
call_order = []
229+
230+
def handler1(data):
231+
call_order.append(("handler1", data))
232+
233+
def handler2(data):
234+
call_order.append(("handler2", data))
235+
236+
self.connection.on("test_event", handler1)
237+
self.connection.on("test_event", handler2)
238+
self.connection._emit("test_event", {"value": 42})
239+
240+
assert call_order == [
241+
("handler1", {"value": 42}),
242+
("handler2", {"value": 42})
243+
]
244+
245+
def test_emit_isolates_handler_exceptions(self, capsys):
246+
"""Test that an exception in one handler doesn't prevent others from running"""
247+
results = []
248+
249+
def bad_handler(data):
250+
raise ValueError("Handler error")
251+
252+
def good_handler(data):
253+
results.append(data)
254+
255+
self.connection.on("test_event", bad_handler)
256+
self.connection.on("test_event", good_handler)
257+
258+
# Should not raise, and good_handler should still be called
259+
self.connection._emit("test_event", {"value": "test"})
260+
261+
assert results == [{"value": "test"}]
262+
captured = capsys.readouterr()
263+
assert "Error in event handler" in captured.out
264+
265+
def test_emit_with_no_handlers_does_not_raise(self):
266+
"""Test that emitting to an event with no handlers is a no-op"""
267+
# Should not raise
268+
self.connection._emit("nonexistent_event", {"data": "test"})
269+
270+
def test_handlers_receive_correct_arguments(self):
271+
"""Test that handlers receive all emitted arguments"""
272+
received_args = []
273+
274+
def handler(*args):
275+
received_args.extend(args)
276+
277+
self.connection.on(RealtimeEvents.PARTIAL_TRANSCRIPT, handler)
278+
self.connection._emit(RealtimeEvents.PARTIAL_TRANSCRIPT, "arg1", "arg2", {"key": "value"})
279+
280+
assert received_args == ["arg1", "arg2", {"key": "value"}]

0 commit comments

Comments
 (0)