Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,8 @@ async def create_resources(self, agent_name: str,
# DEF - Remove the arg if possible
_ = agent_name

# Create the list of callbacks to pass to the LLM ChatModel
callbacks: List[BaseCallbackHandler] = [
JournalingCallbackHandler(self.journal)
]
full_name: str = Origination.get_full_name_from_origin(self.origin)

# Consult the agent spec for level of verbosity as it pertains to callbacks.
agent_spec: Dict[str, Any] = self.tool_caller.get_agent_tool_spec()
verbose: Union[bool, str] = agent_spec.get("verbose", False)
if isinstance(verbose, str) and verbose.lower() in ("extra", "logging"):
# This particular class adds a *lot* of very detailed messages
# to the logs. Add this because some people are interested in it.
callbacks.append(LoggingCallbackHandler(self.logger))

# Now that we have a name, we can create an ErrorDetector for the output.
self.error_detector = ErrorDetector(full_name,
Expand All @@ -199,14 +188,12 @@ async def create_resources(self, agent_name: str,

prompt_template: ChatPromptTemplate = await self._create_prompt_template(instructions)

self.agent = self.create_agent_with_fallbacks(prompt_template, callbacks)
self.agent = self.create_agent_with_fallbacks(prompt_template)

def create_agent_with_fallbacks(self, prompt_template: ChatPromptTemplate,
callbacks: List[BaseCallbackHandler]) -> Agent:
def create_agent_with_fallbacks(self, prompt_template: ChatPromptTemplate) -> Agent:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove callbacks from the argument list of create_agent_with_fallbacks. This eliminates the need for
a callbacks parameter in both default_llm_factory and standard_langchain_llm_factory.
However, that cleanup is outside the scope of this PR and will be addressed separately.

"""
Creates an agent with potential fallback llms to use.
:param prompt_template: The ChatPromptTemplate to use for the agent
:param callbacks: The list of callbacks to use when creating any LLM via the factory
:return: An Agent (Runnable)
"""
# Initialize our return value
Expand All @@ -226,7 +213,7 @@ def create_agent_with_fallbacks(self, prompt_template: ChatPromptTemplate,
for index, fallback in enumerate(fallbacks):

# Create a model we might use.
one_llm: BaseLanguageModel = llm_factory.create_llm(fallback, callbacks=callbacks)
one_llm: BaseLanguageModel = llm_factory.create_llm(fallback)
one_agent: Agent = self.create_agent(prompt_template, one_llm)

if index == 0:
Expand Down Expand Up @@ -324,7 +311,7 @@ async def _create_base_tool(self, name: str) -> BaseTool:
if toolbox:
toolbox_factory: ContextTypeToolboxFactory = self.invocation_context.get_toolbox_factory()
try:
tool_from_toolbox = toolbox_factory.create_tool_from_toolbox(toolbox, agent_spec.get("args"))
tool_from_toolbox = toolbox_factory.create_tool_from_toolbox(toolbox, agent_spec.get("args"), name)
# If the tool from toolbox is base tool or list of base tool, return the tool as is
# since the tool's definition and args schema are predefined in the class of the tool.
if isinstance(tool_from_toolbox, BaseTool) or (
Expand Down Expand Up @@ -451,10 +438,28 @@ async def wait_on_run(self, run: Run, journal: Journal = None) -> Run:
"chat_history": previous_chat_history,
"input": self.recent_human_message
}

# Create the list of callbacks to pass when invoking
parent_origin: List[Dict[str, Any]] = self.get_origin()
base_journal: Journal = self.invocation_context.get_journal()
origination: Origination = self.invocation_context.get_origination()
callbacks: List[BaseCallbackHandler] = [
JournalingCallbackHandler(self.journal, base_journal, parent_origin, origination)
]
# Consult the agent spec for level of verbosity as it pertains to callbacks.
agent_spec: Dict[str, Any] = self.tool_caller.get_agent_tool_spec()
verbose: Union[bool, str] = agent_spec.get("verbose", False)
if isinstance(verbose, str) and verbose.lower() in ("extra", "logging"):
# This particular class adds a *lot* of very detailed messages
# to the logs. Add this because some people are interested in it.
callbacks.append(LoggingCallbackHandler(self.logger))

# Add callbacks as an invoke config
invoke_config = {
"configurable": {
"session_id": run.get_id()
}
},
"callbacks": callbacks
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use callbacks as invoke_config in wait_on_run() to have access to on_tool_start() and on_tool_run()

}

# Chat history is updated in write_message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,22 @@
from collections.abc import Sequence
from typing import Any
from typing import Dict
from typing import List

from pydantic import ConfigDict

from langchain_core.agents import AgentAction
from langchain_core.agents import AgentFinish
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.documents import Document
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.outputs import LLMResult
from langchain_core.outputs.chat_generation import ChatGeneration

from neuro_san.internals.journals.journal import Journal
from neuro_san.internals.journals.originating_journal import OriginatingJournal
from neuro_san.internals.messages.origination import Origination
from neuro_san.internals.messages.agent_message import AgentMessage


Expand Down Expand Up @@ -52,13 +57,39 @@ class JournalingCallbackHandler(AsyncCallbackHandler):
# a non-pydantic Journal as a member, we need to do this.
model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, journal: OriginatingJournal):
def __init__(
self,
calling_agent_journal: Journal,
base_journal: Journal,
parent_origin: List[Dict[str, Any]],
origination: Origination
):
"""
Constructor

:param journal: The journal to write messages to
:param calling_agent_journal: The journal of the calling agent
:param base_journal: The Journal instance that allows message reporting during the course of the AgentSession.
This is used to construct the langchain_tool_journal.
:param parent_origin: A List of origin dictionaries indicating the origin of the run
This is used to construct the langchain_tool_journal.
:param origination: The Origination instance carrying state about tool instantiation
during the course of the AgentSession. This is used to construct the langchain_tool_journal.
"""
self.journal: OriginatingJournal = journal

# The calling-agent journal logs the execution flow from the perspective of the agent invoking the tool
# (e.g., MusicNerdPro). In contrast, the LangChain tool journal represents the tool's own execution
# context—similar to how coded tools like Accountant have their own journal tied to their run context.

# LangChain tools don’t instantiate their own RunContext, so they lack a dedicated journal by default.
# To maintain consistency with how other tools are tracked, we explicitly create a langchain_tool_journal
# when the tool starts. This ensures tool-specific inputs and outputs are captured independently,
# while still allowing the calling agent to log its own perspective.

self.calling_agent_journal: Journal = calling_agent_journal
self.base_journal: Journal = base_journal
self.parent_origin: List[Dict[str, Any]] = parent_origin
self.origination: Origination = origination
self.langchain_tool_journal: Journal = None

async def on_llm_end(self, response: LLMResult,
**kwargs: Any) -> None:
Expand All @@ -76,17 +107,73 @@ async def on_llm_end(self, response: LLMResult,
# Some AGENT messages that come from this source end up being dupes
# of AI messages that can come later.
# Use this method to put the message on hold for later comparison.
await self.journal.write_message_if_next_not_dupe(message)
await self.calling_agent_journal.write_message_if_next_not_dupe(message)

async def on_chain_end(self, outputs: Dict[str, Any],
                       **kwargs: Any) -> None:
    """
    Callback triggered when a chain finishes execution.

    Intentionally a no-op: chain completion is not journaled.

    :param outputs: The outputs produced by the chain (unused).
    """
    # Debugging hook kept for reference:
    # print(f"In on_chain_end() with {outputs}")
    return

async def on_tool_end(self, output: Any,
**kwargs: Any) -> None:
# print(f"In on_tool_end() with {output}")
return
async def on_tool_start(
    self,
    serialized: Dict[str, Any],
    input_str: str,
    tags: List[str] = None,
    inputs: Dict[str, Any] = None,
    **kwargs: Any
) -> None:
    """
    Callback triggered when a tool starts execution.

    If the tool is identified as a LangChain tool (via the "langchain_tool" tag),
    this method creates a journal entry containing the tool's input arguments,
    origin metadata, and full tool name.

    :param serialized: Serialized representation of the tool, including its name and description.
    :param input_str: String representation of the tool's input.
    :param tags: List of tags associated with the tool. Used to determine whether it is a LangChain tool.
                 May be None when the framework supplies no tags.
    :param inputs: Structured dictionary of input arguments passed to the tool.
                   May be None when the input could not be parsed into a dict.
    """
    # Guard clause: langchain does not always pass tags, so tags can be None.
    # Without this check, "in tags" would raise a TypeError.
    if not tags or "langchain_tool" not in tags:
        return

    # Extract tool name from the serialized data
    agent_name: str = serialized.get("name")

    # Build the origin path
    origin: List[Dict[str, Any]] = self.origination.add_spec_name_to_origin(self.parent_origin, agent_name)
    full_name: str = self.origination.get_full_name_from_origin(origin)

    # Combine the original tool inputs with origin metadata.
    # Copy defensively so we never mutate the caller's dict; tolerate inputs=None.
    combined_args: Dict[str, Any] = dict(inputs) if inputs else {}
    combined_args["origin"] = origin
    combined_args["origin_str"] = full_name

    # Create a journal entry for this invocation and log the combined inputs
    self.langchain_tool_journal = OriginatingJournal(self.base_journal, origin)
    message: BaseMessage = AgentMessage(content=f"Received arguments {combined_args}")
    await self.langchain_tool_journal.write_message(message)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating journal for langchain tool and write message.


async def on_tool_end(self, output: Any, tags: List[str] = None, **kwargs: Any) -> None:
    """
    Callback triggered when a tool finishes execution.

    If the tool is identified as a LangChain tool (via the "langchain_tool" tag),
    this method logs the tool's output to both the calling agent's journal and the
    LangChain tool's specific journal.

    :param output: The result produced by the tool after execution.
    :param tags: List of tags associated with the tool. Used to determine whether it is a LangChain tool.
                 May be None when the framework supplies no tags.
    """
    # Guard clause: langchain does not always pass tags, so tags can be None.
    # Without this check, "in tags" would raise a TypeError.
    if not tags or "langchain_tool" not in tags:
        return

    # Log the tool output to the calling agent's journal
    await self.calling_agent_journal.write_message(AIMessage(content=output))

    # Also log the tool output to the LangChain tool-specific journal.
    # That journal is created lazily in on_tool_start(); guard in case that
    # callback never fired for this invocation.
    if self.langchain_tool_journal is not None:
        message: BaseMessage = AgentMessage(content=f"Got result: {output}")
        await self.langchain_tool_journal.write_message(message)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Write message in the langchain journal with tool output. This can be used as an indicator that the tool is finished.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Noravee Since the UI already checks for "Got result:" for coded tool end events on the text field, this should work fine.

We also may want to consider adding a way to signify an end event for coded tools or langchain tools on the structure key instead, so that we don't have to do string parsing on the text field.

cc: @d1donlydfink

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can do that in a separate PR @swensel for full consistency

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @d1donlydfink , that would be great.

cc: @Noravee


async def on_agent_action(self, action: AgentAction,
**kwargs: Any) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,25 @@ def load(self):
def create_tool_from_toolbox(
self,
tool_name: str,
user_args: Dict[str, Any] = None
user_args: Dict[str, Any] = None,
agent_name: str = None
) -> Union[BaseTool, Dict[str, Any], List[BaseTool]]:
"""
Resolves dependencies and instantiates the requested tool.

:param tool_name: The name of the tool to instantiate.
:param user_args: Arguments provided by the user, which override the config file.
:param agent_name: The name of the agent to prefix each BaseTool's name in BaseToolkit with,
ensuring tool names are unique and traceable to their agent.
:return: - Instantiated tool if "class" of tool_name points to a BaseTool class
- A list of tools if "class" of tool_name points to a BaseToolkit class.
- A dict of tool's "description" and "parameters" if tool_name points to a CodedTool
"""

# agent_name is required when the tool is used as an internal agent.
# However, tools from the toolbox could potentially be used as external tools,
# in which case agent_name may not be needed.

empty: Dict[str, Any] = {}

tool_info: Dict[str, Any] = self.toolbox_infos.get(tool_name)
Expand Down Expand Up @@ -168,8 +176,20 @@ def create_tool_from_toolbox(

# If the instantiated class has "get_tools()", assume it's a toolkit and return a list of tools
if hasattr(instance, "get_tools") and callable(instance.get_tools):
return instance.get_tools()

toolkit: List[BaseTool] = instance.get_tools()
for tool in toolkit:
if agent_name:
# Prefix the name of the agent to each tool
tool.name = f"{agent_name}_{tool.name}"
# Add "langchain_tool" tags so journal callback can identify it
tool.tags = ["langchain_tool"]
return toolkit

if agent_name:
# Replace langchain tool's name with agent name
instance.name = agent_name
# Add "langchain_tool" tags so journal callback can identify it
instance.tags = ["langchain_tool"]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Override langchain tool's name with agent_name and add langchain_tool to tags as identifier.

return instance

def _resolve_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)


class TestBaseToolFactory:
class TestToolboxFactory:
"""Simplified test suite for ToolboxFactory."""

@pytest.fixture
Expand Down Expand Up @@ -54,6 +54,8 @@ def test_create_toolbox_returns_single_base_tool(self, factory):
mock_resolver.return_value = mock_tool_class

mock_instance = MagicMock(spec=BaseTool)
mock_instance.name = MagicMock(spec=str)
mock_instance.tags = MagicMock(spec=list)
mock_tool_class.return_value = mock_instance

tool = factory.create_tool_from_toolbox("test_tool", user_args)
Expand Down Expand Up @@ -90,7 +92,13 @@ def test_create_toolbox_with_toolkit_constructor(self, factory):
mock_resolver.return_value = mock_toolkit_class

mock_instance = MagicMock()
mock_tools = [MagicMock(spec=BaseTool), MagicMock(spec=BaseTool)]
mock_tool_1 = MagicMock(spec=BaseTool)
mock_tool_1.name = MagicMock(spec=str)
mock_tool_1.tags = MagicMock(spec=list)
mock_tool_2 = MagicMock(spec=BaseTool)
mock_tool_2.name = MagicMock(spec=str)
mock_tool_2.tags = MagicMock(spec=list)
mock_tools = [mock_tool_1, mock_tool_2]
mock_instance.get_tools.return_value = mock_tools
mock_toolkit_class.return_value = mock_instance

Expand Down Expand Up @@ -134,7 +142,11 @@ def test_create_toolbox_with_toolkit_class_method(self, factory):

# Mock get_tools() returning a list of tools
mock_tool_1 = MagicMock(spec=BaseTool)
mock_tool_1.name = MagicMock(spec=str)
mock_tool_1.tags = MagicMock(spec=list)
mock_tool_2 = MagicMock(spec=BaseTool)
mock_tool_2.name = MagicMock(spec=str)
mock_tool_2.tags = MagicMock(spec=list)
mock_toolkit_instance.get_tools.return_value = [mock_tool_1, mock_tool_2]

# Call the factory method
Expand Down