Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 48 additions & 25 deletions src/phoenix/server/api/helpers/dataset_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Sequence, TypedDict

from openinference.semconv.trace import (
MessageAttributes,
Expand All @@ -10,6 +10,7 @@
ToolAttributes,
ToolCallAttributes,
)
from typing_extensions import NotRequired

from phoenix.db.models import Span
from phoenix.trace.attributes import get_attribute_value
Expand Down Expand Up @@ -193,32 +194,54 @@ def _get_generic_io_value(
return {}


def _get_message(message: Mapping[str, Any]) -> dict[str, Any]:
class _Function(TypedDict):
name: str
arguments: Any


class _ToolCall(TypedDict):
function: _Function


class _Message(TypedDict):
role: str
content: NotRequired[Any]
name: NotRequired[str]
function_call: NotRequired[_Function]
tool_calls: NotRequired[Sequence[_ToolCall]]


def _get_message(message: Mapping[str, Any]) -> _Message:
content = get_attribute_value(message, MESSAGE_CONTENT)
name = get_attribute_value(message, MESSAGE_NAME)
function_call_name = get_attribute_value(message, MESSAGE_FUNCTION_CALL_NAME)
function_call_arguments = get_attribute_value(message, MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON)
function_call = (
{"name": function_call_name, "arguments": function_call_arguments}
if function_call_name is not None or function_call_arguments is not None
else None
)
tool_calls = [
{
"function": {
"name": get_attribute_value(tool_call, TOOL_CALL_FUNCTION_NAME),
"arguments": get_attribute_value(tool_call, TOOL_CALL_FUNCTION_ARGUMENTS_JSON),
}
}
for tool_call in get_attribute_value(message, MESSAGE_TOOL_CALLS) or ()
]
return {
"role": get_attribute_value(message, MESSAGE_ROLE),
**({"content": content} if content is not None else {}),
**({"name": name} if name is not None else {}),
**({"function_call": function_call} if function_call is not None else {}),
**({"tool_calls": tool_calls} if tool_calls else {}),
}
function_call: _Function | None = None
if function_call_name := get_attribute_value(message, MESSAGE_FUNCTION_CALL_NAME):
arguments = get_attribute_value(message, MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON)
function_call_arguments = _safely_json_decode(arguments)
if function_call_arguments is None:
function_call_arguments = arguments
function_call = _Function(name=function_call_name, arguments=function_call_arguments)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JSON null arguments stay as strings

Low Severity

_get_message treats a decoded None as a parse failure and falls back to the original string. Because _safely_json_decode("null") returns None, valid JSON null in MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON or TOOL_CALL_FUNCTION_ARGUMENTS_JSON is emitted as "null" instead of None.

Additional Locations (1)

Fix in Cursor Fix in Web

tool_calls = []
for tool_call in get_attribute_value(message, MESSAGE_TOOL_CALLS) or ():
if function_name := get_attribute_value(tool_call, TOOL_CALL_FUNCTION_NAME):
arguments = get_attribute_value(tool_call, TOOL_CALL_FUNCTION_ARGUMENTS_JSON)
function_arguments = _safely_json_decode(arguments)
if function_arguments is None:
function_arguments = arguments
function = _Function(name=function_name, arguments=function_arguments)
tool_calls.append(_ToolCall(function=function))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tool calls dropped without function names

Medium Severity

The new tool_calls loop only appends entries when TOOL_CALL_FUNCTION_NAME is truthy. Tool calls that have TOOL_CALL_FUNCTION_ARGUMENTS_JSON but no name are now discarded, causing loss of captured tool-call data in dataset_helpers.py.

Fix in Cursor Fix in Web

role = get_attribute_value(message, MESSAGE_ROLE) or "assistant"
content = get_attribute_value(message, MESSAGE_CONTENT)
msg = _Message(role=role)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing role now mislabeled as assistant

Medium Severity

_get_message now defaults role to "assistant" when MESSAGE_ROLE is absent. This changes previously missing/unknown roles into concrete assistant messages, which can misrepresent conversation data produced by dataset_helpers.py.

Fix in Cursor Fix in Web

if content is not None:
msg["content"] = content
if name is not None:
msg["name"] = name
if function_call is not None:
msg["function_call"] = function_call
if tool_calls:
msg["tool_calls"] = tool_calls
return msg


def _parse_retrieval_documents(retrieval_documents: Any) -> Optional[list[dict[str, Any]]]:
Expand Down
126 changes: 123 additions & 3 deletions tests/unit/server/api/helpers/test_dataset_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from phoenix.db.models import Span
from phoenix.server.api.helpers.dataset_helpers import (
_get_message,
get_dataset_example_input,
get_dataset_example_output,
get_experiment_example_output,
Expand Down Expand Up @@ -116,14 +117,14 @@
{"content": "user-message", "role": "user"},
{
"role": "assistant",
"function_call": {"name": "add", "arguments": '{"a": 363, "b": 42}'},
"function_call": {"name": "add", "arguments": {"a": 363, "b": 42}},
},
{"content": "user-message", "role": "user"},
{
"role": "assistant",
"tool_calls": [
{"function": {"name": "multiply", "arguments": '{"a": 121, "b": 3}'}},
{"function": {"name": "add", "arguments": '{"a": 363, "b": 42}'}},
{"function": {"name": "multiply", "arguments": {"a": 121, "b": 3}}},
{"function": {"name": "add", "arguments": {"a": 363, "b": 42}}},
],
},
],
Expand Down Expand Up @@ -251,6 +252,125 @@ def test_get_dataset_example_input(span: Span, expected_input_value: dict[str, A
assert expected_input_value == input_value


@pytest.mark.parametrize(
"message, expected",
[
pytest.param(
unflatten([(MESSAGE_ROLE, "user")]),
{"role": "user"},
id="role_only",
),
pytest.param(
unflatten(
[
(MESSAGE_ROLE, "assistant"),
(MESSAGE_CONTENT, "hi"),
(MESSAGE_NAME, "bot"),
]
),
{"role": "assistant", "content": "hi", "name": "bot"},
id="content_and_name",
),
pytest.param(
unflatten(
[
(MESSAGE_ROLE, "assistant"),
(MESSAGE_FUNCTION_CALL_NAME, "add"),
(MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, json.dumps({"a": 1, "b": 2})),
]
),
{
"role": "assistant",
"function_call": {"name": "add", "arguments": {"a": 1, "b": 2}},
},
id="function_call_json_string_deserialized",
),
pytest.param(
unflatten(
[
(MESSAGE_ROLE, "assistant"),
(MESSAGE_FUNCTION_CALL_NAME, "add"),
(MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, "not valid json"),
]
),
{
"role": "assistant",
"function_call": {"name": "add", "arguments": "not valid json"},
},
id="function_call_invalid_json_left_as_string",
),
pytest.param(
unflatten(
[
(MESSAGE_ROLE, "assistant"),
(
MESSAGE_TOOL_CALLS,
[
unflatten(
[
(TOOL_CALL_FUNCTION_NAME, "multiply"),
(
TOOL_CALL_FUNCTION_ARGUMENTS_JSON,
json.dumps({"x": 10}),
),
]
),
],
),
]
),
{
"role": "assistant",
"tool_calls": [
{"function": {"name": "multiply", "arguments": {"x": 10}}},
],
},
id="tool_calls_json_string_deserialized",
),
pytest.param(
unflatten(
[
(MESSAGE_ROLE, "assistant"),
(
MESSAGE_TOOL_CALLS,
[
unflatten(
[
(TOOL_CALL_FUNCTION_NAME, "run"),
(
TOOL_CALL_FUNCTION_ARGUMENTS_JSON,
"{broken",
),
]
),
],
),
]
),
{
"role": "assistant",
"tool_calls": [
{"function": {"name": "run", "arguments": "{broken"}},
],
},
id="tool_calls_invalid_json_left_as_string",
),
pytest.param(
unflatten([(MESSAGE_ROLE, "user"), (MESSAGE_TOOL_CALLS, [])]),
{"role": "user"},
id="empty_tool_calls_omitted",
),
pytest.param(
unflatten([(MESSAGE_ROLE, "user"), (MESSAGE_TOOL_CALLS, None)]),
{"role": "user"},
id="none_tool_calls_omitted",
),
],
)
def test_get_message(message: dict[str, Any], expected: dict[str, Any]) -> None:
assert _get_message(message) == expected


@pytest.mark.parametrize(
"span, expected_output_value",
[
Expand Down
Loading