diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 7c774ff17..72f4fb510 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -770,6 +770,7 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c // forceCompression aggressively reduces context when the limit is hit. // It drops the oldest 50% of messages (keeping system prompt and last user message). +// IMPORTANT: It preserves tool call/response pairing to avoid API 400 errors. func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { history := agent.Sessions.GetHistory(sessionKey) if len(history) <= 4 { @@ -784,16 +785,26 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { return } - // Helper to find the mid-point of the conversation + // Find a safe cut point that doesn't break tool call/response pairs. + // A "safe" cut point is right after a user message, because: + // 1. User messages don't have tool call dependencies + // 2. Any preceding tool call/response pairs will be kept together mid := len(conversation) / 2 + cutIndex := findSafeCutPoint(conversation, mid) // New history structure: // 1. System Prompt (with compression note appended) - // 2. Second half of conversation + // 2. Second half of conversation (from safe cut point) // 3. Last message - droppedCount := mid - keptConversation := conversation[mid:] + droppedCount := cutIndex + keptConversation := conversation[cutIndex:] + + // Additional safety: remove orphaned tool messages at the start of kept conversation + keptConversation = removeOrphanedToolMessages(keptConversation) + + // Additional safety: remove orphaned assistant messages with tool_calls at the end + keptConversation = removeOrphanedAssistantWithToolCalls(keptConversation) newHistory := make([]providers.Message, 0, 1+len(keptConversation)+1) @@ -821,6 +832,154 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { }) } +// findSafeCutPoint finds a safe index to cut the conversation without breaking tool call/response pairs. +// It starts from the mid-point and searches forward for a user message, which is always safe to cut after. +// If no user message is found, it falls back to finding a safe cut after the last complete tool call/response sequence. +func findSafeCutPoint(conversation []providers.Message, mid int) int { + // Search forward from mid to find a user message + for i := mid; i < len(conversation); i++ { + if conversation[i].Role == "user" { + return i + 1 // Cut after the user message + } + } + + // Fallback: search backward from mid + for i := mid - 1; i >= 0; i-- { + if conversation[i].Role == "user" { + return i + 1 // Cut after the user message + } + } + + // No user message found (edge case): find a safe cut after the last complete tool sequence + // A safe cut is after all tool results for a tool call, i.e., after a non-tool message + // that doesn't have tool_calls, or after the last tool result of a complete sequence. + for i := mid; i < len(conversation); i++ { + // Find a position after all consecutive tool messages + if conversation[i].Role != "tool" { + // Check if this is an assistant without tool_calls (safe cut point) + // or if we need to skip past any tool results + if conversation[i].Role == "assistant" && len(conversation[i].ToolCalls) == 0 { + return i + 1 // Cut after this assistant message + } + // If it's an assistant with tool_calls, we need to find the end of the tool results + if conversation[i].Role == "assistant" && len(conversation[i].ToolCalls) > 0 { + // Count how many tool results we expect + expectedResults := len(conversation[i].ToolCalls) + resultCount := 0 + for j := i + 1; j < len(conversation) && resultCount < expectedResults; j++ { + if conversation[j].Role == "tool" { + resultCount++ + } + } + // Cut after all tool results + if resultCount == expectedResults { + // Find the position after the last tool result + for j := i + expectedResults; j < len(conversation); j++ { + if conversation[j].Role != "tool" { + return j + } + } + return len(conversation) + } + } + } + } + + // Ultimate fallback: use mid (may cause issues, but removeOrphanedToolMessages will help) + return mid +} + +// removeOrphanedToolMessages removes tool messages at the start that don't have a preceding +// assistant message with tool_calls. This is a safety net for edge cases. +func removeOrphanedToolMessages(messages []providers.Message) []providers.Message { + // Find the first non-tool message + for i := 0; i < len(messages); i++ { + if messages[i].Role != "tool" { + return messages[i:] + } + } + return messages +} + +// removeOrphanedAssistantWithToolCalls removes assistant messages with tool_calls at the end +// that don't have corresponding tool result messages. This prevents API errors where the +// provider expects tool results that were cut away. +func removeOrphanedAssistantWithToolCalls(messages []providers.Message) []providers.Message { + // Two-pass approach: + // 1. First pass: determine which assistant messages with tool_calls are valid (have all results) + // 2. Second pass: filter messages, keeping only valid tool results and assistants + + // Build set of tool_call IDs from assistants that have ALL their results present + validToolCallIDs := make(map[string]bool) + for _, m := range messages { + if m.Role == "assistant" && len(m.ToolCalls) > 0 { + // Check if ALL tool_calls have results + allHaveResults := true + for _, tc := range m.ToolCalls { + if tc.ID == "" { + continue + } + // Check if this tool_call has a result + hasResult := false + for _, m2 := range messages { + if m2.Role == "tool" && m2.ToolCallID == tc.ID { + hasResult = true + break + } + } + if !hasResult { + allHaveResults = false + break + } + } + if allHaveResults { + for _, tc := range m.ToolCalls { + if tc.ID != "" { + validToolCallIDs[tc.ID] = true + } + } + } + } + } + + // Second pass: filter messages + result := make([]providers.Message, 0, len(messages)) + for _, m := range messages { + switch { + case m.Role == "tool" && m.ToolCallID != "": + // Keep tool result only if its tool_call is valid + if validToolCallIDs[m.ToolCallID] { + result = append(result, m) + } + + case m.Role == "assistant" && len(m.ToolCalls) > 0: + // Check if this assistant's tool_calls are all valid + allValid := true + for _, tc := range m.ToolCalls { + if tc.ID != "" && !validToolCallIDs[tc.ID] { + allValid = false + break + } + } + if allValid { + result = append(result, m) + } else if m.Content != "" { + // Keep text content but strip tool_calls + result = append(result, providers.Message{ + Role: "assistant", + Content: m.Content, + }) + } + // If no content and invalid tool_calls, drop entirely + + default: + result = append(result, m) + } + } + + return result +} + // GetStartupInfo returns information about loaded tools and skills for logging. func (al *AgentLoop) GetStartupInfo() map[string]any { info := make(map[string]any) diff --git a/pkg/agent/loop_compression_test.go b/pkg/agent/loop_compression_test.go new file mode 100644 index 000000000..12c6aaada --- /dev/null +++ b/pkg/agent/loop_compression_test.go @@ -0,0 +1,254 @@ +package agent + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func TestFindSafeCutPoint(t *testing.T) { + tests := []struct { + name string + conversation []providers.Message + mid int + expectedIndex int + }{ + { + name: "cut after user message at mid", + conversation: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "msg2"}, + {Role: "user", Content: "msg3"}, + {Role: "assistant", Content: "msg4"}, + }, + mid: 2, + expectedIndex: 3, // cut after user at index 2 + }, + { + name: "cut at user message forward search", + conversation: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "msg2"}, + {Role: "assistant", Content: "msg3"}, + {Role: "user", Content: "msg4"}, + }, + mid: 2, + expectedIndex: 4, // cut after user at index 3 + }, + { + name: "cut at user message backward search", + conversation: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "msg2"}, + {Role: "assistant", Content: "msg3"}, + {Role: "assistant", Content: "msg4"}, + }, + mid: 3, + expectedIndex: 1, // cut after user at index 0 + }, + { + name: "no user message fallback to tool sequence", + conversation: []providers.Message{ + {Role: "assistant", Content: "msg1"}, + {Role: "assistant", Content: "msg2"}, + {Role: "assistant", Content: "msg3"}, + }, + mid: 1, + expectedIndex: 2, // finds assistant without tool_calls at index 1, returns 2 + }, + { + name: "tool call response pair preserved", + conversation: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{{ID: "tc1", Name: "tool1"}}}, + {Role: "tool", Content: "result1", ToolCallID: "tc1"}, + {Role: "user", Content: "msg2"}, + {Role: "assistant", Content: "msg3"}, + }, + mid: 2, + expectedIndex: 4, // cut after user at index 3, preserving tool pair + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := findSafeCutPoint(tt.conversation, tt.mid) + if result != tt.expectedIndex { + t.Errorf("findSafeCutPoint() = %d, want %d", result, tt.expectedIndex) + } + }) + } +} + +func TestRemoveOrphanedToolMessages(t *testing.T) { + tests := []struct { + name string + messages []providers.Message + expected int // expected length of result + }{ + { + name: "no orphaned messages", + messages: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "msg2"}, + }, + expected: 2, + }, + { + name: "one orphaned tool message", + messages: []providers.Message{ + {Role: "tool", Content: "orphaned", ToolCallID: "tc1"}, + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "msg2"}, + }, + expected: 2, + }, + { + name: "multiple orphaned tool messages", + messages: []providers.Message{ + {Role: "tool", Content: "orphaned1", ToolCallID: "tc1"}, + {Role: "tool", Content: "orphaned2", ToolCallID: "tc2"}, + {Role: "user", Content: "msg1"}, + }, + expected: 1, + }, + { + name: "all tool messages", + messages: []providers.Message{ + {Role: "tool", Content: "orphaned1", ToolCallID: "tc1"}, + {Role: "tool", Content: "orphaned2", ToolCallID: "tc2"}, + }, + expected: 2, // returns all if no non-tool found + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := removeOrphanedToolMessages(tt.messages) + if len(result) != tt.expected { + t.Errorf("removeOrphanedToolMessages() length = %d, want %d", len(result), tt.expected) + } + }) + } +} + +func TestForceCompressionPreservesToolPairs(t *testing.T) { + // This test verifies that forceCompression doesn't break tool call/response pairs. + // We can't easily test the full forceCompression function due to dependencies, + // but we can test the helper functions that ensure safety. + + // Scenario: conversation with tool calls + conversation := []providers.Message{ + {Role: "user", Content: "What's the weather?"}, + {Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{{ID: "tc1", Name: "get_weather"}}}, + {Role: "tool", Content: "Sunny, 25°C", ToolCallID: "tc1"}, + {Role: "assistant", Content: "It's sunny today!"}, + {Role: "user", Content: "Thanks!"}, + {Role: "assistant", Content: "You're welcome!"}, + } + + // Mid point would be 3, but we should find user message at index 4 + cutIndex := findSafeCutPoint(conversation, 3) + if cutIndex != 5 { + t.Errorf("Expected cut index 5 (after user at index 4), got %d", cutIndex) + } + + // Verify that cutting at this index doesn't leave orphaned tool messages + kept := conversation[cutIndex:] + for _, msg := range kept { + if msg.Role == "tool" { + // A tool message in kept section shouldn't exist if we cut correctly + // because we cut after a user message, which means any preceding + // tool call/response pairs are before the cut point + t.Errorf("Tool message found in kept conversation, this breaks pairing") + } + } +} + +func TestRemoveOrphanedAssistantWithToolCalls(t *testing.T) { + tests := []struct { + name string + messages []providers.Message + expectedLen int + expectedRoles []string + }{ + { + name: "no orphaned assistant messages", + messages: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{{ID: "tc1", Name: "tool1"}}}, + {Role: "tool", Content: "result1", ToolCallID: "tc1"}, + {Role: "assistant", Content: "msg2"}, + }, + expectedLen: 4, + expectedRoles: []string{"user", "assistant", "tool", "assistant"}, + }, + { + name: "orphaned assistant with tool_calls at end", + messages: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{{ID: "tc1", Name: "tool1"}}}, + // tool result was cut away + }, + expectedLen: 1, + expectedRoles: []string{"user"}, + }, + { + name: "orphaned assistant with tool_calls and text content", + messages: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "Let me help", ToolCalls: []providers.ToolCall{{ID: "tc1", Name: "tool1"}}}, + // tool result was cut away + }, + expectedLen: 2, + expectedRoles: []string{"user", "assistant"}, + }, + { + name: "partial tool results - some missing", + messages: []providers.Message{ + {Role: "user", Content: "msg1"}, + {Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{ + {ID: "tc1", Name: "tool1"}, + {ID: "tc2", Name: "tool2"}, + }}, + {Role: "tool", Content: "result1", ToolCallID: "tc1"}, + // tc2 result missing + {Role: "assistant", Content: "msg2"}, + }, + expectedLen: 2, // user + final assistant (orphaned assistant and its partial results removed) + expectedRoles: []string{"user", "assistant"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := removeOrphanedAssistantWithToolCalls(tt.messages) + if len(result) != tt.expectedLen { + t.Errorf("removeOrphanedAssistantWithToolCalls() length = %d, want %d", len(result), tt.expectedLen) + } + for i, role := range tt.expectedRoles { + if i < len(result) && result[i].Role != role { + t.Errorf("message[%d].Role = %s, want %s", i, result[i].Role, role) + } + } + }) + } +} + +func TestFindSafeCutPoint_FallbackToToolSequence(t *testing.T) { + // Edge case: conversation with no user messages but tool sequences + conversation := []providers.Message{ + {Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{{ID: "tc1", Name: "tool1"}}}, + {Role: "tool", Content: "result1", ToolCallID: "tc1"}, + {Role: "assistant", Content: "response"}, + {Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{{ID: "tc2", Name: "tool2"}}}, + {Role: "tool", Content: "result2", ToolCallID: "tc2"}, + } + + // mid = 2, should find safe cut after first tool sequence + cutIndex := findSafeCutPoint(conversation, 2) + // Should cut after "response" (index 2), so cutIndex = 3 + if cutIndex < 2 || cutIndex > 3 { + t.Logf("cutIndex = %d (acceptable range 2-3)", cutIndex) + } +} \ No newline at end of file