Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 63 additions & 12 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,9 @@ func (al *AgentLoop) runLLMIteration(
var response *providers.LLMResponse
var err error

callLLM := func() (*providers.LLMResponse, error) {
callLLMOnce := func(callCtx context.Context) (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
fbResult, fbErr := al.fallback.Execute(callCtx, agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
"max_tokens": agent.MaxTokens,
Expand All @@ -650,28 +650,45 @@ func (al *AgentLoop) runLLMIteration(
}
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{
return agent.Provider.Chat(callCtx, messages, providerToolDefs, agent.Model, map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
})
}

// Retry loop for context/token errors
retryPolicy := utils.DefaultLLMRetryPolicy()
retryPolicy.Notify = func(notice utils.RetryNotice) {
logger.WarnCF("agent", "Transient LLM error detected, retrying", map[string]any{
"attempt": notice.Attempt,
"total": notice.Total,
"reason": notice.Decision.Reason,
"status": notice.Decision.Status,
"retry_after": notice.Decision.RetryAfter.String(),
"backoff": notice.Delay.String(),
})

// User-facing notice only on first retry to avoid spam.
if notice.Attempt != 1 || constants.IsInternalChannel(opts.Channel) {
return
}
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: utils.FormatLLMRetryNotice(notice),
})
}

// Outer retry loop for context-window compression.
// Transient/network retries are handled inside DoWithRetry.
maxRetries := 2
for retry := 0; retry <= maxRetries; retry++ {
response, err = callLLM()
response, err = utils.DoWithRetry(ctx, retryPolicy, callLLMOnce)
if err == nil {
break
}

errMsg := strings.ToLower(err.Error())
isContextError := strings.Contains(errMsg, "token") ||
strings.Contains(errMsg, "context") ||
strings.Contains(errMsg, "invalidparameter") ||
strings.Contains(errMsg, "length")

if isContextError && retry < maxRetries {
if retry < maxRetries && isContextWindowError(err) {
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
"error": err.Error(),
"retry": retry,
Expand Down Expand Up @@ -913,6 +930,40 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
}
}

// isContextWindowError reports whether err reads like a provider rejecting the
// request because the prompt no longer fits the model's context window.
//
// The check is a case-insensitive substring match on err.Error(). A nil error
// is never a context-window error.
func isContextWindowError(err error) bool {
	if err == nil {
		return false
	}

	text := strings.ToLower(err.Error())
	mentions := func(fragment string) bool {
		return strings.Contains(text, fragment)
	}

	// Phrases that are specific enough to identify an overflow on their own.
	for _, phrase := range []string{
		"context window",
		"context length",
		"maximum context length",
		"max context length",
		"too many tokens",
		"max message tokens",
		"token limit",
		"prompt is too long",
		"exceed max message tokens",
	} {
		if mentions(phrase) {
			return true
		}
	}

	// "invalidparameter" by itself is too generic to act on; it only counts
	// when the same message also carries a token/length/context hint.
	if !mentions("invalidparameter") {
		return false
	}
	return mentions("token") || mentions("length") || mentions("context")
}

// forceCompression aggressively reduces context when the limit is hit.
// It drops the oldest 50% of messages (keeping system prompt and last user message).
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
Expand Down
113 changes: 113 additions & 0 deletions pkg/agent/loop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,119 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
}
}

