-
Notifications
You must be signed in to change notification settings - Fork 28
UN-3276 support for user defined llm in hocon file with class key #270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
de0a29a
ecdee78
d23186e
928ea36
58b3b8b
5906c6f
8a8128f
23891db
2a8131f
8e4d881
962abf7
17974b5
9d90aaa
7647dc9
741e50e
f8ce89d
69889ad
9363936
a796cf8
bbc3fbc
50bd7f0
1aac685
69f38f2
e6dc921
0b98a39
cca1d03
b021671
35c0362
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
|
|
||
| # Copyright (C) 2023-2025 Cognizant Digital Business, Evolutionary AI. | ||
| # All Rights Reserved. | ||
| # Issued under the Academic Public License. | ||
| # | ||
| # You can be released from the terms, and requirements of the Academic Public | ||
| # License by purchasing a commercial license. | ||
| # Purchase of a commercial license is mandatory for any use of the | ||
| # neuro-san SDK Software in commercial settings. | ||
| # | ||
| # END COPYRIGHT | ||
|
|
||
| from typing import Any | ||
| from typing import Dict | ||
| from typing import List | ||
|
|
||
| from langchain_core.callbacks.base import BaseCallbackHandler | ||
| from langchain_core.language_models.base import BaseLanguageModel | ||
| from langchain_groq import ChatGroq | ||
|
|
||
| from neuro_san.internals.run_context.langchain.llms.langchain_llm_factory import LangChainLlmFactory | ||
|
|
||
|
|
||
| class GroqLangChainLlmFactory(LangChainLlmFactory): | ||
| """ | ||
| Factory class for LLM operations | ||
| """ | ||
|
|
||
| def create_base_chat_model(self, config: Dict[str, Any], | ||
| callbacks: List[BaseCallbackHandler] = None) -> BaseLanguageModel: | ||
| """ | ||
| Create a BaseLanguageModel from the fully-specified llm config. | ||
| :param config: The fully specified llm config which is a product of | ||
| _create_full_llm_config() above. | ||
| :param callbacks: A list of BaseCallbackHandlers to add to the chat model. | ||
| :return: A BaseLanguageModel (can be Chat or LLM) | ||
| Can raise a ValueError if the config's class or model_name value is | ||
| unknown to this method. | ||
| """ | ||
| # Construct the LLM | ||
| llm: BaseLanguageModel = None | ||
| chat_class: str = config.get("class") | ||
| if chat_class is not None: | ||
| chat_class = chat_class.lower() | ||
|
|
||
| model_name: str = config.get("model_name") | ||
|
|
||
| if chat_class == "groq": | ||
| llm = ChatGroq( | ||
| model=model_name, | ||
| temperature=config.get("temperature") | ||
| ) | ||
| elif chat_class is None: | ||
| raise ValueError(f"Class name {chat_class} for model_name {model_name} is unspecified.") | ||
| else: | ||
| raise ValueError(f"Class {chat_class} for model_name {model_name} is unrecognized.") | ||
|
|
||
| return llm | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| # Copyright (C) 2023-2025 Cognizant Digital Business, Evolutionary AI. | ||
| # All Rights Reserved. | ||
| # Issued under the Academic Public License. | ||
| # | ||
| # You can be released from the terms, and requirements of the Academic Public | ||
| # License by purchasing a commercial license. | ||
| # Purchase of a commercial license is mandatory for any use of the | ||
| # neuro-san SDK Software in commercial settings. | ||
| # | ||
| # END COPYRIGHT | ||
|
|
||
| # The schema specifications for this file are documented here: | ||
| # https://github.com/cognizant-ai-lab/neuro-san/blob/main/docs/llm_info_hocon_reference.md | ||
|
|
||
| { | ||
|
|
||
|
|
||
| "classes": { | ||
| "factories": [ "llm_extension.groq_langchain_llm_factory.GroqLangChainLlmFactory" ] | ||
| "groq": { | ||
|
||
| # Add arguments like temperature that you want to pass to the llm here. | ||
| "temperature": 0.7 | ||
| } | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,9 +23,8 @@ | |
| from logging import Logger | ||
| from logging import getLogger | ||
|
|
||
| from openai import APIError | ||
| from anthropic import BadRequestError | ||
| from anthropic import AuthenticationError | ||
| from openai import APIError as OpenAI_APIError | ||
| from anthropic import APIError as Anthropic_APIError | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Only import |
||
|
|
||
| from pydantic_core import ValidationError | ||
|
|
||
|
|
@@ -495,7 +494,7 @@ async def ainvoke(self, agent_executor: AgentExecutor, inputs: Dict[str, Any], i | |
| while return_dict is None and retries > 0: | ||
| try: | ||
| return_dict: Dict[str, Any] = await agent_executor.ainvoke(inputs, invoke_config) | ||
| except (APIError, BadRequestError, AuthenticationError, ChatGoogleGenerativeAIError) as api_error: | ||
| except (OpenAI_APIError, Anthropic_APIError, ChatGoogleGenerativeAIError) as api_error: | ||
Noravee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| message: str = ApiKeyErrorCheck.check_for_api_key_exception(api_error) | ||
| if message is not None: | ||
| raise ValueError(message) from api_error | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,7 @@ | |
| from neuro_san.internals.run_context.langchain.llms.langchain_llm_factory import LangChainLlmFactory | ||
| from neuro_san.internals.run_context.langchain.llms.llm_info_restorer import LlmInfoRestorer | ||
| from neuro_san.internals.run_context.langchain.llms.standard_langchain_llm_factory import StandardLangChainLlmFactory | ||
| from neuro_san.internals.run_context.langchain.toolbox.toolbox_factory import ToolboxFactory | ||
| from neuro_san.internals.run_context.langchain.util.api_key_error_check import ApiKeyErrorCheck | ||
|
|
||
|
|
||
|
|
@@ -123,24 +124,13 @@ def resolve_one_llm_factory(self, llm_factory_class_name: str, llm_info_file: st | |
| raise ValueError(f"The value for the classes.factories key in {llm_info_file} " | ||
| "must be a list of strings") | ||
|
|
||
| class_split: List[str] = llm_factory_class_name.split(".") | ||
| if len(class_split) <= 2: | ||
| raise ValueError(f"Value in the classes.factories in {llm_info_file} must be of the form " | ||
| "<package_name>.<module_name>.<ClassName>") | ||
|
|
||
| # Create a list of a single package given the name in the value | ||
| packages: List[str] = [".".join(class_split[:-2])] | ||
| class_name: str = class_split[-1] | ||
| resolver = Resolver(packages) | ||
|
|
||
| # Resolve the class name | ||
| llm_factory_class: Type[LangChainLlmFactory] = None | ||
| try: | ||
| llm_factory_class: Type[LangChainLlmFactory] = \ | ||
| resolver.resolve_class_in_module(class_name, module_name=class_split[-2]) | ||
| except AttributeError as exception: | ||
| raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} " | ||
| "not found in PYTHONPATH") from exception | ||
| # Resolve the factory class | ||
| llm_factory_class = self._resolve_class_from_path( | ||
| class_path=llm_factory_class_name, | ||
| expected_base=LangChainLlmFactory, | ||
| source_file=llm_info_file, | ||
| description="classes.factories" | ||
| ) | ||
|
|
||
| # Instantiate it | ||
| try: | ||
|
|
@@ -155,6 +145,38 @@ def resolve_one_llm_factory(self, llm_factory_class_name: str, llm_info_file: st | |
| "must be of type LangChainLlmFactory") | ||
| return llm_factory | ||
|
|
||
| def _resolve_class_from_path( | ||
| self, | ||
| class_path: str, | ||
| expected_base: Type, | ||
| source_file: str, | ||
| description: str | ||
| ) -> Type: | ||
|
||
|
|
||
| parts = class_path.split(".") | ||
| if len(parts) <= 2: | ||
| raise ValueError( | ||
| f"Value for '{description}' in {source_file} must be of the form " | ||
| "<package>.<module>.<ClassName>" | ||
| ) | ||
Noravee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| module_name = parts[-2] | ||
| class_name = parts[-1] | ||
| packages = [".".join(parts[:-2])] | ||
| resolver = Resolver(packages) | ||
|
|
||
| try: | ||
| cls = resolver.resolve_class_in_module(class_name, module_name=module_name) | ||
| except AttributeError as e: | ||
| raise ValueError(f"Class {class_path} in {source_file} not found in PYTHONPATH") from e | ||
|
|
||
| if not issubclass(cls, expected_base): | ||
| raise ValueError( | ||
| f"Class {class_path} in {source_file} must be a subclass of {expected_base.__name__}" | ||
| ) | ||
|
|
||
| return cls | ||
|
|
||
| def create_llm(self, config: Dict[str, Any], callbacks: List[BaseCallbackHandler] = None) -> BaseLanguageModel: | ||
| """ | ||
| Creates a langchain LLM based on the 'model_name' value of | ||
|
|
@@ -169,13 +191,19 @@ def create_llm(self, config: Dict[str, Any], callbacks: List[BaseCallbackHandler | |
| """ | ||
| full_config: Dict[str, Any] = self.create_full_llm_config(config) | ||
| llm: BaseLanguageModel = self.create_base_chat_model(full_config, callbacks) | ||
| print(f"\n\n{llm=}\n\n") | ||
| return llm | ||
|
|
||
| def create_full_llm_config(self, config: Dict[str, Any]) -> Dict[str, Any]: | ||
| """ | ||
| :param config: The llm_config from the user | ||
| :return: The fully specified config with defaults filled in. | ||
| """ | ||
|
|
||
| if config.get("class"): | ||
| # If config has "class", it is a user-specified llm so return config as is, | ||
| return config | ||
|
||
|
|
||
| default_config: Dict[str, Any] = self.llm_infos.get("default_config") | ||
| use_config = self.overlayer.overlay(default_config, config) | ||
|
|
||
|
|
@@ -279,6 +307,27 @@ def create_base_chat_model(self, config: Dict[str, Any], | |
| # Let the next model have a crack | ||
| found_exception = exception | ||
|
|
||
| # Try resolving via 'class' in config if factories failed | ||
| class_path = config.get("class") | ||
| if found_exception is not None and class_path: | ||
Noravee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if not isinstance(class_path, str): | ||
| raise ValueError("'class' in llm_config must be a string") | ||
|
|
||
| # Resolve the 'class' | ||
| llm_class = self._resolve_class_from_path( | ||
Noravee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| class_path=class_path, | ||
| expected_base=BaseLanguageModel, | ||
| source_file="agent network hocon file", | ||
| description="llm_config" | ||
| ) | ||
|
|
||
| # copy the config, take 'class' out, and unpack into llm constructor | ||
| user_config = config.copy() | ||
Noravee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| user_config.pop("class") | ||
| ToolboxFactory.check_invalid_args(llm_class, user_config) | ||
|
||
| llm = llm_class(**user_config) | ||
| found_exception = None | ||
|
|
||
| if found_exception is not None: | ||
| raise found_exception | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| # | ||
| # END COPYRIGHT | ||
|
|
||
| from inspect import isclass | ||
| from inspect import signature | ||
| from types import MethodType | ||
| from typing import Any | ||
|
|
@@ -152,7 +153,7 @@ def create_tool_from_toolbox( | |
| self._get_from_api_wrapper_method(tool_class) or tool_class | ||
|
|
||
| # Validate and instantiate | ||
| self._check_invalid_args(callable_obj, final_args) | ||
| ToolboxFactory.check_invalid_args(callable_obj, final_args) | ||
Noravee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # Instance can be a BaseTool or a BaseToolkit | ||
| instance: Union[BaseTool, BaseToolkit] = callable_obj(**final_args) | ||
|
|
||
|
|
@@ -178,7 +179,7 @@ def _resolve_args(self, args: Dict[str, Any]) -> Dict[str, Any]: | |
| # If the argument is a class definition, resolve and instantiate it | ||
| nested_class: BaseModel = self._resolve_class(value.get("class")) | ||
| nested_args: Dict[str, Any] = self._resolve_args(value.get("args", empty)) | ||
| self._check_invalid_args(nested_class, nested_args) | ||
| ToolboxFactory.check_invalid_args(nested_class, nested_args) | ||
| resolved_args[key] = nested_class(**nested_args) | ||
| else: | ||
| # Otherwise, keep primitive values as they are | ||
|
|
@@ -210,16 +211,24 @@ def _resolve_class(self, class_path: str) -> Type[BaseTool]: | |
| except AttributeError as exception: | ||
| raise ValueError(f"Class {class_path} not found in PYTHONPATH") from exception | ||
|
|
||
| def _check_invalid_args(self, method_class: Union[Type, MethodType], args: Dict[str, Any]): | ||
| @staticmethod | ||
| def check_invalid_args(method_class: Union[Type, MethodType], args: Dict[str, Any]): | ||
|
||
| """ | ||
| Check for invalid arguments in class or method | ||
| :param method_class: Class or method to check for the invalid arguments | ||
| :param args: Arguments to check | ||
| """ | ||
| class_args_set: Set[str] = set(signature(method_class).parameters.keys()) | ||
| pydantic_args: Set[str] = set() | ||
| # Check for if it is a class that extends pydantic BaseModel | ||
| if isclass(method_class) and issubclass(method_class, BaseModel): | ||
|
||
| # Include field names as args | ||
| pydantic_args = set(method_class.model_fields.keys()) | ||
| # Combine the arguments | ||
| class_args_set: Set[str] = set(signature(method_class).parameters.keys()).union(pydantic_args) | ||
Noravee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| args_set: Set[str] = set(args.keys()) | ||
| invalid_args: Set[str] = args_set - class_args_set | ||
|
|
||
| # If there are args that are not from class args or alias, raise error | ||
| invalid_args: Set[str] = args_set - class_args_set | ||
| if invalid_args: | ||
| raise ValueError( | ||
| f"Arguments {invalid_args} for '{method_class.__name__}' do not match any attributes " | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,7 +47,9 @@ | |
| # Agents having to do with test infrastructure | ||
| "gist.hocon": true, | ||
| "assess_failure.hocon": true, | ||
|
|
||
| "test_new_model_default_class.hocon": true, | ||
| "test_new_class.hocon": true, | ||
| "test_new_model_extended_class.hocon": true | ||
|
||
| # STOP AND READ: YOU PROBABLY DON'T WANT TO ADD YOUR .hocon FILE HERE. | ||
| # | ||
| # The agent network .hocon files above are examples specific to the neuro-san library. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for testing only. It will not be in the final PR.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the spirit of robust tests + limited library dependencies, you could put this class and extra hocon file under the test/ directory and the extra langchain_groq dependency in requirements-build.txt as long as your intent is to add some kind of regularly run test that uses it (unit test if < 15 secs, integration test if longer)