Skip to content

Commit 9cafe05

Browse files
fix: astream output (#358)
* fix astream output Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com> * review comments Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com> * review comment Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com> * adding tests Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com> * adding tests Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com> * review comment Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com> --------- Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent 3204b3a commit 9cafe05

File tree

2 files changed

+247
-1
lines changed

2 files changed

+247
-1
lines changed

mellea/core/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ async def astream(self) -> str:
275275
raise RuntimeError(
276276
f"Cannot use `ModelOutputThunk.astream()` when the generate function is using `{self._generate_type.name}`"
277277
)
278+
# Beginning value
279+
beginning_length = (
280+
0 if self._underlying_value is None else len(str(self._underlying_value))
281+
) # type: ignore
278282

279283
exception_to_raise = None
280284
try:
@@ -350,12 +354,17 @@ async def astream(self) -> str:
350354
assert self.parsed_repr is not None, (
351355
"enforce constraint that a computed ModelOutputThunk has a non-None parsed_repr"
352356
)
357+
return self._underlying_value # type: ignore
353358

354359
# Re-raise exception after cleanup if one occurred
355360
if exception_to_raise is not None:
356361
raise exception_to_raise
357362

358-
return self._underlying_value # type: ignore
363+
return (
364+
self._underlying_value
365+
if beginning_length == 0
366+
else self._underlying_value[beginning_length:] # type: ignore
367+
)
359368

360369
def __repr__(self):
361370
"""Provides a python-parsable representation (usually).
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""Tests for ModelOutputThunk.astream() incremental return behavior.
2+
3+
Tests that astream() returns only new content added since the beginning of
4+
each astream() call, not the entire accumulated value.
5+
"""
6+
7+
import pytest
8+
9+
from mellea.backends import ModelOption
10+
from mellea.core import CBlock, ModelOutputThunk
11+
from mellea.stdlib.context import SimpleContext
12+
from mellea.stdlib.session import start_session
13+
14+
15+
@pytest.mark.ollama
@pytest.mark.llm
async def test_astream_returns_incremental_chunks():
    """Verify astream() yields only content added since the call began.

    Exercises the fix where the already-streamed length is captured at the
    start of astream() and the return value is sliced so that previously
    returned content is not repeated.
    """
    session = start_session()
    opts = {ModelOption.STREAM: True}

    mot, _ = await session.backend.generate_from_context(
        CBlock("Count from 1 to 5 slowly."), SimpleContext(), model_options=opts
    )

    # The first call streams from the very beginning of the output.
    first = await mot.astream()
    assert first is not None, "First chunk should not be None"
    assert len(first) > 0, "First chunk should have content"

    # The second call must hand back only what arrived after the first one.
    second = await mot.astream()

    if mot.is_computed():
        # Stream finished already; the completing call returns the full
        # value, so only the prefix property can be checked here.
        final = await mot.avalue()
        assert final.startswith(first), "Final value should start with first chunk"
        return

    assert second is not None, "Second chunk should not be None if not computed"

    # Core property: the second chunk is incremental, not cumulative.
    if second:
        assert second != first, (
            "Second chunk should be different from first (incremental)"
        )

    final = await mot.avalue()

    # Both chunks must appear, in order, within the finished value.
    assert final.startswith(first), (
        "Final value should start with first chunk"
    )
    joined = first + second
    assert final.startswith(joined) or joined.startswith(
        final
    ), "Accumulated chunks should match final value progression"
68+
69+
@pytest.mark.ollama
@pytest.mark.llm
async def test_astream_multiple_calls_accumulate_correctly():
    """Test that multiple astream() calls accumulate to the final value.

    Note: The final astream() call that marks the thunk as computed returns
    the FULL value, not just the incremental part, so only the chunks before
    the last one are expected to be incremental.
    """
    session = start_session()
    model_opts = {ModelOption.STREAM: True}

    mot, _ = await session.backend.generate_from_context(
        CBlock("Write a short sentence."), SimpleContext(), model_options=model_opts
    )

    chunks = []

    # Stream until the thunk reports completion.
    # (Fix: the old `accumulated` string built here was never read; the
    # incremental prefix is recomputed below from chunks[:-1].)
    while not mot.is_computed():
        chunk = await mot.astream()
        if chunk:
            chunks.append(chunk)

        # Safety: don't loop forever
        if len(chunks) > 100:
            break

    # Get final value
    final_val = await mot.avalue()

    # The last chunk should be the full value when computed
    if len(chunks) > 0:
        assert chunks[-1] == final_val, (
            f"Last chunk (when computed) should be full value.\n"
            f"Last chunk: {chunks[-1]!r}\n"
            f"Final: {final_val!r}"
        )

    # All chunks except the last should be incremental
    if len(chunks) > 1:
        incremental_accumulated = "".join(chunks[:-1])
        assert final_val.startswith(incremental_accumulated), (
            f"Incremental chunks should be prefix of final value.\n"
            f"Accumulated: {incremental_accumulated!r}\n"
            f"Final: {final_val!r}"
        )
119+
120+
121+
@pytest.mark.ollama
@pytest.mark.llm
async def test_astream_beginning_length_tracking():
    """Test that beginning_length is correctly tracked across astream calls.

    This specifically tests that beginning_length is captured at the start
    of each astream() call so that later calls return only new content.
    """
    session = start_session()
    model_opts = {ModelOption.STREAM: True}

    mot, _ = await session.backend.generate_from_context(
        CBlock("Say hello."), SimpleContext(), model_options=model_opts
    )

    # First call: beginning_length should be 0 (or length of any pre-existing value)
    chunk1 = await mot.astream()

    # Second call: beginning_length should be captured at start of this call
    chunk2 = await mot.astream()

    # Fix: the astream() call that completes the stream returns the FULL
    # value (which legitimately starts with chunk1), so the incremental
    # property only holds while the thunk is still streaming. Guard on
    # is_computed() to avoid a flaky failure when the second call finishes
    # the generation.
    if not mot.is_computed() and chunk1 and chunk2:
        # chunk2 should not include chunk1's content; this verifies the
        # return-value slicing in astream().
        assert not chunk2.startswith(chunk1), (
            "Second chunk should not start with first chunk (should be incremental)"
        )
149+
150+
151+
@pytest.mark.ollama
@pytest.mark.llm
async def test_astream_empty_beginning():
    """Test astream when _underlying_value starts as None."""
    session = start_session()

    mot, _ = await session.backend.generate_from_context(
        CBlock("Hi"),
        SimpleContext(),
        model_options={ModelOption.STREAM: True},
    )

    # _underlying_value may still be None before the first chunk arrives;
    # in that case beginning_length is 0 and the whole value is returned.
    chunk = await mot.astream()

    assert chunk is not None, "Should get a chunk even when starting from None"

    underlying = mot._underlying_value
    if underlying:
        # With beginning_length == 0, the chunk is the full underlying value
        # (or a prefix of it, if more content arrived afterwards).
        assert chunk == underlying or underlying.startswith(
            chunk
        ), "When beginning_length is 0, should return the full underlying value"
173+
174+
175+
@pytest.mark.ollama
@pytest.mark.llm
async def test_astream_computed_returns_full_value():
    """Test that astream returns full value when already computed."""
    # Build a thunk that is already marked computed, bypassing generation.
    mot = ModelOutputThunk(value="Hello, world!")
    mot._computed = True

    # An already-computed thunk short-circuits and returns its value as-is.
    result = await mot.astream()
    assert result == "Hello, world!", "Computed thunk should return full value"
187+
188+
189+
@pytest.mark.ollama
@pytest.mark.llm
async def test_astream_final_call_returns_full_value():
    """Test that the final astream call returns the full value when computed.

    The astream() call during which the thunk becomes computed hands back
    the complete _underlying_value rather than just the newest increment.
    """
    session = start_session()

    mot, _ = await session.backend.generate_from_context(
        CBlock("Count: 1, 2, 3"),
        SimpleContext(),
        model_options={ModelOption.STREAM: True},
    )

    chunks = []

    # Drain the stream, keeping every non-empty chunk.
    while not mot.is_computed():
        piece = await mot.astream()
        if piece:
            chunks.append(piece)

        if len(chunks) > 100:  # Safety
            break

    final_val = await mot.avalue()

    if chunks:
        # The completing call returns the whole value, not an increment.
        assert chunks[-1] == final_val, (
            f"Final chunk should be the complete value.\n"
            f"Last chunk: {chunks[-1]!r}\n"
            f"Final value: {final_val!r}"
        )

        # Chunks before the last are incremental, so no later one should
        # begin with an earlier one.
        earlier = chunks[:-1]
        for i in range(len(earlier) - 1):
            for j in range(i + 1, len(earlier)):
                if earlier[j] and earlier[i]:
                    assert not earlier[j].startswith(earlier[i]), (
                        f"Incremental chunk {j} should not start with chunk {i}"
                    )
234+
235+
236+
if __name__ == "__main__":
    # Convenience entry point: allow running this test module directly.
    pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)