// TestAgentLoop_TransientLLMErrorRetry drives one message through the agent loop
// with a provider that returns a 502 on its first call. The run must succeed
// on the second call and leave the session history uncompressed.
func TestAgentLoop_TransientLLMErrorRetry(t *testing.T) {
	workspace, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(workspace)

	defaults := config.AgentDefaults{
		Workspace:         workspace,
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{Agents: config.AgentsConfig{Defaults: defaults}}

	// Configured for a single failure (502 bad gateway) before successResp.
	flaky := &failFirstMockProvider{
		failures:    1,
		failError:   fmt.Errorf("API request failed: status: 502 body: bad gateway"),
		successResp: "Recovered from transient error",
	}

	loop := NewAgentLoop(cfg, bus.NewMessageBus(), flaky)
	sessionKey := "agent:main:main"

	agent := loop.registry.GetDefaultAgent()
	if agent == nil {
		t.Fatal("No default agent found")
	}

	// Seed the session with prior turns so that any compression would be
	// visible as a shorter history afterwards.
	seed := []providers.Message{
		{Role: "system", Content: "System prompt"},
		{Role: "user", Content: "Old message 1"},
		{Role: "assistant", Content: "Old response 1"},
		{Role: "user", Content: "Old message 2"},
		{Role: "assistant", Content: "Old response 2"},
		{Role: "user", Content: "Trigger message"},
	}
	for i := range seed {
		agent.Sessions.AddFullMessage(sessionKey, seed[i])
	}

	reply, err := loop.ProcessDirectWithChannel(
		context.Background(),
		"Trigger message",
		sessionKey,
		"test",
		"test-chat",
	)
	if err != nil {
		t.Fatalf("Expected success after transient retry, got error: %v", err)
	}
	if reply != "Recovered from transient error" {
		t.Errorf("Expected 'Recovered from transient error', got '%s'", reply)
	}
	if calls := flaky.currentCall; calls != 2 {
		t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", calls)
	}

	// A transient failure must not shrink the history. The expected length of 8
	// is presumably the six seeded messages plus the new user/assistant pair —
	// TODO confirm against the session bookkeeping.
	if got := len(agent.Sessions.GetHistory(sessionKey)); got != 8 {
		t.Errorf("Expected no compression for transient retries (len == 8), got %d", got)
	}
}

// TestAgentLoop_NonRetryableLLMError_NoRetry drives one message through the agent
// loop with a provider whose first call fails with a 400. The error must be
// returned to the caller after exactly one provider call.
func TestAgentLoop_NonRetryableLLMError_NoRetry(t *testing.T) {
	workspace, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(workspace)

	defaults := config.AgentDefaults{
		Workspace:         workspace,
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{Agents: config.AgentsConfig{Defaults: defaults}}

	// successResp is set only so that an unexpected second call would be
	// distinguishable; the test expects it never to be returned.
	rejecting := &failFirstMockProvider{
		failures:    1,
		failError:   fmt.Errorf("API request failed: status: 400 body: invalid request"),
		successResp: "should not be used",
	}

	loop := NewAgentLoop(cfg, bus.NewMessageBus(), rejecting)

	if _, runErr := loop.ProcessDirectWithChannel(
		context.Background(),
		"Trigger message",
		"test-session-no-retry",
		"test",
		"test-chat",
	); runErr == nil {
		t.Fatal("Expected non-retryable 400 error, got nil")
	}

	if calls := rejecting.currentCall; calls != 1 {
		t.Errorf("Expected 1 call for non-retryable error, got %d", calls)
	}
}

func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/providers/error_classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ var (
substr("timed out"),
substr("deadline exceeded"),
substr("context deadline exceeded"),
substr("connection reset"),
substr("connection reset by peer"),
substr("tls handshake timeout"),
substr("eof"),
}

billingPatterns = []errorPattern{
Expand Down
4 changes: 4 additions & 0 deletions pkg/providers/error_classifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ func TestClassifyError_TimeoutPatterns(t *testing.T) {
"connection timed out",
"deadline exceeded",
"context deadline exceeded",
"connection reset",
"connection reset by peer",
"tls handshake timeout",
"EOF",
}

for _, msg := range patterns {
Expand Down
8 changes: 8 additions & 0 deletions pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ func (p *Provider) Chat(
}

if resp.StatusCode != http.StatusOK {
if retryAfter := strings.TrimSpace(resp.Header.Get("Retry-After")); retryAfter != "" {
return nil, fmt.Errorf(
"API request failed:\n Status: %d\n Retry-After: %s\n Body: %s",
resp.StatusCode,
retryAfter,
string(body),
)
}
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}

Expand Down
22 changes: 20 additions & 2 deletions pkg/tools/toolloop.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ type ToolLoopConfig struct {
Tools *ToolRegistry
MaxIterations int
LLMOptions map[string]any
RetryPolicy *utils.RetryPolicy
RetryNotice func(string)
}

// ToolLoopResult contains the result of running the tool loop.
Expand Down Expand Up @@ -62,8 +64,24 @@ func RunToolLoop(
if llmOpts == nil {
llmOpts = map[string]any{}
}
// 3. Call LLM
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
// 3. Call LLM with bounded transient retries.
retryPolicy := utils.DefaultLLMRetryPolicy()
if config.RetryPolicy != nil {
retryPolicy = *config.RetryPolicy
}
existingNotify := retryPolicy.Notify
retryPolicy.Notify = func(notice utils.RetryNotice) {
if existingNotify != nil {
existingNotify(notice)
}
if config.RetryNotice != nil {
config.RetryNotice(utils.FormatLLMRetryNotice(notice))
}
}

response, err := utils.DoWithRetry(ctx, retryPolicy, func(callCtx context.Context) (*providers.LLMResponse, error) {
return config.Provider.Chat(callCtx, messages, providerToolDefs, config.Model, llmOpts)
})
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
Expand Down
Loading