diff --git a/neuro_san/internals/run_context/langchain/core/langchain_run_context.py b/neuro_san/internals/run_context/langchain/core/langchain_run_context.py index e19813378..13ec6c847 100644 --- a/neuro_san/internals/run_context/langchain/core/langchain_run_context.py +++ b/neuro_san/internals/run_context/langchain/core/langchain_run_context.py @@ -23,9 +23,8 @@ from logging import Logger from logging import getLogger -from openai import APIError -from anthropic import BadRequestError -from anthropic import AuthenticationError +from openai import APIError as OpenAI_APIError +from anthropic import APIError as Anthropic_APIError from pydantic_core import ValidationError @@ -495,7 +494,7 @@ async def ainvoke(self, agent_executor: AgentExecutor, inputs: Dict[str, Any], i while return_dict is None and retries > 0: try: return_dict: Dict[str, Any] = await agent_executor.ainvoke(inputs, invoke_config) - except (APIError, BadRequestError, AuthenticationError, ChatGoogleGenerativeAIError) as api_error: + except (OpenAI_APIError, Anthropic_APIError, ChatGoogleGenerativeAIError) as api_error: message: str = ApiKeyErrorCheck.check_for_api_key_exception(api_error) if message is not None: raise ValueError(message) from api_error diff --git a/neuro_san/internals/run_context/langchain/llms/default_llm_factory.py b/neuro_san/internals/run_context/langchain/llms/default_llm_factory.py index a15fff47c..fac7ca4eb 100644 --- a/neuro_san/internals/run_context/langchain/llms/default_llm_factory.py +++ b/neuro_san/internals/run_context/langchain/llms/default_llm_factory.py @@ -13,7 +13,6 @@ from typing import Any from typing import Dict from typing import List -from typing import Optional from typing import Type import os @@ -26,7 +25,6 @@ from langchain_core.language_models.base import BaseLanguageModel from leaf_common.config.dictionary_overlay import DictionaryOverlay -from leaf_common.config.resolver import Resolver from leaf_common.parsers.dictionary_extractor import 
DictionaryExtractor from neuro_san.internals.interfaces.context_type_llm_factory import ContextTypeLlmFactory @@ -34,6 +32,8 @@ from neuro_san.internals.run_context.langchain.llms.llm_info_restorer import LlmInfoRestorer from neuro_san.internals.run_context.langchain.llms.standard_langchain_llm_factory import StandardLangChainLlmFactory from neuro_san.internals.run_context.langchain.util.api_key_error_check import ApiKeyErrorCheck +from neuro_san.internals.run_context.langchain.util.argument_validator import ArgumentValidator +from neuro_san.internals.utils.resolver_util import ResolverUtil class DefaultLlmFactory(ContextTypeLlmFactory, LangChainLlmFactory): @@ -63,7 +63,7 @@ class DefaultLlmFactory(ContextTypeLlmFactory, LangChainLlmFactory): the model description in this class. """ - def __init__(self, config: Optional[Dict[str, Any]] = None): + def __init__(self, config: Dict[str, Any] = None): """ Constructor @@ -123,39 +123,20 @@ def resolve_one_llm_factory(self, llm_factory_class_name: str, llm_info_file: st raise ValueError(f"The value for the classes.factories key in {llm_info_file} " "must be a list of strings") - class_split: List[str] = llm_factory_class_name.split(".") - if len(class_split) <= 2: - raise ValueError(f"Value in the classes.factories in {llm_info_file} must be of the form " - "..") - - # Create a list of a single package given the name in the value - packages: List[str] = [".".join(class_split[:-2])] - class_name: str = class_split[-1] - resolver = Resolver(packages) - - # Resolve the class name - llm_factory_class: Type[LangChainLlmFactory] = None - try: - llm_factory_class: Type[LangChainLlmFactory] = \ - resolver.resolve_class_in_module(class_name, module_name=class_split[-2]) - except AttributeError as exception: - raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} " - "not found in PYTHONPATH") from exception - - # Instantiate it - try: - llm_factory: LangChainLlmFactory = llm_factory_class() - except TypeError as 
exception: - raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} " - "must have a no-args constructor") from exception - - # Make sure its the correct type - if not isinstance(llm_factory, LangChainLlmFactory): - raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} " - "must be of type LangChainLlmFactory") + # Resolve and instantiate the factory class + llm_factory = ResolverUtil.create_instance( + class_name=llm_factory_class_name, + class_name_source=llm_info_file, + type_of_class=LangChainLlmFactory + ) + return llm_factory - def create_llm(self, config: Dict[str, Any], callbacks: List[BaseCallbackHandler] = None) -> BaseLanguageModel: + def create_llm( + self, + config: Dict[str, Any], + callbacks: List[BaseCallbackHandler] = None + ) -> BaseLanguageModel: """ Creates a langchain LLM based on the 'model_name' value of the config passed in. @@ -176,6 +157,28 @@ def create_full_llm_config(self, config: Dict[str, Any]) -> Dict[str, Any]: :param config: The llm_config from the user :return: The fully specified config with defaults filled in. """ + + class_from_llm_config: str = config.get("class") + if class_from_llm_config: + if not isinstance(class_from_llm_config, str): + raise ValueError("Value of 'class' has to be string.") + # A "class" key in the config indicates the user has specified a particular LLM implementation. + # However, the config may only contain partial arguments (e.g., {"arg_1": 0.5}) and omit others. + # + # In the standard factory, LLM classes are instantiated like: + # ChatOpenAI(arg_1=config.get("arg_1"), arg_2=config.get("arg_2")) + # If a required argument like "arg_2" is missing in the config, config.get("arg_2") returns None, + # which may raise an error during instantiation if the argument has no default. + # + # To prevent this, we first fetch the default arguments for the given class from llm_info, + # then merge them with the user-provided config. 
This ensures all expected arguments are present, + # and the user’s config values take precedence over the defaults. + config_from_class_in_llm_info: Dict[str, Any] = self.get_chat_class_args(class_from_llm_config) + + # Merge the defaults from llm_info with the user-defined config, + # giving priority to values in config. + return self.overlayer.overlay(config_from_class_in_llm_info, config) + default_config: Dict[str, Any] = self.llm_infos.get("default_config") use_config = self.overlayer.overlay(default_config, config) @@ -215,7 +218,7 @@ def create_full_llm_config(self, config: Dict[str, Any]) -> Dict[str, Any]: return full_config - def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict[str, Any]: + def get_chat_class_args(self, chat_class_name: str, use_model_name: str = None) -> Dict[str, Any]: """ :param chat_class_name: string name of the chat class to look up. :param use_model_name: the original model name that prompted the chat class lookups @@ -227,8 +230,13 @@ def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict chat_classes: Dict[str, Any] = self.llm_infos.get("classes") chat_class: Dict[str, Any] = chat_classes.get(chat_class_name) if chat_class is None: - raise ValueError(f"llm info entry for {use_model_name} uses a 'class' of {chat_class_name} " - "which is not defined in the 'classes' table.") + if use_model_name is not None: + # If use_model_name is given, it must have a "class" in "classes" + raise ValueError(f"llm info entry for {use_model_name} uses a 'class' of {chat_class_name} " + "which is not defined in the 'classes' table.") + # If use_model_name is not provided and chat_class_name is not in "classes" in llm_info, + # it could be a user-specified langchain model class + return {} # Get the args from the chat class args: Dict[str, Any] = chat_class.get("args") @@ -244,7 +252,8 @@ def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict def 
create_base_chat_model(self, config: Dict[str, Any], callbacks: List[BaseCallbackHandler] = None) -> BaseLanguageModel: """ - Create a BaseLanguageModel from the fully-specified llm config. + Create a BaseLanguageModel from the fully-specified llm config either from standard LLM factory, + user-defined LLM factory, or user-specified langchain model class. :param config: The fully specified llm config which is a product of _create_full_llm_config() above. :param callbacks: A list of BaseCallbackHandlers to add to the chat model. @@ -279,11 +288,54 @@ def create_base_chat_model(self, config: Dict[str, Any], # Let the next model have a crack found_exception = exception + # Try resolving via 'class' in config if factories failed + class_path: str = config.get("class") + if llm is None and found_exception is not None and class_path: + llm = self.create_base_chat_model_from_user_class(class_path, config) + found_exception = None + if found_exception is not None: raise found_exception return llm + def create_base_chat_model_from_user_class( + self, + class_path: str, + config: Dict[str, Any], + callbacks: List[BaseCallbackHandler] = None + ) -> BaseLanguageModel: + """ + Create a BaseLanguageModel from the user-specified langchain model class. + :param class_path: A string in the form of package.module.ClassName + :param config: The fully specified llm config which is a product of + _create_full_llm_config() above. + :param callbacks: A list of BaseCallbackHandlers to add to the chat model. 
+ + :return: A BaseLanguageModel + """ + + if not isinstance(class_path, str): + raise ValueError("'class' in llm_config must be a string") + + # Resolve the 'class' + llm_class: Type[BaseLanguageModel] = ResolverUtil.create_class( + class_name=class_path, + class_name_source="agent network hocon file", + type_of_class=BaseLanguageModel + ) + + # Copy the config, take 'class' out, and add callbacks + # Then unpack into llm constructor + user_config: Dict[str, Any] = config.copy() + user_config.pop("class") + user_config["callbacks"] = callbacks + + # Check for invalid args and throw error if found + ArgumentValidator.check_invalid_args(llm_class, user_config) + + return llm_class(**user_config) + def get_max_prompt_tokens(self, config: Dict[str, Any]) -> int: """ :param config: A dictionary which describes which LLM to use. diff --git a/neuro_san/internals/run_context/langchain/llms/standard_langchain_llm_factory.py b/neuro_san/internals/run_context/langchain/llms/standard_langchain_llm_factory.py index ef101be59..cbab9b1d9 100644 --- a/neuro_san/internals/run_context/langchain/llms/standard_langchain_llm_factory.py +++ b/neuro_san/internals/run_context/langchain/llms/standard_langchain_llm_factory.py @@ -70,7 +70,10 @@ def create_base_chat_model(self, config: Dict[str, Any], if chat_class is not None: chat_class = chat_class.lower() - model_name: str = config.get("model_name") + # Check for key "model_name", "model", and "model_id" to use as model name + # If the config is from default_llm_info, this is always "model_name" + # but with a user-specified config, it is possible that one of the other keys is specified instead. 
+ model_name: str = config.get("model_name") or config.get("model") or config.get("model_id") if chat_class == "openai": llm = ChatOpenAI( diff --git a/neuro_san/internals/run_context/langchain/toolbox/toolbox_factory.py b/neuro_san/internals/run_context/langchain/toolbox/toolbox_factory.py index 5e28c3fda..2e75ccf3c 100644 --- a/neuro_san/internals/run_context/langchain/toolbox/toolbox_factory.py +++ b/neuro_san/internals/run_context/langchain/toolbox/toolbox_factory.py @@ -10,14 +10,12 @@ # # END COPYRIGHT -from inspect import signature -from types import MethodType + from typing import Any from typing import Callable from typing import Dict from typing import List from typing import Optional -from typing import Set from typing import Type from typing import Union @@ -32,6 +30,7 @@ from neuro_san.internals.interfaces.context_type_toolbox_factory import ContextTypeToolboxFactory from neuro_san.internals.run_context.langchain.toolbox.toolbox_info_restorer import ToolboxInfoRestorer +from neuro_san.internals.run_context.langchain.util.argument_validator import ArgumentValidator class ToolboxFactory(ContextTypeToolboxFactory): @@ -152,7 +151,7 @@ def create_tool_from_toolbox( self._get_from_api_wrapper_method(tool_class) or tool_class # Validate and instantiate - self._check_invalid_args(callable_obj, final_args) + ArgumentValidator.check_invalid_args(callable_obj, final_args) # Instance can be a BaseTool or a BaseToolkit instance: Union[BaseTool, BaseToolkit] = callable_obj(**final_args) @@ -178,7 +177,7 @@ def _resolve_args(self, args: Dict[str, Any]) -> Dict[str, Any]: # If the argument is a class definition, resolve and instantiate it nested_class: BaseModel = self._resolve_class(value.get("class")) nested_args: Dict[str, Any] = self._resolve_args(value.get("args", empty)) - self._check_invalid_args(nested_class, nested_args) + ArgumentValidator.check_invalid_args(nested_class, nested_args) resolved_args[key] = nested_class(**nested_args) else: # Otherwise, keep 
primitive values as they are @@ -210,22 +209,6 @@ def _resolve_class(self, class_path: str) -> Type[BaseTool]: except AttributeError as exception: raise ValueError(f"Class {class_path} not found in PYTHONPATH") from exception - def _check_invalid_args(self, method_class: Union[Type, MethodType], args: Dict[str, Any]): - """ - Check for invalid arguments in class or method - :param method_class: Class or method to check for the invalid arguments - :param args: Arguments to check - """ - class_args_set: Set[str] = set(signature(method_class).parameters.keys()) - args_set: Set[str] = set(args.keys()) - invalid_args: Set[str] = args_set - class_args_set - - if invalid_args: - raise ValueError( - f"Arguments {invalid_args} for '{method_class.__name__}' do not match any attributes " - "of the class or any arguments of the method." - ) - def _get_from_api_wrapper_method( self, tool_class: Union[Type[BaseTool], Type[BaseToolkit]] diff --git a/neuro_san/internals/run_context/langchain/util/argument_validator.py b/neuro_san/internals/run_context/langchain/util/argument_validator.py new file mode 100644 index 000000000..c48748657 --- /dev/null +++ b/neuro_san/internals/run_context/langchain/util/argument_validator.py @@ -0,0 +1,81 @@ +# Copyright (C) 2023-2025 Cognizant Digital Business, Evolutionary AI. +# All Rights Reserved. +# Issued under the Academic Public License. +# +# You can be released from the terms, and requirements of the Academic Public +# License by purchasing a commercial license. +# Purchase of a commercial license is mandatory for any use of the +# neuro-san SDK Software in commercial settings. 
+# +# END COPYRIGHT + +from typing import Any +from typing import Dict +from typing import Set +from typing import Type +from typing import Union +from types import MethodType +from inspect import isclass +from inspect import signature +from pydantic import BaseModel + + +class ArgumentValidator: + """ + A utility class for inspecting method and class arguments, particularly useful for + validating input against Pydantic `BaseModel` subclasses or callable method signatures. + + This class provides static methods to: + - Validate that a dictionary of arguments matches the accepted parameters of a class or method. + - Extract all field names and aliases from a Pydantic BaseModel subclass. + """ + + @staticmethod + def check_invalid_args(method_class: Union[Type, MethodType], args: Dict[str, Any]): + """ + Check for invalid arguments in a class constructor or method call. + + :param method_class: The class or method to validate against. + :param args: Dictionary of arguments to check. + :raises ValueError: If any argument name is not accepted by the method or class. + """ + + # If method_class is a Pydantic BaseModel, get its field names and aliases + if isclass(method_class) and issubclass(method_class, BaseModel): + class_args_set: Set[str] = ArgumentValidator.get_base_model_args(method_class) + else: + # Otherwise, extract argument names from the function/method signature + class_args_set = set(signature(method_class).parameters.keys()) + + # Get the argument keys provided by the user + args_set: Set[str] = set(args.keys()) + + # Identify which arguments are not accepted by the method/class + invalid_args: Set[str] = args_set - class_args_set + if invalid_args: + raise ValueError( + f"Arguments {invalid_args} for '{method_class.__name__}' do not match any attributes " + "of the class or any arguments of the method." 
+ ) + + @staticmethod + def get_base_model_args(base_model_class: Type[BaseModel]) -> Set[str]: + """ + Extract all field names and aliases from a Pydantic BaseModel class. + + :param base_model_class: A class that inherits from `BaseModel`. + :return: A set of valid argument names, including both field names and aliases. + """ + + fields_and_aliases: Set[str] = set() + + # Check for field name and info + # field info includes attributes like "required", "default", "description", and "alias" + for field_name, field_info in base_model_class.model_fields.items(): + # Add field name to the set + fields_and_aliases.add(field_name) + if field_info.alias: + # If there is "alias" in the info add it to the set as well + fields_and_aliases.add(field_info.alias) + + return fields_and_aliases diff --git a/neuro_san/internals/utils/resolver_util.py b/neuro_san/internals/utils/resolver_util.py index 98c79c3de..45b482018 100644 --- a/neuro_san/internals/utils/resolver_util.py +++ b/neuro_san/internals/utils/resolver_util.py @@ -37,6 +37,35 @@ def create_instance(class_name: str, class_name_source: str, type_of_class: Type """ instance: Any = None + class_reference: Type[Any] = ResolverUtil.create_class(class_name, class_name_source, type_of_class) + + if class_reference is None: + return None + + # Instantiate the class + try: + instance = class_reference() + except TypeError as exception: + raise ValueError(f"Class '{class_name}' from {class_name_source} " + "must have a no-args constructor") from exception + + return instance + + @staticmethod + def create_class(class_name: str, class_name_source: str, type_of_class: Type) -> Type: + """ + Resolves a fully qualified class name string into an actual Python class object. + + This method expects the input string to follow the format: + 'package.module.ClassName' and uses a Resolver to dynamically + locate and return the class object. + + :param class_name: The fully qualified name of the class to resolve. 
+ :param class_name_source: A description of the source of the class_name string, + used for clearer error messages. + :param type_of_class: Base type or interface the class must inherit from. + :return: The resolved class object. Can return None if class_name is a None or empty string. + """ if class_name is None or len(class_name) == 0: return None @@ -59,16 +88,10 @@ def create_instance(class_name: str, class_name_source: str, type_of_class: Type raise ValueError(f"Class '{class_name}' from {class_name_source} " "not found in PYTHONPATH") from exception - # Instantiate the class - try: - instance = class_reference() - except TypeError as exception: - raise ValueError(f"Class '{class_name}' from {class_name_source} " - "must have a no-args constructor") from exception - # Make sure it is the correct type - if not isinstance(instance, type_of_class): - raise ValueError(f"Class '{class_name}' from {class_name_source} " - "must be of type {type_of_class.__name__}") + if not issubclass(class_reference, type_of_class): + raise ValueError( + f"Class {class_name} in {class_name_source} must be a subclass of {type_of_class.__name__}" + ) - return instance + return class_reference diff --git a/neuro_san/registries/google_serper.hocon b/neuro_san/registries/google_serper.hocon index 1128f8ce8..44064bc11 100644 --- a/neuro_san/registries/google_serper.hocon +++ b/neuro_san/registries/google_serper.hocon @@ -35,10 +35,10 @@ { "name": "searcher", "instructions": "Use your tool to respond to the inquiry.", - "function": { + "function": { # The description acts as an initial prompt. "description": "Assist user with answer from internet." 
- } + } "tools": ["search_tool"] }, { diff --git a/tests/neuro_san/internals/run_context/langchain/toolbox/test_toolbox_factory.py b/tests/neuro_san/internals/run_context/langchain/toolbox/test_toolbox_factory.py index 483a8b08f..0579d450c 100644 --- a/tests/neuro_san/internals/run_context/langchain/toolbox/test_toolbox_factory.py +++ b/tests/neuro_san/internals/run_context/langchain/toolbox/test_toolbox_factory.py @@ -18,6 +18,12 @@ from neuro_san.internals.run_context.langchain.toolbox.toolbox_factory import ToolboxFactory +RESOLVER_PATH = "leaf_common.config.resolver.Resolver.resolve_class_in_module" +VALIDATIOR_PATH = ( + "neuro_san.internals.run_context.langchain.util.argument_validator." + "ArgumentValidator.check_invalid_args" +) + class TestBaseToolFactory: """Simplified test suite for ToolboxFactory.""" @@ -43,8 +49,7 @@ def test_create_toolbox_returns_single_base_tool(self, factory): # Mock user-provided arguments user_args = {"param2": "user_value", "param3": "extra_value"} - with patch("leaf_common.config.resolver.Resolver.resolve_class_in_module") as mock_resolver, \ - patch.object(factory, "_check_invalid_args") as mock_check_invalid: + with patch(RESOLVER_PATH) as mock_resolver, patch(VALIDATIOR_PATH) as mock_check_invalid: mock_tool_class = MagicMock(spec=BaseTool) mock_resolver.return_value = mock_tool_class @@ -80,8 +85,7 @@ def test_create_toolbox_with_toolkit_constructor(self, factory): # Mock user-provided arguments user_args = {"param2": "user_value", "param3": "extra_value"} - with patch("leaf_common.config.resolver.Resolver.resolve_class_in_module") as mock_resolver, \ - patch.object(factory, "_check_invalid_args") as mock_check_invalid: + with patch(RESOLVER_PATH) as mock_resolver, patch(VALIDATIOR_PATH) as mock_check_invalid: mock_toolkit_class = MagicMock(spec=BaseToolkit) mock_resolver.return_value = mock_toolkit_class @@ -119,8 +123,7 @@ def test_create_toolbox_with_toolkit_class_method(self, factory): # Mock user-provided arguments 
user_args = {"param2": "user_value", "param3": "extra_value"} - with patch("leaf_common.config.resolver.Resolver.resolve_class_in_module") as mock_resolver, \ - patch.object(factory, "_check_invalid_args") as mock_check_invalid: + with patch(RESOLVER_PATH) as mock_resolver, patch(VALIDATIOR_PATH) as mock_check_invalid: # Mock the toolkit class mock_toolkit_class = MagicMock() mock_resolver.return_value = mock_toolkit_class diff --git a/tests/neuro_san/internals/run_context/langchain/util/test_argument_validator.py b/tests/neuro_san/internals/run_context/langchain/util/test_argument_validator.py new file mode 100644 index 000000000..84a66d45c --- /dev/null +++ b/tests/neuro_san/internals/run_context/langchain/util/test_argument_validator.py @@ -0,0 +1,88 @@ +# Copyright (C) 2023-2025 Cognizant Digital Business, Evolutionary AI. +# All Rights Reserved. +# Issued under the Academic Public License. +# +# You can be released from the terms, and requirements of the Academic Public +# License by purchasing a commercial license. +# Purchase of a commercial license is mandatory for any use of the +# neuro-san SDK Software in commercial settings. 
+# +# END COPYRIGHT + +from typing import Any +from typing import Dict +from typing import Set + +from pydantic import BaseModel +from pydantic import Field +import pytest + +from neuro_san.internals.run_context.langchain.util.argument_validator import ArgumentValidator + + +# ----------- Sample Models and Functions for Testing ----------- + +class TestModel(BaseModel): + """Used for get_base_model_args method testing""" + name: str + age: int = Field(alias="years") + + +def sample_function(name: str, city: str): + """Used for method signature testing""" + print(name, city) + + +# ------------------------- Tests ------------------------------- + +class TestArgumentValidator: + """Tests for the ArgumentValidator class and its static validation utilities.""" + + def test_get_base_model_args_returns_field_names_and_aliases(self): + """ + Test that get_base_model_args returns both the original field names + and their aliases from a Pydantic BaseModel. + """ + result: Set[str] = ArgumentValidator.get_base_model_args(TestModel) + # Expect both the field name 'name' and the alias 'years' to be included + assert "name" in result + assert "years" in result + assert "age" in result # The original field name is also always included + + def test_check_invalid_args_with_valid_model_args_passes(self): + """ + Test that check_invalid_args does not raise an error when passed + valid field names and aliases for a BaseModel. + """ + args: Dict[str, Any] = {"name": "Alice", "years": 30} + # Should not raise + ArgumentValidator.check_invalid_args(TestModel, args) + + def test_check_invalid_args_with_invalid_model_args_raises(self): + """ + Test that check_invalid_args raises a ValueError when invalid field names + are passed to a BaseModel. 
+ """ + args: Dict[str, Any] = {"name": "Alice", "invalid_field": "oops"} + with pytest.raises(ValueError) as excinfo: + ArgumentValidator.check_invalid_args(TestModel, args) + assert "invalid_field" in str(excinfo.value) + + def test_check_invalid_args_with_valid_function_args_passes(self): + """ + Test that check_invalid_args does not raise an error when valid argument + names are passed to a regular Python function. + """ + args: Dict[str, Any] = {"name": "Alice", "city": "NYC"} + # Should not raise + ArgumentValidator.check_invalid_args(sample_function, args) + + def test_check_invalid_args_with_invalid_function_args_raises(self): + """ + Test that check_invalid_args raises a ValueError when invalid argument + names are passed to a regular Python function. + """ + args: Dict[str, Any] = {"name": "Alice", "bad_arg": "fail"} + with pytest.raises(ValueError) as excinfo: + ArgumentValidator.check_invalid_args(sample_function, args) + assert "bad_arg" in str(excinfo.value)