@@ -213,20 +213,20 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]:
213213
214214
215215def validate_tool_arguments (
216- func : Callable ,
216+ tool : AbstractMelleaTool ,
217217 args : Mapping [str , Any ],
218218 * ,
219219 coerce_types : bool = True ,
220220 strict : bool = False ,
221221) -> dict [str , Any ]:
222- """Validate and optionally coerce tool arguments against function signature .
222+ """Validate and optionally coerce tool arguments against tool's JSON schema .
223223
224224 This function validates tool call arguments extracted from LLM responses against
225- the expected function signature . It can automatically coerce common type mismatches
226- (e.g., string "30" to int 30) and provides detailed error messages.
225+ the tool's JSON schema from as_json_tool . It can automatically coerce common type
226+ mismatches (e.g., string "30" to int 30) and provides detailed error messages.
227227
228228 Args:
229- func : The tool function to validate against
229+ tool : The MelleaTool instance to validate against
230230 args: Raw arguments from model (post-JSON parsing)
231231 coerce_types: If True, attempt type coercion for common cases (default: True)
232232 strict: If True, raise ValidationError on failures; if False, log warnings
@@ -241,16 +241,17 @@ def validate_tool_arguments(
241241 Examples:
242242 >>> def get_weather(location: str, days: int = 1) -> dict:
243243 ... return {"location": location, "days": days}
244+ >>> tool = MelleaTool.from_callable(get_weather)
244245
245246 >>> # LLM returns days as string
246247 >>> args = {"location": "Boston", "days": "3"}
247- >>> validated = validate_tool_arguments(get_weather , args)
248+ >>> validated = validate_tool_arguments(tool , args)
248249 >>> validated
249250 {'location': 'Boston', 'days': 3}
250251
251252 >>> # Strict mode raises on validation errors
252253 >>> bad_args = {"location": "Boston", "days": "not_a_number"}
253- >>> validate_tool_arguments(get_weather , bad_args, strict=True)
254+ >>> validate_tool_arguments(tool , bad_args, strict=True)
254255 Traceback (most recent call last):
255256 ...
256257 pydantic.ValidationError: ...
@@ -259,34 +260,45 @@ def validate_tool_arguments(
259260
260261 from ..core import FancyLogger
261262
262- # Get function signature
263- sig = inspect .signature (func )
264-
265- # Build Pydantic model from function signature
266- # This reuses the logic from convert_function_to_tool
263+ # Extract JSON schema from tool
264+ tool_schema = tool .as_json_tool .get ("function" , {})
265+ tool_name = tool_schema .get ("name" , "unknown_tool" )
266+ parameters = tool_schema .get ("parameters" , {})
267+ properties = parameters .get ("properties" , {})
268+ required_fields = parameters .get ("required" , [])
269+
270+ # Map JSON schema types to Python types
271+ JSON_TYPE_TO_PYTHON = {
272+ "string" : str ,
273+ "integer" : int ,
274+ "number" : float ,
275+ "boolean" : bool ,
276+ "array" : list ,
277+ "object" : dict ,
278+ }
279+
280+ # Build Pydantic model from JSON schema
267281 field_definitions : dict [str , Any ] = {}
268282
269- for param_name , param in sig .parameters .items ():
270- # Skip *args and **kwargs
271- if param .kind in (
272- inspect .Parameter .VAR_POSITIONAL ,
273- inspect .Parameter .VAR_KEYWORD ,
274- ):
275- continue
283+ for param_name , param_schema in properties .items ():
284+ # Get type from JSON schema
285+ json_type = param_schema .get ("type" , "string" )
286+
287+ # Handle comma-separated types (e.g., "string, integer")
288+ if isinstance (json_type , str ) and "," in json_type :
289+ # Take the first type for simplicity
290+ json_type = json_type .split ("," )[0 ].strip ()
276291
277- # Get type annotation
278- param_type = param .annotation
279- if param_type == inspect .Parameter .empty :
280- # No type hint, default to Any
281- param_type = Any
292+ # Map to Python type
293+ param_type = JSON_TYPE_TO_PYTHON .get (json_type , Any )
282294
283- # Handle default values
284- if param . default == inspect . Parameter . empty :
295+ # Determine if parameter is required
296+ if param_name in required_fields :
285297 # Required parameter
286298 field_definitions [param_name ] = (param_type , ...)
287299 else :
288- # Optional parameter with default
289- field_definitions [param_name ] = (param_type , param . default )
300+ # Optional parameter ( default to None)
301+ field_definitions [param_name ] = (param_type , None )
290302
291303 # Configure model for type coercion if requested
292304 if coerce_types :
@@ -305,7 +317,7 @@ def validate_tool_arguments(
305317
306318 # Create dynamic Pydantic model for validation
307319 ValidatorModel = create_model (
308- f"{ func . __name__ } _Validator" , __config__ = model_config , ** field_definitions
320+ f"{ tool_name } _Validator" , __config__ = model_config , ** field_definitions
309321 )
310322
311323 try :
@@ -332,7 +344,7 @@ def validate_tool_arguments(
332344
333345 if coerced_fields and coerce_types :
334346 FancyLogger .get_logger ().debug (
335- f"Tool '{ func . __name__ } ' arguments coerced: { ', ' .join (coerced_fields )} "
347+ f"Tool '{ tool_name } ' arguments coerced: { ', ' .join (coerced_fields )} "
336348 )
337349
338350 return validated_args
@@ -345,9 +357,8 @@ def validate_tool_arguments(
345357 msg = error ["msg" ]
346358 error_details .append (f" - { field } : { msg } " )
347359
348- error_msg = (
349- f"Tool argument validation failed for '{ func .__name__ } ':\n "
350- + "\n " .join (error_details )
360+ error_msg = f"Tool argument validation failed for '{ tool_name } ':\n " + "\n " .join (
361+ error_details
351362 )
352363
353364 if strict :
@@ -363,7 +374,7 @@ def validate_tool_arguments(
363374
364375 except Exception as e :
365376 # Catch any other errors during validation
366- error_msg = f"Unexpected error validating tool '{ func . __name__ } ' arguments: { e } "
377+ error_msg = f"Unexpected error validating tool '{ tool_name } ' arguments: { e } "
367378
368379 if strict :
369380 FancyLogger .get_logger ().error (error_msg )
0 commit comments