Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
de0a29a
add user_specified_langchain_llm_factory
Noravee Jun 24, 2025
ecdee78
fix indent and pylint
Noravee Jun 24, 2025
d23186e
Add more comments
Noravee Jun 24, 2025
928ea36
Merge branch 'main' into UN-3276_Support_for_user-defined_llm_in_hoco…
Noravee Jun 26, 2025
58b3b8b
remove ollama response error
Noravee Jun 26, 2025
5906c6f
Merge branch 'UN-3276_Support_for_user-defined_llm_in_hocon_file_with…
Noravee Jun 26, 2025
8a8128f
Merge branch 'main' into UN-3276_Support_for_user-defined_llm_in_hoco…
Noravee Jun 26, 2025
23891db
minor changes
Noravee Jun 30, 2025
2a8131f
Merge branch 'main' into UN-3276_Support_for_user-defined_llm_in_hoco…
Noravee Jul 2, 2025
8e4d881
- Remove user_specified_langchain_llm_factory
Noravee Jul 2, 2025
962abf7
Merge branch 'UN-3276_Support_for_user-defined_llm_in_hocon_file_with…
Noravee Jul 2, 2025
17974b5
use alias for api error in langchain run context
Noravee Jul 2, 2025
9d90aaa
change comments
Noravee Jul 2, 2025
7647dc9
remove space
Noravee Jul 2, 2025
741e50e
Add comments
Noravee Jul 2, 2025
f8ce89d
Merge branch 'main' into UN-3276_Support_for_user-defined_llm_in_hoco…
Noravee Jul 2, 2025
69889ad
- Refactor logic on creating llm based on "class" into another method
Noravee Jul 3, 2025
9363936
Merge branch 'UN-3276_Support_for_user-defined_llm_in_hocon_file_with…
Noravee Jul 3, 2025
a796cf8
Add type hints
Noravee Jul 3, 2025
bbc3fbc
Add callbacks
Noravee Jul 3, 2025
50bd7f0
combine user config with the one in class in llm_info
Noravee Jul 4, 2025
1aac685
Merge branch 'main' into UN-3276_Support_for_user-defined_llm_in_hoco…
Noravee Jul 7, 2025
69f38f2
Merge branch 'main' into UN-3276_Support_for_user-defined_llm_in_hoco…
Noravee Jul 7, 2025
e6dc921
remove optional
Noravee Jul 7, 2025
0b98a39
Merge branch 'UN-3276_Support_for_user-defined_llm_in_hocon_file_with…
Noravee Jul 7, 2025
cca1d03
refactor default llm factory with resolver util
Noravee Jul 8, 2025
b021671
Merge branch 'main' into UN-3276_Support_for_user-defined_llm_in_hoco…
Noravee Jul 8, 2025
35c0362
Merge branch 'main' into UN-3276_Support_for_user-defined_llm_in_hoco…
Noravee Jul 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
from logging import Logger
from logging import getLogger

from openai import APIError
from anthropic import BadRequestError
from anthropic import AuthenticationError
from openai import APIError as OpenAI_APIError
from anthropic import APIError as Anthropic_APIError
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only import APIError.


from pydantic_core import ValidationError

