-
Notifications
You must be signed in to change notification settings - Fork 28
UN-3276 support for user defined llm in hocon file with class key #270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 27 commits
de0a29a
ecdee78
d23186e
928ea36
58b3b8b
5906c6f
8a8128f
23891db
2a8131f
8e4d881
962abf7
17974b5
9d90aaa
7647dc9
741e50e
f8ce89d
69889ad
9363936
a796cf8
bbc3fbc
50bd7f0
1aac685
69f38f2
e6dc921
0b98a39
cca1d03
b021671
35c0362
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,6 @@ | |
| from typing import Any | ||
| from typing import Dict | ||
| from typing import List | ||
| from typing import Optional | ||
| from typing import Type | ||
|
|
||
| import os | ||
|
|
@@ -26,14 +25,15 @@ | |
| from langchain_core.language_models.base import BaseLanguageModel | ||
|
|
||
| from leaf_common.config.dictionary_overlay import DictionaryOverlay | ||
| from leaf_common.config.resolver import Resolver | ||
| from leaf_common.parsers.dictionary_extractor import DictionaryExtractor | ||
|
|
||
| from neuro_san.internals.interfaces.context_type_llm_factory import ContextTypeLlmFactory | ||
| from neuro_san.internals.run_context.langchain.llms.langchain_llm_factory import LangChainLlmFactory | ||
| from neuro_san.internals.run_context.langchain.llms.llm_info_restorer import LlmInfoRestorer | ||
| from neuro_san.internals.run_context.langchain.llms.standard_langchain_llm_factory import StandardLangChainLlmFactory | ||
| from neuro_san.internals.run_context.langchain.util.api_key_error_check import ApiKeyErrorCheck | ||
| from neuro_san.internals.run_context.langchain.util.argument_validator import ArgumentValidator | ||
| from neuro_san.internals.utils.resolver_util import ResolverUtil | ||
|
|
||
|
|
||
| class DefaultLlmFactory(ContextTypeLlmFactory, LangChainLlmFactory): | ||
|
|
@@ -63,7 +63,7 @@ class DefaultLlmFactory(ContextTypeLlmFactory, LangChainLlmFactory): | |
| the model description in this class. | ||
| """ | ||
|
|
||
| def __init__(self, config: Optional[Dict[str, Any]] = None): | ||
| def __init__(self, config: Dict[str, Any] = None): | ||
| """ | ||
| Constructor | ||
|
|
||
|
|
@@ -123,39 +123,20 @@ def resolve_one_llm_factory(self, llm_factory_class_name: str, llm_info_file: st | |
| raise ValueError(f"The value for the classes.factories key in {llm_info_file} " | ||
| "must be a list of strings") | ||
|
|
||
| class_split: List[str] = llm_factory_class_name.split(".") | ||
| if len(class_split) <= 2: | ||
| raise ValueError(f"Value in the classes.factories in {llm_info_file} must be of the form " | ||
| "<package_name>.<module_name>.<ClassName>") | ||
|
|
||
| # Create a list of a single package given the name in the value | ||
| packages: List[str] = [".".join(class_split[:-2])] | ||
| class_name: str = class_split[-1] | ||
| resolver = Resolver(packages) | ||
|
|
||
| # Resolve the class name | ||
| llm_factory_class: Type[LangChainLlmFactory] = None | ||
| try: | ||
| llm_factory_class: Type[LangChainLlmFactory] = \ | ||
| resolver.resolve_class_in_module(class_name, module_name=class_split[-2]) | ||
| except AttributeError as exception: | ||
| raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} " | ||
| "not found in PYTHONPATH") from exception | ||
|
|
||
| # Instantiate it | ||
| try: | ||
| llm_factory: LangChainLlmFactory = llm_factory_class() | ||
| except TypeError as exception: | ||
| raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} " | ||
| "must have a no-args constructor") from exception | ||
|
|
||
| # Make sure its the correct type | ||
| if not isinstance(llm_factory, LangChainLlmFactory): | ||
| raise ValueError(f"Class {llm_factory_class_name} in {llm_info_file} " | ||
| "must be of type LangChainLlmFactory") | ||
| # Resolve and instantiate the factory class | ||
| llm_factory = ResolverUtil.create_instance( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use ResolverUtil.create_instance to resolve and instantiate the factory class.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ooo. Nice. |
||
| class_name=llm_factory_class_name, | ||
| class_name_source=llm_info_file, | ||
| type_of_class=LangChainLlmFactory | ||
| ) | ||
|
|
||
| return llm_factory | ||
|
|
||
| def create_llm(self, config: Dict[str, Any], callbacks: List[BaseCallbackHandler] = None) -> BaseLanguageModel: | ||
| def create_llm( | ||
| self, | ||
| config: Dict[str, Any], | ||
| callbacks: List[BaseCallbackHandler] = None | ||
| ) -> BaseLanguageModel: | ||
| """ | ||
| Creates a langchain LLM based on the 'model_name' value of | ||
| the config passed in. | ||
|
|
@@ -176,6 +157,28 @@ def create_full_llm_config(self, config: Dict[str, Any]) -> Dict[str, Any]: | |
| :param config: The llm_config from the user | ||
| :return: The fully specified config with defaults filled in. | ||
| """ | ||
|
|
||
| class_from_llm_config: str = config.get("class") | ||
| if class_from_llm_config: | ||
| if not isinstance(class_from_llm_config, str): | ||
| raise ValueError("Value of 'class' has to be string.") | ||
| # A "class" key in the config indicates the user has specified a particular LLM implementation. | ||
| # However, the config may only contain partial arguments (e.g., {"arg_1": 0.5}) and omit others. | ||
| # | ||
| # In the standard factory, LLM classes are instantiated like: | ||
| # ChatOpenAI(arg_1=config.get("arg_1"), arg_2=config.get("arg_2")) | ||
| # If a required argument like "arg_2" is missing in the config, config.get("arg_2") returns None, | ||
| # which may raise an error during instantiation if the argument has no default. | ||
| # | ||
| # To prevent this, we first fetch the default arguments for the given class from llm_info, | ||
| # then merge them with the user-provided config. This ensures all expected arguments are present, | ||
| # and the user’s config values take precedence over the defaults. | ||
| config_from_class_in_llm_info: Dict[str, Any] = self.get_chat_class_args(class_from_llm_config) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Combine user-specified config with the one in |
||
|
|
||
| # Merge the defaults from llm_info with the user-defined config, | ||
| # giving priority to values in config. | ||
| return self.overlayer.overlay(config_from_class_in_llm_info, config) | ||
|
|
||
| default_config: Dict[str, Any] = self.llm_infos.get("default_config") | ||
| use_config = self.overlayer.overlay(default_config, config) | ||
|
|
||
|
|
@@ -215,7 +218,7 @@ def create_full_llm_config(self, config: Dict[str, Any]) -> Dict[str, Any]: | |
|
|
||
| return full_config | ||
|
|
||
| def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict[str, Any]: | ||
| def get_chat_class_args(self, chat_class_name: str, use_model_name: str = None) -> Dict[str, Any]: | ||
| """ | ||
| :param chat_class_name: string name of the chat class to look up. | ||
| :param use_model_name: the original model name that prompted the chat class lookups | ||
|
|
@@ -227,8 +230,13 @@ def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict | |
| chat_classes: Dict[str, Any] = self.llm_infos.get("classes") | ||
| chat_class: Dict[str, Any] = chat_classes.get(chat_class_name) | ||
| if chat_class is None: | ||
| raise ValueError(f"llm info entry for {use_model_name} uses a 'class' of {chat_class_name} " | ||
| "which is not defined in the 'classes' table.") | ||
| if use_model_name is not None: | ||
| # If use_model_name is given, it must have a "class" in "classes" | ||
| raise ValueError(f"llm info entry for {use_model_name} uses a 'class' of {chat_class_name} " | ||
| "which is not defined in the 'classes' table.") | ||
| # If use_model_name is not provided and chat_class_name is not in "classes" in llm_info, | ||
| # it could be a user-specified langchain model class | ||
| return {} | ||
Noravee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Get the args from the chat class | ||
| args: Dict[str, Any] = chat_class.get("args") | ||
|
|
@@ -244,7 +252,8 @@ def get_chat_class_args(self, chat_class_name: str, use_model_name: str) -> Dict | |
| def create_base_chat_model(self, config: Dict[str, Any], | ||
| callbacks: List[BaseCallbackHandler] = None) -> BaseLanguageModel: | ||
| """ | ||
| Create a BaseLanguageModel from the fully-specified llm config. | ||
| Create a BaseLanguageModel from the fully-specified llm config either from standard LLM factory, | ||
| user-defined LLM factory, or user-specified langchain model class. | ||
| :param config: The fully specified llm config which is a product of | ||
| _create_full_llm_config() above. | ||
| :param callbacks: A list of BaseCallbackHandlers to add to the chat model. | ||
|
|
@@ -279,11 +288,54 @@ def create_base_chat_model(self, config: Dict[str, Any], | |
| # Let the next model have a crack | ||
| found_exception = exception | ||
|
|
||
| # Try resolving via 'class' in config if factories failed | ||
| class_path: str = config.get("class") | ||
| if llm is None and found_exception is not None and class_path: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The conditions to instantiate using
|
||
| llm = self.create_base_chat_model_from_user_class(class_path, config) | ||
| found_exception = None | ||
|
|
||
| if found_exception is not None: | ||
| raise found_exception | ||
|
|
||
| return llm | ||
|
|
||
| def create_base_chat_model_from_user_class( | ||
| self, | ||
| class_path: str, | ||
| config: Dict[str, Any], | ||
| callbacks: List[BaseCallbackHandler] = None | ||
| ) -> BaseLanguageModel: | ||
| """ | ||
| Create a BaseLanguageModel from the user-specified langchain model class. | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactor the logic of instantiating using "class" and user-defined args to |
||
| :param class_path: A string in the form of <package>.<module>.<Class> | ||
| :param config: The fully specified llm config which is a product of | ||
| _create_full_llm_config() above. | ||
| :param callbacks: A list of BaseCallbackHandlers to add to the chat model. | ||
|
|
||
| :return: A BaseLanguageModel | ||
| """ | ||
|
|
||
| if not isinstance(class_path, str): | ||
| raise ValueError("'class' in llm_config must be a string") | ||
|
|
||
| # Resolve the 'class' | ||
| llm_class: Type[BaseLanguageModel] = ResolverUtil.create_class( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
| class_name=class_path, | ||
| class_name_source="agent network hocon file", | ||
| type_of_class=BaseLanguageModel | ||
| ) | ||
|
|
||
| # Copy the config, take 'class' out, and add callbacks | ||
| # Then unpack into llm constructor | ||
| user_config: Dict[str, Any] = config.copy() | ||
| user_config.pop("class") | ||
| user_config["callbacks"] = callbacks | ||
|
|
||
| # Check for invalid args and throw error if found | ||
| ArgumentValidator.check_invalid_args(llm_class, user_config) | ||
|
|
||
| return llm_class(**user_config) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LLM instantiation with args. |
||
|
|
||
| def get_max_prompt_tokens(self, config: Dict[str, Any]) -> int: | ||
| """ | ||
| :param config: A dictionary which describes which LLM to use. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| # Copyright (C) 2023-2025 Cognizant Digital Business, Evolutionary AI. | ||
| # All Rights Reserved. | ||
| # Issued under the Academic Public License. | ||
| # | ||
| # You can be released from the terms, and requirements of the Academic Public | ||
| # License by purchasing a commercial license. | ||
| # Purchase of a commercial license is mandatory for any use of the | ||
| # neuro-san SDK Software in commercial settings. | ||
| # | ||
| # END COPYRIGHT | ||
|
|
||
| from typing import Any | ||
| from typing import Dict | ||
| from typing import Set | ||
| from typing import Type | ||
| from typing import Union | ||
| from types import MethodType | ||
| from inspect import isclass | ||
| from inspect import signature | ||
| from pydantic import BaseModel | ||
|
|
||
|
|
||
| class ArgumentValidator: | ||
| """ | ||
| A utility class for inspecting method and class arguments, particularly useful for | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactor the invalid-argument checking into a utility class.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is very nicely abstracted. |
||
| validating input against Pydantic `BaseModel` subclasses or callable method signatures. | ||
|
|
||
| This class provides static methods to: | ||
| - Validate that a dictionary of arguments matches the accepted parameters of a class or method. | ||
| - Extract all field names and aliases from a Pydantic BaseModel subclass. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def check_invalid_args(method_class: Union[Type, MethodType], args: Dict[str, Any]): | ||
| """ | ||
| Check for invalid arguments in a class constructor or method call. | ||
|
|
||
| :param method_class: The class or method to validate against. | ||
| :param args: Dictionary of argument to check. | ||
| :raises ValueError: If any argument name is not accepted by the method or class. | ||
| """ | ||
|
|
||
| # If method_class is a Pydantic BaseModel, get its field names and aliases | ||
| if isclass(method_class) and issubclass(method_class, BaseModel): | ||
| class_args_set: Set[str] = ArgumentValidator.get_base_model_args(method_class) | ||
| else: | ||
| # Otherwise, extract argument names from the function/method signature | ||
| class_args_set = set(signature(method_class).parameters.keys()) | ||
|
|
||
| # Get the argument keys provided by the user | ||
| args_set: Set[str] = set(args.keys()) | ||
|
|
||
| # Identify which arguments are not accepted by the method/class | ||
| invalid_args: Set[str] = args_set - class_args_set | ||
| if invalid_args: | ||
| raise ValueError( | ||
| f"Arguments {invalid_args} for '{method_class.__name__}' do not match any attributes " | ||
| "of the class or any arguments of the method." | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def get_base_model_args(base_model_class: Type[BaseModel]) -> Set[str]: | ||
| """ | ||
| Extract all field names and aliases from a Pydantic BaseModel class. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's good that you call out the BaseModel class as coming from Pydantic. |
||
|
|
||
| :param base_model_class: A class that inherits from `BaseModel`. | ||
| :return: A set of valid argument names, including both field names and aliases. | ||
| """ | ||
|
|
||
| fields_and_aliases: Set[str] = set() | ||
|
|
||
| # Check for field name and info | ||
| # field info includes attributes like "required", "default", "description", and "alias" | ||
| for field_name, field_info in base_model_class.model_fields.items(): | ||
| # Add field name to the set | ||
| fields_and_aliases.add(field_name) | ||
| if field_info.alias: | ||
| # If there is "alias" in the info add it to the set as well | ||
| fields_and_aliases.add(field_info.alias) | ||
|
|
||
| return fields_and_aliases | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only import
APIError.