Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,7 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c

// forceCompression aggressively reduces context when the limit is hit.
// It drops the oldest 50% of messages (keeping system prompt and last user message).
// IMPORTANT: It preserves tool call/response pairing to avoid API 400 errors.
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
history := agent.Sessions.GetHistory(sessionKey)
if len(history) <= 4 {
Expand All @@ -784,16 +785,23 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
return
}

// Helper to find the mid-point of the conversation
// Find a safe cut point that doesn't break tool call/response pairs.
// A "safe" cut point is right after a user message, because:
// 1. User messages don't have tool call dependencies
// 2. Any preceding tool call/response pairs will be kept together
mid := len(conversation) / 2
cutIndex := findSafeCutPoint(conversation, mid)

// New history structure:
// 1. System Prompt (with compression note appended)
// 2. Second half of conversation
// 2. Second half of conversation (from safe cut point)
// 3. Last message

droppedCount := mid
keptConversation := conversation[mid:]
droppedCount := cutIndex
keptConversation := conversation[cutIndex:]

// Additional safety: remove orphaned tool messages at the start of kept conversation
keptConversation = removeOrphanedToolMessages(keptConversation)

newHistory := make([]providers.Message, 0, 1+len(keptConversation)+1)

Expand Down Expand Up @@ -821,6 +829,39 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
})
}

// findSafeCutPoint finds a safe index to cut the conversation without breaking tool call/response pairs.
// It starts from the mid-point and searches forward for a user message, which is always safe to cut after.
func findSafeCutPoint(conversation []providers.Message, mid int) int {
// Search forward from mid to find a user message
for i := mid; i < len(conversation); i++ {
if conversation[i].Role == "user" {
return i + 1 // Cut after the user message
}
}

// Fallback: search backward from mid
for i := mid - 1; i >= 0; i-- {
if conversation[i].Role == "user" {
return i + 1 // Cut after the user message
}
}

// No user message found (edge case), use mid but this may cause issues
return mid
}

// removeOrphanedToolMessages removes tool messages at the start that don't have a preceding
// assistant message with tool_calls. This is a safety net for edge cases.
func removeOrphanedToolMessages(messages []providers.Message) []providers.Message {
// Find the first non-tool message
for i := 0; i < len(messages); i++ {
if messages[i].Role != "tool" {
return messages[i:]
}
}
return messages
}

// GetStartupInfo returns information about loaded tools and skills for logging.
func (al *AgentLoop) GetStartupInfo() map[string]any {
info := make(map[string]any)
Expand Down
166 changes: 166 additions & 0 deletions pkg/agent/loop_compression_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package agent

import (
"testing"

"github.com/sipeed/picoclaw/pkg/providers"
)

func TestFindSafeCutPoint(t *testing.T) {
tests := []struct {
name string
conversation []providers.Message
mid int
expectedIndex int
}{
{
name: "cut after user message at mid",
conversation: []providers.Message{
{Role: "user", Content: "msg1"},
{Role: "assistant", Content: "msg2"},
{Role: "user", Content: "msg3"},
{Role: "assistant", Content: "msg4"},
},
mid: 2,
expectedIndex: 3, // cut after user at index 2
},
{
name: "cut at user message forward search",
conversation: []providers.Message{
{Role: "user", Content: "msg1"},
{Role: "assistant", Content: "msg2"},
{Role: "assistant", Content: "msg3"},
{Role: "user", Content: "msg4"},
},
mid: 2,
expectedIndex: 4, // cut after user at index 3
},
{
name: "cut at user message backward search",
conversation: []providers.Message{
{Role: "user", Content: "msg1"},
{Role: "assistant", Content: "msg2"},
{Role: "assistant", Content: "msg3"},
{Role: "assistant", Content: "msg4"},
},
mid: 3,
expectedIndex: 1, // cut after user at index 0
},
{
name: "no user message fallback to mid",
conversation: []providers.Message{
{Role: "assistant", Content: "msg1"},
{Role: "assistant", Content: "msg2"},
{Role: "assistant", Content: "msg3"},
},
mid: 1,
expectedIndex: 1, // fallback to mid
},
{
name: "tool call response pair preserved",
conversation: []providers.Message{
{Role: "user", Content: "msg1"},
{Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{{ID: "tc1", Name: "tool1"}}},
{Role: "tool", Content: "result1", ToolCallID: "tc1"},
{Role: "user", Content: "msg2"},
{Role: "assistant", Content: "msg3"},
},
mid: 2,
expectedIndex: 4, // cut after user at index 3, preserving tool pair
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := findSafeCutPoint(tt.conversation, tt.mid)
if result != tt.expectedIndex {
t.Errorf("findSafeCutPoint() = %d, want %d", result, tt.expectedIndex)
}
})
}
}

func TestRemoveOrphanedToolMessages(t *testing.T) {
tests := []struct {
name string
messages []providers.Message
expected int // expected length of result
}{
{
name: "no orphaned messages",
messages: []providers.Message{
{Role: "user", Content: "msg1"},
{Role: "assistant", Content: "msg2"},
},
expected: 2,
},
{
name: "one orphaned tool message",
messages: []providers.Message{
{Role: "tool", Content: "orphaned", ToolCallID: "tc1"},
{Role: "user", Content: "msg1"},
{Role: "assistant", Content: "msg2"},
},
expected: 2,
},
{
name: "multiple orphaned tool messages",
messages: []providers.Message{
{Role: "tool", Content: "orphaned1", ToolCallID: "tc1"},
{Role: "tool", Content: "orphaned2", ToolCallID: "tc2"},
{Role: "user", Content: "msg1"},
},
expected: 1,
},
{
name: "all tool messages",
messages: []providers.Message{
{Role: "tool", Content: "orphaned1", ToolCallID: "tc1"},
{Role: "tool", Content: "orphaned2", ToolCallID: "tc2"},
},
expected: 2, // returns all if no non-tool found
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := removeOrphanedToolMessages(tt.messages)
if len(result) != tt.expected {
t.Errorf("removeOrphanedToolMessages() length = %d, want %d", len(result), tt.expected)
}
})
}
}

func TestForceCompressionPreservesToolPairs(t *testing.T) {
// This test verifies that forceCompression doesn't break tool call/response pairs.
// We can't easily test the full forceCompression function due to dependencies,
// but we can test the helper functions that ensure safety.

// Scenario: conversation with tool calls
conversation := []providers.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{{ID: "tc1", Name: "get_weather"}}},
{Role: "tool", Content: "Sunny, 25°C", ToolCallID: "tc1"},
{Role: "assistant", Content: "It's sunny today!"},
{Role: "user", Content: "Thanks!"},
{Role: "assistant", Content: "You're welcome!"},
}

// Mid point would be 3, but we should find user message at index 4
cutIndex := findSafeCutPoint(conversation, 3)
if cutIndex != 5 {
t.Errorf("Expected cut index 5 (after user at index 4), got %d", cutIndex)
}

// Verify that cutting at this index doesn't leave orphaned tool messages
kept := conversation[cutIndex:]
for _, msg := range kept {
if msg.Role == "tool" {
// A tool message in kept section shouldn't exist if we cut correctly
// because we cut after a user message, which means any preceding
// tool call/response pairs are before the cut point
t.Errorf("Tool message found in kept conversation, this breaks pairing")
}
}
}