Expand Down Expand Up @@ -495,7 +494,7 @@ async def ainvoke(self, agent_executor: AgentExecutor, inputs: Dict[str, Any], i
while return_dict is None and retries > 0:
try:
return_dict: Dict[str, Any] = await agent_executor.ainvoke(inputs, invoke_config)
except (APIError, BadRequestError, AuthenticationError, ChatGoogleGenerativeAIError) as api_error:
except (OpenAI_APIError, Anthropic_APIError, ChatGoogleGenerativeAIError) as api_error:
message: str = ApiKeyErrorCheck.check_for_api_key_exception(api_error)
if message is not None:
raise ValueError(message) from api_error
Expand Down
128 changes: 90 additions & 38 deletions neuro_san/internals/run_context/langchain/llms/default_llm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Type

import os
Expand All @@ -26,14 +25,15 @@
from langchain_core.language_models.base import BaseLanguageModel

from leaf_common.config.dictionary_overlay import DictionaryOverlay
from leaf_common.config.resolver import Resolver
from leaf_common.parsers.dictionary_extractor import DictionaryExtractor

from neuro_san.internals.interfaces.context_type_llm_factory import ContextTypeLlmFactory
from neuro_san.internals.run_context.langchain.llms.langchain_llm_factory import LangChainLlmFactory
from neuro_san.internals.run_context.langchain.llms.llm_info_restorer import LlmInfoRestorer
from neuro_san.internals.run_context.langchain.llms.standard_langchain_llm_factory import StandardLangChainLlmFactory
from neuro_san.internals.run_context.langchain.util.api_key_error_check import ApiKeyErrorCheck
from neuro_san.internals.run_context.langchain.util.argument_validator import ArgumentValidator
from neuro_san.internals.utils.resolver_util import ResolverUtil


class DefaultLlmFactory(ContextTypeLlmFactory, LangChainLlmFactory):
Expand Down Expand Up @@ -63,7 +63,7 @@ class DefaultLlmFactory(ContextTypeLlmFactory, LangChainLlmFactory):
the model description in this class.
"""

def __init__(self, config: Optional[Dict[str, Any]] = None):
def __init__(self, config: Dict[str, Any] = None):
"""
Constructor

Expand Down Expand Up @@ -123,39 +123,20 @@ def resolve_one_llm_factory(self, llm_factory_class_name: str, llm_info_file: st
raise ValueError(f"The value for the classes.factories key in {llm_info_file} "
"must be a list of strings")

class_split: List[str] = llm_factory_class_name.split(".")
if len(class_split) <= 2:
raise ValueError(f"Value in the classes.factories in {llm_info_file} must be of the form "
"<package_name>.<module_name>.<ClassName>")

# Create a list of a single package given the name in the value
packages: List[str] = [".".join(class_split[:-2])]
class_name: str = class_split[-1]
resolver = Resolver(packages)

# Resolve the class name
llm_factory_class: Type[LangChainLlmFactory] = None
try:
llm_factory_class: Type[LangChainLlmFactory] = \
resolver.resolve_class_in_module(class_name, module_name=class_split[-2])
except AttributeError as exception:
raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} "
"not found in PYTHONPATH") from exception

# Instantiate it
try:
llm_factory: LangChainLlmFactory = llm_factory_class()
except TypeError as exception:
raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} "
"must have a no-args constructor") from exception

# Make sure its the correct type
if not isinstance(llm_factory, LangChainLlmFactory):
raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} "
"must be of type LangChainLlmFactory")
# Resolve and instantiate the factory class
llm_factory = ResolverUtil.create_instance(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ResolverUtil.create_instance for resolve and instantiation of the factory class.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooo. Nice.

class_name=llm_factory_class_name,
class_name_source=llm_info_file,
type_of_class=LangChainLlmFactory
)

return llm_factory

def create_llm(self, config: Dict[str, Any], callbacks: List[BaseCallbackHandler] = None) -> BaseLanguageModel:
def create_llm(
self,
config: Dict[str, Any],
callbacks: List[BaseCallbackHandler] = None
) -> BaseLanguageModel:
"""
Creates a langchain LLM based on the 'model_name' value of
the config passed in.
Expand All @@ -176,6 +157,28 @@ def create_full_llm_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
:param config: The llm_config from the user
:return: The fully specified config with defaults filled in.
"""

class_from_llm_config: str = config.get("class")
if class_from_llm_config:
if not isinstance(class_from_llm_config, str):
raise ValueError("Value of 'class' has to be string.")
# A "class" key in the config indicates the user has specified a particular LLM implementation.
# However, the config may only contain partial arguments (e.g., {"arg_1": 0.5}) and omit others.
#
# In the standard factory, LLM classes are instantiated like:
# ChatOpenAI(arg_1=config.get("arg_1"), arg_2=config.get("arg_2"))
# If a required argument like "arg_2" is missing in the config, config.get("arg_2") returns None,
# which may raise an error during instantiation if the argument has no default.
#
# To prevent this, we first fetch the default arguments for the given class from llm_info,
# then merge them with the user-provided config. This ensures all expected arguments are present,
# and the user’s config values take precedence over the defaults.
config_from_class_in_llm_info: Dict[str, Any] = self.get_chat_class_args(class_from_llm_config)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combine user-specified config with the one in class in llm_info.


# Merge the defaults from llm_info with the user-defined config,
# giving priority to values in config.
return self.overlayer.overlay(config_from_class_in_llm_info, config)

default_config: Dict[str, Any] = self.llm_infos.get("default_config")
use_config = self.overlayer.overlay(default_config, config)

Expand Down Expand Up @@ -215,7 +218,7 @@ def create_full_llm_config(self, config: Dict[str, Any]) -> Dict[str, Any]:

return full_config

