diff --git a/docs/examples/tools/interpreter_example.py b/docs/examples/tools/interpreter_example.py index ea77e801..3b8c5546 100644 --- a/docs/examples/tools/interpreter_example.py +++ b/docs/examples/tools/interpreter_example.py @@ -1,7 +1,7 @@ # pytest: ollama, llm from mellea import MelleaSession, start_session -from mellea.backends import ModelOption +from mellea.backends import ModelOption, tool from mellea.stdlib.requirements import tool_arg_validator, uses_tool from mellea.stdlib.tools import code_interpreter, local_code_interpreter diff --git a/docs/examples/tools/tool_decorator_example.py b/docs/examples/tools/tool_decorator_example.py new file mode 100644 index 00000000..e35e1405 --- /dev/null +++ b/docs/examples/tools/tool_decorator_example.py @@ -0,0 +1,137 @@ +# pytest: ollama, llm +"""Example demonstrating the @tool decorator for cleaner tool definitions.""" + +import ast + +from mellea import start_session +from mellea.backends import ModelOption, tool + + +# Define tools using the @tool decorator - much cleaner than MelleaTool.from_callable() +@tool +def get_weather(location: str, days: int = 1) -> dict: + """Get weather forecast for a location. + + Args: + location: City name + days: Number of days to forecast (default: 1) + """ + # Mock implementation + return {"location": location, "days": days, "forecast": "sunny", "temperature": 72} + + +@tool +def search_web(query: str, max_results: int = 5) -> list[str]: + """Search the web for information. + + Args: + query: Search query + max_results: Maximum number of results to return + """ + # Mock implementation + return [f"Result {i + 1} for '{query}'" for i in range(max_results)] + + +@tool(name="calculator") +def calculate(expression: str) -> str: + """Evaluate a mathematical expression. + + Args: + expression: Mathematical expression to evaluate + """ + try: + # Use ast.literal_eval for safe evaluation of simple expressions + result = ast.literal_eval(expression) + return f"Result: {result}" + except Exception as e: + return f"Error: {e!s}" + + +def example_basic_usage(): + """Example 1: Basic usage with decorated tools.""" + print("\n=== Example 1: Basic Tool Usage ===") + + # Without the decorator, you can add tools using: + # tools = [MelleaTool.from_callable(get_weather), MelleaTool.from_callable(search_web)] + + # Now you can just pass the decorated functions directly to model_options + # Example: model_options={ModelOption.TOOLS: [get_weather, search_web, calculate]} + + # The decorated tools must be called using .run() + weather = get_weather.run("Boston", days=3) + print(f"Tool call via .run(): {weather}") + + # And they have tool properties + print(f"Tool name: {get_weather.name}") + print(f"Tool has JSON schema: {'function' in get_weather.as_json_tool}") + + +def example_with_llm(): + """Example 2: Using decorated tools with an LLM.""" + print("\n=== Example 2: Using Tools with LLM ===") + + m = start_session() + + # Pass decorated tools directly - no wrapping needed! + response = m.instruct( + description="What's the weather like in San Francisco?", + model_options={ModelOption.TOOLS: [get_weather, search_web]}, + ) + + print(f"Response: {response}") + + +def example_custom_name(): + """Example 3: Using custom tool names.""" + print("\n=== Example 3: Custom Tool Names ===") + + # The calculator tool was decorated with @tool(name="calculator") + # So its name is "calculator" instead of "calculate" + print("Function name: calculate") + print(f"Tool name: {calculate.name}") + + # Must use .run() to invoke + result = calculate.run("2 + 2") + print(f"Result: {result}") + + +def example_comparison(): + """Example 4: Comparison of old vs new approach.""" + print("\n=== Example 4: Old vs New Approach ===") + + # OLD APPROACH (still works, but verbose): + from mellea.backends.tools import MelleaTool + + def old_style_tool(x: int) -> int: + """Old style tool. + + Args: + x: Input value + """ + return x * 2 + + old_tool = MelleaTool.from_callable(old_style_tool) + print(f"Old approach - tool name: {old_tool.name}") + + # NEW APPROACH (cleaner): + @tool + def new_style_tool(x: int) -> int: + """New style tool. + + Args: + x: Input value + """ + return x * 2 + + print(f"New approach - tool name: {new_style_tool.name}") + + # Both can be used together in a tools list + tools = [old_tool, new_style_tool, get_weather] + print(f"Mixed tools list: {[t.name for t in tools]}") + + +if __name__ == "__main__": + example_basic_usage() + # example_with_llm() # Uncomment to test with actual LLM + example_custom_name() + example_comparison() diff --git a/docs/tutorial.md b/docs/tutorial.md index 2fa97af3..d494b6a8 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -1347,9 +1347,15 @@ For examples on adding tools to the template representation of a component, see Here's an example of adding a tool through model options. This can be useful when you want to add a tool like web search that should almost always be available: ```python import mellea -from mellea.backends import ModelOption +from mellea.backends import ModelOption, tool +@tool def web_search(query: str) -> str: + """Search the web for information. + + Args: + query: The search query + """ ... m = mellea.start_session() diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index 564d9950..9dd45518 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -6,12 +6,15 @@ from .cache import SimpleLRUCache from .model_ids import ModelIdentifier from .model_options import ModelOption +from .tools import MelleaTool, tool __all__ = [ "Backend", "BaseModelSubclass", "FormatterBackend", + "MelleaTool", "ModelIdentifier", "ModelOption", "SimpleLRUCache", + "tool", ] diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 09249c66..2fd047f4 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -5,7 +5,7 @@ import re from collections import defaultdict from collections.abc import Callable, Generator, Iterable, Mapping, Sequence -from typing import Any, Literal +from typing import Any, Literal, overload from pydantic import BaseModel, ConfigDict, Field @@ -91,10 +91,84 @@ def from_callable(cls, func: Callable, name: str | None = None): return MelleaTool(tool_name, tool_call, as_json) +@overload +def tool(func: Callable) -> MelleaTool: ... + + +@overload +def tool(*, name: str | None = None) -> Callable[[Callable], MelleaTool]: ... + + +def tool( + func: Callable | None = None, name: str | None = None +) -> MelleaTool | Callable[[Callable], MelleaTool]: + """Decorator to mark a function as a Mellea tool. + + This decorator wraps a function to make it usable as a tool without + requiring explicit MelleaTool.from_callable() calls. The decorated + function returns a MelleaTool instance that must be called via .run(). + + Args: + func: The function to decorate (when used without arguments) + name: Optional custom name for the tool (defaults to function name) + + Returns: + A MelleaTool instance. Use .run() to invoke the tool. + The returned object passes isinstance(result, MelleaTool) checks. + + Examples: + Basic usage: + >>> @tool + ... def get_weather(location: str, days: int = 1) -> dict: + ... '''Get weather forecast. + ... + ... Args: + ... location: City name + ... days: Number of days to forecast + ... ''' + ... return {"location": location, "forecast": "sunny"} + >>> + >>> # The decorated function IS a MelleaTool + >>> isinstance(get_weather, MelleaTool) # True + >>> + >>> # Can be used directly in tools list (no extraction needed) + >>> tools = [get_weather] + >>> + >>> # Must use .run() to invoke the tool + >>> result = get_weather.run(location="Boston") + + With custom name (as decorator): + >>> @tool(name="weather_api") + ... def get_weather(location: str) -> dict: + ... return {"location": location} + >>> + >>> result = get_weather.run(location="New York") + + With custom name (as function): + >>> def new_tool(): ... + >>> differently_named_tool = tool(new_tool, name="different_name") + """ + + def decorator(f: Callable) -> MelleaTool: + # Simply return the base MelleaTool instance + return MelleaTool.from_callable(f, name=name) + + # Handle both @tool and @tool() syntax + if func is None: + # Called with arguments: @tool(name="custom") + return decorator + else: + # Called without arguments: @tool + return decorator(func) + + def add_tools_from_model_options( tools_dict: dict[str, AbstractMelleaTool], model_options: dict[str, Any] ): - """If model_options has tools, add those tools to the tools_dict.""" + """If model_options has tools, add those tools to the tools_dict. + + Accepts MelleaTool instances or @tool decorated functions. + """ model_opts_tools = model_options.get(ModelOption.TOOLS, None) if model_opts_tools is None: return @@ -110,16 +184,20 @@ def add_tools_from_model_options( assert isinstance(tool_name, str), ( f"If ModelOption.TOOLS is a dict, it must be a dict of [str, Tool]; found {type(tool_name)} as the key instead" ) - assert isinstance(tool_instance, MelleaTool), ( + assert isinstance(tool_instance, AbstractMelleaTool), ( f"If ModelOption.TOOLS is a dict, it must be a dict of [str, Tool]; found {type(tool_instance)} as the value instead" ) tools_dict[tool_name] = tool_instance else: # Handle any other iterable / list here. for tool_instance in model_opts_tools: - assert isinstance(tool_instance, MelleaTool), ( + assert isinstance(tool_instance, AbstractMelleaTool), ( f"If ModelOption.TOOLS is a list, it must be a list of Tool; found {type(tool_instance)}" ) + # MelleaTool (and subclasses like CallableMelleaTool) have a name attribute + assert isinstance(tool_instance, MelleaTool), ( + f"Tool must be a MelleaTool instance with a name attribute; found {type(tool_instance)}" + ) tools_dict[tool_instance.name] = tool_instance diff --git a/test/backends/test_tool_decorator.py b/test/backends/test_tool_decorator.py new file mode 100644 index 00000000..139d0a95 --- /dev/null +++ b/test/backends/test_tool_decorator.py @@ -0,0 +1,309 @@ +"""Tests for the @tool decorator.""" + +import pytest + +from mellea.backends import MelleaTool, tool +from mellea.core import ModelToolCall + +# ============================================================================ +# Test Fixtures - Tool Functions +# ============================================================================ + + +@tool +def simple_tool(message: str) -> str: + """A simple tool that takes a string. + + Args: + message: The message to process + """ + return f"Processed: {message}" + + +@tool(name="custom_name") +def tool_with_custom_name(value: int) -> int: + """Tool with custom name. + + Args: + value: A value to process + """ + return value * 2 + + +@tool +def multi_param_tool(name: str, age: int, active: bool = True) -> dict: + """Tool with multiple parameters. + + Args: + name: Person's name + age: Person's age + active: Whether active + """ + return {"name": name, "age": age, "active": active} + + +def undecorated_function(x: int) -> int: + """A regular function without the decorator. + + Args: + x: Input value + """ + return x + 1 + + +# ============================================================================ +# Test Cases: Basic Decorator Functionality +# ============================================================================ + + +class TestToolDecoratorBasics: + """Test basic decorator functionality.""" + + def test_decorated_function_is_callable(self): + """Test that decorated function can be called via .run().""" + result = simple_tool.run("hello") + assert result == "Processed: hello" + + def test_decorated_function_has_name_attribute(self): + """Test that decorated function has name attribute.""" + assert hasattr(simple_tool, "name") + assert simple_tool.name == "simple_tool" + + def test_decorated_function_has_as_json_tool(self): + """Test that decorated function has as_json_tool property.""" + assert hasattr(simple_tool, "as_json_tool") + json_tool = simple_tool.as_json_tool + assert isinstance(json_tool, dict) + assert "function" in json_tool + + def test_decorated_function_has_run_method(self): + """Test that decorated function has run method.""" + assert hasattr(simple_tool, "run") + result = simple_tool.run("test") + assert result == "Processed: test" + + def test_decorated_function_preserves_metadata(self): + """Test that decorator preserves function metadata.""" + # MelleaTool doesn't have __name__ or __doc__ attributes + # but has name attribute and the original function's docstring in as_json_tool + assert simple_tool.name == "simple_tool" + json_tool = simple_tool.as_json_tool + assert "simple tool" in json_tool["function"]["description"].lower() + + def test_custom_name_decorator(self): + """Test decorator with custom name parameter.""" + assert tool_with_custom_name.name == "custom_name" + # Function should still work via .run() + result = tool_with_custom_name.run(5) + assert result == 10 + + +# ============================================================================ +# Test Cases: Integration with MelleaTool +# ============================================================================ + + +class TestToolDecoratorIntegration: + """Test integration with existing MelleaTool infrastructure.""" + + def test_decorated_tool_in_list(self): + """Test that decorated tools can be used in a list.""" + tools = [simple_tool, multi_param_tool] + assert len(tools) == 2 + # Should be able to access tool properties + assert tools[0].name == "simple_tool" + assert tools[1].name == "multi_param_tool" + + def test_decorated_tool_with_model_tool_call(self): + """Test that decorated tools work with ModelToolCall.""" + args = {"message": "test message"} + # Decorated function IS a MelleaTool, can be passed directly + tool_call = ModelToolCall("simple_tool", simple_tool, args) + result = tool_call.call_func() + assert result == "Processed: test message" + + def test_decorated_tool_json_schema(self): + """Test that decorated tool generates correct JSON schema.""" + json_tool = simple_tool.as_json_tool + assert json_tool["type"] == "function" + assert json_tool["function"]["name"] == "simple_tool" + assert "parameters" in json_tool["function"] + properties = json_tool["function"]["parameters"]["properties"] + assert "message" in properties + assert properties["message"]["type"] == "string" + + def test_multi_param_tool_schema(self): + """Test schema generation for multi-parameter tool.""" + json_tool = multi_param_tool.as_json_tool + properties = json_tool["function"]["parameters"]["properties"] + assert "name" in properties + assert "age" in properties + assert "active" in properties + # Check required fields + required = json_tool["function"]["parameters"]["required"] + assert "name" in required + assert "age" in required + # active has default, so might not be required + + +# ============================================================================ +# Test Cases: Comparison with from_callable +# ============================================================================ + + +class TestToolDecoratorVsFromCallable: + """Test that decorator produces equivalent results to from_callable.""" + + def test_decorator_equivalent_to_from_callable(self): + """Test that @tool produces same result as MelleaTool.from_callable.""" + # Create tool using from_callable + manual_tool = MelleaTool.from_callable(undecorated_function) + + # Create tool using decorator + @tool + def decorated_version(x: int) -> int: + """A regular function without the decorator. + + Args: + x: Input value + """ + return x + 1 + + # Compare JSON schemas + manual_json = manual_tool.as_json_tool + decorated_json = decorated_version.as_json_tool + + # Names should match + assert manual_json["function"]["name"] == "undecorated_function" + assert decorated_json["function"]["name"] == "decorated_version" + + # Parameters should have same structure + assert ( + manual_json["function"]["parameters"]["type"] + == decorated_json["function"]["parameters"]["type"] + ) + + def test_both_approaches_work_in_tools_list(self): + """Test that both decorated and from_callable tools work together.""" + manual_tool = MelleaTool.from_callable(undecorated_function) + tools = [simple_tool, manual_tool] + + # Both should have name attribute + assert hasattr(tools[0], "name") + assert hasattr(tools[1], "name") + + # Both should have as_json_tool + assert hasattr(tools[0], "as_json_tool") + assert hasattr(tools[1], "as_json_tool") + + +# ============================================================================ +# Test Cases: Edge Cases +# ============================================================================ + + +class TestToolDecoratorEdgeCases: + """Test edge cases and error conditions.""" + + def test_decorator_with_no_params_function(self): + """Test decorator on function with no parameters.""" + + @tool + def no_params() -> str: + """Function with no parameters.""" + return "no params" + + result = no_params.run() + assert result == "no params" + assert no_params.name == "no_params" + + def test_decorator_preserves_function_behavior(self): + """Test that decorator doesn't change function behavior.""" + + @tool + def add(a: int, b: int) -> int: + """Add two numbers. + + Args: + a: First number + b: Second number + """ + return a + b + + # Should work via .run() method + assert add.run(2, 3) == 5 + assert add.run(10, 20) == 30 + assert add.run(5, 7) == 12 + + def test_decorator_with_complex_types(self): + """Test decorator with complex parameter types.""" + + @tool + def complex_tool(items: list[str], config: dict) -> int: + """Tool with complex types. + + Args: + items: List of items + config: Configuration dict + """ + return len(items) + len(config) + + result = complex_tool.run(["a", "b"], {"x": 1, "y": 2}) + assert result == 4 + + def test_multiple_decorators_on_same_function(self): + """Test that decorator can be applied multiple times (creates new instances).""" + + def base_func(x: int) -> int: + """Base function. + + Args: + x: Input + """ + return x + + tool1 = tool(base_func) + tool2 = tool(name="custom")(base_func) + + assert tool1.name == "base_func" + assert tool2.name == "custom" + + +# ============================================================================ +# Test Cases: Usage Patterns +# ============================================================================ + + +class TestToolDecoratorUsagePatterns: + """Test common usage patterns.""" + + def test_tools_in_dict(self): + """Test using decorated tools in a dictionary.""" + tools_dict = {"simple": simple_tool, "multi": multi_param_tool} + + assert tools_dict["simple"].name == "simple_tool" + assert tools_dict["multi"].name == "multi_param_tool" + + def test_tools_passed_to_function(self): + """Test passing decorated tools to a function.""" + + def process_tools(tool_list): + """Process a list of tools.""" + return [t.name for t in tool_list] + + tools = [simple_tool, multi_param_tool] + names = process_tools(tools) + assert "simple_tool" in names + assert "multi_param_tool" in names + + def test_accessing_underlying_mellea_tool(self): + """Test that decorated function IS a MelleaTool instance.""" + assert isinstance(simple_tool, MelleaTool) + assert simple_tool.name == "simple_tool" + # Verify it has all MelleaTool properties + assert hasattr(simple_tool, "as_json_tool") + assert hasattr(simple_tool, "run") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])