Skip to content

Commit efd8d6a

Browse files
committed
review comments
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent bd96f65 commit efd8d6a

File tree

4 files changed

+75
-248
lines changed

4 files changed

+75
-248
lines changed

mellea/core/base.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -410,50 +410,63 @@ class ComputedModelOutputThunk(ModelOutputThunk[S]):
410410
and those that are already computed. It should be returned from synchronous functions
411411
and sampling strategies to indicate that no awaiting is needed.
412412
413+
Rather than creating from scratch, this wraps an existing ModelOutputThunk and ensures
414+
it's fully computed. All attribute access is delegated to the wrapped thunk.
415+
413416
Key differences from ModelOutputThunk:
414-
- Always initialized with a value (cannot be None)
417+
- Always initialized from a computed ModelOutputThunk
415418
- _computed is always True
416419
- Cannot be used for streaming (generation fields are not set)
417420
- Provides type safety to indicate "already computed"
418421
"""
419422

420-
def __init__(
421-
self,
422-
value: str,
423-
meta: dict[str, Any] | None = None,
424-
parsed_repr: S | None = None,
425-
tool_calls: dict[str, ModelToolCall] | None = None,
426-
):
427-
"""Initializes a computed ModelOutputThunk with a required value.
423+
def __init__(self, thunk: ModelOutputThunk[S]):
424+
"""Wraps an existing ModelOutputThunk, ensuring it's computed.
428425
429426
Args:
430-
value: The computed string value (required, cannot be None)
431-
meta: Optional metadata dictionary
432-
parsed_repr: Optional parsed representation
433-
tool_calls: Optional tool calls dictionary
427+
thunk: A ModelOutputThunk that must be fully computed (value cannot be None)
428+
429+
Raises:
430+
ValueError: If the thunk is not computed or has a None value
434431
"""
435-
if value is None:
432+
if not thunk.is_computed():
433+
raise ValueError(
434+
"ComputedModelOutputThunk requires a computed ModelOutputThunk"
435+
)
436+
if thunk.value is None:
436437
raise ValueError("ComputedModelOutputThunk requires a non-None value")
437438

438-
super().__init__(value, meta, parsed_repr, tool_calls)
439-
440-
# Ensure computed flag is set
441-
assert self._computed, "ComputedModelOutputThunk must be computed"
439+
# Store the wrapped thunk and ensure it's marked as computed
440+
self._wrapped_thunk = thunk
441+
self._wrapped_thunk._computed = True
442442

443443
# Clear generation-related fields since this is already computed
444-
self._generate = None
445-
self._generate_type = GenerateType.NONE
446-
self._generate_extra = None
447-
self._process = None
448-
self._post_process = None
444+
self._wrapped_thunk._generate = None
445+
self._wrapped_thunk._generate_type = GenerateType.NONE
446+
self._wrapped_thunk._generate_extra = None
447+
self._wrapped_thunk._process = None
448+
self._wrapped_thunk._post_process = None
449+
450+
def __getattr__(self, name: str) -> Any:
451+
"""Delegate all attribute access to the wrapped thunk."""
452+
return getattr(self._wrapped_thunk, name)
453+
454+
def __setattr__(self, name: str, value: Any) -> None:
455+
"""Delegate all attribute setting to the wrapped thunk, except for _wrapped_thunk itself."""
456+
if name == "_wrapped_thunk":
457+
object.__setattr__(self, name, value)
458+
else:
459+
setattr(self._wrapped_thunk, name, value)
449460

450461
async def avalue(self) -> str:
451462
"""Returns the value immediately since it's already computed.
452463
453464
Overrides the parent method to avoid unnecessary async operations.
454465
"""
455-
assert self.value is not None, "ComputedModelOutputThunk value cannot be None"
456-
return self.value
466+
assert self._wrapped_thunk.value is not None, (
467+
"ComputedModelOutputThunk value cannot be None"
468+
)
469+
return self._wrapped_thunk.value
457470

458471
async def astream(self) -> str:
459472
"""Returns the value immediately since streaming is not applicable.

mellea/stdlib/functional.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def act(
4444
format: type[BaseModelSubclass] | None = None,
4545
model_options: dict | None = None,
4646
tool_calls: bool = False,
47-
) -> tuple[ModelOutputThunk[S], Context]: ...
47+
) -> tuple[ComputedModelOutputThunk[S], Context]: ...
4848

4949

5050
@overload
@@ -73,7 +73,7 @@ def act(
7373
format: type[BaseModelSubclass] | None = None,
7474
model_options: dict | None = None,
7575
tool_calls: bool = False,
76-
) -> tuple[ModelOutputThunk[S], Context] | SamplingResult[S]:
76+
) -> tuple[ComputedModelOutputThunk[S], Context] | SamplingResult[S]:
7777
"""Runs a generic action, and adds both the action and the result to the context.
7878
7979
Args:
@@ -129,7 +129,7 @@ def instruct(
129129
format: type[BaseModelSubclass] | None = None,
130130
model_options: dict | None = None,
131131
tool_calls: bool = False,
132-
) -> tuple[ModelOutputThunk[str], Context]: ...
132+
) -> tuple[ComputedModelOutputThunk[str], Context]: ...
133133

134134

135135
@overload
@@ -170,7 +170,7 @@ def instruct(
170170
format: type[BaseModelSubclass] | None = None,
171171
model_options: dict | None = None,
172172
tool_calls: bool = False,
173-
) -> tuple[ModelOutputThunk[str], Context] | SamplingResult[str]:
173+
) -> tuple[ComputedModelOutputThunk[str], Context] | SamplingResult[str]:
174174
"""Generates from an instruction.
175175
176176
Args:
@@ -555,20 +555,9 @@ async def aact(
555555
# ._generate_log should never be None after generation.
556556
assert result._generate_log is not None
557557
result._generate_log.is_final_result = True
558-
generate_logs.append(result._generate_log)
559558

560559
# Wrap in ComputedModelOutputThunk to indicate it's fully computed
561-
computed_result = ComputedModelOutputThunk(
562-
value=result.value, # type: ignore
563-
meta=result._meta,
564-
parsed_repr=result.parsed_repr,
565-
tool_calls=result.tool_calls,
566-
)
567-
# Copy over important fields
568-
computed_result._thinking = result._thinking
569-
computed_result._context = result._context
570-
computed_result._action = result._action
571-
computed_result._model_options = result._model_options
560+
computed_result = ComputedModelOutputThunk(result)
572561
computed_result._generate_log = result._generate_log
573562

574563
# Update context to point to the wrapped result instead of original
@@ -611,18 +600,7 @@ async def aact(
611600
)
612601

613602
# Wrap sampling result in ComputedModelOutputThunk since it's always computed
614-
computed_result = ComputedModelOutputThunk(
615-
value=result.value, # type: ignore
616-
meta=result._meta,
617-
parsed_repr=result.parsed_repr,
618-
tool_calls=result.tool_calls,
619-
)
620-
# Copy over important fields
621-
computed_result._thinking = result._thinking
622-
computed_result._context = result._context
623-
computed_result._action = result._action
624-
computed_result._model_options = result._model_options
625-
computed_result._generate_log = result._generate_log
603+
computed_result = ComputedModelOutputThunk(result)
626604

627605
# Update the sampling result to use the computed thunk
628606
sampling_result.sample_generations[sampling_result.result_index] = (

test/backends/test_tool_argument_validation_README.md

Lines changed: 0 additions & 181 deletions
This file was deleted.

0 commit comments

Comments
 (0)