def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict[str, Any]:
def get_chat_class_args(self, chat_class_name: str, use_model_name: str = None) -> Dict[str, Any]:
"""
:param chat_class_name: string name of the chat class to look up.
:param use_model_name: the original model name that prompted the chat class lookups
Expand All @@ -227,8 +230,13 @@ def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict
chat_classes: Dict[str, Any] = self.llm_infos.get("classes")
chat_class: Dict[str, Any] = chat_classes.get(chat_class_name)
if chat_class is None:
raise ValueError(f"llm info entry for {use_model_name} uses a 'class' of {chat_class_name} "
"which is not defined in the 'classes' table.")
if use_model_name is not None:
# If use_model_name is given, it must have a "class" in "classes"
raise ValueError(f"llm info entry for {use_model_name} uses a 'class' of {chat_class_name} "
"which is not defined in the 'classes' table.")
# If use_model_name is not provided and chat_class_name is not in "classes" in llm_info,
# it could be a user-specified langchain model class
return {}

# Get the args from the chat class
args: Dict[str, Any] = chat_class.get("args")
Expand All @@ -244,7 +252,8 @@ def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict
def create_base_chat_model(self, config: Dict[str, Any],
callbacks: List[BaseCallbackHandler] = None) -> BaseLanguageModel:
"""
Create a BaseLanguageModel from the fully-specified llm config.
Create a BaseLanguageModel from the fully-specified llm config either from standard LLM factory,
user-defined LLM factory, or user-specified langchain model class.
:param config: The fully specified llm config which is a product of
_create_full_llm_config() above.
:param callbacks: A list of BaseCallbackHandlers to add to the chat model.
Expand Down Expand Up @@ -279,11 +288,54 @@ def create_base_chat_model(self, config: Dict[str, Any],
# Let the next model have a crack
found_exception = exception

# Try resolving via 'class' in config if factories failed
class_path: str = config.get("class")
if llm is None and found_exception is not None and class_path:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conditions to instantiate using class in llm_config in agent network hocon are

  • Failed instantiation from the llm factories (thus llm is None)
  • Consequently, found_exception is not None
  • There is class in the config.

llm = self.create_base_chat_model_from_user_class(class_path, config)
found_exception = None

if found_exception is not None:
raise found_exception

return llm

def create_base_chat_model_from_user_class(
        self,
        class_path: str,
        config: Dict[str, Any],
        callbacks: List[BaseCallbackHandler] = None
) -> BaseLanguageModel:
    """
    Create a BaseLanguageModel from a langchain model class named directly
    by the user in the llm config.

    :param class_path: A string of the form <package>.<module>.<Class>
    :param config: The fully specified llm config which is a product of
            _create_full_llm_config() above.
    :param callbacks: A list of BaseCallbackHandlers to add to the chat model.

    :return: A BaseLanguageModel instantiated from the named class
    :raises ValueError: if class_path is not a string, cannot be resolved,
            or if config carries arguments the class does not accept.
    """

    if not isinstance(class_path, str):
        raise ValueError("'class' in llm_config must be a string")

    # Resolve the user-named class, verifying it is a BaseLanguageModel subclass
    model_class: Type[BaseLanguageModel] = ResolverUtil.create_class(
        class_name=class_path,
        class_name_source="agent network hocon file",
        type_of_class=BaseLanguageModel
    )

    # Build the constructor arguments: everything from the config except the
    # 'class' key itself, plus the callbacks to attach to the chat model.
    constructor_args: Dict[str, Any] = dict(config)
    del constructor_args["class"]
    constructor_args["callbacks"] = callbacks

    # Fail early with a clear message if any argument is not accepted
    # by the resolved class.
    ArgumentValidator.check_invalid_args(model_class, constructor_args)

    return model_class(**constructor_args)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLM instantiation with args.


def get_max_prompt_tokens(self, config: Dict[str, Any]) -> int:
"""
:param config: A dictionary which describes which LLM to use.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def create_base_chat_model(self, config: Dict[str, Any],
if chat_class is not None:
chat_class = chat_class.lower()

model_name: str = config.get("model_name")
# Check for key "model_name", "model", and "model_id" to use as model name
# If the config is from default_llm_info, this is always "model_name"
# but with a user-specified config, one of the other keys may be specified instead.
model_name: str = config.get("model_name") or config.get("model") or config.get("model_id")

if chat_class == "openai":
llm = ChatOpenAI(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
#
# END COPYRIGHT

from inspect import signature
from types import MethodType

from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Type
from typing import Union

Expand All @@ -32,6 +30,7 @@

from neuro_san.internals.interfaces.context_type_toolbox_factory import ContextTypeToolboxFactory
from neuro_san.internals.run_context.langchain.toolbox.toolbox_info_restorer import ToolboxInfoRestorer
from neuro_san.internals.run_context.langchain.util.argument_validator import ArgumentValidator


class ToolboxFactory(ContextTypeToolboxFactory):
Expand Down Expand Up @@ -152,7 +151,7 @@ def create_tool_from_toolbox(
self._get_from_api_wrapper_method(tool_class) or tool_class

# Validate and instantiate
self._check_invalid_args(callable_obj, final_args)
ArgumentValidator.check_invalid_args(callable_obj, final_args)
# Instance can be a BaseTool or a BaseToolkit
instance: Union[BaseTool, BaseToolkit] = callable_obj(**final_args)

Expand All @@ -178,7 +177,7 @@ def _resolve_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
# If the argument is a class definition, resolve and instantiate it
nested_class: BaseModel = self._resolve_class(value.get("class"))
nested_args: Dict[str, Any] = self._resolve_args(value.get("args", empty))
self._check_invalid_args(nested_class, nested_args)
ArgumentValidator.check_invalid_args(nested_class, nested_args)
resolved_args[key] = nested_class(**nested_args)
else:
# Otherwise, keep primitive values as they are
Expand Down Expand Up @@ -210,22 +209,6 @@ def _resolve_class(self, class_path: str) -> Type[BaseTool]:
except AttributeError as exception:
raise ValueError(f"Class {class_path} not found in PYTHONPATH") from exception

def _check_invalid_args(self, method_class: Union[Type, MethodType], args: Dict[str, Any]):
"""
Check for invalid arguments in class or method
:param method_class: Class or method to check for the invalid arguments
:param args: Arguments to check
"""
class_args_set: Set[str] = set(signature(method_class).parameters.keys())
args_set: Set[str] = set(args.keys())
invalid_args: Set[str] = args_set - class_args_set

if invalid_args:
raise ValueError(
f"Arguments {invalid_args} for '{method_class.__name__}' do not match any attributes "
"of the class or any arguments of the method."
)

def _get_from_api_wrapper_method(
self,
tool_class: Union[Type[BaseTool], Type[BaseToolkit]]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (C) 2023-2025 Cognizant Digital Business, Evolutionary AI.
# All Rights Reserved.
# Issued under the Academic Public License.
#
# You can be released from the terms, and requirements of the Academic Public
# License by purchasing a commercial license.
# Purchase of a commercial license is mandatory for any use of the
# neuro-san SDK Software in commercial settings.
#
# END COPYRIGHT

from typing import Any
from typing import Dict
from typing import Set
from typing import Type
from typing import Union
from types import MethodType
from inspect import isclass
from inspect import signature
from pydantic import BaseModel


class ArgumentValidator:
    """
    Utility for checking dictionaries of arguments against what a class
    constructor or callable actually accepts.  Particularly useful for
    validating input destined for pydantic ``BaseModel`` subclasses or for
    plain callables inspected via their signatures.

    Provides static methods to:
    - Verify that a dict of arguments matches the accepted parameters of a
      class or method.
    - Collect all field names and aliases from a pydantic BaseModel subclass.
    """

    @staticmethod
    def check_invalid_args(method_class: Union[Type, MethodType], args: Dict[str, Any]):
        """
        Check for invalid arguments in a class constructor or method call.

        :param method_class: The class or method to validate against.
        :param args: Dictionary of arguments to check.
        :raises ValueError: If any argument name is not accepted by the
                method or class.
        """

        # Pydantic models publish their accepted names (fields + aliases);
        # everything else is introspected through its call signature.
        if isclass(method_class) and issubclass(method_class, BaseModel):
            accepted: Set[str] = ArgumentValidator.get_base_model_args(method_class)
        else:
            accepted = {name for name in signature(method_class).parameters}

        # Anything the caller supplied that the target does not accept
        unknown: Set[str] = set(args) - accepted
        if not unknown:
            return

        raise ValueError(
            f"Arguments {unknown} for '{method_class.__name__}' do not match any attributes "
            "of the class or any arguments of the method."
        )

    @staticmethod
    def get_base_model_args(base_model_class: Type[BaseModel]) -> Set[str]:
        """
        Extract all field names and aliases from a pydantic BaseModel class.

        :param base_model_class: A class that inherits from pydantic's `BaseModel`.
        :return: A set of valid argument names, including both field names
                and aliases.
        """

        # model_fields maps each field name to its FieldInfo, which may
        # carry an alias in addition to the declared name.
        names: Set[str] = set(base_model_class.model_fields)
        aliases: Set[str] = {
            info.alias
            for info in base_model_class.model_fields.values()
            if info.alias
        }
        return names | aliases
Loading