-
Notifications
You must be signed in to change notification settings - Fork 75
feat: add tool calling argument validation #364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6a79d5b
3b36153
dcc34fa
7daedc4
2a3d34b
70cd10e
3015682
8420c9a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -211,10 +211,200 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]: | |
| tool_name, tool_arguments = find_func(possible_tool) | ||
| if tool_name is not None and tool_arguments is not None: | ||
| tools.append((tool_name, tool_arguments)) | ||
|
|
||
| return tools | ||
|
|
||
|
|
||
| def validate_tool_arguments( | ||
| tool: AbstractMelleaTool, | ||
| args: Mapping[str, Any], | ||
| *, | ||
| coerce_types: bool = True, | ||
| strict: bool = False, | ||
| ) -> dict[str, Any]: | ||
|
Comment on lines
217
to
223
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I looked at the changes for this PR and Hendrik's; if Hendrik's gets merged, first you'll just have to change this
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I saw the PR. Right now, the parameter information is from the function signature only. I understand that some refactor is necessary. |
||
| """Validate and optionally coerce tool arguments against tool's JSON schema. | ||
|
|
||
| This function validates tool call arguments extracted from LLM responses against | ||
| the tool's JSON schema from as_json_tool. It can automatically coerce common type | ||
| mismatches (e.g., string "30" to int 30) and provides detailed error messages. | ||
|
|
||
| Args: | ||
| tool: The MelleaTool instance to validate against | ||
| args: Raw arguments from model (post-JSON parsing) | ||
| coerce_types: If True, attempt type coercion for common cases (default: True) | ||
| strict: If True, raise ValidationError on failures; if False, log warnings | ||
| and return original args (default: False) | ||
|
|
||
| Returns: | ||
| Validated and optionally coerced arguments dict | ||
|
|
||
| Raises: | ||
| ValidationError: If strict=True and validation fails | ||
|
|
||
| Examples: | ||
| >>> def get_weather(location: str, days: int = 1) -> dict: | ||
| ... return {"location": location, "days": days} | ||
| >>> tool = MelleaTool.from_callable(get_weather) | ||
|
|
||
| >>> # LLM returns days as string | ||
| >>> args = {"location": "Boston", "days": "3"} | ||
| >>> validated = validate_tool_arguments(tool, args) | ||
| >>> validated | ||
| {'location': 'Boston', 'days': 3} | ||
|
|
||
| >>> # Strict mode raises on validation errors | ||
| >>> bad_args = {"location": "Boston", "days": "not_a_number"} | ||
| >>> validate_tool_arguments(tool, bad_args, strict=True) | ||
| Traceback (most recent call last): | ||
| ... | ||
| pydantic.ValidationError: ... | ||
| """ | ||
| from pydantic import ValidationError, create_model | ||
|
|
||
| from ..core import FancyLogger | ||
|
|
||
| # Extract JSON schema from tool | ||
| tool_schema = tool.as_json_tool.get("function", {}) | ||
| tool_name = tool_schema.get("name", "unknown_tool") | ||
| parameters = tool_schema.get("parameters", {}) | ||
| properties = parameters.get("properties", {}) | ||
| required_fields = parameters.get("required", []) | ||
|
|
||
| # Map JSON schema types to Python types | ||
| JSON_TYPE_TO_PYTHON = { | ||
| "string": str, | ||
| "integer": int, | ||
| "number": float, | ||
| "boolean": bool, | ||
| "array": list, | ||
| "object": dict, | ||
| } | ||
|
|
||
| # Build Pydantic model from JSON schema | ||
| field_definitions: dict[str, Any] = {} | ||
|
|
||
| for param_name, param_schema in properties.items(): | ||
| # Get type from JSON schema | ||
| json_type = param_schema.get("type", "string") | ||
|
|
||
| # Handle comma-separated types (e.g., "integer, string" for Union types) | ||
| if isinstance(json_type, str) and "," in json_type: | ||
| # Create Union type for multiple types | ||
| type_list = [t.strip() for t in json_type.split(",")] | ||
| python_types = [JSON_TYPE_TO_PYTHON.get(t, Any) for t in type_list] | ||
| # Remove duplicates while preserving order | ||
| seen = set() | ||
| unique_types = [] | ||
| for t in python_types: | ||
| if t not in seen: | ||
| seen.add(t) | ||
| unique_types.append(t) | ||
|
|
||
| if len(unique_types) == 1: | ||
| param_type = unique_types[0] | ||
| else: | ||
| # Use modern union syntax (Python 3.10+) | ||
| from functools import reduce | ||
| from operator import or_ | ||
|
|
||
| param_type = reduce(or_, unique_types) | ||
| else: | ||
| # Map to Python type | ||
| param_type = JSON_TYPE_TO_PYTHON.get(json_type, Any) | ||
|
|
||
| # Determine if parameter is required | ||
| if param_name in required_fields: | ||
| # Required parameter | ||
| field_definitions[param_name] = (param_type, ...) | ||
| else: | ||
| # Optional parameter (default to None) | ||
| field_definitions[param_name] = (param_type, None) | ||
|
|
||
| # Configure model for type coercion if requested | ||
| if coerce_types: | ||
| model_config = ConfigDict( | ||
| str_strip_whitespace=True, | ||
| strict=False, # Allow type coercion | ||
| extra="forbid" if strict else "allow", # Handle extra fields | ||
| # Enable coercion modes for common LLM output issues | ||
| coerce_numbers_to_str=True, # Allow int/float -> str | ||
| ) | ||
| else: | ||
| model_config = ConfigDict( | ||
| strict=True, # No coercion | ||
| extra="forbid" if strict else "allow", | ||
| ) | ||
|
|
||
| # Create dynamic Pydantic model for validation | ||
| ValidatorModel = create_model( | ||
| f"{tool_name}_Validator", __config__=model_config, **field_definitions | ||
| ) | ||
|
|
||
| try: | ||
| # Validate using Pydantic | ||
| validated_model = ValidatorModel(**args) | ||
| validated_args = validated_model.model_dump() | ||
|
|
||
| # In lenient mode with extra="allow", Pydantic includes extra fields | ||
| # but we need to preserve them from the original args | ||
| if not strict: | ||
| # Add back any extra fields that weren't in the model | ||
| for key, value in args.items(): | ||
| if key not in field_definitions: | ||
| validated_args[key] = value | ||
|
|
||
| # Log successful validation with coercion details | ||
| coerced_fields = [] | ||
| for key, original_value in args.items(): | ||
| validated_value = validated_args.get(key) | ||
| if type(original_value) is not type(validated_value): | ||
| coerced_fields.append( | ||
| f"{key}: {type(original_value).__name__} → {type(validated_value).__name__}" | ||
| ) | ||
|
|
||
| if coerced_fields and coerce_types: | ||
| FancyLogger.get_logger().debug( | ||
| f"Tool '{tool_name}' arguments coerced: {', '.join(coerced_fields)}" | ||
| ) | ||
|
|
||
| return validated_args | ||
|
|
||
| except ValidationError as e: | ||
| # Format error message | ||
| error_details = [] | ||
| for error in e.errors(): | ||
| field = ".".join(str(loc) for loc in error["loc"]) | ||
| msg = error["msg"] | ||
| error_details.append(f" - {field}: {msg}") | ||
|
|
||
| error_msg = f"Tool argument validation failed for '{tool_name}':\n" + "\n".join( | ||
| error_details | ||
| ) | ||
|
|
||
| if strict: | ||
| # Re-raise with enhanced message | ||
| FancyLogger.get_logger().error(error_msg) | ||
| raise | ||
| else: | ||
| # Log warning and return original args | ||
| FancyLogger.get_logger().warning( | ||
| error_msg + "\nReturning original arguments without validation." | ||
| ) | ||
| return dict(args) | ||
|
|
||
| except Exception as e: | ||
| # Catch any other errors during validation | ||
| error_msg = f"Unexpected error validating tool '{tool_name}' arguments: {e}" | ||
|
|
||
| if strict: | ||
| FancyLogger.get_logger().error(error_msg) | ||
| raise | ||
| else: | ||
| FancyLogger.get_logger().warning( | ||
| error_msg + "\nReturning original arguments without validation." | ||
| ) | ||
| return dict(args) | ||
|
|
||
|
|
||
| # Below functions and classes extracted from Ollama Python SDK (v0.6.1) | ||
| # so that all backends don't need it installed. | ||
| # https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_types.py#L19-L101 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Were you able to verify that these backends can generate incorrect parameters? I know huggingface has some issues with its tool requests that could benefit from this, but I'm not sure the other backends need any validation (but that was an open question I wasn't certain on).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't remember if I saw the backend generated incorrect parameters but it is good to validate them before making tool calls. We can not assume all backends working correctly :-)