-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathagent.py
More file actions
136 lines (125 loc) · 4.93 KB
/
agent.py
File metadata and controls
136 lines (125 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from typing import Generic, TypeVar
from pydantic_ai.agent import Agent, AgentRun, AgentRunResult
from pydantic_ai.mcp import MCPServer
from pydantic_ai.messages import (
AgentStreamEvent,
FinalResultEvent,
HandleResponseEvent,
ModelMessage,
UserContent,
)
from pydantic_ai.models import Model
from pydantic_ai.models.function import FunctionModel
from pydantic_ai.result import ToolOutput
from pydantic_ai.tools import Tool
from pydantic_ai.usage import Usage
from lightblue_ai.log import logger
from lightblue_ai.mcps import get_mcp_servers
from lightblue_ai.models import infer_model
from lightblue_ai.prompts import get_system_prompt
from lightblue_ai.settings import Settings
from lightblue_ai.tools.manager import LightBlueToolManager
# Type parameter for the agent's structured result payload.
# NOTE: the string argument must match the variable name (PEP 484);
# the original `TypeVar("T")` mismatch breaks static type checkers.
OutputDataT = TypeVar("OutputDataT")
class LightBlueAgent(Generic[OutputDataT]):
    """High-level wrapper around a pydantic-ai ``Agent``.

    Wires together the LightBlue tool manager, configured MCP servers, the
    default system prompt, and model selection from :class:`Settings`, and
    exposes ``run``/``iter`` helpers that manage the MCP-server lifecycle.
    """

    def __init__(
        self,
        model: str | Model | None = None,
        system_prompt: str | None = None,
        result_type: type[OutputDataT] = str,
        result_tool_name: str = "final_result",
        result_tool_description: str | None = None,
        tools: list[Tool] | None = None,
        mcp_servers: list[MCPServer] | None = None,
        retries: int = 3,
        max_description_length: int | None = None,
        strict: bool | None = None,
    ):
        """Build the underlying pydantic-ai agent.

        Args:
            model: Model name or instance; falls back to ``Settings.default_model``.
            system_prompt: System prompt; falls back to ``get_system_prompt()``.
            result_type: Expected result type; non-``str`` types are wrapped
                in a ``ToolOutput`` final-result tool.
            result_tool_name: Name of the final-result tool.
            result_tool_description: Optional description for that tool.
            tools: Extra tools, merged with the tool manager's tools.
            mcp_servers: Extra MCP servers, merged with ``get_mcp_servers()``.
            retries: Retry count passed through to the agent.
            max_description_length: Cap on tool-description length; defaults
                to 1000 for OpenAI-compatible/non-Anthropic/non-Gemini-2.5
                backends, otherwise left as given (``None`` = no cap).
            strict: Strict-schema flag forwarded to tools and the result tool.

        Raises:
            ValueError: If no model is given and ``DEFAULT_MODEL`` is unset.
        """
        self.settings = Settings()
        model = model or self.settings.default_model
        tools = tools or []
        mcp_servers = mcp_servers or []
        if not model:
            raise ValueError("model or ENV `DEFAULT_MODEL` must be set")
        model_name = model.model_name if isinstance(model, Model) else model
        system_prompt = system_prompt or get_system_prompt()
        if (
            "openrouter" in model_name
            or "openai" in model_name
            or ("anthropic" not in model_name and "gemini-2.5" not in model_name)
        ) and not isinstance(model, FunctionModel):
            # OpenAI-compatible OR not anthropic/gemini-2.5: cap the tool
            # description length unless the caller supplied an explicit value.
            max_description_length = max_description_length or 1000
        # Otherwise keep the caller-provided value (possibly None = unlimited).
        logger.info(f"Using model: {model_name}, description length: {max_description_length}")
        self.tool_manager = LightBlueToolManager(max_description_length=max_description_length, strict=strict)
        if max_description_length and self.settings.append_tools_to_prompt:
            # Descriptions were truncated, so append the full tool listing to
            # the system prompt to keep the model fully informed.
            system_prompt = "\n".join([
                system_prompt,
                "## The following tools are available to you:",
                self.tool_manager.describe_all_tools(),
            ])
        self.agent = Agent[result_type](
            infer_model(model),
            output_type=(
                ToolOutput(
                    type_=result_type,
                    name=result_tool_name,
                    description=result_tool_description,
                    strict=strict,
                )
                if result_type is not str
                else str
            ),
            system_prompt=system_prompt,
            tools=[*tools, *self.tool_manager.get_all_tools()],
            mcp_servers=[*mcp_servers, *get_mcp_servers()],
            retries=retries,
        )

    async def run(
        self,
        user_prompt: str | Sequence[UserContent],
        *,
        message_history: None | list[ModelMessage] = None,
        usage: None | Usage = None,
    ) -> AgentRunResult[OutputDataT]:
        """Run the agent to completion with MCP servers started.

        Args:
            user_prompt: The user prompt (text or multi-part content).
            message_history: Optional prior conversation messages.
            usage: Optional accumulator; incremented with this run's usage.

        Returns:
            The completed agent run result.
        """
        async with self.agent.run_mcp_servers():
            result = await self.agent.run(user_prompt, message_history=message_history)
            if usage:
                usage.incr(result.usage())
            return result

    @asynccontextmanager
    async def iter(
        self,
        user_prompt: str | Sequence[UserContent],
        *,
        message_history: None | list[ModelMessage] = None,
        usage: None | Usage = None,
    ) -> AsyncIterator[AgentRun]:
        """Async context manager yielding a node-by-node ``AgentRun``.

        MCP servers are active for the duration of the context. If ``usage``
        is given, it is incremented after the caller's block exits normally.
        """
        async with (
            self.agent.run_mcp_servers(),
            self.agent.iter(
                user_prompt,
                message_history=message_history,
            ) as run,
        ):
            yield run
            if usage:
                usage.incr(run.usage())

    async def yield_response_event(self, run: AgentRun) -> AsyncIterator[HandleResponseEvent | AgentStreamEvent]:
        """Yield streaming events from each model-request/tool-call node.

        User-prompt and end nodes are skipped; empty events and
        ``FinalResultEvent`` markers are filtered out. Unknown node types are
        logged and ignored.
        """
        async for node in run:
            if Agent.is_user_prompt_node(node) or Agent.is_end_node(node):
                # Nothing to stream for these node kinds.
                continue
            elif Agent.is_model_request_node(node) or Agent.is_call_tools_node(node):
                async with node.stream(run.ctx) as request_stream:
                    async for event in request_stream:
                        if not event or isinstance(event, FinalResultEvent):
                            continue
                        yield event
            else:
                logger.warning(f"Unknown node: {node}")