Skip to content

Commit 93a2d1d

Browse files
committed
refactor agent
1 parent ec6d2ff commit 93a2d1d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+5402
-9116
lines changed

agent/agent.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,89 @@ func (a *Agent) GetCurrentChatID() string {
446446
return "main"
447447
}
448448

449+
// Steer adds a steering message to interrupt the agent mid-run
450+
// Inspired by pi-mono's Agent.steer() method
451+
func (a *Agent) Steer(msg AgentMessage) {
452+
a.mu.Lock()
453+
defer a.mu.Unlock()
454+
455+
a.state.Steer(msg)
456+
}
457+
458+
// FollowUp adds a follow-up message to be processed after agent finishes
459+
// Inspired by pi-mono's Agent.followUp() method
460+
func (a *Agent) FollowUp(msg AgentMessage) {
461+
a.mu.Lock()
462+
defer a.mu.Unlock()
463+
464+
a.state.FollowUp(msg)
465+
}
466+
467+
// WaitForIdle waits until the agent is not streaming
468+
// Inspired by pi-mono's Agent.waitForIdle() method
469+
func (a *Agent) WaitForIdle(ctx context.Context) error {
470+
ticker := time.NewTicker(10 * time.Millisecond)
471+
defer ticker.Stop()
472+
473+
for {
474+
select {
475+
case <-ctx.Done():
476+
return ctx.Err()
477+
case <-ticker.C:
478+
a.mu.RLock()
479+
isStreaming := a.state.IsStreaming
480+
a.mu.RUnlock()
481+
if !isStreaming {
482+
return nil
483+
}
484+
}
485+
}
486+
}
487+
488+
// Abort aborts the current agent execution
489+
// Inspired by pi-mono's Agent.abort() method
490+
func (a *Agent) Abort() {
491+
a.orchestrator.Stop()
492+
}
493+
494+
// Reset resets the agent state
495+
// Inspired by pi-mono's Agent.reset() method
496+
func (a *Agent) Reset() {
497+
a.mu.Lock()
498+
defer a.mu.Unlock()
499+
500+
a.state = NewAgentState()
501+
a.state.SystemPrompt = a.context.BuildSystemPrompt(nil)
502+
a.state.Model = getModelName(a.provider)
503+
a.state.Provider = "provider"
504+
a.state.SessionKey = "main"
505+
a.state.Tools = ToAgentTools(a.tools.ListExisting())
506+
}
507+
508+
// SetSteeringMode sets how steering messages are delivered
509+
func (a *Agent) SetSteeringMode(mode MessageQueueMode) {
510+
a.mu.Lock()
511+
defer a.mu.Unlock()
512+
a.state.SteeringMode = mode
513+
}
514+
515+
// SetFollowUpMode sets how follow-up messages are delivered
516+
func (a *Agent) SetFollowUpMode(mode MessageQueueMode) {
517+
a.mu.Lock()
518+
defer a.mu.Unlock()
519+
a.state.FollowUpMode = mode
520+
}
521+
522+
// ReplaceMessages replaces the message history
523+
// Inspired by pi-mono's Agent.replaceMessages() method
524+
func (a *Agent) ReplaceMessages(messages []AgentMessage) {
525+
a.mu.Lock()
526+
defer a.mu.Unlock()
527+
528+
a.state.Messages = make([]AgentMessage, len(messages))
529+
copy(a.state.Messages, messages)
530+
}
531+
449532
// GetOrchestrator 获取 orchestrator(供 AgentManager 使用)
450533
func (a *Agent) GetOrchestrator() *Orchestrator {
451534
return a.orchestrator

agent/agent_tools.go

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"fmt"
6+
)
7+
8+
// AgentTool is the unified tool interface for the agent
9+
// Inspired by pi-mono's AgentTool<TParameters, TDetails> interface
10+
type AgentTool interface {
11+
// Name returns the tool name (used for tool calls)
12+
Name() string
13+
14+
// Description returns what the tool does
15+
Description() string
16+
17+
// Parameters returns JSON Schema for the tool's parameters
18+
Parameters() map[string]any
19+
20+
// Label returns a human-readable label for UI display
21+
// Inspired by pi-mono's AgentTool.label
22+
Label() string
23+
24+
// Execute runs the tool with streaming update support
25+
// toolCallId: unique identifier for this tool call
26+
// params: validated parameters
27+
// signal: cancellation signal
28+
// onUpdate: callback for streaming updates
29+
Execute(ctx context.Context, toolCallId string, params map[string]any, signal context.Context, onUpdate func(AgentToolResult)) (AgentToolResult, error)
30+
}
31+
32+
// AgentToolResult represents the result of a tool execution
33+
// Inspired by pi-mono's AgentToolResult<T>
34+
type AgentToolResult struct {
35+
// Content blocks supporting text and images
36+
Content []ContentBlock `json:"content"`
37+
// Details to be displayed in a UI or logged
38+
Details map[string]any `json:"details"`
39+
}
40+
41+
// NewAgentToolResult creates a new tool result
42+
func NewAgentToolResult(content string) AgentToolResult {
43+
return AgentToolResult{
44+
Content: []ContentBlock{TextContent{Text: content}},
45+
Details: make(map[string]any),
46+
}
47+
}
48+
49+
// NewAgentToolResultWithDetails creates a new tool result with details
50+
func NewAgentToolResultWithDetails(content string, details map[string]any) AgentToolResult {
51+
return AgentToolResult{
52+
Content: []ContentBlock{TextContent{Text: content}},
53+
Details: details,
54+
}
55+
}
56+
57+
// BaseAgentTool provides a base implementation of AgentTool
58+
type BaseAgentTool struct {
59+
name string
60+
label string
61+
description string
62+
parameters map[string]any
63+
executeFunc func(ctx context.Context, toolCallId string, params map[string]any, signal context.Context, onUpdate func(AgentToolResult)) (AgentToolResult, error)
64+
}
65+
66+
// NewBaseAgentTool creates a new base agent tool
67+
func NewBaseAgentTool(name, label, description string, parameters map[string]any, executeFunc func(ctx context.Context, toolCallId string, params map[string]any, signal context.Context, onUpdate func(AgentToolResult)) (AgentToolResult, error)) *BaseAgentTool {
68+
if label == "" {
69+
label = name
70+
}
71+
return &BaseAgentTool{
72+
name: name,
73+
label: label,
74+
description: description,
75+
parameters: parameters,
76+
executeFunc: executeFunc,
77+
}
78+
}
79+
80+
// Name returns the tool name
81+
func (t *BaseAgentTool) Name() string {
82+
return t.name
83+
}
84+
85+
// Label returns the tool label
86+
func (t *BaseAgentTool) Label() string {
87+
return t.label
88+
}
89+
90+
// Description returns the tool description
91+
func (t *BaseAgentTool) Description() string {
92+
return t.description
93+
}
94+
95+
// Parameters returns the tool parameters
96+
func (t *BaseAgentTool) Parameters() map[string]any {
97+
return t.parameters
98+
}
99+
100+
// Execute executes the tool
101+
func (t *BaseAgentTool) Execute(ctx context.Context, toolCallId string, params map[string]any, signal context.Context, onUpdate func(AgentToolResult)) (AgentToolResult, error) {
102+
return t.executeFunc(ctx, toolCallId, params, signal, onUpdate)
103+
}
104+
105+
// AdaptTool converts an existing Tool to AgentTool
106+
func AdaptTool(tool Tool) AgentTool {
107+
return &toolAgentAdapter{tool: tool}
108+
}
109+
110+
// toolAgentAdapter adapts Tool to AgentTool interface
111+
type toolAgentAdapter struct {
112+
tool Tool
113+
}
114+
115+
func (a *toolAgentAdapter) Name() string {
116+
return a.tool.Name()
117+
}
118+
119+
func (a *toolAgentAdapter) Label() string {
120+
return a.tool.Label()
121+
}
122+
123+
func (a *toolAgentAdapter) Description() string {
124+
return a.tool.Description()
125+
}
126+
127+
func (a *toolAgentAdapter) Parameters() map[string]any {
128+
return a.tool.Parameters()
129+
}
130+
131+
func (a *toolAgentAdapter) Execute(ctx context.Context, toolCallId string, params map[string]any, signal context.Context, onUpdate func(AgentToolResult)) (AgentToolResult, error) {
132+
// Call the tool's Execute method
133+
result, err := a.tool.Execute(ctx, params, func(tr ToolResult) {
134+
if onUpdate != nil {
135+
onUpdate(AgentToolResult{
136+
Content: tr.Content,
137+
Details: tr.Details,
138+
})
139+
}
140+
})
141+
142+
if err != nil {
143+
return AgentToolResult{}, err
144+
}
145+
146+
return AgentToolResult{
147+
Content: result.Content,
148+
Details: result.Details,
149+
}, nil
150+
}
151+
152+
// AgentToolFromFunc creates an AgentTool from a simple function
153+
func AgentToolFromFunc(name, label, description string, parameters map[string]any, fn func(ctx context.Context, params map[string]any) (string, error)) AgentTool {
154+
return NewBaseAgentTool(name, label, description, parameters,
155+
func(ctx context.Context, toolCallId string, params map[string]any, signal context.Context, onUpdate func(AgentToolResult)) (AgentToolResult, error) {
156+
result, err := fn(ctx, params)
157+
if err != nil {
158+
return AgentToolResult{}, err
159+
}
160+
return AgentToolResult{
161+
Content: []ContentBlock{TextContent{Text: result}},
162+
Details: make(map[string]any),
163+
}, nil
164+
},
165+
)
166+
}
167+
168+
// ToAgentTools converts a slice of Tool to AgentTool
169+
func ToAgentToolsSlice(tools []Tool) []AgentTool {
170+
result := make([]AgentTool, len(tools))
171+
for i, t := range tools {
172+
result[i] = AdaptTool(t)
173+
}
174+
return result
175+
}
176+
177+
// ToTools converts a slice of AgentTool to Tool (for compatibility)
178+
func ToTools(agentTools []AgentTool) []Tool {
179+
result := make([]Tool, len(agentTools))
180+
for i, t := range agentTools {
181+
result[i] = &agentToolAdapter{tool: t}
182+
}
183+
return result
184+
}
185+
186+
// agentToolAdapter adapts AgentTool to Tool interface
187+
type agentToolAdapter struct {
188+
tool AgentTool
189+
}
190+
191+
func (a *agentToolAdapter) Name() string {
192+
return a.tool.Name()
193+
}
194+
195+
func (a *agentToolAdapter) Label() string {
196+
return a.tool.Label()
197+
}
198+
199+
func (a *agentToolAdapter) Description() string {
200+
return a.tool.Description()
201+
}
202+
203+
func (a *agentToolAdapter) Parameters() map[string]any {
204+
return a.tool.Parameters()
205+
}
206+
207+
func (a *agentToolAdapter) Execute(ctx context.Context, params map[string]any, onUpdate func(ToolResult)) (ToolResult, error) {
208+
result, err := a.tool.Execute(ctx, "", params, nil, func(atr AgentToolResult) {
209+
if onUpdate != nil {
210+
onUpdate(ToolResult{
211+
Content: atr.Content,
212+
Details: atr.Details,
213+
})
214+
}
215+
})
216+
217+
if err != nil {
218+
return ToolResult{Error: err}, err
219+
}
220+
221+
return ToolResult{
222+
Content: result.Content,
223+
Details: result.Details,
224+
}, nil
225+
}
226+
227+
// ValidateToolParameters validates tool parameters against schema
228+
func ValidateToolParameters(params map[string]any, schema map[string]any) error {
229+
required := []string{}
230+
if req, ok := schema["required"].([]any); ok {
231+
for _, r := range req {
232+
if s, ok := r.(string); ok {
233+
required = append(required, s)
234+
}
235+
}
236+
}
237+
238+
for _, field := range required {
239+
if _, ok := params[field]; !ok {
240+
return &ToolValidationError{
241+
Field: field,
242+
Message: "required field missing",
243+
}
244+
}
245+
}
246+
247+
return nil
248+
}
249+
250+
// ToolValidationError is returned when parameter validation fails
251+
type ToolValidationError struct {
252+
Field string
253+
Message string
254+
}
255+
256+
func (e *ToolValidationError) Error() string {
257+
return fmt.Sprintf("validation error for field '%s': %s", e.Field, e.Message)
258+
}

0 commit comments

Comments
 (0)