From 6a79d5bdc5691078e1e7d297664701f4aacceb56 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 26 Jan 2026 14:59:02 -0500 Subject: [PATCH 1/8] add tool calling argument validation Signed-off-by: Akihiko Kuroda --- mellea/backends/litellm.py | 8 +- mellea/backends/ollama.py | 7 +- mellea/backends/tools.py | 146 ++++- mellea/backends/utils.py | 6 +- mellea/backends/watsonx.py | 6 +- mellea/helpers/openai_compatible_helpers.py | 6 +- .../backends/test_tool_argument_validation.py | 613 ++++++++++++++++++ .../test_tool_validation_integration.py | 359 ++++++++++ 8 files changed, 1144 insertions(+), 7 deletions(-) create mode 100644 test/backends/test_tool_argument_validation.py create mode 100644 test/backends/test_tool_validation_integration.py diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 226251ce..6fab9d6b 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -41,6 +41,7 @@ add_tools_from_context_actions, add_tools_from_model_options, convert_tools_to_json, + validate_tool_arguments, ) format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors @@ -603,7 +604,12 @@ def _extract_model_tool_requests( # Returns the args as a string. Parse it here. 
args = json.loads(tool_args) - model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args) + + # Validate and coerce argument types + validated_args = validate_tool_arguments(func, args, strict=False) + model_tool_calls[tool_name] = ModelToolCall( + tool_name, func, validated_args + ) if len(model_tool_calls) > 0: return model_tool_calls diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index d20e2aa1..19efc1e3 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -502,6 +502,8 @@ async def generate_from_raw( def _extract_model_tool_requests( self, tools: dict[str, AbstractMelleaTool], chat_response: ollama.ChatResponse ) -> dict[str, ModelToolCall] | None: + from .tools import validate_tool_arguments + model_tool_calls: dict[str, ModelToolCall] = {} if chat_response.message.tool_calls: @@ -514,8 +516,11 @@ def _extract_model_tool_requests( continue # skip this function if we can't find it. args = tool.function.arguments + + # Validate and coerce argument types + validated_args = validate_tool_arguments(func, args, strict=False) model_tool_calls[tool.function.name] = ModelToolCall( - tool.function.name, func, args + tool.function.name, func, validated_args ) if len(model_tool_calls) > 0: diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 09249c66..36537fb8 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -212,7 +212,151 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]: if tool_name is not None and tool_arguments is not None: tools.append((tool_name, tool_arguments)) - return tools + +def validate_tool_arguments( + func: Callable, + args: Mapping[str, Any], + *, + coerce_types: bool = True, + strict: bool = False, +) -> dict[str, Any]: + """Validate and optionally coerce tool arguments against function signature. + + This function validates tool call arguments extracted from LLM responses against + the expected function signature. 
It can automatically coerce common type mismatches + (e.g., string "30" to int 30) and provides detailed error messages. + + Args: + func: The tool function to validate against + args: Raw arguments from model (post-JSON parsing) + coerce_types: If True, attempt type coercion for common cases (default: True) + strict: If True, raise ValidationError on failures; if False, log warnings + and return original args (default: False) + + Returns: + Validated and optionally coerced arguments dict + + Raises: + ValidationError: If strict=True and validation fails + + Examples: + >>> def get_weather(location: str, days: int = 1) -> dict: + ... return {"location": location, "days": days} + + >>> # LLM returns days as string + >>> args = {"location": "Boston", "days": "3"} + >>> validated = validate_tool_arguments(get_weather, args) + >>> validated + {'location': 'Boston', 'days': 3} + + >>> # Strict mode raises on validation errors + >>> bad_args = {"location": "Boston", "days": "not_a_number"} + >>> validate_tool_arguments(get_weather, bad_args, strict=True) + Traceback (most recent call last): + ... + pydantic.ValidationError: ... + """ + from pydantic import ValidationError, create_model + + from ..core import FancyLogger + + # Get function signature + sig = inspect.signature(func) + + # Build Pydantic model from function signature + # This reuses the logic from convert_function_to_tool + field_definitions: dict[str, Any] = {} + + for param_name, param in sig.parameters.items(): + # Skip *args and **kwargs + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + + # Get type annotation + param_type = param.annotation + if param_type == inspect.Parameter.empty: + # No type hint, default to Any + param_type = Any + + # Handle default values + if param.default == inspect.Parameter.empty: + # Required parameter + field_definitions[param_name] = (param_type, ...) 
+ else: + # Optional parameter with default + field_definitions[param_name] = (param_type, param.default) + + # Create dynamic Pydantic model for validation + ValidatorModel = create_model(f"{func.__name__}_Validator", **field_definitions) + + # Configure model for type coercion if requested + if coerce_types: + # Pydantic v2 uses model_config + ValidatorModel.model_config = ConfigDict( + str_strip_whitespace=True # Strip whitespace from strings + # Pydantic automatically coerces compatible types + ) + + try: + # Validate using Pydantic + validated_model = ValidatorModel(**args) + validated_args = validated_model.model_dump() + + # Log successful validation with coercion details + coerced_fields = [] + for key, original_value in args.items(): + validated_value = validated_args.get(key) + if type(original_value) is not type(validated_value): + coerced_fields.append( + f"{key}: {type(original_value).__name__} β†’ {type(validated_value).__name__}" + ) + + if coerced_fields and coerce_types: + FancyLogger.get_logger().debug( + f"Tool '{func.__name__}' arguments coerced: {', '.join(coerced_fields)}" + ) + + return validated_args + + except ValidationError as e: + # Format error message + error_details = [] + for error in e.errors(): + field = ".".join(str(loc) for loc in error["loc"]) + msg = error["msg"] + error_details.append(f" - {field}: {msg}") + + error_msg = ( + f"Tool argument validation failed for '{func.__name__}':\n" + + "\n".join(error_details) + ) + + if strict: + # Re-raise with enhanced message + FancyLogger.get_logger().error(error_msg) + raise + else: + # Log warning and return original args + FancyLogger.get_logger().warning( + error_msg + "\nReturning original arguments without validation." 
+ ) + return dict(args) + + except Exception as e: + # Catch any other errors during validation + error_msg = f"Unexpected error validating tool '{func.__name__}' arguments: {e}" + + if strict: + FancyLogger.get_logger().error(error_msg) + raise + else: + FancyLogger.get_logger().warning( + error_msg + "\nReturning original arguments without validation." + ) + return dict(args) # Below functions and classes extracted from Ollama Python SDK (v0.6.1) diff --git a/mellea/backends/utils.py b/mellea/backends/utils.py index 2c3c00f6..3a5c5cf7 100644 --- a/mellea/backends/utils.py +++ b/mellea/backends/utils.py @@ -9,7 +9,7 @@ from ..core.base import AbstractMelleaTool from ..formatters import ChatFormatter from ..stdlib.components import Message -from .tools import parse_tools +from .tools import parse_tools, validate_tool_arguments # Chat = dict[Literal["role", "content"], str] # external apply_chat_template type hint is weaker # Chat = dict[str, str | list[dict[str, Any]] ] # for multi-modal models @@ -75,7 +75,9 @@ def to_tool_calls( if len(param_map) == 0: tool_args = {} - model_tool_calls[tool_name] = ModelToolCall(tool_name, func, tool_args) + # Validate and coerce argument types + validated_args = validate_tool_arguments(func, tool_args, strict=False) + model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args) if len(model_tool_calls) > 0: return model_tool_calls diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 0fe1c2bd..27401e53 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -44,6 +44,7 @@ add_tools_from_context_actions, add_tools_from_model_options, convert_tools_to_json, + validate_tool_arguments, ) format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors @@ -590,7 +591,10 @@ def _extract_model_tool_requests( # Watsonx returns the args as a string. Parse it here. 
args = json.loads(tool_args) - model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args) + + # Validate and coerce argument types + validated_args = validate_tool_arguments(func, args, strict=False) + model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args) if len(model_tool_calls) > 0: return model_tool_calls diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index 7374b157..5afec123 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -4,6 +4,7 @@ from collections.abc import Callable from typing import Any +from ..backends.tools import validate_tool_arguments from ..core import FancyLogger, ModelToolCall from ..core.base import AbstractMelleaTool from ..stdlib.components import Document, Message @@ -31,7 +32,10 @@ def extract_model_tool_requests( if tool_args is not None: # Returns the args as a string. Parse it here. args = json.loads(tool_args) - model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args) + + # Validate and coerce argument types + validated_args = validate_tool_arguments(func, args, strict=False) + model_tool_calls[tool_name] = ModelToolCall(tool_name, func, validated_args) if len(model_tool_calls) > 0: return model_tool_calls diff --git a/test/backends/test_tool_argument_validation.py b/test/backends/test_tool_argument_validation.py new file mode 100644 index 00000000..458a899c --- /dev/null +++ b/test/backends/test_tool_argument_validation.py @@ -0,0 +1,613 @@ +"""Comprehensive test suite for tool call argument validation. 
+ +Tests cover: +- Basic type validation and coercion +- Complex nested types (dicts, lists) +- Union types and Optional parameters +- Missing required arguments +- Extra arguments +- Malformed JSON parsing +- Pydantic model arguments +""" + +import json +from typing import Any, Optional, Union + +import pytest +from pydantic import BaseModel, ValidationError + +from mellea.backends.tools import parse_tools +from mellea.core import ModelToolCall + + +# ============================================================================ +# Test Fixtures - Tool Functions with Various Signatures +# ============================================================================ + + +def simple_string_tool(message: str) -> str: + """A simple tool that takes a string. + + Args: + message: The message to process + """ + return f"Processed: {message}" + + +def typed_primitives_tool(name: str, age: int, score: float, active: bool) -> dict: + """Tool with multiple primitive types. + + Args: + name: Person's name + age: Person's age in years + score: Performance score + active: Whether person is active + """ + return {"name": name, "age": age, "score": score, "active": active} + + +def optional_params_tool(required: str, optional: Optional[str] = None) -> str: + """Tool with optional parameters. + + Args: + required: A required parameter + optional: An optional parameter + """ + return f"{required}:{optional or 'none'}" + + +def union_type_tool(value: Union[str, int]) -> str: + """Tool with union type parameter. + + Args: + value: Can be string or integer + """ + return f"Value: {value} (type: {type(value).__name__})" + + +def list_param_tool(items: list[str]) -> int: + """Tool with list parameter. + + Args: + items: List of string items + """ + return len(items) + + +def dict_param_tool(config: dict[str, Any]) -> str: + """Tool with dict parameter. 
+ + Args: + config: Configuration dictionary + """ + return json.dumps(config) + + +def nested_structure_tool(data: dict[str, list[int]]) -> int: + """Tool with nested structure. + + Args: + data: Dictionary mapping strings to lists of integers + """ + return sum(sum(values) for values in data.values()) + + +def default_values_tool(name: str, count: int = 10, prefix: str = "item") -> str: + """Tool with default values. + + Args: + name: Base name + count: Number of items (default: 10) + prefix: Prefix for items (default: "item") + """ + return f"{prefix}_{name}_{count}" + + +class UserModel(BaseModel): + """Pydantic model for testing.""" + + name: str + age: int + email: Optional[str] = None + + +def pydantic_model_tool(user: UserModel) -> str: + """Tool that accepts a Pydantic model. + + Args: + user: User information + """ + return f"User: {user.name}, Age: {user.age}" + + +def no_params_tool() -> str: + """Tool with no parameters.""" + return "No params needed" + + +# ============================================================================ +# Test Cases: Basic Type Validation +# ============================================================================ + + +class TestBasicTypeValidation: + """Test basic type validation and coercion.""" + + def test_string_argument(self): + """Test simple string argument.""" + args = {"message": "Hello, World!"} + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + result = tool_call.call_func() + assert result == "Processed: Hello, World!" 
+ + def test_integer_argument(self): + """Test integer argument.""" + args = {"name": "Alice", "age": 30, "score": 95.5, "active": True} + tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + result = tool_call.call_func() + assert result["age"] == 30 + assert isinstance(result["age"], int) + + def test_float_argument(self): + """Test float argument.""" + args = {"name": "Bob", "age": 25, "score": 88.7, "active": False} + tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + result = tool_call.call_func() + assert result["score"] == 88.7 + assert isinstance(result["score"], float) + + def test_boolean_argument(self): + """Test boolean argument.""" + args = {"name": "Charlie", "age": 35, "score": 92.0, "active": True} + tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + result = tool_call.call_func() + assert result["active"] is True + assert isinstance(result["active"], bool) + + +# ============================================================================ +# Test Cases: Type Coercion +# ============================================================================ + + +class TestTypeCoercion: + """Test automatic type coercion scenarios.""" + + def test_string_to_int_coercion(self): + """Test that string "123" can be coerced to int 123.""" + # This currently FAILS without validation + args = {"name": "Test", "age": "30", "score": 95.5, "active": True} + tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + + # Without validation, this will fail at call time + with pytest.raises(TypeError): + tool_call.call_func() + + def test_string_to_float_coercion(self): + """Test that string "95.5" can be coerced to float 95.5.""" + args = {"name": "Test", "age": 30, "score": "95.5", "active": True} + tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + + # Without validation, this will fail + with pytest.raises(TypeError): + 
tool_call.call_func() + + def test_int_to_string_coercion(self): + """Test that int 123 can be coerced to string "123".""" + args = {"message": 123} # Should be string + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + + # This might work due to Python's duck typing, but not guaranteed + result = tool_call.call_func() + assert "123" in result + + def test_string_to_bool_coercion(self): + """Test boolean coercion from strings.""" + # Common LLM outputs: "true", "false", "True", "False" + args = {"name": "Test", "age": 30, "score": 95.5, "active": "true"} + tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + + # Without validation, this will fail + with pytest.raises(TypeError): + tool_call.call_func() + + +# ============================================================================ +# Test Cases: Optional and Default Parameters +# ============================================================================ + + +class TestOptionalParameters: + """Test optional and default parameter handling.""" + + def test_optional_param_provided(self): + """Test optional parameter when provided.""" + args = {"required": "value1", "optional": "value2"} + tool_call = ModelToolCall("optional_params_tool", optional_params_tool, args) + result = tool_call.call_func() + assert result == "value1:value2" + + def test_optional_param_omitted(self): + """Test optional parameter when omitted.""" + args = {"required": "value1"} + tool_call = ModelToolCall("optional_params_tool", optional_params_tool, args) + result = tool_call.call_func() + assert result == "value1:none" + + def test_optional_param_none(self): + """Test optional parameter explicitly set to None.""" + args = {"required": "value1", "optional": None} + tool_call = ModelToolCall("optional_params_tool", optional_params_tool, args) + result = tool_call.call_func() + assert result == "value1:none" + + def test_default_values_all_provided(self): + """Test tool with all default 
values provided.""" + args = {"name": "test", "count": 5, "prefix": "custom"} + tool_call = ModelToolCall("default_values_tool", default_values_tool, args) + result = tool_call.call_func() + assert result == "custom_test_5" + + def test_default_values_partial(self): + """Test tool with some default values omitted.""" + args = {"name": "test", "count": 7} + tool_call = ModelToolCall("default_values_tool", default_values_tool, args) + result = tool_call.call_func() + assert result == "item_test_7" + + def test_default_values_minimal(self): + """Test tool with only required parameters.""" + args = {"name": "test"} + tool_call = ModelToolCall("default_values_tool", default_values_tool, args) + result = tool_call.call_func() + assert result == "item_test_10" + + +# ============================================================================ +# Test Cases: Union Types +# ============================================================================ + + +class TestUnionTypes: + """Test union type parameter handling.""" + + def test_union_with_string(self): + """Test union type with string value.""" + args = {"value": "hello"} + tool_call = ModelToolCall("union_type_tool", union_type_tool, args) + result = tool_call.call_func() + assert "hello" in result + assert "str" in result + + def test_union_with_int(self): + """Test union type with integer value.""" + args = {"value": 42} + tool_call = ModelToolCall("union_type_tool", union_type_tool, args) + result = tool_call.call_func() + assert "42" in result + assert "int" in result + + def test_union_with_string_number(self): + """Test union type with string that looks like number.""" + # Without validation, this stays as string + args = {"value": "42"} + tool_call = ModelToolCall("union_type_tool", union_type_tool, args) + result = tool_call.call_func() + assert "42" in result + # Type depends on whether validation coerces + + +# ============================================================================ +# Test Cases: Complex 
Types (Lists, Dicts) +# ============================================================================ + + +class TestComplexTypes: + """Test complex type parameters (lists, dicts, nested structures).""" + + def test_list_of_strings(self): + """Test list parameter with strings.""" + args = {"items": ["apple", "banana", "cherry"]} + tool_call = ModelToolCall("list_param_tool", list_param_tool, args) + result = tool_call.call_func() + assert result == 3 + + def test_empty_list(self): + """Test empty list parameter.""" + args = {"items": []} + tool_call = ModelToolCall("list_param_tool", list_param_tool, args) + result = tool_call.call_func() + assert result == 0 + + def test_dict_parameter(self): + """Test dictionary parameter.""" + args = {"config": {"key1": "value1", "key2": 42, "key3": True}} + tool_call = ModelToolCall("dict_param_tool", dict_param_tool, args) + result = tool_call.call_func() + parsed = json.loads(result) + assert parsed["key1"] == "value1" + assert parsed["key2"] == 42 + + def test_nested_structure(self): + """Test nested dictionary with lists.""" + args = {"data": {"group1": [1, 2, 3], "group2": [4, 5], "group3": [6]}} + tool_call = ModelToolCall("nested_structure_tool", nested_structure_tool, args) + result = tool_call.call_func() + assert result == 21 # Sum of all numbers + + def test_nested_structure_empty(self): + """Test nested structure with empty lists.""" + args = {"data": {"group1": [], "group2": []}} + tool_call = ModelToolCall("nested_structure_tool", nested_structure_tool, args) + result = tool_call.call_func() + assert result == 0 + + +# ============================================================================ +# Test Cases: Error Conditions +# ============================================================================ + + +class TestErrorConditions: + """Test error handling for invalid arguments.""" + + def test_missing_required_argument(self): + """Test that missing required argument raises error.""" + args = {} # Missing 
'message' + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + + with pytest.raises(TypeError, match="missing.*required"): + tool_call.call_func() + + def test_extra_arguments(self): + """Test that extra arguments are ignored (Python behavior).""" + args = {"message": "test", "extra": "ignored"} + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + + # Python ignores extra kwargs by default + with pytest.raises(TypeError, match="unexpected keyword argument"): + tool_call.call_func() + + def test_wrong_type_no_coercion(self): + """Test that wrong types fail without coercion.""" + args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} + tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + + # This will fail when Python tries to use the value + with pytest.raises((TypeError, ValueError)): + tool_call.call_func() + + def test_none_for_required_param(self): + """Test that None for required parameter fails.""" + args = {"message": None} + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + + # Depends on function implementation + result = tool_call.call_func() + # May work or fail depending on function + + +# ============================================================================ +# Test Cases: JSON Parsing +# ============================================================================ + + +class TestJSONParsing: + """Test JSON parsing scenarios from model responses.""" + + def test_valid_json_string(self): + """Test parsing valid JSON string.""" + json_str = '{"message": "Hello, World!"}' + args = json.loads(json_str) + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + result = tool_call.call_func() + assert "Hello, World!" 
in result + + def test_malformed_json(self): + """Test that malformed JSON raises error.""" + json_str = '{"message": "Hello, World!"' # Missing closing brace + + with pytest.raises(json.JSONDecodeError): + json.loads(json_str) + + def test_json_with_escaped_quotes(self): + """Test JSON with escaped quotes.""" + json_str = '{"message": "He said \\"Hello\\""}' + args = json.loads(json_str) + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + result = tool_call.call_func() + assert "Hello" in result + + def test_json_with_unicode(self): + """Test JSON with unicode characters.""" + json_str = '{"message": "Hello δΈ–η•Œ 🌍"}' + args = json.loads(json_str) + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + result = tool_call.call_func() + assert "δΈ–η•Œ" in result + + def test_json_with_nested_objects(self): + """Test JSON with nested objects.""" + json_str = '{"config": {"nested": {"key": "value"}, "list": [1, 2, 3]}}' + args = json.loads(json_str) + tool_call = ModelToolCall("dict_param_tool", dict_param_tool, args) + result = tool_call.call_func() + parsed = json.loads(result) + assert parsed["nested"]["key"] == "value" + + +# ============================================================================ +# Test Cases: Pydantic Models +# ============================================================================ + + +class TestPydanticModels: + """Test tools that accept Pydantic models as arguments.""" + + def test_pydantic_model_from_dict(self): + """Test creating Pydantic model from dict.""" + args = {"user": {"name": "Alice", "age": 30, "email": "alice@example.com"}} + + # Need to convert dict to Pydantic model + user_data = args["user"] + user = UserModel(**user_data) + args_with_model = {"user": user} + + tool_call = ModelToolCall( + "pydantic_model_tool", pydantic_model_tool, args_with_model + ) + result = tool_call.call_func() + assert "Alice" in result + assert "30" in result + + def 
test_pydantic_model_validation_error(self): + """Test that invalid Pydantic model data raises error.""" + user_data = {"name": "Bob", "age": "not_an_int"} # Invalid age type + + with pytest.raises(ValidationError): + UserModel(**user_data) + + def test_pydantic_model_with_optional(self): + """Test Pydantic model with optional field.""" + user_data = {"name": "Charlie", "age": 25} # email is optional + user = UserModel(**user_data) + args = {"user": user} + + tool_call = ModelToolCall("pydantic_model_tool", pydantic_model_tool, args) + result = tool_call.call_func() + assert "Charlie" in result + + +# ============================================================================ +# Test Cases: Edge Cases +# ============================================================================ + + +class TestEdgeCases: + """Test edge cases and unusual scenarios.""" + + def test_no_parameters_tool(self): + """Test tool with no parameters.""" + args = {} + tool_call = ModelToolCall("no_params_tool", no_params_tool, args) + result = tool_call.call_func() + assert result == "No params needed" + + def test_no_parameters_with_hallucinated_args(self): + """Test that hallucinated args for no-param tool are handled.""" + # Models sometimes hallucinate parameters + args = {"fake_param": "should_be_ignored"} + tool_call = ModelToolCall("no_params_tool", no_params_tool, args) + + # This should fail without validation that clears args + with pytest.raises(TypeError): + tool_call.call_func() + + def test_very_long_string(self): + """Test with very long string argument.""" + long_string = "x" * 10000 + args = {"message": long_string} + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + result = tool_call.call_func() + assert len(result) > 10000 + + def test_special_characters_in_string(self): + """Test with special characters.""" + special = "!@#$%^&*()_+-=[]{}|;:',.<>?/~`" + args = {"message": special} + tool_call = ModelToolCall("simple_string_tool", 
simple_string_tool, args) + result = tool_call.call_func() + assert special in result + + def test_empty_string(self): + """Test with empty string.""" + args = {"message": ""} + tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + result = tool_call.call_func() + assert result == "Processed: " + + +# ============================================================================ +# Test Cases: parse_tools() Function +# ============================================================================ + + +class TestParseToolsFunction: + """Test the parse_tools() function from backends/tools.py.""" + + def test_parse_single_tool_call(self): + """Test parsing a single tool call from text.""" + text = """ + I'll call the function now: + {"name": "get_weather", "arguments": {"location": "Boston", "days": 3}} + """ + tools = list(parse_tools(text)) + assert len(tools) == 1 + assert tools[0][0] == "get_weather" + assert tools[0][1]["location"] == "Boston" + assert tools[0][1]["days"] == 3 + + def test_parse_multiple_tool_calls(self): + """Test parsing multiple tool calls from text.""" + text = """ + First: {"name": "tool1", "arguments": {"arg1": "value1"}} + Second: {"name": "tool2", "arguments": {"arg2": "value2"}} + """ + tools = list(parse_tools(text)) + assert len(tools) == 2 + assert tools[0][0] == "tool1" + assert tools[1][0] == "tool2" + + def test_parse_with_extra_text(self): + """Test parsing tool calls with surrounding text.""" + text = """ + Let me help you with that. I'll use the get_temperature function. + {"name": "get_temperature", "arguments": {"location": "New York"}} + That should give us the current temperature. + """ + tools = list(parse_tools(text)) + assert len(tools) == 1 + assert tools[0][0] == "get_temperature" + + def test_parse_no_tools(self): + """Test parsing text with no tool calls.""" + text = "This is just regular text with no tool calls." 
+ tools = list(parse_tools(text)) + assert len(tools) == 0 + + def test_parse_malformed_json(self): + """Test that malformed JSON is skipped.""" + text = """ + {"name": "tool1", "arguments": {"arg1": "value1"}} + {"name": "bad_tool", "arguments": {broken json}} + {"name": "tool2", "arguments": {"arg2": "value2"}} + """ + tools = list(parse_tools(text)) + # Should parse the valid ones and skip the malformed one + assert len(tools) == 2 + + +# ============================================================================ +# Integration Test Markers +# ============================================================================ + + +@pytest.mark.integration +class TestToolValidationIntegration: + """Integration tests that would use actual validation function.""" + + @pytest.mark.skip(reason="Validation function not yet implemented") + def test_validation_with_coercion(self): + """Test validation with type coercion enabled.""" + # This test will be enabled once validation is implemented + pass + + @pytest.mark.skip(reason="Validation function not yet implemented") + def test_validation_strict_mode(self): + """Test validation in strict mode.""" + # This test will be enabled once validation is implemented + pass + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/backends/test_tool_validation_integration.py b/test/backends/test_tool_validation_integration.py new file mode 100644 index 00000000..6d04f926 --- /dev/null +++ b/test/backends/test_tool_validation_integration.py @@ -0,0 +1,359 @@ +"""Integration tests for the validate_tool_arguments function. + +These tests verify that the validation function works correctly with +the actual tool call flow. 
+""" + +import pytest +from typing import Any, Optional, Union + +from pydantic import ValidationError + +from mellea.backends.tools import validate_tool_arguments +from mellea.core import ModelToolCall + + +# ============================================================================ +# Test Fixtures - Tool Functions +# ============================================================================ + + +def simple_tool(message: str) -> str: + """A simple tool that takes a string. + + Args: + message: The message to process + """ + return f"Processed: {message}" + + +def typed_tool(name: str, age: int, score: float, active: bool) -> dict: + """Tool with multiple primitive types. + + Args: + name: Person's name + age: Person's age in years + score: Performance score + active: Whether person is active + """ + return {"name": name, "age": age, "score": score, "active": active} + + +def optional_tool(required: str, optional: Optional[str] = None) -> str: + """Tool with optional parameters. + + Args: + required: A required parameter + optional: An optional parameter + """ + return f"{required}:{optional or 'none'}" + + +def union_tool(value: Union[str, int]) -> str: + """Tool with union type parameter. + + Args: + value: Can be string or integer + """ + return f"Value: {value} (type: {type(value).__name__})" + + +def list_tool(items: list[str]) -> int: + """Tool with list parameter. + + Args: + items: List of string items + """ + return len(items) + + +def dict_tool(config: dict[str, Any]) -> str: + """Tool with dict parameter. 
+ + Args: + config: Configuration dictionary + """ + import json + + return json.dumps(config) + + +def no_params_tool() -> str: + """Tool with no parameters.""" + return "No params needed" + + +# ============================================================================ +# Test Cases: Type Coercion +# ============================================================================ + + +class TestTypeCoercion: + """Test automatic type coercion with validation.""" + + def test_string_to_int_coercion(self): + """Test that string "30" is coerced to int 30.""" + args = {"name": "Test", "age": "30", "score": 95.5, "active": True} + validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + + assert validated["age"] == 30 + assert isinstance(validated["age"], int) + + def test_string_to_float_coercion(self): + """Test that string "95.5" is coerced to float 95.5.""" + args = {"name": "Test", "age": 30, "score": "95.5", "active": True} + validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + + assert validated["score"] == 95.5 + assert isinstance(validated["score"], float) + + def test_int_to_float_coercion(self): + """Test that int 95 is coerced to float 95.0.""" + args = {"name": "Test", "age": 30, "score": 95, "active": True} + validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + + assert validated["score"] == 95.0 + assert isinstance(validated["score"], float) + + def test_int_to_string_coercion(self): + """Test that int 123 is coerced to string "123".""" + args = {"message": 123} + validated = validate_tool_arguments(simple_tool, args, coerce_types=True) + + assert validated["message"] == "123" + assert isinstance(validated["message"], str) + + def test_bool_coercion_from_int(self): + """Test that int 1/0 is coerced to bool True/False.""" + args = {"name": "Test", "age": 30, "score": 95.5, "active": 1} + validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + + assert validated["active"] is 
True + assert isinstance(validated["active"], bool) + + args["active"] = 0 + validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + assert validated["active"] is False + + +class TestValidationModes: + """Test strict vs. lenient validation modes.""" + + def test_lenient_mode_with_invalid_type(self): + """Test that lenient mode returns original args on validation failure.""" + args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} + validated = validate_tool_arguments(typed_tool, args, strict=False) + + # Should return original args + assert validated == args + assert validated["age"] == "not_a_number" + + def test_strict_mode_with_invalid_type(self): + """Test that strict mode raises ValidationError on failure.""" + args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} + + with pytest.raises(ValidationError): + validate_tool_arguments(typed_tool, args, strict=True) + + def test_lenient_mode_with_missing_required(self): + """Test lenient mode with missing required parameter.""" + args = {"optional": "value"} # Missing 'required' + validated = validate_tool_arguments(optional_tool, args, strict=False) + + # Should return original args + assert validated == args + + def test_strict_mode_with_missing_required(self): + """Test strict mode with missing required parameter.""" + args = {"optional": "value"} # Missing 'required' + + with pytest.raises(ValidationError): + validate_tool_arguments(optional_tool, args, strict=True) + + +class TestWithModelToolCall: + """Test validation integrated with ModelToolCall.""" + + def test_validated_tool_call_with_coercion(self): + """Test that validated args work correctly with ModelToolCall.""" + # LLM returns age as string + args = {"name": "Alice", "age": "30", "score": "95.5", "active": True} + + # Validate and coerce + validated_args = validate_tool_arguments(typed_tool, args, coerce_types=True) + + # Create tool call with validated args + tool_call = 
ModelToolCall("typed_tool", typed_tool, validated_args) + result = tool_call.call_func() + + # Verify result has correct types + assert result["age"] == 30 + assert isinstance(result["age"], int) + assert result["score"] == 95.5 + assert isinstance(result["score"], float) + + def test_unvalidated_vs_validated_comparison(self): + """Compare behavior with and without validation.""" + args = {"name": "Bob", "age": "25", "score": "88.7", "active": True} + + # Without validation - types stay as strings + unvalidated_call = ModelToolCall("typed_tool", typed_tool, args) + unvalidated_result = unvalidated_call.call_func() + assert isinstance(unvalidated_result["age"], str) # Still string! + + # With validation - types are coerced + validated_args = validate_tool_arguments(typed_tool, args, coerce_types=True) + validated_call = ModelToolCall("typed_tool", typed_tool, validated_args) + validated_result = validated_call.call_func() + assert isinstance(validated_result["age"], int) # Correctly coerced! 
+ + +class TestOptionalParameters: + """Test validation with optional parameters.""" + + def test_optional_param_provided(self): + """Test validation when optional parameter is provided.""" + args = {"required": "value1", "optional": "value2"} + validated = validate_tool_arguments(optional_tool, args) + + assert validated == args + + def test_optional_param_omitted(self): + """Test validation when optional parameter is omitted.""" + args = {"required": "value1"} + validated = validate_tool_arguments(optional_tool, args) + + assert validated["required"] == "value1" + assert "optional" not in validated or validated.get("optional") is None + + def test_optional_param_none(self): + """Test validation when optional parameter is explicitly None.""" + args = {"required": "value1", "optional": None} + validated = validate_tool_arguments(optional_tool, args) + + assert validated["required"] == "value1" + assert validated["optional"] is None + + +class TestComplexTypes: + """Test validation with complex types.""" + + def test_list_parameter(self): + """Test validation with list parameter.""" + args = {"items": ["apple", "banana", "cherry"]} + validated = validate_tool_arguments(list_tool, args) + + assert validated["items"] == ["apple", "banana", "cherry"] + assert isinstance(validated["items"], list) + + def test_dict_parameter(self): + """Test validation with dict parameter.""" + args = {"config": {"key1": "value1", "key2": 42, "key3": True}} + validated = validate_tool_arguments(dict_tool, args) + + assert validated["config"] == args["config"] + assert isinstance(validated["config"], dict) + + def test_empty_list(self): + """Test validation with empty list.""" + args = {"items": []} + validated = validate_tool_arguments(list_tool, args) + + assert validated["items"] == [] + + +class TestUnionTypes: + """Test validation with union types.""" + + def test_union_with_string(self): + """Test union type with string value.""" + args = {"value": "hello"} + validated = 
validate_tool_arguments(union_tool, args) + + assert validated["value"] == "hello" + assert isinstance(validated["value"], str) + + def test_union_with_int(self): + """Test union type with integer value.""" + args = {"value": 42} + validated = validate_tool_arguments(union_tool, args) + + assert validated["value"] == 42 + assert isinstance(validated["value"], int) + + def test_union_with_string_number(self): + """Test union type with string that looks like number.""" + args = {"value": "42"} + validated = validate_tool_arguments(union_tool, args, coerce_types=True) + + # Pydantic will try to coerce to the first matching type + # Result depends on Union order and Pydantic's coercion rules + assert validated["value"] in ["42", 42] + + +class TestEdgeCases: + """Test edge cases.""" + + def test_no_parameters_tool(self): + """Test validation with no-parameter tool.""" + args = {} + validated = validate_tool_arguments(no_params_tool, args) + + assert validated == {} + + def test_no_parameters_with_extra_args(self): + """Test that extra args for no-param tool are handled.""" + args = {"fake_param": "should_be_ignored"} + + # In lenient mode, returns original args + validated = validate_tool_arguments(no_params_tool, args, strict=False) + assert validated == args + + # In strict mode, should raise + with pytest.raises(ValidationError): + validate_tool_arguments(no_params_tool, args, strict=True) + + def test_whitespace_stripping(self): + """Test that whitespace is stripped from strings.""" + args = {"message": " hello world "} + validated = validate_tool_arguments(simple_tool, args, coerce_types=True) + + assert validated["message"] == "hello world" + + def test_empty_string(self): + """Test validation with empty string.""" + args = {"message": ""} + validated = validate_tool_arguments(simple_tool, args) + + assert validated["message"] == "" + + +class TestErrorMessages: + """Test that error messages are helpful.""" + + def test_missing_required_error_message(self): + 
"""Test error message for missing required parameter.""" + args = {} + + try: + validate_tool_arguments(simple_tool, args, strict=True) + pytest.fail("Should have raised ValidationError") + except ValidationError as e: + error_str = str(e) + assert "message" in error_str.lower() + assert "required" in error_str.lower() or "missing" in error_str.lower() + + def test_type_mismatch_error_message(self): + """Test error message for type mismatch.""" + args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} + + try: + validate_tool_arguments(typed_tool, args, strict=True) + pytest.fail("Should have raised ValidationError") + except ValidationError as e: + error_str = str(e) + assert "age" in error_str.lower() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 3b3615380db8b6bef852a030a02d1d22e0735411 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Wed, 28 Jan 2026 19:08:39 -0500 Subject: [PATCH 2/8] fix merge error Signed-off-by: Akihiko Kuroda --- mellea/backends/tools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 36537fb8..63957fba 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -211,6 +211,7 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]: tool_name, tool_arguments = find_func(possible_tool) if tool_name is not None and tool_arguments is not None: tools.append((tool_name, tool_arguments)) + return tools def validate_tool_arguments( From dcc34fa02a9803c49113a1ef4fd6b3a801c3488d Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Wed, 28 Jan 2026 20:30:42 -0500 Subject: [PATCH 3/8] fix failing tests Signed-off-by: Akihiko Kuroda --- mellea/backends/tools.py | 31 +++++++++++++---- .../backends/test_tool_argument_validation.py | 34 +++++++++---------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 63957fba..3a56e998 100644 --- 
a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -290,22 +290,39 @@ def validate_tool_arguments( # Optional parameter with default field_definitions[param_name] = (param_type, param.default) - # Create dynamic Pydantic model for validation - ValidatorModel = create_model(f"{func.__name__}_Validator", **field_definitions) - # Configure model for type coercion if requested if coerce_types: - # Pydantic v2 uses model_config - ValidatorModel.model_config = ConfigDict( - str_strip_whitespace=True # Strip whitespace from strings - # Pydantic automatically coerces compatible types + model_config = ConfigDict( + str_strip_whitespace=True, + strict=False, # Allow type coercion + extra="forbid" if strict else "allow", # Handle extra fields + # Enable coercion modes for common LLM output issues + coerce_numbers_to_str=True, # Allow int/float -> str + ) + else: + model_config = ConfigDict( + strict=True, # No coercion + extra="forbid" if strict else "allow", ) + # Create dynamic Pydantic model for validation + ValidatorModel = create_model( + f"{func.__name__}_Validator", __config__=model_config, **field_definitions + ) + try: # Validate using Pydantic validated_model = ValidatorModel(**args) validated_args = validated_model.model_dump() + # In lenient mode with extra="allow", Pydantic includes extra fields + # but we need to preserve them from the original args + if not strict: + # Add back any extra fields that weren't in the model + for key, value in args.items(): + if key not in field_definitions: + validated_args[key] = value + # Log successful validation with coercion details coerced_fields = [] for key, original_value in args.items(): diff --git a/test/backends/test_tool_argument_validation.py b/test/backends/test_tool_argument_validation.py index 458a899c..e449f905 100644 --- a/test/backends/test_tool_argument_validation.py +++ b/test/backends/test_tool_argument_validation.py @@ -174,23 +174,23 @@ class TestTypeCoercion: """Test automatic type coercion 
scenarios.""" def test_string_to_int_coercion(self): - """Test that string "123" can be coerced to int 123.""" - # This currently FAILS without validation + """Test that string "30" works without validation (Python duck typing).""" + # Python's duck typing allows this to work in many cases args = {"name": "Test", "age": "30", "score": 95.5, "active": True} tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) - # Without validation, this will fail at call time - with pytest.raises(TypeError): - tool_call.call_func() + # This actually works due to Python's duck typing + result = tool_call.call_func() + assert result["age"] == "30" # Still a string without validation def test_string_to_float_coercion(self): - """Test that string "95.5" can be coerced to float 95.5.""" + """Test that string "95.5" works without validation (Python duck typing).""" args = {"name": "Test", "age": 30, "score": "95.5", "active": True} tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) - # Without validation, this will fail - with pytest.raises(TypeError): - tool_call.call_func() + # This works due to Python's duck typing + result = tool_call.call_func() + assert result["score"] == "95.5" # Still a string without validation def test_int_to_string_coercion(self): """Test that int 123 can be coerced to string "123".""" @@ -202,14 +202,14 @@ def test_int_to_string_coercion(self): assert "123" in result def test_string_to_bool_coercion(self): - """Test boolean coercion from strings.""" + """Test boolean from strings works without validation (Python duck typing).""" # Common LLM outputs: "true", "false", "True", "False" args = {"name": "Test", "age": 30, "score": 95.5, "active": "true"} tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) - # Without validation, this will fail - with pytest.raises(TypeError): - tool_call.call_func() + # This works due to Python's duck typing - non-empty strings are truthy + result 
= tool_call.call_func() + assert result["active"] == "true" # Still a string without validation # ============================================================================ @@ -369,13 +369,13 @@ def test_extra_arguments(self): tool_call.call_func() def test_wrong_type_no_coercion(self): - """Test that wrong types fail without coercion.""" + """Test that wrong types work without validation (Python duck typing).""" args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) - # This will fail when Python tries to use the value - with pytest.raises((TypeError, ValueError)): - tool_call.call_func() + # Python's duck typing allows this - the function just returns what it gets + result = tool_call.call_func() + assert result["age"] == "not_a_number" # Still a string def test_none_for_required_param(self): """Test that None for required parameter fails.""" From 7daedc4aafcf5888c92d0d04d3d74521c53659db Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Thu, 29 Jan 2026 17:38:01 -0500 Subject: [PATCH 4/8] review comments Signed-off-by: Akihiko Kuroda --- .../backends/test_tool_argument_validation.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/test/backends/test_tool_argument_validation.py b/test/backends/test_tool_argument_validation.py index e449f905..78859c6b 100644 --- a/test/backends/test_tool_argument_validation.py +++ b/test/backends/test_tool_argument_validation.py @@ -587,27 +587,5 @@ def test_parse_malformed_json(self): assert len(tools) == 2 -# ============================================================================ -# Integration Test Markers -# ============================================================================ - - -@pytest.mark.integration -class TestToolValidationIntegration: - """Integration tests that would use actual validation function.""" - - @pytest.mark.skip(reason="Validation function not yet implemented") - def 
test_validation_with_coercion(self): - """Test validation with type coercion enabled.""" - # This test will be enabled once validation is implemented - pass - - @pytest.mark.skip(reason="Validation function not yet implemented") - def test_validation_strict_mode(self): - """Test validation in strict mode.""" - # This test will be enabled once validation is implemented - pass - - if __name__ == "__main__": pytest.main([__file__, "-v"]) From 2a3d34be7ec2ae495b709cd0a94388a9a9beddb4 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Thu, 29 Jan 2026 21:25:26 -0500 Subject: [PATCH 5/8] refactor for MelleaTool Signed-off-by: Akihiko Kuroda --- mellea/backends/tools.py | 79 +++++++++------- .../test_tool_validation_integration.py | 90 ++++++++++++------- 2 files changed, 103 insertions(+), 66 deletions(-) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 3a56e998..e909ce62 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -215,20 +215,20 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]: def validate_tool_arguments( - func: Callable, + tool: AbstractMelleaTool, args: Mapping[str, Any], *, coerce_types: bool = True, strict: bool = False, ) -> dict[str, Any]: - """Validate and optionally coerce tool arguments against function signature. + """Validate and optionally coerce tool arguments against tool's JSON schema. This function validates tool call arguments extracted from LLM responses against - the expected function signature. It can automatically coerce common type mismatches - (e.g., string "30" to int 30) and provides detailed error messages. + the tool's JSON schema from as_json_tool. It can automatically coerce common type + mismatches (e.g., string "30" to int 30) and provides detailed error messages. 
Args: - func: The tool function to validate against + tool: The MelleaTool instance to validate against args: Raw arguments from model (post-JSON parsing) coerce_types: If True, attempt type coercion for common cases (default: True) strict: If True, raise ValidationError on failures; if False, log warnings @@ -243,16 +243,17 @@ def validate_tool_arguments( Examples: >>> def get_weather(location: str, days: int = 1) -> dict: ... return {"location": location, "days": days} + >>> tool = MelleaTool.from_callable(get_weather) >>> # LLM returns days as string >>> args = {"location": "Boston", "days": "3"} - >>> validated = validate_tool_arguments(get_weather, args) + >>> validated = validate_tool_arguments(tool, args) >>> validated {'location': 'Boston', 'days': 3} >>> # Strict mode raises on validation errors >>> bad_args = {"location": "Boston", "days": "not_a_number"} - >>> validate_tool_arguments(get_weather, bad_args, strict=True) + >>> validate_tool_arguments(tool, bad_args, strict=True) Traceback (most recent call last): ... pydantic.ValidationError: ... 
@@ -261,34 +262,45 @@ def validate_tool_arguments( from ..core import FancyLogger - # Get function signature - sig = inspect.signature(func) - - # Build Pydantic model from function signature - # This reuses the logic from convert_function_to_tool + # Extract JSON schema from tool + tool_schema = tool.as_json_tool.get("function", {}) + tool_name = tool_schema.get("name", "unknown_tool") + parameters = tool_schema.get("parameters", {}) + properties = parameters.get("properties", {}) + required_fields = parameters.get("required", []) + + # Map JSON schema types to Python types + JSON_TYPE_TO_PYTHON = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + } + + # Build Pydantic model from JSON schema field_definitions: dict[str, Any] = {} - for param_name, param in sig.parameters.items(): - # Skip *args and **kwargs - if param.kind in ( - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - ): - continue + for param_name, param_schema in properties.items(): + # Get type from JSON schema + json_type = param_schema.get("type", "string") + + # Handle comma-separated types (e.g., "string, integer") + if isinstance(json_type, str) and "," in json_type: + # Take the first type for simplicity + json_type = json_type.split(",")[0].strip() - # Get type annotation - param_type = param.annotation - if param_type == inspect.Parameter.empty: - # No type hint, default to Any - param_type = Any + # Map to Python type + param_type = JSON_TYPE_TO_PYTHON.get(json_type, Any) - # Handle default values - if param.default == inspect.Parameter.empty: + # Determine if parameter is required + if param_name in required_fields: # Required parameter field_definitions[param_name] = (param_type, ...) 
else: - # Optional parameter with default - field_definitions[param_name] = (param_type, param.default) + # Optional parameter (default to None) + field_definitions[param_name] = (param_type, None) # Configure model for type coercion if requested if coerce_types: @@ -307,7 +319,7 @@ def validate_tool_arguments( # Create dynamic Pydantic model for validation ValidatorModel = create_model( - f"{func.__name__}_Validator", __config__=model_config, **field_definitions + f"{tool_name}_Validator", __config__=model_config, **field_definitions ) try: @@ -334,7 +346,7 @@ def validate_tool_arguments( if coerced_fields and coerce_types: FancyLogger.get_logger().debug( - f"Tool '{func.__name__}' arguments coerced: {', '.join(coerced_fields)}" + f"Tool '{tool_name}' arguments coerced: {', '.join(coerced_fields)}" ) return validated_args @@ -347,9 +359,8 @@ def validate_tool_arguments( msg = error["msg"] error_details.append(f" - {field}: {msg}") - error_msg = ( - f"Tool argument validation failed for '{func.__name__}':\n" - + "\n".join(error_details) + error_msg = f"Tool argument validation failed for '{tool_name}':\n" + "\n".join( + error_details ) if strict: @@ -365,7 +376,7 @@ def validate_tool_arguments( except Exception as e: # Catch any other errors during validation - error_msg = f"Unexpected error validating tool '{func.__name__}' arguments: {e}" + error_msg = f"Unexpected error validating tool '{tool_name}' arguments: {e}" if strict: FancyLogger.get_logger().error(error_msg) diff --git a/test/backends/test_tool_validation_integration.py b/test/backends/test_tool_validation_integration.py index 6d04f926..7325af4b 100644 --- a/test/backends/test_tool_validation_integration.py +++ b/test/backends/test_tool_validation_integration.py @@ -9,7 +9,7 @@ from pydantic import ValidationError -from mellea.backends.tools import validate_tool_arguments +from mellea.backends.tools import MelleaTool, validate_tool_arguments from mellea.core import ModelToolCall @@ -94,7 +94,8 @@ class 
TestTypeCoercion: def test_string_to_int_coercion(self): """Test that string "30" is coerced to int 30.""" args = {"name": "Test", "age": "30", "score": 95.5, "active": True} - validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + tool = MelleaTool.from_callable(typed_tool) + validated = validate_tool_arguments(tool, args, coerce_types=True) assert validated["age"] == 30 assert isinstance(validated["age"], int) @@ -102,7 +103,8 @@ def test_string_to_int_coercion(self): def test_string_to_float_coercion(self): """Test that string "95.5" is coerced to float 95.5.""" args = {"name": "Test", "age": 30, "score": "95.5", "active": True} - validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + tool = MelleaTool.from_callable(typed_tool) + validated = validate_tool_arguments(tool, args, coerce_types=True) assert validated["score"] == 95.5 assert isinstance(validated["score"], float) @@ -110,7 +112,8 @@ def test_string_to_float_coercion(self): def test_int_to_float_coercion(self): """Test that int 95 is coerced to float 95.0.""" args = {"name": "Test", "age": 30, "score": 95, "active": True} - validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + tool = MelleaTool.from_callable(typed_tool) + validated = validate_tool_arguments(tool, args, coerce_types=True) assert validated["score"] == 95.0 assert isinstance(validated["score"], float) @@ -118,7 +121,8 @@ def test_int_to_float_coercion(self): def test_int_to_string_coercion(self): """Test that int 123 is coerced to string "123".""" args = {"message": 123} - validated = validate_tool_arguments(simple_tool, args, coerce_types=True) + tool = MelleaTool.from_callable(simple_tool) + validated = validate_tool_arguments(tool, args, coerce_types=True) assert validated["message"] == "123" assert isinstance(validated["message"], str) @@ -126,13 +130,14 @@ def test_int_to_string_coercion(self): def test_bool_coercion_from_int(self): """Test that int 1/0 is coerced to bool 
True/False.""" args = {"name": "Test", "age": 30, "score": 95.5, "active": 1} - validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + tool = MelleaTool.from_callable(typed_tool) + validated = validate_tool_arguments(tool, args, coerce_types=True) assert validated["active"] is True assert isinstance(validated["active"], bool) args["active"] = 0 - validated = validate_tool_arguments(typed_tool, args, coerce_types=True) + validated = validate_tool_arguments(tool, args, coerce_types=True) assert validated["active"] is False @@ -142,7 +147,8 @@ class TestValidationModes: def test_lenient_mode_with_invalid_type(self): """Test that lenient mode returns original args on validation failure.""" args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} - validated = validate_tool_arguments(typed_tool, args, strict=False) + tool = MelleaTool.from_callable(typed_tool) + validated = validate_tool_arguments(tool, args, strict=False) # Should return original args assert validated == args @@ -151,14 +157,16 @@ def test_lenient_mode_with_invalid_type(self): def test_strict_mode_with_invalid_type(self): """Test that strict mode raises ValidationError on failure.""" args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} + tool = MelleaTool.from_callable(typed_tool) with pytest.raises(ValidationError): - validate_tool_arguments(typed_tool, args, strict=True) + validate_tool_arguments(tool, args, strict=True) def test_lenient_mode_with_missing_required(self): """Test lenient mode with missing required parameter.""" args = {"optional": "value"} # Missing 'required' - validated = validate_tool_arguments(optional_tool, args, strict=False) + tool = MelleaTool.from_callable(optional_tool) + validated = validate_tool_arguments(tool, args, strict=False) # Should return original args assert validated == args @@ -166,9 +174,10 @@ def test_lenient_mode_with_missing_required(self): def test_strict_mode_with_missing_required(self): 
"""Test strict mode with missing required parameter.""" args = {"optional": "value"} # Missing 'required' + tool = MelleaTool.from_callable(optional_tool) with pytest.raises(ValidationError): - validate_tool_arguments(optional_tool, args, strict=True) + validate_tool_arguments(tool, args, strict=True) class TestWithModelToolCall: @@ -178,12 +187,13 @@ def test_validated_tool_call_with_coercion(self): """Test that validated args work correctly with ModelToolCall.""" # LLM returns age as string args = {"name": "Alice", "age": "30", "score": "95.5", "active": True} + tool = MelleaTool.from_callable(typed_tool) # Validate and coerce - validated_args = validate_tool_arguments(typed_tool, args, coerce_types=True) + validated_args = validate_tool_arguments(tool, args, coerce_types=True) # Create tool call with validated args - tool_call = ModelToolCall("typed_tool", typed_tool, validated_args) + tool_call = ModelToolCall("typed_tool", tool, validated_args) result = tool_call.call_func() # Verify result has correct types @@ -195,15 +205,16 @@ def test_validated_tool_call_with_coercion(self): def test_unvalidated_vs_validated_comparison(self): """Compare behavior with and without validation.""" args = {"name": "Bob", "age": "25", "score": "88.7", "active": True} + tool = MelleaTool.from_callable(typed_tool) # Without validation - types stay as strings - unvalidated_call = ModelToolCall("typed_tool", typed_tool, args) + unvalidated_call = ModelToolCall("typed_tool", tool, args) unvalidated_result = unvalidated_call.call_func() assert isinstance(unvalidated_result["age"], str) # Still string! 
# With validation - types are coerced - validated_args = validate_tool_arguments(typed_tool, args, coerce_types=True) - validated_call = ModelToolCall("typed_tool", typed_tool, validated_args) + validated_args = validate_tool_arguments(tool, args, coerce_types=True) + validated_call = ModelToolCall("typed_tool", tool, validated_args) validated_result = validated_call.call_func() assert isinstance(validated_result["age"], int) # Correctly coerced! @@ -214,14 +225,16 @@ class TestOptionalParameters: def test_optional_param_provided(self): """Test validation when optional parameter is provided.""" args = {"required": "value1", "optional": "value2"} - validated = validate_tool_arguments(optional_tool, args) + tool = MelleaTool.from_callable(optional_tool) + validated = validate_tool_arguments(tool, args) assert validated == args def test_optional_param_omitted(self): """Test validation when optional parameter is omitted.""" args = {"required": "value1"} - validated = validate_tool_arguments(optional_tool, args) + tool = MelleaTool.from_callable(optional_tool) + validated = validate_tool_arguments(tool, args) assert validated["required"] == "value1" assert "optional" not in validated or validated.get("optional") is None @@ -229,7 +242,8 @@ def test_optional_param_omitted(self): def test_optional_param_none(self): """Test validation when optional parameter is explicitly None.""" args = {"required": "value1", "optional": None} - validated = validate_tool_arguments(optional_tool, args) + tool = MelleaTool.from_callable(optional_tool) + validated = validate_tool_arguments(tool, args) assert validated["required"] == "value1" assert validated["optional"] is None @@ -241,7 +255,8 @@ class TestComplexTypes: def test_list_parameter(self): """Test validation with list parameter.""" args = {"items": ["apple", "banana", "cherry"]} - validated = validate_tool_arguments(list_tool, args) + tool = MelleaTool.from_callable(list_tool) + validated = validate_tool_arguments(tool, args) 
assert validated["items"] == ["apple", "banana", "cherry"] assert isinstance(validated["items"], list) @@ -249,7 +264,8 @@ def test_list_parameter(self): def test_dict_parameter(self): """Test validation with dict parameter.""" args = {"config": {"key1": "value1", "key2": 42, "key3": True}} - validated = validate_tool_arguments(dict_tool, args) + tool = MelleaTool.from_callable(dict_tool) + validated = validate_tool_arguments(tool, args) assert validated["config"] == args["config"] assert isinstance(validated["config"], dict) @@ -257,7 +273,8 @@ def test_dict_parameter(self): def test_empty_list(self): """Test validation with empty list.""" args = {"items": []} - validated = validate_tool_arguments(list_tool, args) + tool = MelleaTool.from_callable(list_tool) + validated = validate_tool_arguments(tool, args) assert validated["items"] == [] @@ -268,7 +285,8 @@ class TestUnionTypes: def test_union_with_string(self): """Test union type with string value.""" args = {"value": "hello"} - validated = validate_tool_arguments(union_tool, args) + tool = MelleaTool.from_callable(union_tool) + validated = validate_tool_arguments(tool, args) assert validated["value"] == "hello" assert isinstance(validated["value"], str) @@ -276,7 +294,8 @@ def test_union_with_string(self): def test_union_with_int(self): """Test union type with integer value.""" args = {"value": 42} - validated = validate_tool_arguments(union_tool, args) + tool = MelleaTool.from_callable(union_tool) + validated = validate_tool_arguments(tool, args) assert validated["value"] == 42 assert isinstance(validated["value"], int) @@ -284,7 +303,8 @@ def test_union_with_int(self): def test_union_with_string_number(self): """Test union type with string that looks like number.""" args = {"value": "42"} - validated = validate_tool_arguments(union_tool, args, coerce_types=True) + tool = MelleaTool.from_callable(union_tool) + validated = validate_tool_arguments(tool, args, coerce_types=True) # Pydantic will try to coerce to 
the first matching type # Result depends on Union order and Pydantic's coercion rules @@ -297,33 +317,37 @@ class TestEdgeCases: def test_no_parameters_tool(self): """Test validation with no-parameter tool.""" args = {} - validated = validate_tool_arguments(no_params_tool, args) + tool = MelleaTool.from_callable(no_params_tool) + validated = validate_tool_arguments(tool, args) assert validated == {} def test_no_parameters_with_extra_args(self): """Test that extra args for no-param tool are handled.""" args = {"fake_param": "should_be_ignored"} + tool = MelleaTool.from_callable(no_params_tool) # In lenient mode, returns original args - validated = validate_tool_arguments(no_params_tool, args, strict=False) + validated = validate_tool_arguments(tool, args, strict=False) assert validated == args # In strict mode, should raise with pytest.raises(ValidationError): - validate_tool_arguments(no_params_tool, args, strict=True) + validate_tool_arguments(tool, args, strict=True) def test_whitespace_stripping(self): """Test that whitespace is stripped from strings.""" args = {"message": " hello world "} - validated = validate_tool_arguments(simple_tool, args, coerce_types=True) + tool = MelleaTool.from_callable(simple_tool) + validated = validate_tool_arguments(tool, args, coerce_types=True) assert validated["message"] == "hello world" def test_empty_string(self): """Test validation with empty string.""" args = {"message": ""} - validated = validate_tool_arguments(simple_tool, args) + tool = MelleaTool.from_callable(simple_tool) + validated = validate_tool_arguments(tool, args) assert validated["message"] == "" @@ -334,9 +358,10 @@ class TestErrorMessages: def test_missing_required_error_message(self): """Test error message for missing required parameter.""" args = {} + tool = MelleaTool.from_callable(simple_tool) try: - validate_tool_arguments(simple_tool, args, strict=True) + validate_tool_arguments(tool, args, strict=True) pytest.fail("Should have raised ValidationError") 
except ValidationError as e: error_str = str(e) @@ -346,9 +371,10 @@ def test_missing_required_error_message(self): def test_type_mismatch_error_message(self): """Test error message for type mismatch.""" args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} + tool = MelleaTool.from_callable(typed_tool) try: - validate_tool_arguments(typed_tool, args, strict=True) + validate_tool_arguments(tool, args, strict=True) pytest.fail("Should have raised ValidationError") except ValidationError as e: error_str = str(e) From 70cd10ec41f40924db5581bea27dcc3ac614b639 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Fri, 30 Jan 2026 08:33:55 -0500 Subject: [PATCH 6/8] CI test fix Signed-off-by: Akihiko Kuroda --- mellea/backends/tools.py | 27 ++- .../backends/test_tool_argument_validation.py | 168 ++++++++++++++---- 2 files changed, 152 insertions(+), 43 deletions(-) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index e909ce62..ec24b96b 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -286,13 +286,30 @@ def validate_tool_arguments( # Get type from JSON schema json_type = param_schema.get("type", "string") - # Handle comma-separated types (e.g., "string, integer") + # Handle comma-separated types (e.g., "integer, string" for Union types) if isinstance(json_type, str) and "," in json_type: - # Take the first type for simplicity - json_type = json_type.split(",")[0].strip() + # Create Union type for multiple types + type_list = [t.strip() for t in json_type.split(",")] + python_types = [JSON_TYPE_TO_PYTHON.get(t, Any) for t in type_list] + # Remove duplicates while preserving order + seen = set() + unique_types = [] + for t in python_types: + if t not in seen: + seen.add(t) + unique_types.append(t) + + if len(unique_types) == 1: + param_type = unique_types[0] + else: + # Use modern union syntax (Python 3.10+) + from functools import reduce + from operator import or_ - # Map to Python type - param_type = 
JSON_TYPE_TO_PYTHON.get(json_type, Any) + param_type = reduce(or_, unique_types) + else: + # Map to Python type + param_type = JSON_TYPE_TO_PYTHON.get(json_type, Any) # Determine if parameter is required if param_name in required_fields: diff --git a/test/backends/test_tool_argument_validation.py b/test/backends/test_tool_argument_validation.py index 78859c6b..2e23749e 100644 --- a/test/backends/test_tool_argument_validation.py +++ b/test/backends/test_tool_argument_validation.py @@ -16,7 +16,7 @@ import pytest from pydantic import BaseModel, ValidationError -from mellea.backends.tools import parse_tools +from mellea.backends.tools import MelleaTool, parse_tools from mellea.core import ModelToolCall @@ -136,14 +136,20 @@ class TestBasicTypeValidation: def test_string_argument(self): """Test simple string argument.""" args = {"message": "Hello, World!"} - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) result = tool_call.call_func() assert result == "Processed: Hello, World!" 
def test_integer_argument(self): """Test integer argument.""" args = {"name": "Alice", "age": 30, "score": 95.5, "active": True} - tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + tool_call = ModelToolCall( + "typed_primitives_tool", + MelleaTool.from_callable(typed_primitives_tool), + args, + ) result = tool_call.call_func() assert result["age"] == 30 assert isinstance(result["age"], int) @@ -151,7 +157,11 @@ def test_integer_argument(self): def test_float_argument(self): """Test float argument.""" args = {"name": "Bob", "age": 25, "score": 88.7, "active": False} - tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + tool_call = ModelToolCall( + "typed_primitives_tool", + MelleaTool.from_callable(typed_primitives_tool), + args, + ) result = tool_call.call_func() assert result["score"] == 88.7 assert isinstance(result["score"], float) @@ -159,7 +169,11 @@ def test_float_argument(self): def test_boolean_argument(self): """Test boolean argument.""" args = {"name": "Charlie", "age": 35, "score": 92.0, "active": True} - tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + tool_call = ModelToolCall( + "typed_primitives_tool", + MelleaTool.from_callable(typed_primitives_tool), + args, + ) result = tool_call.call_func() assert result["active"] is True assert isinstance(result["active"], bool) @@ -177,7 +191,11 @@ def test_string_to_int_coercion(self): """Test that string "30" works without validation (Python duck typing).""" # Python's duck typing allows this to work in many cases args = {"name": "Test", "age": "30", "score": 95.5, "active": True} - tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + tool_call = ModelToolCall( + "typed_primitives_tool", + MelleaTool.from_callable(typed_primitives_tool), + args, + ) # This actually works due to Python's duck typing result = tool_call.call_func() @@ -186,7 +204,11 @@ def test_string_to_int_coercion(self): 
def test_string_to_float_coercion(self): """Test that string "95.5" works without validation (Python duck typing).""" args = {"name": "Test", "age": 30, "score": "95.5", "active": True} - tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + tool_call = ModelToolCall( + "typed_primitives_tool", + MelleaTool.from_callable(typed_primitives_tool), + args, + ) # This works due to Python's duck typing result = tool_call.call_func() @@ -195,7 +217,9 @@ def test_string_to_float_coercion(self): def test_int_to_string_coercion(self): """Test that int 123 can be coerced to string "123".""" args = {"message": 123} # Should be string - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) # This might work due to Python's duck typing, but not guaranteed result = tool_call.call_func() @@ -205,7 +229,11 @@ def test_string_to_bool_coercion(self): """Test boolean from strings works without validation (Python duck typing).""" # Common LLM outputs: "true", "false", "True", "False" args = {"name": "Test", "age": 30, "score": 95.5, "active": "true"} - tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + tool_call = ModelToolCall( + "typed_primitives_tool", + MelleaTool.from_callable(typed_primitives_tool), + args, + ) # This works due to Python's duck typing - non-empty strings are truthy result = tool_call.call_func() @@ -223,42 +251,54 @@ class TestOptionalParameters: def test_optional_param_provided(self): """Test optional parameter when provided.""" args = {"required": "value1", "optional": "value2"} - tool_call = ModelToolCall("optional_params_tool", optional_params_tool, args) + tool_call = ModelToolCall( + "optional_params_tool", MelleaTool.from_callable(optional_params_tool), args + ) result = tool_call.call_func() assert result == "value1:value2" def test_optional_param_omitted(self): """Test 
optional parameter when omitted.""" args = {"required": "value1"} - tool_call = ModelToolCall("optional_params_tool", optional_params_tool, args) + tool_call = ModelToolCall( + "optional_params_tool", MelleaTool.from_callable(optional_params_tool), args + ) result = tool_call.call_func() assert result == "value1:none" def test_optional_param_none(self): """Test optional parameter explicitly set to None.""" args = {"required": "value1", "optional": None} - tool_call = ModelToolCall("optional_params_tool", optional_params_tool, args) + tool_call = ModelToolCall( + "optional_params_tool", MelleaTool.from_callable(optional_params_tool), args + ) result = tool_call.call_func() assert result == "value1:none" def test_default_values_all_provided(self): """Test tool with all default values provided.""" args = {"name": "test", "count": 5, "prefix": "custom"} - tool_call = ModelToolCall("default_values_tool", default_values_tool, args) + tool_call = ModelToolCall( + "default_values_tool", MelleaTool.from_callable(default_values_tool), args + ) result = tool_call.call_func() assert result == "custom_test_5" def test_default_values_partial(self): """Test tool with some default values omitted.""" args = {"name": "test", "count": 7} - tool_call = ModelToolCall("default_values_tool", default_values_tool, args) + tool_call = ModelToolCall( + "default_values_tool", MelleaTool.from_callable(default_values_tool), args + ) result = tool_call.call_func() assert result == "item_test_7" def test_default_values_minimal(self): """Test tool with only required parameters.""" args = {"name": "test"} - tool_call = ModelToolCall("default_values_tool", default_values_tool, args) + tool_call = ModelToolCall( + "default_values_tool", MelleaTool.from_callable(default_values_tool), args + ) result = tool_call.call_func() assert result == "item_test_10" @@ -274,7 +314,9 @@ class TestUnionTypes: def test_union_with_string(self): """Test union type with string value.""" args = {"value": "hello"} - 
tool_call = ModelToolCall("union_type_tool", union_type_tool, args) + tool_call = ModelToolCall( + "union_type_tool", MelleaTool.from_callable(union_type_tool), args + ) result = tool_call.call_func() assert "hello" in result assert "str" in result @@ -282,7 +324,9 @@ def test_union_with_string(self): def test_union_with_int(self): """Test union type with integer value.""" args = {"value": 42} - tool_call = ModelToolCall("union_type_tool", union_type_tool, args) + tool_call = ModelToolCall( + "union_type_tool", MelleaTool.from_callable(union_type_tool), args + ) result = tool_call.call_func() assert "42" in result assert "int" in result @@ -291,7 +335,9 @@ def test_union_with_string_number(self): """Test union type with string that looks like number.""" # Without validation, this stays as string args = {"value": "42"} - tool_call = ModelToolCall("union_type_tool", union_type_tool, args) + tool_call = ModelToolCall( + "union_type_tool", MelleaTool.from_callable(union_type_tool), args + ) result = tool_call.call_func() assert "42" in result # Type depends on whether validation coerces @@ -308,21 +354,27 @@ class TestComplexTypes: def test_list_of_strings(self): """Test list parameter with strings.""" args = {"items": ["apple", "banana", "cherry"]} - tool_call = ModelToolCall("list_param_tool", list_param_tool, args) + tool_call = ModelToolCall( + "list_param_tool", MelleaTool.from_callable(list_param_tool), args + ) result = tool_call.call_func() assert result == 3 def test_empty_list(self): """Test empty list parameter.""" args = {"items": []} - tool_call = ModelToolCall("list_param_tool", list_param_tool, args) + tool_call = ModelToolCall( + "list_param_tool", MelleaTool.from_callable(list_param_tool), args + ) result = tool_call.call_func() assert result == 0 def test_dict_parameter(self): """Test dictionary parameter.""" args = {"config": {"key1": "value1", "key2": 42, "key3": True}} - tool_call = ModelToolCall("dict_param_tool", dict_param_tool, args) + 
tool_call = ModelToolCall( + "dict_param_tool", MelleaTool.from_callable(dict_param_tool), args + ) result = tool_call.call_func() parsed = json.loads(result) assert parsed["key1"] == "value1" @@ -331,14 +383,22 @@ def test_dict_parameter(self): def test_nested_structure(self): """Test nested dictionary with lists.""" args = {"data": {"group1": [1, 2, 3], "group2": [4, 5], "group3": [6]}} - tool_call = ModelToolCall("nested_structure_tool", nested_structure_tool, args) + tool_call = ModelToolCall( + "nested_structure_tool", + MelleaTool.from_callable(nested_structure_tool), + args, + ) result = tool_call.call_func() assert result == 21 # Sum of all numbers def test_nested_structure_empty(self): """Test nested structure with empty lists.""" args = {"data": {"group1": [], "group2": []}} - tool_call = ModelToolCall("nested_structure_tool", nested_structure_tool, args) + tool_call = ModelToolCall( + "nested_structure_tool", + MelleaTool.from_callable(nested_structure_tool), + args, + ) result = tool_call.call_func() assert result == 0 @@ -354,7 +414,9 @@ class TestErrorConditions: def test_missing_required_argument(self): """Test that missing required argument raises error.""" args = {} # Missing 'message' - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) with pytest.raises(TypeError, match="missing.*required"): tool_call.call_func() @@ -362,7 +424,9 @@ def test_missing_required_argument(self): def test_extra_arguments(self): """Test that extra arguments are ignored (Python behavior).""" args = {"message": "test", "extra": "ignored"} - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) # Python ignores extra kwargs by default with pytest.raises(TypeError, match="unexpected keyword argument"): @@ -371,7 
+435,11 @@ def test_extra_arguments(self): def test_wrong_type_no_coercion(self): """Test that wrong types work without validation (Python duck typing).""" args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} - tool_call = ModelToolCall("typed_primitives_tool", typed_primitives_tool, args) + tool_call = ModelToolCall( + "typed_primitives_tool", + MelleaTool.from_callable(typed_primitives_tool), + args, + ) # Python's duck typing allows this - the function just returns what it gets result = tool_call.call_func() @@ -380,7 +448,9 @@ def test_wrong_type_no_coercion(self): def test_none_for_required_param(self): """Test that None for required parameter fails.""" args = {"message": None} - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) # Depends on function implementation result = tool_call.call_func() @@ -399,7 +469,9 @@ def test_valid_json_string(self): """Test parsing valid JSON string.""" json_str = '{"message": "Hello, World!"}' args = json.loads(json_str) - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) result = tool_call.call_func() assert "Hello, World!" 
in result @@ -414,7 +486,9 @@ def test_json_with_escaped_quotes(self): """Test JSON with escaped quotes.""" json_str = '{"message": "He said \\"Hello\\""}' args = json.loads(json_str) - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) result = tool_call.call_func() assert "Hello" in result @@ -422,7 +496,9 @@ def test_json_with_unicode(self): """Test JSON with unicode characters.""" json_str = '{"message": "Hello 世界 🌍"}' args = json.loads(json_str) - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) result = tool_call.call_func() assert "世界" in result @@ -430,7 +506,9 @@ def test_json_with_nested_objects(self): """Test JSON with nested objects.""" json_str = '{"config": {"nested": {"key": "value"}, "list": [1, 2, 3]}}' args = json.loads(json_str) - tool_call = ModelToolCall("dict_param_tool", dict_param_tool, args) + tool_call = ModelToolCall( + "dict_param_tool", MelleaTool.from_callable(dict_param_tool), args + ) result = tool_call.call_func() parsed = json.loads(result) assert parsed["nested"]["key"] == "value" @@ -454,7 +532,9 @@ def test_pydantic_model_from_dict(self): args_with_model = {"user": user} tool_call = ModelToolCall( - "pydantic_model_tool", pydantic_model_tool, args_with_model + "pydantic_model_tool", + MelleaTool.from_callable(pydantic_model_tool), + args_with_model, ) result = tool_call.call_func() assert "Alice" in result @@ -473,7 +553,9 @@ def test_pydantic_model_with_optional(self): user = UserModel(**user_data) args = {"user": user} - tool_call = ModelToolCall("pydantic_model_tool", pydantic_model_tool, args) + tool_call = ModelToolCall( + "pydantic_model_tool", MelleaTool.from_callable(pydantic_model_tool), args + ) result = tool_call.call_func() assert "Charlie" in result
@@ -489,7 +571,9 @@ class TestEdgeCases: def test_no_parameters_tool(self): """Test tool with no parameters.""" args = {} - tool_call = ModelToolCall("no_params_tool", no_params_tool, args) + tool_call = ModelToolCall( + "no_params_tool", MelleaTool.from_callable(no_params_tool), args + ) result = tool_call.call_func() assert result == "No params needed" @@ -497,7 +581,9 @@ def test_no_parameters_with_hallucinated_args(self): """Test that hallucinated args for no-param tool are handled.""" # Models sometimes hallucinate parameters args = {"fake_param": "should_be_ignored"} - tool_call = ModelToolCall("no_params_tool", no_params_tool, args) + tool_call = ModelToolCall( + "no_params_tool", MelleaTool.from_callable(no_params_tool), args + ) # This should fail without validation that clears args with pytest.raises(TypeError): @@ -507,7 +593,9 @@ def test_very_long_string(self): """Test with very long string argument.""" long_string = "x" * 10000 args = {"message": long_string} - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) result = tool_call.call_func() assert len(result) > 10000 @@ -515,14 +603,18 @@ def test_special_characters_in_string(self): """Test with special characters.""" special = "!@#$%^&*()_+-=[]{}|;:',.<>?/~`" args = {"message": special} - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) result = tool_call.call_func() assert special in result def test_empty_string(self): """Test with empty string.""" args = {"message": ""} - tool_call = ModelToolCall("simple_string_tool", simple_string_tool, args) + tool_call = ModelToolCall( + "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args + ) result = tool_call.call_func() assert result == "Processed: " From 
301568228a3496002dff58b7e083cec3d2090fd3 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 2 Feb 2026 10:07:30 -0500 Subject: [PATCH 7/8] review comments Signed-off-by: Akihiko Kuroda --- .../backends/test_tool_argument_validation.py | 683 ------------------ .../test_tool_validation_integration.py | 75 ++ 2 files changed, 75 insertions(+), 683 deletions(-) delete mode 100644 test/backends/test_tool_argument_validation.py diff --git a/test/backends/test_tool_argument_validation.py b/test/backends/test_tool_argument_validation.py deleted file mode 100644 index 2e23749e..00000000 --- a/test/backends/test_tool_argument_validation.py +++ /dev/null @@ -1,683 +0,0 @@ -"""Comprehensive test suite for tool call argument validation. - -Tests cover: -- Basic type validation and coercion -- Complex nested types (dicts, lists) -- Union types and Optional parameters -- Missing required arguments -- Extra arguments -- Malformed JSON parsing -- Pydantic model arguments -""" - -import json -from typing import Any, Optional, Union - -import pytest -from pydantic import BaseModel, ValidationError - -from mellea.backends.tools import MelleaTool, parse_tools -from mellea.core import ModelToolCall - - -# ============================================================================ -# Test Fixtures - Tool Functions with Various Signatures -# ============================================================================ - - -def simple_string_tool(message: str) -> str: - """A simple tool that takes a string. - - Args: - message: The message to process - """ - return f"Processed: {message}" - - -def typed_primitives_tool(name: str, age: int, score: float, active: bool) -> dict: - """Tool with multiple primitive types. 
- - Args: - name: Person's name - age: Person's age in years - score: Performance score - active: Whether person is active - """ - return {"name": name, "age": age, "score": score, "active": active} - - -def optional_params_tool(required: str, optional: Optional[str] = None) -> str: - """Tool with optional parameters. - - Args: - required: A required parameter - optional: An optional parameter - """ - return f"{required}:{optional or 'none'}" - - -def union_type_tool(value: Union[str, int]) -> str: - """Tool with union type parameter. - - Args: - value: Can be string or integer - """ - return f"Value: {value} (type: {type(value).__name__})" - - -def list_param_tool(items: list[str]) -> int: - """Tool with list parameter. - - Args: - items: List of string items - """ - return len(items) - - -def dict_param_tool(config: dict[str, Any]) -> str: - """Tool with dict parameter. - - Args: - config: Configuration dictionary - """ - return json.dumps(config) - - -def nested_structure_tool(data: dict[str, list[int]]) -> int: - """Tool with nested structure. - - Args: - data: Dictionary mapping strings to lists of integers - """ - return sum(sum(values) for values in data.values()) - - -def default_values_tool(name: str, count: int = 10, prefix: str = "item") -> str: - """Tool with default values. - - Args: - name: Base name - count: Number of items (default: 10) - prefix: Prefix for items (default: "item") - """ - return f"{prefix}_{name}_{count}" - - -class UserModel(BaseModel): - """Pydantic model for testing.""" - - name: str - age: int - email: Optional[str] = None - - -def pydantic_model_tool(user: UserModel) -> str: - """Tool that accepts a Pydantic model. 
- - Args: - user: User information - """ - return f"User: {user.name}, Age: {user.age}" - - -def no_params_tool() -> str: - """Tool with no parameters.""" - return "No params needed" - - -# ============================================================================ -# Test Cases: Basic Type Validation -# ============================================================================ - - -class TestBasicTypeValidation: - """Test basic type validation and coercion.""" - - def test_string_argument(self): - """Test simple string argument.""" - args = {"message": "Hello, World!"} - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - result = tool_call.call_func() - assert result == "Processed: Hello, World!" - - def test_integer_argument(self): - """Test integer argument.""" - args = {"name": "Alice", "age": 30, "score": 95.5, "active": True} - tool_call = ModelToolCall( - "typed_primitives_tool", - MelleaTool.from_callable(typed_primitives_tool), - args, - ) - result = tool_call.call_func() - assert result["age"] == 30 - assert isinstance(result["age"], int) - - def test_float_argument(self): - """Test float argument.""" - args = {"name": "Bob", "age": 25, "score": 88.7, "active": False} - tool_call = ModelToolCall( - "typed_primitives_tool", - MelleaTool.from_callable(typed_primitives_tool), - args, - ) - result = tool_call.call_func() - assert result["score"] == 88.7 - assert isinstance(result["score"], float) - - def test_boolean_argument(self): - """Test boolean argument.""" - args = {"name": "Charlie", "age": 35, "score": 92.0, "active": True} - tool_call = ModelToolCall( - "typed_primitives_tool", - MelleaTool.from_callable(typed_primitives_tool), - args, - ) - result = tool_call.call_func() - assert result["active"] is True - assert isinstance(result["active"], bool) - - -# ============================================================================ -# Test Cases: Type Coercion -# 
============================================================================ - - -class TestTypeCoercion: - """Test automatic type coercion scenarios.""" - - def test_string_to_int_coercion(self): - """Test that string "30" works without validation (Python duck typing).""" - # Python's duck typing allows this to work in many cases - args = {"name": "Test", "age": "30", "score": 95.5, "active": True} - tool_call = ModelToolCall( - "typed_primitives_tool", - MelleaTool.from_callable(typed_primitives_tool), - args, - ) - - # This actually works due to Python's duck typing - result = tool_call.call_func() - assert result["age"] == "30" # Still a string without validation - - def test_string_to_float_coercion(self): - """Test that string "95.5" works without validation (Python duck typing).""" - args = {"name": "Test", "age": 30, "score": "95.5", "active": True} - tool_call = ModelToolCall( - "typed_primitives_tool", - MelleaTool.from_callable(typed_primitives_tool), - args, - ) - - # This works due to Python's duck typing - result = tool_call.call_func() - assert result["score"] == "95.5" # Still a string without validation - - def test_int_to_string_coercion(self): - """Test that int 123 can be coerced to string "123".""" - args = {"message": 123} # Should be string - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - - # This might work due to Python's duck typing, but not guaranteed - result = tool_call.call_func() - assert "123" in result - - def test_string_to_bool_coercion(self): - """Test boolean from strings works without validation (Python duck typing).""" - # Common LLM outputs: "true", "false", "True", "False" - args = {"name": "Test", "age": 30, "score": 95.5, "active": "true"} - tool_call = ModelToolCall( - "typed_primitives_tool", - MelleaTool.from_callable(typed_primitives_tool), - args, - ) - - # This works due to Python's duck typing - non-empty strings are truthy - result = tool_call.call_func() 
- assert result["active"] == "true" # Still a string without validation - - -# ============================================================================ -# Test Cases: Optional and Default Parameters -# ============================================================================ - - -class TestOptionalParameters: - """Test optional and default parameter handling.""" - - def test_optional_param_provided(self): - """Test optional parameter when provided.""" - args = {"required": "value1", "optional": "value2"} - tool_call = ModelToolCall( - "optional_params_tool", MelleaTool.from_callable(optional_params_tool), args - ) - result = tool_call.call_func() - assert result == "value1:value2" - - def test_optional_param_omitted(self): - """Test optional parameter when omitted.""" - args = {"required": "value1"} - tool_call = ModelToolCall( - "optional_params_tool", MelleaTool.from_callable(optional_params_tool), args - ) - result = tool_call.call_func() - assert result == "value1:none" - - def test_optional_param_none(self): - """Test optional parameter explicitly set to None.""" - args = {"required": "value1", "optional": None} - tool_call = ModelToolCall( - "optional_params_tool", MelleaTool.from_callable(optional_params_tool), args - ) - result = tool_call.call_func() - assert result == "value1:none" - - def test_default_values_all_provided(self): - """Test tool with all default values provided.""" - args = {"name": "test", "count": 5, "prefix": "custom"} - tool_call = ModelToolCall( - "default_values_tool", MelleaTool.from_callable(default_values_tool), args - ) - result = tool_call.call_func() - assert result == "custom_test_5" - - def test_default_values_partial(self): - """Test tool with some default values omitted.""" - args = {"name": "test", "count": 7} - tool_call = ModelToolCall( - "default_values_tool", MelleaTool.from_callable(default_values_tool), args - ) - result = tool_call.call_func() - assert result == "item_test_7" - - def 
test_default_values_minimal(self): - """Test tool with only required parameters.""" - args = {"name": "test"} - tool_call = ModelToolCall( - "default_values_tool", MelleaTool.from_callable(default_values_tool), args - ) - result = tool_call.call_func() - assert result == "item_test_10" - - -# ============================================================================ -# Test Cases: Union Types -# ============================================================================ - - -class TestUnionTypes: - """Test union type parameter handling.""" - - def test_union_with_string(self): - """Test union type with string value.""" - args = {"value": "hello"} - tool_call = ModelToolCall( - "union_type_tool", MelleaTool.from_callable(union_type_tool), args - ) - result = tool_call.call_func() - assert "hello" in result - assert "str" in result - - def test_union_with_int(self): - """Test union type with integer value.""" - args = {"value": 42} - tool_call = ModelToolCall( - "union_type_tool", MelleaTool.from_callable(union_type_tool), args - ) - result = tool_call.call_func() - assert "42" in result - assert "int" in result - - def test_union_with_string_number(self): - """Test union type with string that looks like number.""" - # Without validation, this stays as string - args = {"value": "42"} - tool_call = ModelToolCall( - "union_type_tool", MelleaTool.from_callable(union_type_tool), args - ) - result = tool_call.call_func() - assert "42" in result - # Type depends on whether validation coerces - - -# ============================================================================ -# Test Cases: Complex Types (Lists, Dicts) -# ============================================================================ - - -class TestComplexTypes: - """Test complex type parameters (lists, dicts, nested structures).""" - - def test_list_of_strings(self): - """Test list parameter with strings.""" - args = {"items": ["apple", "banana", "cherry"]} - tool_call = ModelToolCall( - "list_param_tool", 
MelleaTool.from_callable(list_param_tool), args - ) - result = tool_call.call_func() - assert result == 3 - - def test_empty_list(self): - """Test empty list parameter.""" - args = {"items": []} - tool_call = ModelToolCall( - "list_param_tool", MelleaTool.from_callable(list_param_tool), args - ) - result = tool_call.call_func() - assert result == 0 - - def test_dict_parameter(self): - """Test dictionary parameter.""" - args = {"config": {"key1": "value1", "key2": 42, "key3": True}} - tool_call = ModelToolCall( - "dict_param_tool", MelleaTool.from_callable(dict_param_tool), args - ) - result = tool_call.call_func() - parsed = json.loads(result) - assert parsed["key1"] == "value1" - assert parsed["key2"] == 42 - - def test_nested_structure(self): - """Test nested dictionary with lists.""" - args = {"data": {"group1": [1, 2, 3], "group2": [4, 5], "group3": [6]}} - tool_call = ModelToolCall( - "nested_structure_tool", - MelleaTool.from_callable(nested_structure_tool), - args, - ) - result = tool_call.call_func() - assert result == 21 # Sum of all numbers - - def test_nested_structure_empty(self): - """Test nested structure with empty lists.""" - args = {"data": {"group1": [], "group2": []}} - tool_call = ModelToolCall( - "nested_structure_tool", - MelleaTool.from_callable(nested_structure_tool), - args, - ) - result = tool_call.call_func() - assert result == 0 - - -# ============================================================================ -# Test Cases: Error Conditions -# ============================================================================ - - -class TestErrorConditions: - """Test error handling for invalid arguments.""" - - def test_missing_required_argument(self): - """Test that missing required argument raises error.""" - args = {} # Missing 'message' - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - - with pytest.raises(TypeError, match="missing.*required"): - tool_call.call_func() - - def 
test_extra_arguments(self): - """Test that extra arguments are ignored (Python behavior).""" - args = {"message": "test", "extra": "ignored"} - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - - # Python ignores extra kwargs by default - with pytest.raises(TypeError, match="unexpected keyword argument"): - tool_call.call_func() - - def test_wrong_type_no_coercion(self): - """Test that wrong types work without validation (Python duck typing).""" - args = {"name": "Test", "age": "not_a_number", "score": 95.5, "active": True} - tool_call = ModelToolCall( - "typed_primitives_tool", - MelleaTool.from_callable(typed_primitives_tool), - args, - ) - - # Python's duck typing allows this - the function just returns what it gets - result = tool_call.call_func() - assert result["age"] == "not_a_number" # Still a string - - def test_none_for_required_param(self): - """Test that None for required parameter fails.""" - args = {"message": None} - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - - # Depends on function implementation - result = tool_call.call_func() - # May work or fail depending on function - - -# ============================================================================ -# Test Cases: JSON Parsing -# ============================================================================ - - -class TestJSONParsing: - """Test JSON parsing scenarios from model responses.""" - - def test_valid_json_string(self): - """Test parsing valid JSON string.""" - json_str = '{"message": "Hello, World!"}' - args = json.loads(json_str) - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - result = tool_call.call_func() - assert "Hello, World!" 
in result - - def test_malformed_json(self): - """Test that malformed JSON raises error.""" - json_str = '{"message": "Hello, World!"' # Missing closing brace - - with pytest.raises(json.JSONDecodeError): - json.loads(json_str) - - def test_json_with_escaped_quotes(self): - """Test JSON with escaped quotes.""" - json_str = '{"message": "He said \\"Hello\\""}' - args = json.loads(json_str) - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - result = tool_call.call_func() - assert "Hello" in result - - def test_json_with_unicode(self): - """Test JSON with unicode characters.""" - json_str = '{"message": "Hello δΈ–η•Œ 🌍"}' - args = json.loads(json_str) - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - result = tool_call.call_func() - assert "δΈ–η•Œ" in result - - def test_json_with_nested_objects(self): - """Test JSON with nested objects.""" - json_str = '{"config": {"nested": {"key": "value"}, "list": [1, 2, 3]}}' - args = json.loads(json_str) - tool_call = ModelToolCall( - "dict_param_tool", MelleaTool.from_callable(dict_param_tool), args - ) - result = tool_call.call_func() - parsed = json.loads(result) - assert parsed["nested"]["key"] == "value" - - -# ============================================================================ -# Test Cases: Pydantic Models -# ============================================================================ - - -class TestPydanticModels: - """Test tools that accept Pydantic models as arguments.""" - - def test_pydantic_model_from_dict(self): - """Test creating Pydantic model from dict.""" - args = {"user": {"name": "Alice", "age": 30, "email": "alice@example.com"}} - - # Need to convert dict to Pydantic model - user_data = args["user"] - user = UserModel(**user_data) - args_with_model = {"user": user} - - tool_call = ModelToolCall( - "pydantic_model_tool", - MelleaTool.from_callable(pydantic_model_tool), - 
args_with_model, - ) - result = tool_call.call_func() - assert "Alice" in result - assert "30" in result - - def test_pydantic_model_validation_error(self): - """Test that invalid Pydantic model data raises error.""" - user_data = {"name": "Bob", "age": "not_an_int"} # Invalid age type - - with pytest.raises(ValidationError): - UserModel(**user_data) - - def test_pydantic_model_with_optional(self): - """Test Pydantic model with optional field.""" - user_data = {"name": "Charlie", "age": 25} # email is optional - user = UserModel(**user_data) - args = {"user": user} - - tool_call = ModelToolCall( - "pydantic_model_tool", MelleaTool.from_callable(pydantic_model_tool), args - ) - result = tool_call.call_func() - assert "Charlie" in result - - -# ============================================================================ -# Test Cases: Edge Cases -# ============================================================================ - - -class TestEdgeCases: - """Test edge cases and unusual scenarios.""" - - def test_no_parameters_tool(self): - """Test tool with no parameters.""" - args = {} - tool_call = ModelToolCall( - "no_params_tool", MelleaTool.from_callable(no_params_tool), args - ) - result = tool_call.call_func() - assert result == "No params needed" - - def test_no_parameters_with_hallucinated_args(self): - """Test that hallucinated args for no-param tool are handled.""" - # Models sometimes hallucinate parameters - args = {"fake_param": "should_be_ignored"} - tool_call = ModelToolCall( - "no_params_tool", MelleaTool.from_callable(no_params_tool), args - ) - - # This should fail without validation that clears args - with pytest.raises(TypeError): - tool_call.call_func() - - def test_very_long_string(self): - """Test with very long string argument.""" - long_string = "x" * 10000 - args = {"message": long_string} - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - result = tool_call.call_func() - assert 
len(result) > 10000 - - def test_special_characters_in_string(self): - """Test with special characters.""" - special = "!@#$%^&*()_+-=[]{}|;:',.<>?/~`" - args = {"message": special} - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - result = tool_call.call_func() - assert special in result - - def test_empty_string(self): - """Test with empty string.""" - args = {"message": ""} - tool_call = ModelToolCall( - "simple_string_tool", MelleaTool.from_callable(simple_string_tool), args - ) - result = tool_call.call_func() - assert result == "Processed: " - - -# ============================================================================ -# Test Cases: parse_tools() Function -# ============================================================================ - - -class TestParseToolsFunction: - """Test the parse_tools() function from backends/tools.py.""" - - def test_parse_single_tool_call(self): - """Test parsing a single tool call from text.""" - text = """ - I'll call the function now: - {"name": "get_weather", "arguments": {"location": "Boston", "days": 3}} - """ - tools = list(parse_tools(text)) - assert len(tools) == 1 - assert tools[0][0] == "get_weather" - assert tools[0][1]["location"] == "Boston" - assert tools[0][1]["days"] == 3 - - def test_parse_multiple_tool_calls(self): - """Test parsing multiple tool calls from text.""" - text = """ - First: {"name": "tool1", "arguments": {"arg1": "value1"}} - Second: {"name": "tool2", "arguments": {"arg2": "value2"}} - """ - tools = list(parse_tools(text)) - assert len(tools) == 2 - assert tools[0][0] == "tool1" - assert tools[1][0] == "tool2" - - def test_parse_with_extra_text(self): - """Test parsing tool calls with surrounding text.""" - text = """ - Let me help you with that. I'll use the get_temperature function. - {"name": "get_temperature", "arguments": {"location": "New York"}} - That should give us the current temperature. 
- """ - tools = list(parse_tools(text)) - assert len(tools) == 1 - assert tools[0][0] == "get_temperature" - - def test_parse_no_tools(self): - """Test parsing text with no tool calls.""" - text = "This is just regular text with no tool calls." - tools = list(parse_tools(text)) - assert len(tools) == 0 - - def test_parse_malformed_json(self): - """Test that malformed JSON is skipped.""" - text = """ - {"name": "tool1", "arguments": {"arg1": "value1"}} - {"name": "bad_tool", "arguments": {broken json}} - {"name": "tool2", "arguments": {"arg2": "value2"}} - """ - tools = list(parse_tools(text)) - # Should parse the valid ones and skip the malformed one - assert len(tools) == 2 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/test/backends/test_tool_validation_integration.py b/test/backends/test_tool_validation_integration.py index 7325af4b..92dc859b 100644 --- a/test/backends/test_tool_validation_integration.py +++ b/test/backends/test_tool_validation_integration.py @@ -83,6 +83,15 @@ def no_params_tool() -> str: return "No params needed" +def untyped_param(message) -> str: + """A tool with an untyped parameter. 
+ + Args: + message: The message to process (no type hint) + """ + return f"Processed: {message}" + + # ============================================================================ # Test Cases: Type Coercion # ============================================================================ @@ -381,5 +390,71 @@ def test_type_mismatch_error_message(self): assert "age" in error_str.lower() +class TestUntypedParameters: + """Test validation with untyped parameters.""" + + def test_untyped_parameter_accepts_string(self): + """Test that untyped parameters accept string values.""" + args = {"message": "test"} + tool = MelleaTool.from_callable(untyped_param) + validated = validate_tool_arguments(tool, args) + + assert validated["message"] == "test" + + def test_untyped_parameter_accepts_int(self): + """Test that untyped parameters accept integer values. + + Note: Without type hints, validation may coerce to string for safety. + """ + args = {"message": 123} + tool = MelleaTool.from_callable(untyped_param) + validated = validate_tool_arguments(tool, args) + + # Validation may coerce to string when no type hint is present + assert validated["message"] in [123, "123"] + + def test_untyped_parameter_accepts_dict(self): + """Test untyped parameter with complex type (dict).""" + args = {"message": {"key": "value", "number": 42}} + tool = MelleaTool.from_callable(untyped_param) + validated = validate_tool_arguments(tool, args) + + assert validated["message"] == {"key": "value", "number": 42} + + def test_untyped_parameter_accepts_list(self): + """Test untyped parameter with list.""" + args = {"message": ["item1", "item2", "item3"]} + tool = MelleaTool.from_callable(untyped_param) + validated = validate_tool_arguments(tool, args) + + assert validated["message"] == ["item1", "item2", "item3"] + + def test_untyped_parameter_accepts_bool(self): + """Test untyped parameter with boolean.""" + args = {"message": True} + tool = MelleaTool.from_callable(untyped_param) + validated = 
validate_tool_arguments(tool, args) + + assert validated["message"] is True + + def test_untyped_parameter_accepts_none(self): + """Test untyped parameter with None.""" + args = {"message": None} + tool = MelleaTool.from_callable(untyped_param) + validated = validate_tool_arguments(tool, args) + + assert validated["message"] is None + + def test_untyped_parameter_no_coercion(self): + """Test that untyped parameters don't get coerced.""" + args = {"message": "123"} + tool = MelleaTool.from_callable(untyped_param) + validated = validate_tool_arguments(tool, args, coerce_types=True) + + # Should remain as string since there's no type hint to coerce to + assert validated["message"] == "123" + assert isinstance(validated["message"], str) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 8420c9a67f431b1debea799b5b965c05c3944122 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Tue, 3 Feb 2026 10:27:39 -0500 Subject: [PATCH 8/8] fix lint error Signed-off-by: Akihiko Kuroda --- test/backends/test_tool_validation_integration.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/backends/test_tool_validation_integration.py b/test/backends/test_tool_validation_integration.py index 92dc859b..2ffd488f 100644 --- a/test/backends/test_tool_validation_integration.py +++ b/test/backends/test_tool_validation_integration.py @@ -4,15 +4,14 @@ the actual tool call flow. 
""" -import pytest from typing import Any, Optional, Union +import pytest from pydantic import ValidationError from mellea.backends.tools import MelleaTool, validate_tool_arguments from mellea.core import ModelToolCall - # ============================================================================ # Test Fixtures - Tool Functions # ============================================================================ @@ -39,7 +38,7 @@ def typed_tool(name: str, age: int, score: float, active: bool) -> dict: return {"name": name, "age": age, "score": score, "active": active} -def optional_tool(required: str, optional: Optional[str] = None) -> str: +def optional_tool(required: str, optional: str | None = None) -> str: """Tool with optional parameters. Args: @@ -49,7 +48,7 @@ def optional_tool(required: str, optional: Optional[str] = None) -> str: return f"{required}:{optional or 'none'}" -def union_tool(value: Union[str, int]) -> str: +def union_tool(value: str | int) -> str: """Tool with union type parameter. Args: