-
Notifications
You must be signed in to change notification settings - Fork 28
UN-3376 For langchain tools, make sure the UI knows that the communication between those nodes has been completed #337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
5efff05
3864dde
3904f2c
9a9e912
945afbd
5245ccc
fa1dd36
742e16d
29d5afd
18e4ea1
5ed6879
4f0ec73
7985e06
bba2542
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -168,19 +168,8 @@ async def create_resources(self, agent_name: str, | |
| # DEF - Remove the arg if possible | ||
| _ = agent_name | ||
|
|
||
| # Create the list of callbacks to pass to the LLM ChatModel | ||
| callbacks: List[BaseCallbackHandler] = [ | ||
| JournalingCallbackHandler(self.journal) | ||
| ] | ||
| full_name: str = Origination.get_full_name_from_origin(self.origin) | ||
|
|
||
| # Consult the agent spec for level of verbosity as it pertains to callbacks. | ||
| agent_spec: Dict[str, Any] = self.tool_caller.get_agent_tool_spec() | ||
| verbose: Union[bool, str] = agent_spec.get("verbose", False) | ||
| if isinstance(verbose, str) and verbose.lower() in ("extra", "logging"): | ||
| # This particular class adds a *lot* of very detailed messages | ||
| # to the logs. Add this because some people are interested in it. | ||
| callbacks.append(LoggingCallbackHandler(self.logger)) | ||
|
|
||
| # Now that we have a name, we can create an ErrorDetector for the output. | ||
| self.error_detector = ErrorDetector(full_name, | ||
|
|
@@ -199,10 +188,10 @@ async def create_resources(self, agent_name: str, | |
|
|
||
| prompt_template: ChatPromptTemplate = await self._create_prompt_template(instructions) | ||
|
|
||
| self.agent = self.create_agent_with_fallbacks(prompt_template, callbacks) | ||
| self.agent = self.create_agent_with_fallbacks(prompt_template) | ||
|
|
||
| def create_agent_with_fallbacks(self, prompt_template: ChatPromptTemplate, | ||
| callbacks: List[BaseCallbackHandler]) -> Agent: | ||
| callbacks: List[BaseCallbackHandler] = None) -> Agent: | ||
Noravee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ | ||
| Creates an agent with potential fallback llms to use. | ||
| :param prompt_template: The ChatPromptTemplate to use for the agent | ||
|
|
@@ -324,7 +313,7 @@ async def _create_base_tool(self, name: str) -> BaseTool: | |
| if toolbox: | ||
| toolbox_factory: ContextTypeToolboxFactory = self.invocation_context.get_toolbox_factory() | ||
| try: | ||
| tool_from_toolbox = toolbox_factory.create_tool_from_toolbox(toolbox, agent_spec.get("args")) | ||
| tool_from_toolbox = toolbox_factory.create_tool_from_toolbox(toolbox, agent_spec.get("args"), name) | ||
Noravee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # If the tool from toolbox is base tool or list of base tool, return the tool as is | ||
| # since the tool's definition and args schema are predefined in the tool's class. | ||
| if isinstance(tool_from_toolbox, BaseTool) or ( | ||
|
|
@@ -451,10 +440,28 @@ async def wait_on_run(self, run: Run, journal: Journal = None) -> Run: | |
| "chat_history": previous_chat_history, | ||
| "input": self.recent_human_message | ||
| } | ||
|
|
||
| # Create the list of callbacks to pass when invoking | ||
| parent_origin: List[Dict[str, Any]] = self.get_origin() | ||
| base_journal: Journal = self.invocation_context.get_journal() | ||
| origination: Origination = self.invocation_context.get_origination() | ||
| callbacks: List[BaseCallbackHandler] = [ | ||
| JournalingCallbackHandler(self.journal, base_journal, parent_origin, origination) | ||
| ] | ||
| # Consult the agent spec for level of verbosity as it pertains to callbacks. | ||
| agent_spec: Dict[str, Any] = self.tool_caller.get_agent_tool_spec() | ||
| verbose: Union[bool, str] = agent_spec.get("verbose", False) | ||
| if isinstance(verbose, str) and verbose.lower() in ("extra", "logging"): | ||
| # This particular class adds a *lot* of very detailed messages | ||
| # to the logs. Add this because some people are interested in it. | ||
| callbacks.append(LoggingCallbackHandler(self.logger)) | ||
|
|
||
| # Add callbacks as an invoke config | ||
| invoke_config = { | ||
| "configurable": { | ||
| "session_id": run.get_id() | ||
| } | ||
| }, | ||
| "callbacks": callbacks | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use callbacks as |
||
| } | ||
|
|
||
| # Chat history is updated in write_message | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,17 +12,22 @@ | |
| from collections.abc import Sequence | ||
| from typing import Any | ||
| from typing import Dict | ||
| from typing import List | ||
|
|
||
| from pydantic import ConfigDict | ||
|
|
||
| from langchain_core.agents import AgentAction | ||
| from langchain_core.agents import AgentFinish | ||
| from langchain_core.callbacks.base import AsyncCallbackHandler | ||
| from langchain_core.documents import Document | ||
| from langchain_core.messages.ai import AIMessage | ||
| from langchain_core.messages.base import BaseMessage | ||
| from langchain_core.outputs import LLMResult | ||
| from langchain_core.outputs.chat_generation import ChatGeneration | ||
|
|
||
| from neuro_san.internals.journals.journal import Journal | ||
| from neuro_san.internals.journals.originating_journal import OriginatingJournal | ||
| from neuro_san.internals.messages.origination import Origination | ||
| from neuro_san.internals.messages.agent_message import AgentMessage | ||
|
|
||
|
|
||
|
|
@@ -52,13 +57,29 @@ class JournalingCallbackHandler(AsyncCallbackHandler): | |
| # a non-pydantic Journal as a member, we need to do this. | ||
| model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
|
||
| def __init__(self, journal: OriginatingJournal): | ||
| def __init__( | ||
| self, | ||
| calling_agent_journal: Journal, | ||
| base_journal: Journal, | ||
| parent_origin: List[Dict[str, Any]], | ||
| origination: Origination | ||
| ): | ||
| """ | ||
| Constructor | ||
|
|
||
| :param journal: The journal to write messages to | ||
| :param calling_agent_journal: The journal of the calling agent | ||
| :param base_journal: The Journal instance that allows message reporting during the course of the AgentSession. | ||
| This is used to construct the langchain_tool_journal. | ||
| :param parent_origin: A List of origin dictionaries indicating the origin of the run | ||
| This is used to construct the langchain_tool_journal. | ||
| :param origination: The Origination instance carrying state about tool instantiation | ||
| during the course of the AgentSession. This is used to construct the langchain_tool_journal. | ||
| """ | ||
| self.journal: OriginatingJournal = journal | ||
| self.calling_agent_journal: Journal = calling_agent_journal | ||
| self.base_journal: Journal = base_journal | ||
Noravee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.parent_origin: List[Dict[str, Any]] = parent_origin | ||
| self.origination: Origination = origination | ||
| self.langchain_tool_journal: Journal = None | ||
|
|
||
| async def on_llm_end(self, response: LLMResult, | ||
| **kwargs: Any) -> None: | ||
|
|
@@ -76,17 +97,73 @@ async def on_llm_end(self, response: LLMResult, | |
| # Some AGENT messages that come from this source end up being dupes | ||
| # of AI messages that can come later. | ||
| # Use this method to put the message on hold for later comparison. | ||
| await self.journal.write_message_if_next_not_dupe(message) | ||
| await self.calling_agent_journal.write_message_if_next_not_dupe(message) | ||
|
|
||
| async def on_chain_end(self, outputs: Dict[str, Any], | ||
| **kwargs: Any) -> None: | ||
| # print(f"In on_chain_end() with {outputs}") | ||
| return | ||
|
|
||
| async def on_tool_end(self, output: Any, | ||
| **kwargs: Any) -> None: | ||
| # print(f"In on_tool_end() with {output}") | ||
| return | ||
| async def on_tool_start( | ||
| self, | ||
| serialized: Dict[str, Any], | ||
| input_str: str, | ||
| tags: List[str] = None, | ||
| inputs: Dict[str, Any] = None, | ||
| **kwargs: Any | ||
| ) -> None: | ||
| """ | ||
| Callback triggered when a tool starts execution. | ||
|
|
||
| If the tool is identified as a LangChain tool (via the "langchain_tool" tag), | ||
| this method creates a journal entry containing the tool's input arguments, | ||
| origin metadata, and full tool name. | ||
|
|
||
| :param serialized: Serialized representation of the tool, including its name and description. | ||
| :param input_str: String representation of the tool's input. | ||
| :param tags: List of tags associated with the tool. Used to determine whether it is a LangChain tool. | ||
| :param inputs: Structured dictionary of input arguments | ||
| passed to the tool. | ||
| """ | ||
|
|
||
| if "langchain_tool" in tags: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check if the tool is langchain's tool with |
||
|
|
||
| # Extract tool name from the serialized data | ||
| agent_name: str = serialized.get("name") | ||
|
|
||
| # Build the origin path | ||
| origin: List[Dict[str, Any]] = self.origination.add_spec_name_to_origin(self.parent_origin, agent_name) | ||
| full_name: str = self.origination.get_full_name_from_origin(origin) | ||
|
|
||
| # Combine the original tool inputs with origin metadata | ||
| combined_args: Dict[str, Any] = inputs.copy() | ||
| combined_args["origin"] = origin | ||
| combined_args["origin_str"] = full_name | ||
|
|
||
| # Create a journal entry for this invocation and log the combined inputs | ||
| self.langchain_tool_journal = OriginatingJournal(self.base_journal, origin) | ||
| message: BaseMessage = AgentMessage(content=f"Received arguments {combined_args}") | ||
| await self.langchain_tool_journal.write_message(message) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Creating journal for langchain tool and write message. |
||
|
|
||
| async def on_tool_end(self, output: Any, tags: List[str] = None, **kwargs: Any) -> None: | ||
| """ | ||
| Callback triggered when a tool finishes execution. | ||
|
|
||
| If the tool is identified as a LangChain tool (via the "langchain_tool" tag), | ||
| this method logs the tool's output to both the calling agent's journal and the | ||
| LangChain tool's specific journal. | ||
|
|
||
| :param output: The result produced by the tool after execution. | ||
| :param tags: List of tags associated with the tool. Used to determine whether it is a LangChain tool. | ||
| """ | ||
|
|
||
| if "langchain_tool" in tags: | ||
| # Log the tool output to the calling agent's journal | ||
| await self.calling_agent_journal.write_message(AIMessage(content=output)) | ||
|
|
||
| # Also log the tool output to the LangChain tool-specific journal | ||
| message: BaseMessage = AgentMessage(content=f"Got result: {output}") | ||
| await self.langchain_tool_journal.write_message(message) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Write message in the langchain journal with tool output. This can be used as an indicator that the tool is finished.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Noravee Since the UI already checks for We also may want to consider adding a way to signify an end event for coded tools or langchain tools on the cc: @d1donlydfink
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can do that in a separate PR @swensel for full consistency
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @d1donlydfink , that would be great. cc: @Noravee |
||
|
|
||
| async def on_agent_action(self, action: AgentAction, | ||
| **kwargs: Any) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -116,13 +116,16 @@ def load(self): | |
| def create_tool_from_toolbox( | ||
| self, | ||
| tool_name: str, | ||
| user_args: Dict[str, Any] = None | ||
| user_args: Dict[str, Any] = None, | ||
| agent_name: str = None | ||
Noravee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) -> Union[BaseTool, Dict[str, Any], List[BaseTool]]: | ||
| """ | ||
| Resolves dependencies and instantiates the requested tool. | ||
|
|
||
| :param tool_name: The name of the tool to instantiate. | ||
| :param user_args: Arguments provided by the user, which override the config file. | ||
| :param agent_name: The name of the agent to prefix each BaseTool's name in BaseToolkit with, | ||
| ensuring tool names are unique and traceable to their agent. | ||
| :return: - Instantiated tool if "class" of tool_name points to a BaseTool class | ||
| - A list of tools if "class" of tool_name points to a BaseToolkit class. | ||
| - A dict of tool's "description" and "parameters" if tool_name points to a CodedTool | ||
|
|
@@ -168,8 +171,18 @@ def create_tool_from_toolbox( | |
|
|
||
| # If the instantiated class has "get_tools()", assume it's a toolkit and return a list of tools | ||
| if hasattr(instance, "get_tools") and callable(instance.get_tools): | ||
| return instance.get_tools() | ||
|
|
||
| toolkit: List[BaseTool] = instance.get_tools() | ||
| for tool in toolkit: | ||
| # Prefix the name of the agent to each tool | ||
| tool.name = f"{agent_name}_{tool.name}" | ||
|
||
| # Add "langchain_tool" tags so journal callback can identify it | ||
| tool.tags = ["langchain_tool"] | ||
| return toolkit | ||
|
|
||
| # Replace langchain tool's name with agent name | ||
| instance.name = agent_name | ||
| # Add "langchain_tool" tags so journal callback can identify it | ||
| instance.tags = ["langchain_tool"] | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Override langchain tool's name with |
||
| return instance | ||
|
|
||
| def _resolve_args(self, args: Dict[str, Any]) -> Dict[str, Any]: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.