|
3 | 3 | import asyncio |
4 | 4 | import json |
5 | 5 | import tempfile |
| 6 | +import warnings |
6 | 7 | from pathlib import Path |
7 | 8 | from typing import Any, cast |
8 | 9 | from unittest.mock import patch |
|
11 | 12 | import pytest |
12 | 13 | from openai import BadRequestError |
13 | 14 | from openai.types.responses import ResponseFunctionToolCall |
| 15 | +from openai.types.responses.response_output_text import AnnotationFileCitation, ResponseOutputText |
14 | 16 | from typing_extensions import TypedDict |
15 | 17 |
|
16 | 18 | from agents import ( |
|
48 | 50 | from agents.run_internal.items import ( |
49 | 51 | drop_orphan_function_calls, |
50 | 52 | ensure_input_item_format, |
| 53 | + fingerprint_input_item, |
51 | 54 | normalize_input_items_for_api, |
52 | 55 | normalize_resumed_input, |
53 | 56 | ) |
@@ -329,6 +332,14 @@ def test_normalize_input_items_for_api_preserves_provider_data():
329 | 332 | assert second["provider_data"] == {"trace": "remove"} |
330 | 333 |
|
331 | 334 |
|
| 335 | +def test_fingerprint_input_item_returns_none_when_model_dump_fails(): |
| 336 | + class _BrokenModelDump: |
| 337 | + def model_dump(self, *_args: Any, **_kwargs: Any) -> dict[str, Any]: |
| 338 | + raise RuntimeError("model_dump failed") |
| 339 | + |
| 340 | + assert fingerprint_input_item(_BrokenModelDump()) is None |
| 341 | + |
| 342 | + |
332 | 343 | def test_server_conversation_tracker_tracks_previous_response_id(): |
333 | 344 | tracker = OpenAIServerConversationTracker(conversation_id=None, previous_response_id="resp_a") |
334 | 345 | response = ModelResponse( |
@@ -1310,6 +1321,29 @@ def model_dump(self, exclude_unset: bool = True) -> dict[str, Any]: |
1310 | 1321 | assert converted["output"] == "dumped" |
1311 | 1322 |
|
1312 | 1323 |
|
| 1324 | +def test_ensure_api_input_item_avoids_pydantic_serialization_warnings(): |
| 1325 | + annotation = AnnotationFileCitation.model_construct( |
| 1326 | + type="container_file_citation", |
| 1327 | + file_id="file_123", |
| 1328 | + filename="result.txt", |
| 1329 | + index=0, |
| 1330 | + ) |
| 1331 | + output_text = ResponseOutputText.model_construct( |
| 1332 | + type="output_text", |
| 1333 | + text="done", |
| 1334 | + annotations=[annotation], |
| 1335 | + ) |
| 1336 | + |
| 1337 | + with warnings.catch_warnings(record=True) as captured: |
| 1338 | + warnings.simplefilter("always") |
| 1339 | + converted = ensure_input_item_format(cast(Any, output_text)) |
| 1340 | + |
| 1341 | + converted_payload = cast(dict[str, Any], converted) |
| 1342 | + assert captured == [] |
| 1343 | + assert converted_payload["type"] == "output_text" |
| 1344 | + assert converted_payload["annotations"][0]["type"] == "container_file_citation" |
| 1345 | + |
| 1346 | + |
1313 | 1347 | def test_ensure_api_input_item_preserves_object_output(): |
1314 | 1348 | payload = cast( |
1315 | 1349 | TResponseInputItem, |
|