Skip to content

Commit 8aeab48

Browse files
committed
refactor for MelleaTool
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent 2a0b2aa commit 8aeab48

File tree

2 files changed

+103
-66
lines changed

2 files changed

+103
-66
lines changed

mellea/backends/tools.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,20 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]:
213213

214214

215215
def validate_tool_arguments(
216-
func: Callable,
216+
tool: AbstractMelleaTool,
217217
args: Mapping[str, Any],
218218
*,
219219
coerce_types: bool = True,
220220
strict: bool = False,
221221
) -> dict[str, Any]:
222-
"""Validate and optionally coerce tool arguments against function signature.
222+
"""Validate and optionally coerce tool arguments against tool's JSON schema.
223223
224224
This function validates tool call arguments extracted from LLM responses against
225-
the expected function signature. It can automatically coerce common type mismatches
226-
(e.g., string "30" to int 30) and provides detailed error messages.
225+
the tool's JSON schema from as_json_tool. It can automatically coerce common type
226+
mismatches (e.g., string "30" to int 30) and provides detailed error messages.
227227
228228
Args:
229-
func: The tool function to validate against
229+
tool: The MelleaTool instance to validate against
230230
args: Raw arguments from model (post-JSON parsing)
231231
coerce_types: If True, attempt type coercion for common cases (default: True)
232232
strict: If True, raise ValidationError on failures; if False, log warnings
@@ -241,16 +241,17 @@ def validate_tool_arguments(
241241
Examples:
242242
>>> def get_weather(location: str, days: int = 1) -> dict:
243243
... return {"location": location, "days": days}
244+
>>> tool = MelleaTool.from_callable(get_weather)
244245
245246
>>> # LLM returns days as string
246247
>>> args = {"location": "Boston", "days": "3"}
247-
>>> validated = validate_tool_arguments(get_weather, args)
248+
>>> validated = validate_tool_arguments(tool, args)
248249
>>> validated
249250
{'location': 'Boston', 'days': 3}
250251
251252
>>> # Strict mode raises on validation errors
252253
>>> bad_args = {"location": "Boston", "days": "not_a_number"}
253-
>>> validate_tool_arguments(get_weather, bad_args, strict=True)
254+
>>> validate_tool_arguments(tool, bad_args, strict=True)
254255
Traceback (most recent call last):
255256
...
256257
pydantic.ValidationError: ...
@@ -259,34 +260,45 @@ def validate_tool_arguments(
259260

260261
from ..core import FancyLogger
261262

262-
# Get function signature
263-
sig = inspect.signature(func)
264-
265-
# Build Pydantic model from function signature
266-
# This reuses the logic from convert_function_to_tool
263+
# Extract JSON schema from tool
264+
tool_schema = tool.as_json_tool.get("function", {})
265+
tool_name = tool_schema.get("name", "unknown_tool")
266+
parameters = tool_schema.get("parameters", {})
267+
properties = parameters.get("properties", {})
268+
required_fields = parameters.get("required", [])
269+
270+
# Map JSON schema types to Python types
271+
JSON_TYPE_TO_PYTHON = {
272+
"string": str,
273+
"integer": int,
274+
"number": float,
275+
"boolean": bool,
276+
"array": list,
277+
"object": dict,
278+
}
279+
280+
# Build Pydantic model from JSON schema
267281
field_definitions: dict[str, Any] = {}
268282

269-
for param_name, param in sig.parameters.items():
270-
# Skip *args and **kwargs
271-
if param.kind in (
272-
inspect.Parameter.VAR_POSITIONAL,
273-
inspect.Parameter.VAR_KEYWORD,
274-
):
275-
continue
283+
for param_name, param_schema in properties.items():
284+
# Get type from JSON schema
285+
json_type = param_schema.get("type", "string")
286+
287+
# Handle comma-separated types (e.g., "string, integer")
288+
if isinstance(json_type, str) and "," in json_type:
289+
# Take the first type for simplicity
290+
json_type = json_type.split(",")[0].strip()
276291

277-
# Get type annotation
278-
param_type = param.annotation
279-
if param_type == inspect.Parameter.empty:
280-
# No type hint, default to Any
281-
param_type = Any
292+
# Map to Python type
293+
param_type = JSON_TYPE_TO_PYTHON.get(json_type, Any)
282294

283-
# Handle default values
284-
if param.default == inspect.Parameter.empty:
295+
# Determine if parameter is required
296+
if param_name in required_fields:
285297
# Required parameter
286298
field_definitions[param_name] = (param_type, ...)
287299
else:
288-
# Optional parameter with default
289-
field_definitions[param_name] = (param_type, param.default)
300+
# Optional parameter (default to None)
301+
field_definitions[param_name] = (param_type, None)
290302

291303
# Configure model for type coercion if requested
292304
if coerce_types:
@@ -305,7 +317,7 @@ def validate_tool_arguments(
305317

306318
# Create dynamic Pydantic model for validation
307319
ValidatorModel = create_model(
308-
f"{func.__name__}_Validator", __config__=model_config, **field_definitions
320+
f"{tool_name}_Validator", __config__=model_config, **field_definitions
309321
)
310322

311323
try:
@@ -332,7 +344,7 @@ def validate_tool_arguments(
332344

333345
if coerced_fields and coerce_types:
334346
FancyLogger.get_logger().debug(
335-
f"Tool '{func.__name__}' arguments coerced: {', '.join(coerced_fields)}"
347+
f"Tool '{tool_name}' arguments coerced: {', '.join(coerced_fields)}"
336348
)
337349

338350
return validated_args
@@ -345,9 +357,8 @@ def validate_tool_arguments(
345357
msg = error["msg"]
346358
error_details.append(f" - {field}: {msg}")
347359

348-
error_msg = (
349-
f"Tool argument validation failed for '{func.__name__}':\n"
350-
+ "\n".join(error_details)
360+
error_msg = f"Tool argument validation failed for '{tool_name}':\n" + "\n".join(
361+
error_details
351362
)
352363

353364
if strict:
@@ -363,7 +374,7 @@ def validate_tool_arguments(
363374

364375
except Exception as e:
365376
# Catch any other errors during validation
366-
error_msg = f"Unexpected error validating tool '{func.__name__}' arguments: {e}"
377+
error_msg = f"Unexpected error validating tool '{tool_name}' arguments: {e}"
367378

368379
if strict:
369380
FancyLogger.get_logger().error(error_msg)

0 commit comments

Comments
 (0)