Skip to content

Commit 3337332

Browse files
committed
add tool calling argument validation
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent 9155d92 commit 3337332

File tree

8 files changed

+1144
-7
lines changed

8 files changed

+1144
-7
lines changed

mellea/backends/litellm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
add_tools_from_context_actions,
4141
add_tools_from_model_options,
4242
convert_tools_to_json,
43+
validate_tool_arguments,
4344
)
4445

4546
format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors
@@ -602,7 +603,12 @@ def _extract_model_tool_requests(
602603

603604
# Returns the args as a string. Parse it here.
604605
args = json.loads(tool_args)
605-
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)
606+
607+
# Validate and coerce argument types
608+
validated_args = validate_tool_arguments(func, args, strict=False)
609+
model_tool_calls[tool_name] = ModelToolCall(
610+
tool_name, func, validated_args
611+
)
606612

607613
if len(model_tool_calls) > 0:
608614
return model_tool_calls

mellea/backends/ollama.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,8 @@ async def generate_from_raw(
501501
def _extract_model_tool_requests(
502502
self, tools: dict[str, Callable], chat_response: ollama.ChatResponse
503503
) -> dict[str, ModelToolCall] | None:
504+
from .tools import validate_tool_arguments
505+
504506
model_tool_calls: dict[str, ModelToolCall] = {}
505507

506508
if chat_response.message.tool_calls:
@@ -513,8 +515,11 @@ def _extract_model_tool_requests(
513515
continue # skip this function if we can't find it.
514516

515517
args = tool.function.arguments
518+
519+
# Validate and coerce argument types
520+
validated_args = validate_tool_arguments(func, args, strict=False)
516521
model_tool_calls[tool.function.name] = ModelToolCall(
517-
tool.function.name, func, args
522+
tool.function.name, func, validated_args
518523
)
519524

520525
if len(model_tool_calls) > 0:

mellea/backends/tools.py

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,151 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]:
142142
if tool_name is not None and tool_arguments is not None:
143143
tools.append((tool_name, tool_arguments))
144144

145-
return tools
145+
146+
def validate_tool_arguments(
    func: Callable,
    args: Mapping[str, Any],
    *,
    coerce_types: bool = True,
    strict: bool = False,
) -> dict[str, Any]:
    """Validate and optionally coerce tool arguments against a function signature.

    A throwaway Pydantic model is built from the signature of ``func`` and the
    arguments extracted from an LLM response are validated against it. With
    coercion enabled, common type mismatches (e.g. string ``"30"`` for an
    ``int`` parameter) are converted to the annotated type.

    Args:
        func: The tool function to validate against.
        args: Raw arguments from the model (post-JSON parsing). ``None`` or an
            empty mapping is treated as "no arguments".
        coerce_types: If True (default), Pydantic's lax mode is used so
            compatible values are converted to the annotated type. If False,
            Pydantic's strict mode is used and mismatched types are reported
            as validation failures.
        strict: If True, re-raise any failure; if False (default), log a
            warning and return a copy of the original arguments unchanged.

    Returns:
        A new dict with the validated arguments. Parameters that were not
        supplied but have a default are included with that default. Keys that
        do not match a named parameter are dropped unless ``func`` declares
        ``**kwargs``, in which case they are passed through unvalidated.

    Raises:
        ValidationError: If ``strict=True`` and the arguments do not validate.
        Exception: If ``strict=True`` and the validator itself could not be
            built or run (for example the signature cannot be inspected).

    Examples:
        >>> def get_weather(location: str, days: int = 1) -> dict:
        ...     return {"location": location, "days": days}

        >>> # LLM returns days as string
        >>> args = {"location": "Boston", "days": "3"}
        >>> validated = validate_tool_arguments(get_weather, args)
        >>> validated
        {'location': 'Boston', 'days': 3}

        >>> # Strict mode raises on validation errors
        >>> bad_args = {"location": "Boston", "days": "not_a_number"}
        >>> validate_tool_arguments(get_weather, bad_args, strict=True)
        Traceback (most recent call last):
        ...
        pydantic.ValidationError: ...
    """
    from pydantic import ConfigDict, ValidationError, create_model

    from ..core import FancyLogger

    logger = FancyLogger.get_logger()

    # functools.partial objects and callable instances have no __name__.
    tool_name = getattr(func, "__name__", repr(func))

    # Copy once up front: this is both the validation input and the value
    # handed back untouched when validation fails in non-strict mode.
    raw_args: dict[str, Any] = dict(args) if args else {}

    try:
        sig = inspect.signature(func)

        # Translate the signature into Pydantic field definitions.
        field_definitions: dict[str, Any] = {}
        accepts_extra_kwargs = False
        for param_name, param in sig.parameters.items():
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                accepts_extra_kwargs = True
                continue
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                continue

            # Identity checks: `==` would call the default value's own __eq__.
            if param.annotation is inspect.Parameter.empty:
                param_type: Any = Any
            else:
                param_type = param.annotation

            if param.default is inspect.Parameter.empty:
                # Ellipsis marks the field as required.
                field_definitions[param_name] = (param_type, ...)
            else:
                field_definitions[param_name] = (param_type, param.default)

        # The config has to be supplied at creation time; assigning
        # `model_config` on an already-built model does not rebuild its
        # validator and therefore changes nothing.
        ValidatorModel = create_model(
            f"{tool_name}_Validator",
            __config__=ConfigDict(
                # Pydantic's own strict mode is what disables type coercion.
                strict=not coerce_types,
                # Tool parameters are commonly named `model_*`; do not warn.
                protected_namespaces=(),
            ),
            **field_definitions,
        )

        validated_model = ValidatorModel(**raw_args)

        # dict(model) is a shallow conversion: unlike model_dump() it keeps
        # nested Pydantic models as model instances instead of plain dicts.
        validated_args: dict[str, Any] = dict(validated_model)

        coerced_fields = [
            f"{key}: {type(raw_args[key]).__name__} -> "
            f"{type(validated_args[key]).__name__}"
            for key in raw_args
            if key in validated_args
            and type(raw_args[key]) is not type(validated_args[key])
        ]
        if coerced_fields:
            logger.debug(
                f"Tool '{tool_name}' arguments coerced: {', '.join(coerced_fields)}"
            )

        if accepts_extra_kwargs:
            # The tool takes **kwargs, so unknown keys are legitimate input.
            for key, value in raw_args.items():
                validated_args.setdefault(key, value)

        return validated_args

    except ValidationError as e:
        error_details = [
            f" - {'.'.join(str(loc) for loc in error['loc'])}: {error['msg']}"
            for error in e.errors()
        ]
        error_msg = (
            f"Tool argument validation failed for '{tool_name}':\n"
            + "\n".join(error_details)
        )
        if strict:
            logger.error(error_msg)
            raise

    except Exception as e:
        # Boundary catch: a validator that cannot be built or run must not
        # take down tool calling when the caller asked for best effort.
        error_msg = f"Unexpected error validating tool '{tool_name}' arguments: {e}"
        if strict:
            logger.error(error_msg)
            raise

    # Only reached in non-strict mode after a failure.
    logger.warning(error_msg + "\nReturning original arguments without validation.")
    return raw_args
146290

147291

148292
# Below functions and classes extracted from Ollama Python SDK (v0.6.1)

mellea/backends/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..core import CBlock, Component, Context, FancyLogger, ModelToolCall
99
from ..formatters import ChatFormatter
1010
from ..stdlib.components import Message
11-
from .tools import parse_tools
11+
from .tools import parse_tools, validate_tool_arguments
1212

1313
# Chat = dict[Literal["role", "content"], str] # external apply_chat_template type hint is weaker
1414
# Chat = dict[str, str | list[dict[str, Any]] ] # for multi-modal models
@@ -74,7 +74,9 @@ def to_tool_calls(
7474
if len(sig.parameters) == 0:
7575
tool_args = {}
7676

77-
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, tool_args)
77+
# Validate and coerce argument types
78+
validated_args = validate_tool_arguments(func, tool_args, strict=False)
79+
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args)
7880

