Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mellea/backends/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
add_tools_from_context_actions,
add_tools_from_model_options,
convert_tools_to_json,
validate_tool_arguments,
)

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors
Expand Down Expand Up @@ -603,7 +604,12 @@ def _extract_model_tool_requests(

# Returns the args as a string. Parse it here.
args = json.loads(tool_args)
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)

# Validate and coerce argument types
validated_args = validate_tool_arguments(func, args, strict=False)
model_tool_calls[tool_name] = ModelToolCall(
tool_name, func, validated_args
)
Comment on lines +608 to +612
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were you able to verify that these backends can generate incorrect parameters? I know huggingface has some issues with its tool requests that could benefit from this, but I'm not sure the other backends need any validation (but that was an open question I wasn't certain on).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember whether I actually saw a backend generate incorrect parameters, but it is good practice to validate them before making tool calls. We cannot assume that all backends work correctly :-)


if len(model_tool_calls) > 0:
return model_tool_calls
Expand Down
7 changes: 6 additions & 1 deletion mellea/backends/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ async def generate_from_raw(
def _extract_model_tool_requests(
self, tools: dict[str, AbstractMelleaTool], chat_response: ollama.ChatResponse
) -> dict[str, ModelToolCall] | None:
from .tools import validate_tool_arguments

model_tool_calls: dict[str, ModelToolCall] = {}

if chat_response.message.tool_calls:
Expand All @@ -514,8 +516,11 @@ def _extract_model_tool_requests(
continue # skip this function if we can't find it.

args = tool.function.arguments

# Validate and coerce argument types
validated_args = validate_tool_arguments(func, args, strict=False)
model_tool_calls[tool.function.name] = ModelToolCall(
tool.function.name, func, args
tool.function.name, func, validated_args
)

if len(model_tool_calls) > 0:
Expand Down
192 changes: 191 additions & 1 deletion mellea/backends/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,200 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]:
tool_name, tool_arguments = find_func(possible_tool)
if tool_name is not None and tool_arguments is not None:
tools.append((tool_name, tool_arguments))

return tools


def validate_tool_arguments(
    tool: AbstractMelleaTool,
    args: Mapping[str, Any],
    *,
    coerce_types: bool = True,
    strict: bool = False,
) -> dict[str, Any]:
    """Validate and optionally coerce tool arguments against the tool's JSON schema.

    The schema is read from ``tool.as_json_tool["function"]["parameters"]``. A
    throwaway Pydantic model is built from its ``properties`` / ``required``
    entries and ``args`` is validated against that model. Common type mismatches
    produced by LLMs (e.g. the string ``"30"`` for an integer parameter) can be
    coerced automatically.

    Only arguments that the model actually supplied are returned. Optional
    parameters that were omitted are left out of the result so the tool's own
    default values apply when it is eventually called.

    Args:
        tool: The tool whose ``as_json_tool`` schema the arguments are checked
            against.
        args: Raw arguments from the model (post-JSON parsing).
        coerce_types: If True, attempt type coercion for common cases; if False,
            the Pydantic model runs in strict mode and performs no coercion.
        strict: If True, unknown arguments are rejected and any failure is
            logged and re-raised. If False, failures are logged as warnings and
            a copy of the original ``args`` is returned unchanged.

    Returns:
        Validated and optionally coerced arguments dict.

    Raises:
        ValidationError: If ``strict=True`` and validation fails.
        Exception: If ``strict=True`` and the validator could not be built or
            run for any other reason.

    Examples:
        >>> def get_weather(location: str, days: int = 1) -> dict:
        ...     return {"location": location, "days": days}
        >>> tool = MelleaTool.from_callable(get_weather)

        >>> # LLM returns days as string
        >>> args = {"location": "Boston", "days": "3"}
        >>> validated = validate_tool_arguments(tool, args)
        >>> validated
        {'location': 'Boston', 'days': 3}

        >>> # Strict mode raises on validation errors
        >>> bad_args = {"location": "Boston", "days": "not_a_number"}
        >>> validate_tool_arguments(tool, bad_args, strict=True)
        Traceback (most recent call last):
        ...
        pydantic.ValidationError: ...
    """
    from functools import reduce
    from operator import or_

    # ConfigDict is imported here as well so the function does not depend on a
    # module-level pydantic import being present.
    from pydantic import ConfigDict, ValidationError, create_model

    from ..core import FancyLogger

    # Extract JSON schema from tool
    tool_schema = tool.as_json_tool.get("function", {})
    tool_name = tool_schema.get("name", "unknown_tool")
    parameters = tool_schema.get("parameters", {})
    properties = parameters.get("properties", {})
    required_fields = parameters.get("required", [])

    # Map JSON schema types to Python types
    json_type_to_python = {
        "string": str,
        "integer": int,
        "number": float,
        "boolean": bool,
        "array": list,
        "object": dict,
    }

    def resolve_type(json_type: Any) -> Any:
        """Translate a JSON-schema ``type`` entry into a Python annotation."""
        if isinstance(json_type, str):
            # Union types arrive as a comma-separated string, e.g. "integer, string".
            names = [part.strip() for part in json_type.split(",")]
        elif isinstance(json_type, (list, tuple)):
            # Standard JSON schema spells unions as a list, e.g. ["integer", "null"].
            names = [str(part).strip() for part in json_type]
        else:
            return Any

        # dict.fromkeys removes duplicates while preserving order.
        unique_types = list(
            dict.fromkeys(json_type_to_python.get(name, Any) for name in names)
        )
        if not unique_types:
            return Any
        # A single type is returned as-is; several are joined with `|`.
        return reduce(or_, unique_types)

    extra_policy = "forbid" if strict else "allow"
    if coerce_types:
        model_config = ConfigDict(
            str_strip_whitespace=True,
            strict=False,  # Allow type coercion
            extra=extra_policy,
            # Enable coercion modes for common LLM output issues
            coerce_numbers_to_str=True,  # Allow int/float -> str
        )
    else:
        model_config = ConfigDict(
            strict=True,  # No coercion
            extra=extra_policy,
        )

    try:
        # Building the model happens inside the try block so that a schema
        # Pydantic cannot represent falls back to the original arguments in
        # lenient mode instead of aborting tool extraction.
        field_definitions: dict[str, Any] = {}
        for param_name, param_schema in properties.items():
            param_type = resolve_type(param_schema.get("type", "string"))
            # Required parameters use `...`; optional ones default to None.
            default = ... if param_name in required_fields else None
            field_definitions[param_name] = (param_type, default)

        validator_model = create_model(
            f"{tool_name}_Validator", __config__=model_config, **field_definitions
        )

        validated_model = validator_model(**args)
        # exclude_unset keeps omitted optional parameters out of the result;
        # otherwise they would be passed to the tool as None and override the
        # tool's real default values.
        validated_args = validated_model.model_dump(exclude_unset=True)

        if not strict:
            # Preserve arguments unknown to the schema exactly as the model sent them.
            for key, value in args.items():
                if key not in field_definitions:
                    validated_args[key] = value

        if coerce_types:
            coerced_fields = [
                f"{key}: {type(original_value).__name__} → {type(validated_args.get(key)).__name__}"
                for key, original_value in args.items()
                if type(original_value) is not type(validated_args.get(key))
            ]
            if coerced_fields:
                FancyLogger.get_logger().debug(
                    f"Tool '{tool_name}' arguments coerced: {', '.join(coerced_fields)}"
                )

        return validated_args

    except ValidationError as e:
        # Format one line per failing field.
        error_details = [
            f" - {'.'.join(str(loc) for loc in error['loc'])}: {error['msg']}"
            for error in e.errors()
        ]
        error_msg = f"Tool argument validation failed for '{tool_name}':\n" + "\n".join(
            error_details
        )

        if strict:
            FancyLogger.get_logger().error(error_msg)
            raise

        # Lenient mode: log and hand back the untouched arguments.
        FancyLogger.get_logger().warning(
            error_msg + "\nReturning original arguments without validation."
        )
        return dict(args)

    except Exception as e:
        # Catch any other errors while building or running the validator.
        error_msg = f"Unexpected error validating tool '{tool_name}' arguments: {e}"

        if strict:
            FancyLogger.get_logger().error(error_msg)
            raise

        FancyLogger.get_logger().warning(
            error_msg + "\nReturning original arguments without validation."
        )
        return dict(args)


# Below functions and classes extracted from Ollama Python SDK (v0.6.1)
# so that all backends don't need it installed.
# https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_types.py#L19-L101
Expand Down
6 changes: 4 additions & 2 deletions mellea/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..core.base import AbstractMelleaTool
from ..formatters import ChatFormatter
from ..stdlib.components import Message
from .tools import parse_tools
from .tools import parse_tools, validate_tool_arguments

# Chat = dict[Literal["role", "content"], str] # external apply_chat_template type hint is weaker
# Chat = dict[str, str | list[dict[str, Any]] ] # for multi-modal models
Expand Down Expand Up @@ -75,7 +75,9 @@ def to_tool_calls(
if len(param_map) == 0:
tool_args = {}

model_tool_calls[tool_name] = ModelToolCall(tool_name, func, tool_args)
# Validate and coerce argument types
validated_args = validate_tool_arguments(func, tool_args, strict=False)
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args)

if len(model_tool_calls) > 0:
return model_tool_calls
Expand Down
6 changes: 5 additions & 1 deletion mellea/backends/watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
add_tools_from_context_actions,
add_tools_from_model_options,
convert_tools_to_json,
validate_tool_arguments,
)

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors
Expand Down Expand Up @@ -590,7 +591,10 @@ def _extract_model_tool_requests(

# Watsonx returns the args as a string. Parse it here.
args = json.loads(tool_args)
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)

# Validate and coerce argument types
validated_args = validate_tool_arguments(func, args, strict=False)
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args)

if len(model_tool_calls) > 0:
return model_tool_calls
Expand Down
6 changes: 5 additions & 1 deletion mellea/helpers/openai_compatible_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Callable
from typing import Any

from ..backends.tools import validate_tool_arguments
from ..core import FancyLogger, ModelToolCall
from ..core.base import AbstractMelleaTool
from ..stdlib.components import Document, Message
Expand Down Expand Up @@ -31,7 +32,10 @@ def extract_model_tool_requests(
if tool_args is not None:
# Returns the args as a string. Parse it here.
args = json.loads(tool_args)
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args)

# Validate and coerce argument types
validated_args = validate_tool_arguments(func, args, strict=False)
model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args)

if len(model_tool_calls) > 0:
return model_tool_calls
Expand Down
Loading