diff --git a/neuro_san/internals/run_context/langchain/core/langchain_run_context.py b/neuro_san/internals/run_context/langchain/core/langchain_run_context.py index 324167f21..89c4f0e17 100644 --- a/neuro_san/internals/run_context/langchain/core/langchain_run_context.py +++ b/neuro_san/internals/run_context/langchain/core/langchain_run_context.py @@ -168,19 +168,8 @@ async def create_resources(self, agent_name: str, # DEF - Remove the arg if possible _ = agent_name - # Create the list of callbacks to pass to the LLM ChatModel - callbacks: List[BaseCallbackHandler] = [ - JournalingCallbackHandler(self.journal) - ] full_name: str = Origination.get_full_name_from_origin(self.origin) - - # Consult the agent spec for level of verbosity as it pertains to callbacks. agent_spec: Dict[str, Any] = self.tool_caller.get_agent_tool_spec() - verbose: Union[bool, str] = agent_spec.get("verbose", False) - if isinstance(verbose, str) and verbose.lower() in ("extra", "logging"): - # This particular class adds a *lot* of very detailed messages - # to the logs. Add this because some people are interested in it. - callbacks.append(LoggingCallbackHandler(self.logger)) # Now that we have a name, we can create an ErrorDetector for the output. self.error_detector = ErrorDetector(full_name, @@ -199,14 +188,12 @@ async def create_resources(self, agent_name: str, prompt_template: ChatPromptTemplate = await self._create_prompt_template(instructions) - self.agent = self.create_agent_with_fallbacks(prompt_template, callbacks) + self.agent = self.create_agent_with_fallbacks(prompt_template) - def create_agent_with_fallbacks(self, prompt_template: ChatPromptTemplate, - callbacks: List[BaseCallbackHandler]) -> Agent: + def create_agent_with_fallbacks(self, prompt_template: ChatPromptTemplate) -> Agent: """ Creates an agent with potential fallback llms to use. 
:param prompt_template: The ChatPromptTemplate to use for the agent - :param callbacks: The list of callbacks to use when creating any LLM via the factory :return: An Agent (Runnable) """ # Initialize our return value @@ -226,7 +213,7 @@ def create_agent_with_fallbacks(self, prompt_template: ChatPromptTemplate, for index, fallback in enumerate(fallbacks): # Create a model we might use. - one_llm: BaseLanguageModel = llm_factory.create_llm(fallback, callbacks=callbacks) + one_llm: BaseLanguageModel = llm_factory.create_llm(fallback) one_agent: Agent = self.create_agent(prompt_template, one_llm) if index == 0: @@ -324,7 +311,7 @@ async def _create_base_tool(self, name: str) -> BaseTool: if toolbox: toolbox_factory: ContextTypeToolboxFactory = self.invocation_context.get_toolbox_factory() try: - tool_from_toolbox = toolbox_factory.create_tool_from_toolbox(toolbox, agent_spec.get("args")) + tool_from_toolbox = toolbox_factory.create_tool_from_toolbox(toolbox, agent_spec.get("args"), name) # If the tool from toolbox is base tool or list of base tool, return the tool as is # since tool's definition and args schema are predefined in these the class of the tool. if isinstance(tool_from_toolbox, BaseTool) or ( @@ -451,10 +438,28 @@ async def wait_on_run(self, run: Run, journal: Journal = None) -> Run: "chat_history": previous_chat_history, "input": self.recent_human_message } + + # Create the list of callbacks to pass when invoking + parent_origin: List[Dict[str, Any]] = self.get_origin() + base_journal: Journal = self.invocation_context.get_journal() + origination: Origination = self.invocation_context.get_origination() + callbacks: List[BaseCallbackHandler] = [ + JournalingCallbackHandler(self.journal, base_journal, parent_origin, origination) + ] + # Consult the agent spec for level of verbosity as it pertains to callbacks. 
+ agent_spec: Dict[str, Any] = self.tool_caller.get_agent_tool_spec() + verbose: Union[bool, str] = agent_spec.get("verbose", False) + if isinstance(verbose, str) and verbose.lower() in ("extra", "logging"): + # This particular class adds a *lot* of very detailed messages + # to the logs. Add this because some people are interested in it. + callbacks.append(LoggingCallbackHandler(self.logger)) + + # Add callbacks as an invoke config invoke_config = { "configurable": { "session_id": run.get_id() - } + }, + "callbacks": callbacks } # Chat history is updated in write_message diff --git a/neuro_san/internals/run_context/langchain/journaling/journaling_callback_handler.py b/neuro_san/internals/run_context/langchain/journaling/journaling_callback_handler.py index 2c2a9807f..d910311c0 100644 --- a/neuro_san/internals/run_context/langchain/journaling/journaling_callback_handler.py +++ b/neuro_san/internals/run_context/langchain/journaling/journaling_callback_handler.py @@ -12,6 +12,7 @@ from collections.abc import Sequence from typing import Any from typing import Dict +from typing import List from pydantic import ConfigDict @@ -19,10 +20,14 @@ from langchain_core.agents import AgentFinish from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.documents import Document +from langchain_core.messages.ai import AIMessage +from langchain_core.messages.base import BaseMessage from langchain_core.outputs import LLMResult from langchain_core.outputs.chat_generation import ChatGeneration +from neuro_san.internals.journals.journal import Journal from neuro_san.internals.journals.originating_journal import OriginatingJournal +from neuro_san.internals.messages.origination import Origination from neuro_san.internals.messages.agent_message import AgentMessage @@ -52,13 +57,39 @@ class JournalingCallbackHandler(AsyncCallbackHandler): # a non-pydantic Journal as a member, we need to do this. 
model_config = ConfigDict(arbitrary_types_allowed=True) - def __init__(self, journal: OriginatingJournal): + def __init__( + self, + calling_agent_journal: Journal, + base_journal: Journal, + parent_origin: List[Dict[str, Any]], + origination: Origination + ): """ Constructor - :param journal: The journal to write messages to + :param calling_agent_journal: The journal of the calling agent + :param base_journal: The Journal instance that allows message reporting during the course of the AgentSession. + This is used to construct the langchain_tool_journal. + :param parent_origin: A List of origin dictionaries indicating the origin of the run + This is used to construct the langchain_tool_journal. + :param origination: The Origination instance carrying state about tool instantiation + during the course of the AgentSession. This is used to construct the langchain_tool_journal. """ - self.journal: OriginatingJournal = journal + + # The calling-agent journal logs the execution flow from the perspective of the agent invoking the tool + # (e.g., MusicNerdPro). In contrast, the LangChain tool journal represents the tool's own execution + # context—similar to how coded tools like Accountant have their own journal tied to their run context. + + # LangChain tools don’t instantiate their own RunContext, so they lack a dedicated journal by default. + # To maintain consistency with how other tools are tracked, we explicitly create a langchain_tool_journal + # when the tool starts. This ensures tool-specific inputs and outputs are captured independently, + # while still allowing the calling agent to log its own perspective. 
+ + self.calling_agent_journal: Journal = calling_agent_journal + self.base_journal: Journal = base_journal + self.parent_origin: List[Dict[str, Any]] = parent_origin + self.origination: Origination = origination + self.langchain_tool_journal: Journal = None async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: @@ -76,17 +107,73 @@ async def on_llm_end(self, response: LLMResult, # Some AGENT messages that come from this source end up being dupes # of AI messages that can come later. # Use this method to put the message on hold for later comparison. - await self.journal.write_message_if_next_not_dupe(message) + await self.calling_agent_journal.write_message_if_next_not_dupe(message) async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: # print(f"In on_chain_end() with {outputs}") return - async def on_tool_end(self, output: Any, - **kwargs: Any) -> None: - # print(f"In on_tool_end() with {output}") - return + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + tags: List[str] = None, + inputs: Dict[str, Any] = None, + **kwargs: Any + ) -> None: + """ + Callback triggered when a tool starts execution. + + If the tool is identified as a LangChain tool (via the "langchain_tool" tag), + this method creates a journal entry containing the tool's input arguments, + origin metadata, and full tool name. + + :param serialized: Serialized representation of the tool, including its name and description. + :param input_str: String representation of the tool's input. + :param tags: List of tags associated with the tool. Used to determine whether it is a LangChain tool. + :param inputs: Structured dictionary of input arguments + passed to the tool. 
+ """ + + if "langchain_tool" in tags: + + # Extract tool name from the serialized data + agent_name: str = serialized.get("name") + + # Build the origin path + origin: List[Dict[str, Any]] = self.origination.add_spec_name_to_origin(self.parent_origin, agent_name) + full_name: str = self.origination.get_full_name_from_origin(origin) + + # Combine the original tool inputs with origin metadata + combined_args: Dict[str, Any] = inputs.copy() + combined_args["origin"] = origin + combined_args["origin_str"] = full_name + + # Create a journal entry for this invocation and log the combined inputs + self.langchain_tool_journal = OriginatingJournal(self.base_journal, origin) + message: BaseMessage = AgentMessage(content=f"Received arguments {combined_args}") + await self.langchain_tool_journal.write_message(message) + + async def on_tool_end(self, output: Any, tags: List[str] = None, **kwargs: Any) -> None: + """ + Callback triggered when a tool finishes execution. + + If the tool is identified as a LangChain tool (via the "langchain_tool" tag), + this method logs the tool's output to both the calling agent's journal and the + LangChain tool's specific journal. + + :param output: The result produced by the tool after execution. + :param tags: List of tags associated with the tool. Used to determine whether it is a LangChain tool. 
+ """ + + if "langchain_tool" in tags: + # Log the tool output to the calling agent's journal + await self.calling_agent_journal.write_message(AIMessage(content=output)) + + # Also log the tool output to the LangChain tool-specific journal + message: BaseMessage = AgentMessage(content=f"Got result: {output}") + await self.langchain_tool_journal.write_message(message) async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None: diff --git a/neuro_san/internals/run_context/langchain/toolbox/toolbox_factory.py b/neuro_san/internals/run_context/langchain/toolbox/toolbox_factory.py index 1e6db816d..fa39216e2 100644 --- a/neuro_san/internals/run_context/langchain/toolbox/toolbox_factory.py +++ b/neuro_san/internals/run_context/langchain/toolbox/toolbox_factory.py @@ -116,17 +116,25 @@ def load(self): def create_tool_from_toolbox( self, tool_name: str, - user_args: Dict[str, Any] = None + user_args: Dict[str, Any] = None, + agent_name: str = None ) -> Union[BaseTool, Dict[str, Any], List[BaseTool]]: """ Resolves dependencies and instantiates the requested tool. :param tool_name: The name of the tool to instantiate. :param user_args: Arguments provided by the user, which override the config file. + :param agent_name: The name of the agent to prefix each BaseTool's name in BaseToolkit with, + ensuring tool names are unique and traceable to their agent. :return: - Instantiated tool if "class" of tool_name points to a BaseTool class - A list of tools if "class of "tool_name points to a BaseToolkit class. - A dict of tool's "description" and "parameters" if tool_name points to a CodedTool """ + + # agent_name is required when the tool is used as an internal agent. + # However, tools from the toolbox could potentially be used as external tools, + # in which case agent_name may not be needed. 
+ empty: Dict[str, Any] = {} tool_info: Dict[str, Any] = self.toolbox_infos.get(tool_name) @@ -168,8 +176,20 @@ def create_tool_from_toolbox( # If the instantiated class has "get_tools()", assume it's a toolkit and return a list of tools if hasattr(instance, "get_tools") and callable(instance.get_tools): - return instance.get_tools() - + toolkit: List[BaseTool] = instance.get_tools() + for tool in toolkit: + if agent_name: + # Prefix the name of the agent to each tool + tool.name = f"{agent_name}_{tool.name}" + # Add "langchain_tool" tags so journal callback can identify it + tool.tags = ["langchain_tool"] + return toolkit + + if agent_name: + # Replace langchain tool's name with agent name + instance.name = agent_name + # Add "langchain_tool" tags so journal callback can identify it + instance.tags = ["langchain_tool"] return instance def _resolve_args(self, args: Dict[str, Any]) -> Dict[str, Any]: diff --git a/tests/neuro_san/internals/run_context/langchain/toolbox/test_toolbox_factory.py b/tests/neuro_san/internals/run_context/langchain/toolbox/test_toolbox_factory.py index 0579d450c..2c38ab91c 100644 --- a/tests/neuro_san/internals/run_context/langchain/toolbox/test_toolbox_factory.py +++ b/tests/neuro_san/internals/run_context/langchain/toolbox/test_toolbox_factory.py @@ -25,7 +25,7 @@ ) -class TestBaseToolFactory: +class TestToolboxFactory: """Simplified test suite for ToolboxFactory.""" @pytest.fixture @@ -54,6 +54,8 @@ def test_create_toolbox_returns_single_base_tool(self, factory): mock_resolver.return_value = mock_tool_class mock_instance = MagicMock(spec=BaseTool) + mock_instance.name = MagicMock(spec=str) + mock_instance.tags = MagicMock(spec=list) mock_tool_class.return_value = mock_instance tool = factory.create_tool_from_toolbox("test_tool", user_args) @@ -90,7 +92,13 @@ def test_create_toolbox_with_toolkit_constructor(self, factory): mock_resolver.return_value = mock_toolkit_class mock_instance = MagicMock() - mock_tools = 
[MagicMock(spec=BaseTool), MagicMock(spec=BaseTool)] + mock_tool_1 = MagicMock(spec=BaseTool) + mock_tool_1.name = MagicMock(spec=str) + mock_tool_1.tags = MagicMock(spec=list) + mock_tool_2 = MagicMock(spec=BaseTool) + mock_tool_2.name = MagicMock(spec=str) + mock_tool_2.tags = MagicMock(spec=list) + mock_tools = [mock_tool_1, mock_tool_2] mock_instance.get_tools.return_value = mock_tools mock_toolkit_class.return_value = mock_instance @@ -134,7 +142,11 @@ def test_create_toolbox_with_toolkit_class_method(self, factory): # Mock get_tools() returning a list of tools mock_tool_1 = MagicMock(spec=BaseTool) + mock_tool_1.name = MagicMock(spec=str) + mock_tool_1.tags = MagicMock(spec=list) mock_tool_2 = MagicMock(spec=BaseTool) + mock_tool_2.name = MagicMock(spec=str) + mock_tool_2.tags = MagicMock(spec=list) mock_toolkit_instance.get_tools.return_value = [mock_tool_1, mock_tool_2] # Call the factory method