7981
if len(model_tool_calls) > 0:
8082
return model_tool_calls

mellea/backends/watsonx.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
add_tools_from_context_actions,
4444
add_tools_from_model_options,
4545
convert_tools_to_json,
46+
validate_tool_arguments,
4647
)
4748

4849
format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors
@@ -589,7 +590,10 @@ def _extract_model_tool_requests(
589590

590591
# Watsonx returns the args as a string. Parse it here.
591592
args = json.loads(tool_args)
592-
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)
593+
594+
# Validate and coerce argument types
595+
validated_args = validate_tool_arguments(func, args, strict=False)
596+
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args)
593597

594598
if len(model_tool_calls) > 0:
595599
return model_tool_calls

mellea/helpers/openai_compatible_helpers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Callable
55
from typing import Any
66

7+
from ..backends.tools import validate_tool_arguments
78
from ..core import FancyLogger, ModelToolCall
89
from ..stdlib.components import Document, Message
910

@@ -30,7 +31,10 @@ def extract_model_tool_requests(
3031
if tool_args is not None:
3132
# Returns the args as a string. Parse it here.
3233
args = json.loads(tool_args)
33-
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)
34+
35+
# Validate and coerce argument types
36+
validated_args = validate_tool_arguments(func, args, strict=False)
37+
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args)
3438

3539
if len(model_tool_calls) > 0:
3640
return model_tool_calls

0 commit comments

Comments
 (0)