nemoguardrails/integrations/langchain/runnable_rails.py (3 additions, 1 deletion)
@@ -18,7 +18,7 @@
 from typing import Any, List, Optional
 
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
 from langchain_core.runnables import Runnable
 from langchain_core.runnables.config import RunnableConfig
@@ -139,6 +139,8 @@ def _transform_input_to_rails_format(self, _input):
messages.append({"role": "assistant", "content": msg.content})
elif isinstance(msg, HumanMessage):
messages.append({"role": "user", "content": msg.content})
elif isinstance(msg, SystemMessage):
messages.append({"role": "system", "content": msg.content})
elif isinstance(_input, StringPromptValue):
messages.append({"role": "user", "content": _input.text})
elif isinstance(_input, dict):
Expand Down
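The net effect of this change: system prompts arriving as LangChain message objects are forwarded to the guardrails as {"role": "system", ...} entries instead of being silently dropped (previously, a SystemMessage matched none of the isinstance branches). A minimal standalone sketch of the mapping this diff extends; the to_rails_messages helper below is illustrative only, not part of the library's API (in the library this logic lives inside RunnableRails._transform_input_to_rails_format):

# Illustrative sketch: mirrors the message-mapping logic from the diff above
# as a standalone function.
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompt_values import ChatPromptValue


def to_rails_messages(prompt_value: ChatPromptValue) -> list:
    messages = []
    for msg in prompt_value.messages:
        if isinstance(msg, AIMessage):
            messages.append({"role": "assistant", "content": msg.content})
        elif isinstance(msg, HumanMessage):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, SystemMessage):
            # New in this PR: system messages are preserved.
            messages.append({"role": "system", "content": msg.content})
    return messages


prompt = ChatPromptValue(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Got it?"),
    ]
)
print(to_rails_messages(prompt))
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Got it?'}]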
tests/test_runnable_rails.py (59 additions, 1 deletion)
@@ -16,7 +16,8 @@
 from typing import List, Optional
 
 import pytest
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_core.prompt_values import ChatPromptValue
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
 from langchain_core.runnables import (
     Runnable,
@@ -153,6 +154,63 @@ def test_dict_messages_in_dict_messages_out():
assert result["output"] == {"role": "assistant", "content": "Paris."}


+def test_dict_system_message_in_dict_messages_out():
+    """Tests that a system message passed as a dict is forwarded to the rails unchanged."""
+    llm = FakeLLM(
+        responses=[
+            "Okay.",
+        ]
+    )
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    original_generate_async = model_with_rails.rails.generate_async
+    messages_passed = None
+
+    async def mock_generate_async(*args, **kwargs):
+        nonlocal messages_passed
+        messages_passed = kwargs.get("messages")
+        return await original_generate_async(*args, **kwargs)
+
+    model_with_rails.rails.generate_async = mock_generate_async
+
+    result = model_with_rails.invoke(
+        input={
+            "input": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Got it?"},
+            ]
+        }
+    )
+
+    assert isinstance(result, dict)
+    assert result["output"] == {"role": "assistant", "content": "Okay."}
+    assert messages_passed == [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Got it?"},
+    ]
+
+
+def test_list_system_message_in_list_messages_out():
+    """Tests that SystemMessage is correctly handled when input is ChatPromptValue."""
+    llm_response = "Intent: user asks question"
+    llm = FakeLLM(responses=[llm_response])
+
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config)
+
+    chain = model_with_rails | llm
+
+    input_messages = [
+        SystemMessage(content="You are a helpful assistant."),
+        HumanMessage(content="Got it?"),
+    ]
+    result = chain.invoke(input=ChatPromptValue(messages=input_messages))
+
+    assert isinstance(result, AIMessage)
+    assert result.content == llm_response
+
+
 def test_context_passing():
     llm = FakeLLM(
         responses=[
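To exercise just the new coverage, standard pytest keyword selection works, since both added test names contain "system_message" (assuming the repo's usual test setup):

pytest tests/test_runnable_rails.py -k system_message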