From e171d7735347efa4f1f51d5b6bee1007c813992b Mon Sep 17 00:00:00 2001 From: pikaxinge <2392811793@qq.com> Date: Fri, 27 Feb 2026 10:04:51 +0000 Subject: [PATCH 1/3] feat(retry): add complete bounded LLM retry handling (#629) --- pkg/agent/loop.go | 75 +++++- pkg/agent/loop_test.go | 113 +++++++++ pkg/providers/error_classifier.go | 4 + pkg/providers/error_classifier_test.go | 2 + pkg/providers/openai_compat/provider.go | 8 + pkg/tools/toolloop.go | 22 +- pkg/tools/toolloop_test.go | 113 +++++++++ pkg/utils/llm_retry.go | 289 ++++++++++++++++++++++++ pkg/utils/llm_retry_test.go | 152 +++++++++++++ 9 files changed, 764 insertions(+), 14 deletions(-) create mode 100644 pkg/tools/toolloop_test.go create mode 100644 pkg/utils/llm_retry.go create mode 100644 pkg/utils/llm_retry_test.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 693f2227b..647f6187a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -516,9 +516,9 @@ func (al *AgentLoop) runLLMIteration( var response *providers.LLMResponse var err error - callLLM := func() (*providers.LLMResponse, error) { + callLLMOnce := func(callCtx context.Context) (*providers.LLMResponse, error) { if len(agent.Candidates) > 1 && al.fallback != nil { - fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, + fbResult, fbErr := al.fallback.Execute(callCtx, agent.Candidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{ "max_tokens": agent.MaxTokens, @@ -537,28 +537,45 @@ func (al *AgentLoop) runLLMIteration( } return fbResult.Response, nil } - return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{ + return agent.Provider.Chat(callCtx, messages, providerToolDefs, agent.Model, map[string]any{ "max_tokens": agent.MaxTokens, "temperature": agent.Temperature, "prompt_cache_key": agent.ID, }) } - // Retry loop for context/token errors + retryPolicy := 
utils.DefaultLLMRetryPolicy() + retryPolicy.Notify = func(notice utils.RetryNotice) { + logger.WarnCF("agent", "Transient LLM error detected, retrying", map[string]any{ + "attempt": notice.Attempt, + "total": notice.Total, + "reason": notice.Decision.Reason, + "status": notice.Decision.Status, + "retry_after": notice.Decision.RetryAfter.String(), + "backoff": notice.Delay.String(), + }) + + // User-facing notice only on first retry to avoid spam. + if notice.Attempt != 1 || constants.IsInternalChannel(opts.Channel) { + return + } + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: utils.FormatLLMRetryNotice(notice), + }) + } + + // Outer retry loop for context-window compression. + // Transient/network retries are handled inside DoWithRetry. maxRetries := 2 for retry := 0; retry <= maxRetries; retry++ { - response, err = callLLM() + response, err = utils.DoWithRetry(ctx, retryPolicy, callLLMOnce) if err == nil { break } - errMsg := strings.ToLower(err.Error()) - isContextError := strings.Contains(errMsg, "token") || - strings.Contains(errMsg, "context") || - strings.Contains(errMsg, "invalidparameter") || - strings.Contains(errMsg, "length") - - if isContextError && retry < maxRetries { + if retry < maxRetries && isContextWindowError(err) { logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{ "error": err.Error(), "retry": retry, @@ -766,6 +783,40 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c } } +func isContextWindowError(err error) bool { + if err == nil { + return false + } + + errMsg := strings.ToLower(err.Error()) + contextPatterns := []string{ + "context window", + "context length", + "maximum context length", + "max context length", + "too many tokens", + "max message tokens", + "token limit", + "prompt is too long", + "exceed max message tokens", + } + for _, pattern := range contextPatterns { + if strings.Contains(errMsg, 
pattern) { + return true + } + } + + // Provider-specific "invalid parameter" style errors frequently include token/length hints. + if strings.Contains(errMsg, "invalidparameter") && + (strings.Contains(errMsg, "token") || + strings.Contains(errMsg, "length") || + strings.Contains(errMsg, "context")) { + return true + } + + return false +} + // forceCompression aggressively reduces context when the limit is hit. // It drops the oldest 50% of messages (keeping system prompt and last user message). func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 4414398b1..ce3b4251b 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -631,3 +631,116 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) } } + +// TestAgentLoop_TransientLLMErrorRetry verifies transient 5xx failures are retried +// without triggering context compression. 
+func TestAgentLoop_TransientLLMErrorRetry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &failFirstMockProvider{ + failures: 1, + failError: fmt.Errorf("API request failed: status: 502 body: bad gateway"), + successResp: "Recovered from transient error", + } + + al := NewAgentLoop(cfg, msgBus, provider) + routedSessionKey := "agent:main:main" + + history := []providers.Message{ + {Role: "system", Content: "System prompt"}, + {Role: "user", Content: "Old message 1"}, + {Role: "assistant", Content: "Old response 1"}, + {Role: "user", Content: "Old message 2"}, + {Role: "assistant", Content: "Old response 2"}, + {Role: "user", Content: "Trigger message"}, + } + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + for _, m := range history { + defaultAgent.Sessions.AddFullMessage(routedSessionKey, m) + } + + response, err := al.ProcessDirectWithChannel( + context.Background(), + "Trigger message", + routedSessionKey, + "test", + "test-chat", + ) + if err != nil { + t.Fatalf("Expected success after transient retry, got error: %v", err) + } + if response != "Recovered from transient error" { + t.Errorf("Expected 'Recovered from transient error', got '%s'", response) + } + if provider.currentCall != 2 { + t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall) + } + + // Transient errors should not trigger context compression. 
+ finalHistory := defaultAgent.Sessions.GetHistory(routedSessionKey) + if len(finalHistory) != 8 { + t.Errorf("Expected no compression for transient retries (len == 8), got %d", len(finalHistory)) + } +} + +// TestAgentLoop_NonRetryableLLMError_NoRetry verifies non-retryable 4xx failures +// return immediately without additional attempts. +func TestAgentLoop_NonRetryableLLMError_NoRetry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &failFirstMockProvider{ + failures: 1, + failError: fmt.Errorf("API request failed: status: 400 body: invalid request"), + successResp: "should not be used", + } + + al := NewAgentLoop(cfg, msgBus, provider) + _, err = al.ProcessDirectWithChannel( + context.Background(), + "Trigger message", + "test-session-no-retry", + "test", + "test-chat", + ) + if err == nil { + t.Fatal("Expected non-retryable 400 error, got nil") + } + if provider.currentCall != 1 { + t.Errorf("Expected 1 call for non-retryable error, got %d", provider.currentCall) + } +} diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go index a0f003006..064547888 100644 --- a/pkg/providers/error_classifier.go +++ b/pkg/providers/error_classifier.go @@ -41,6 +41,10 @@ var ( substr("timed out"), substr("deadline exceeded"), substr("context deadline exceeded"), + substr("connection reset"), + substr("connection reset by peer"), + substr("tls handshake timeout"), + substr("eof"), } billingPatterns = []errorPattern{ diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go index 865aea57a..810561fdb 100644 --- a/pkg/providers/error_classifier_test.go +++ 
b/pkg/providers/error_classifier_test.go @@ -139,6 +139,8 @@ func TestClassifyError_TimeoutPatterns(t *testing.T) { "connection timed out", "deadline exceeded", "context deadline exceeded", + "connection reset by peer", + "tls handshake timeout", } for _, msg := range patterns { diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 7dace71f2..26d0e17b4 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -188,6 +188,14 @@ func (p *Provider) Chat( } if resp.StatusCode != http.StatusOK { + if retryAfter := strings.TrimSpace(resp.Header.Get("Retry-After")); retryAfter != "" { + return nil, fmt.Errorf( + "API request failed:\n Status: %d\n Retry-After: %s\n Body: %s", + resp.StatusCode, + retryAfter, + string(body), + ) + } return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) } diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index cdfe0d6ce..53f2c52e8 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -23,6 +23,8 @@ type ToolLoopConfig struct { Tools *ToolRegistry MaxIterations int LLMOptions map[string]any + RetryPolicy *utils.RetryPolicy + RetryNotice func(string) } // ToolLoopResult contains the result of running the tool loop. @@ -62,8 +64,24 @@ func RunToolLoop( if llmOpts == nil { llmOpts = map[string]any{} } - // 3. Call LLM - response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts) + // 3. Call LLM with bounded transient retries. 
+ retryPolicy := utils.DefaultLLMRetryPolicy() + if config.RetryPolicy != nil { + retryPolicy = *config.RetryPolicy + } + existingNotify := retryPolicy.Notify + retryPolicy.Notify = func(notice utils.RetryNotice) { + if existingNotify != nil { + existingNotify(notice) + } + if config.RetryNotice != nil && channel != "" && chatID != "" { + config.RetryNotice(utils.FormatLLMRetryNotice(notice)) + } + } + + response, err := utils.DoWithRetry(ctx, retryPolicy, func(callCtx context.Context) (*providers.LLMResponse, error) { + return config.Provider.Chat(callCtx, messages, providerToolDefs, config.Model, llmOpts) + }) if err != nil { logger.ErrorCF("toolloop", "LLM call failed", map[string]any{ diff --git a/pkg/tools/toolloop_test.go b/pkg/tools/toolloop_test.go new file mode 100644 index 000000000..e156f02f5 --- /dev/null +++ b/pkg/tools/toolloop_test.go @@ -0,0 +1,113 @@ +package tools + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/utils" +) + +type flakyToolLoopProvider struct { + errors []error + calls int +} + +func (p *flakyToolLoopProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + options map[string]any, +) (*providers.LLMResponse, error) { + p.calls++ + if p.calls <= len(p.errors) { + return nil, p.errors[p.calls-1] + } + return &providers.LLMResponse{ + Content: "ok", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (p *flakyToolLoopProvider) GetDefaultModel() string { + return "mock-toolloop-model" +} + +func TestRunToolLoop_TransientRetry(t *testing.T) { + provider := &flakyToolLoopProvider{ + errors: []error{ + fmt.Errorf("API request failed: status: 502 body: bad gateway"), + }, + } + + notices := make([]string, 0, 1) + cfg := ToolLoopConfig{ + Provider: provider, + Model: "test-model", + MaxIterations: 1, + RetryPolicy: &utils.RetryPolicy{ + AttemptTimeouts: 
[]time.Duration{time.Second, time.Second}, + }, + RetryNotice: func(content string) { + notices = append(notices, content) + }, + } + + result, err := RunToolLoop( + context.Background(), + cfg, + []providers.Message{{Role: "user", Content: "hello"}}, + "test", + "chat-1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil || result.Content != "ok" { + t.Fatalf("unexpected result: %+v", result) + } + if provider.calls != 2 { + t.Fatalf("provider.calls = %d, want 2", provider.calls) + } + if len(notices) != 1 { + t.Fatalf("notices = %d, want 1", len(notices)) + } + if !strings.Contains(strings.ToLower(notices[0]), "retry") { + t.Fatalf("notice = %q, want retry hint", notices[0]) + } +} + +func TestRunToolLoop_NonRetryableError_NoRetry(t *testing.T) { + provider := &flakyToolLoopProvider{ + errors: []error{ + fmt.Errorf("API request failed: status: 400 body: bad request"), + }, + } + + cfg := ToolLoopConfig{ + Provider: provider, + Model: "test-model", + MaxIterations: 1, + RetryPolicy: &utils.RetryPolicy{ + AttemptTimeouts: []time.Duration{time.Second, time.Second}, + }, + } + + _, err := RunToolLoop( + context.Background(), + cfg, + []providers.Message{{Role: "user", Content: "hello"}}, + "test", + "chat-1", + ) + if err == nil { + t.Fatal("expected error, got nil") + } + if provider.calls != 1 { + t.Fatalf("provider.calls = %d, want 1", provider.calls) + } +} diff --git a/pkg/utils/llm_retry.go b/pkg/utils/llm_retry.go new file mode 100644 index 000000000..737cc0ab2 --- /dev/null +++ b/pkg/utils/llm_retry.go @@ -0,0 +1,289 @@ +package utils + +import ( + "context" + "fmt" + "math/rand" + "regexp" + "strconv" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// RetryDecision captures whether an LLM error should be retried. 
+type RetryDecision struct { + Retryable bool + Reason providers.FailoverReason + Status int + RetryAfter time.Duration +} + +// RetryNotice is emitted before waiting for the next retry attempt. +type RetryNotice struct { + Attempt int // failed attempt number, starts at 1 + Total int // total attempt count + Decision RetryDecision + Delay time.Duration +} + +type RetryNotifyFunc func(RetryNotice) +type RetrySleepFunc func(context.Context, time.Duration) error +type RetryJitterFunc func(time.Duration) time.Duration + +// RetryPolicy defines per-attempt timeouts and backoffs for retry execution. +type RetryPolicy struct { + AttemptTimeouts []time.Duration + Backoffs []time.Duration + MaxElapsed time.Duration + MaxJitter time.Duration + Notify RetryNotifyFunc + Sleep RetrySleepFunc + Jitter RetryJitterFunc +} + +var retryAfterPattern = regexp.MustCompile(`(?i)retry[- ]after[:=]?\s*([^\s\r\n]+)`) + +// DefaultLLMRetryPolicy returns the default retry behavior for LLM calls. +func DefaultLLMRetryPolicy() RetryPolicy { + return RetryPolicy{ + AttemptTimeouts: []time.Duration{45 * time.Second, 90 * time.Second, 120 * time.Second}, + Backoffs: []time.Duration{2 * time.Second, 5 * time.Second}, + MaxElapsed: 120 * time.Second, + MaxJitter: 500 * time.Millisecond, + } +} + +// ClassifyRetryDecision classifies retryability using providers.ClassifyError. 
+func ClassifyRetryDecision(err error) RetryDecision { + if err == nil { + return RetryDecision{} + } + + classified := providers.ClassifyError(err, "", "") + if classified == nil { + return RetryDecision{} + } + + decision := RetryDecision{ + Reason: classified.Reason, + Status: classified.Status, + } + + switch classified.Reason { + case providers.FailoverTimeout, providers.FailoverRateLimit: + decision.Retryable = true + default: + decision.Retryable = false + } + + if retryAfter, ok := extractRetryAfter(err, time.Now()); ok { + decision.RetryAfter = retryAfter + } + + return decision +} + +// DoWithRetry executes fn with retry according to policy. +func DoWithRetry[T any](ctx context.Context, policy RetryPolicy, fn func(context.Context) (T, error)) (T, error) { + var zero T + if len(policy.AttemptTimeouts) == 0 { + return fn(ctx) + } + + runCtx := ctx + cancelRun := func() {} + if policy.MaxElapsed > 0 { + runCtx, cancelRun = context.WithTimeout(ctx, policy.MaxElapsed) + } + defer cancelRun() + + sleepFn := policy.Sleep + if sleepFn == nil { + sleepFn = sleepWithContext + } + jitterFn := policy.Jitter + if jitterFn == nil { + jitterFn = defaultJitter + } + + var lastErr error + totalAttempts := len(policy.AttemptTimeouts) + for attempt := 0; attempt < totalAttempts; attempt++ { + if runCtx.Err() != nil { + return zero, runCtx.Err() + } + + attemptCtx := runCtx + cancelAttempt := func() {} + if attemptTimeout := policy.AttemptTimeouts[attempt]; attemptTimeout > 0 { + timeout, ok := boundedAttemptTimeout(runCtx, attemptTimeout) + if !ok { + return zero, runCtx.Err() + } + attemptCtx, cancelAttempt = context.WithTimeout(runCtx, timeout) + } + + val, err := fn(attemptCtx) + cancelAttempt() + if err == nil { + return val, nil + } + lastErr = err + + // No retries left. 
+ if attempt == totalAttempts-1 { + break + } + + decision := ClassifyRetryDecision(err) + if !decision.Retryable { + break + } + + delay := retryDelay(policy, attempt, decision, jitterFn) + if policy.Notify != nil { + policy.Notify(RetryNotice{ + Attempt: attempt + 1, + Total: totalAttempts, + Decision: decision, + Delay: delay, + }) + } + + if delay > 0 { + if err := sleepFn(runCtx, delay); err != nil { + return zero, err + } + } + } + + return zero, lastErr +} + +// FormatLLMRetryNotice formats user-facing retry notice text. +func FormatLLMRetryNotice(notice RetryNotice) string { + nextAttempt := notice.Attempt + 1 + if nextAttempt > notice.Total { + nextAttempt = notice.Total + } + + switch notice.Decision.Reason { + case providers.FailoverRateLimit: + if notice.Decision.Status > 0 { + return fmt.Sprintf("LLM rate limited (%d). Retrying (%d/%d)...", notice.Decision.Status, nextAttempt, notice.Total) + } + return fmt.Sprintf("LLM rate limited. Retrying (%d/%d)...", nextAttempt, notice.Total) + case providers.FailoverTimeout: + if notice.Decision.Status > 0 { + return fmt.Sprintf("LLM timeout/server error (%d). Retrying (%d/%d)...", notice.Decision.Status, nextAttempt, notice.Total) + } + return fmt.Sprintf("Temporary LLM timeout. Retrying (%d/%d)...", nextAttempt, notice.Total) + default: + return fmt.Sprintf("Temporary LLM error. 
Retrying (%d/%d)...", nextAttempt, notice.Total) + } +} + +func retryDelay(policy RetryPolicy, attempt int, decision RetryDecision, jitterFn RetryJitterFunc) time.Duration { + if decision.RetryAfter > 0 { + return decision.RetryAfter + } + + if attempt < 0 || attempt >= len(policy.Backoffs) { + return 0 + } + + base := policy.Backoffs[attempt] + if base <= 0 { + return 0 + } + if policy.MaxJitter <= 0 { + return base + } + + jitter := jitterFn(policy.MaxJitter) + if jitter < 0 { + jitter = 0 + } + if jitter > policy.MaxJitter { + jitter = policy.MaxJitter + } + return base + jitter +} + +func boundedAttemptTimeout(ctx context.Context, configured time.Duration) (time.Duration, bool) { + if configured <= 0 { + return 0, false + } + + deadline, ok := ctx.Deadline() + if !ok { + return configured, true + } + + remaining := time.Until(deadline) + if remaining <= 0 { + return 0, false + } + if configured > remaining { + return remaining, true + } + return configured, true +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func defaultJitter(max time.Duration) time.Duration { + if max <= 0 { + return 0 + } + //nolint:gosec // Used only for retry backoff jitter. 
+ n := rand.Int63n(int64(max) + 1) + return time.Duration(n) +} + +func extractRetryAfter(err error, now time.Time) (time.Duration, bool) { + if err == nil { + return 0, false + } + + matches := retryAfterPattern.FindStringSubmatch(err.Error()) + if len(matches) < 2 { + return 0, false + } + + value := strings.TrimSpace(matches[1]) + if value == "" { + return 0, false + } + + if secs, convErr := strconv.Atoi(value); convErr == nil { + if secs < 0 { + return 0, false + } + return time.Duration(secs) * time.Second, true + } + + for _, layout := range []string{time.RFC1123, time.RFC1123Z, time.RFC850, time.ANSIC} { + if t, parseErr := time.Parse(layout, value); parseErr == nil { + delay := t.Sub(now) + if delay <= 0 { + return 0, false + } + return delay, true + } + } + + return 0, false +} diff --git a/pkg/utils/llm_retry_test.go b/pkg/utils/llm_retry_test.go new file mode 100644 index 000000000..d3175e98a --- /dev/null +++ b/pkg/utils/llm_retry_test.go @@ -0,0 +1,152 @@ +package utils + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func TestLLMRetry_ClassifyRetryDecision_429WithRetryAfter(t *testing.T) { + err := errors.New("API request failed:\n Status: 429\n Retry-After: 7") + decision := ClassifyRetryDecision(err) + + if !decision.Retryable { + t.Fatal("expected 429 to be retryable") + } + if decision.Reason != providers.FailoverRateLimit { + t.Fatalf("reason = %q, want %q", decision.Reason, providers.FailoverRateLimit) + } + if decision.Status != 429 { + t.Fatalf("status = %d, want 429", decision.Status) + } + if decision.RetryAfter != 7*time.Second { + t.Fatalf("retry-after = %v, want 7s", decision.RetryAfter) + } +} + +func TestLLMRetry_DoWithRetry_ParentDeadlineDoesNotBurnAttempts(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + defer cancel() + + calls := 0 + _, err := DoWithRetry(ctx, RetryPolicy{ + AttemptTimeouts: []time.Duration{time.Second, 
time.Second}, + }, func(callCtx context.Context) (string, error) { + calls++ + <-callCtx.Done() + return "", callCtx.Err() + }) + + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("err = %v, want context deadline exceeded", err) + } + if calls != 1 { + t.Fatalf("calls = %d, want 1", calls) + } +} + +func TestLLMRetry_DoWithRetry_CancelDuringBackoff(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + calls := 0 + sleepCalled := false + + _, err := DoWithRetry(ctx, RetryPolicy{ + AttemptTimeouts: []time.Duration{time.Second, time.Second}, + Backoffs: []time.Duration{time.Hour}, + Sleep: func(waitCtx context.Context, _ time.Duration) error { + sleepCalled = true + <-waitCtx.Done() + return waitCtx.Err() + }, + }, func(context.Context) (string, error) { + calls++ + cancel() + return "", errors.New("API request failed: status: 502 body: bad gateway") + }) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want context canceled", err) + } + if calls != 1 { + t.Fatalf("calls = %d, want 1", calls) + } + if !sleepCalled { + t.Fatal("expected sleep path to be called") + } +} + +func TestLLMRetry_DoWithRetry_JitterBoundedBackoff(t *testing.T) { + calls := 0 + var slept time.Duration + + _, err := DoWithRetry(context.Background(), RetryPolicy{ + AttemptTimeouts: []time.Duration{time.Second, time.Second}, + Backoffs: []time.Duration{100 * time.Millisecond}, + MaxJitter: 50 * time.Millisecond, + Jitter: func(max time.Duration) time.Duration { + if max != 50*time.Millisecond { + t.Fatalf("max jitter = %v, want 50ms", max) + } + return 37 * time.Millisecond + }, + Sleep: func(_ context.Context, d time.Duration) error { + slept = d + return nil + }, + }, func(context.Context) (string, error) { + calls++ + if calls == 1 { + return "", errors.New("API request failed: status: 502 body: bad gateway") + } + return "ok", nil + }) + + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if calls != 2 { + 
t.Fatalf("calls = %d, want 2", calls) + } + if slept != 137*time.Millisecond { + t.Fatalf("slept = %v, want 137ms", slept) + } +} + +func TestLLMRetry_DoWithRetry_UsesRetryAfterFor429(t *testing.T) { + calls := 0 + var slept time.Duration + + _, err := DoWithRetry(context.Background(), RetryPolicy{ + AttemptTimeouts: []time.Duration{time.Second, time.Second}, + Backoffs: []time.Duration{100 * time.Millisecond}, + MaxJitter: 80 * time.Millisecond, + Jitter: func(_ time.Duration) time.Duration { + return 50 * time.Millisecond + }, + Sleep: func(_ context.Context, d time.Duration) error { + slept = d + return nil + }, + }, func(context.Context) (string, error) { + calls++ + if calls == 1 { + return "", errors.New("API request failed:\n Status: 429\n Retry-After: 3") + } + return "ok", nil + }) + + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if calls != 2 { + t.Fatalf("calls = %d, want 2", calls) + } + if slept != 3*time.Second { + t.Fatalf("slept = %v, want 3s", slept) + } +} From 07e3f16ca85efa433d8fdccc5a3e67225fa76b7f Mon Sep 17 00:00:00 2001 From: pikaxinge <2392811793@qq.com> Date: Fri, 27 Feb 2026 10:45:18 +0000 Subject: [PATCH 2/3] fix(retry): handle Retry-After date and align review nits --- pkg/providers/error_classifier_test.go | 2 ++ pkg/utils/llm_retry.go | 16 ++-------------- pkg/utils/llm_retry_test.go | 13 +++++++++++++ 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go index 810561fdb..4c3dd9314 100644 --- a/pkg/providers/error_classifier_test.go +++ b/pkg/providers/error_classifier_test.go @@ -139,8 +139,10 @@ func TestClassifyError_TimeoutPatterns(t *testing.T) { "connection timed out", "deadline exceeded", "context deadline exceeded", + "connection reset", "connection reset by peer", "tls handshake timeout", + "EOF", } for _, msg := range patterns { diff --git a/pkg/utils/llm_retry.go b/pkg/utils/llm_retry.go index 737cc0ab2..dabf2a15b 100644 
--- a/pkg/utils/llm_retry.go
+++ b/pkg/utils/llm_retry.go
@@ -43,7 +43,7 @@ type RetryPolicy struct {
 	Jitter RetryJitterFunc
 }
 
-var retryAfterPattern = regexp.MustCompile(`(?i)retry[- ]after[:=]?\s*([^\s\r\n]+)`)
+var retryAfterPattern = regexp.MustCompile(`(?i)retry[- ]after[:=]?\s*([^\r\n]+)`)
 
 // DefaultLLMRetryPolicy returns the default retry behavior for LLM calls.
 func DefaultLLMRetryPolicy() RetryPolicy {
@@ -101,7 +101,7 @@ func DoWithRetry[T any](ctx context.Context, policy RetryPolicy, fn func(context
 
 	sleepFn := policy.Sleep
 	if sleepFn == nil {
-		sleepFn = sleepWithContext
+		sleepFn = sleepWithCtx
 	}
 	jitterFn := policy.Jitter
 	if jitterFn == nil {
@@ -232,7 +232,7 @@ func boundedAttemptTimeout(ctx context.Context, configured time.Duration) (time.
 	return configured, true
 }
 
-func sleepWithContext(ctx context.Context, d time.Duration) error {
+func sleepWithCtx(ctx context.Context, d time.Duration) error {
 	timer := time.NewTimer(d)
 	defer timer.Stop()
 
diff --git a/pkg/utils/llm_retry_test.go b/pkg/utils/llm_retry_test.go
index d3175e98a..c51729598 100644
--- a/pkg/utils/llm_retry_test.go
+++ b/pkg/utils/llm_retry_test.go
@@ -150,3 +150,16 @@ func TestLLMRetry_DoWithRetry_UsesRetryAfterFor429(t *testing.T) {
 		t.Fatalf("slept = %v, want 3s", slept)
 	}
 }
+
+func TestLLMRetry_ExtractRetryAfter_HTTPDate(t *testing.T) {
+	now := time.Date(2015, 10, 21, 7, 27, 0, 0, time.UTC)
+	err := errors.New("API request failed:\n Status: 429\n Retry-After: Wed, 21 Oct 2015 07:28:00 GMT")
+
+	delay, ok := extractRetryAfter(err, now)
+	if !ok {
+		t.Fatal("expected retry-after date to parse")
+	}
+	if delay != time.Minute {
+		t.Fatalf("delay = %v, want 1m", delay)
+	}
+}

From f8720e16ed558455e33cc8bf3d7dbbda807ccbef Mon Sep 17 00:00:00 2001
From: pikaxinge <2392811793@qq.com>
Date: Sat, 28 Feb 2026 06:52:54 +0000
Subject: [PATCH 3/3] fix(retry): address review edge cases in
notices and deadline handling --- pkg/tools/toolloop.go | 2 +- pkg/tools/toolloop_test.go | 44 +++++++++++++++++++++++++++++++++++++ pkg/utils/llm_retry.go | 5 ++++- pkg/utils/llm_retry_test.go | 35 +++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 2 deletions(-) diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index 53f2c52e8..e524916f6 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -74,7 +74,7 @@ func RunToolLoop( if existingNotify != nil { existingNotify(notice) } - if config.RetryNotice != nil && channel != "" && chatID != "" { + if config.RetryNotice != nil { config.RetryNotice(utils.FormatLLMRetryNotice(notice)) } } diff --git a/pkg/tools/toolloop_test.go b/pkg/tools/toolloop_test.go index e156f02f5..e707955a6 100644 --- a/pkg/tools/toolloop_test.go +++ b/pkg/tools/toolloop_test.go @@ -81,6 +81,50 @@ func TestRunToolLoop_TransientRetry(t *testing.T) { } } +func TestRunToolLoop_TransientRetry_NoticeWithoutChannelContext(t *testing.T) { + provider := &flakyToolLoopProvider{ + errors: []error{ + fmt.Errorf("API request failed: status: 502 body: bad gateway"), + }, + } + + notices := make([]string, 0, 1) + cfg := ToolLoopConfig{ + Provider: provider, + Model: "test-model", + MaxIterations: 1, + RetryPolicy: &utils.RetryPolicy{ + AttemptTimeouts: []time.Duration{time.Second, time.Second}, + }, + RetryNotice: func(content string) { + notices = append(notices, content) + }, + } + + result, err := RunToolLoop( + context.Background(), + cfg, + []providers.Message{{Role: "user", Content: "hello"}}, + "", + "", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil || result.Content != "ok" { + t.Fatalf("unexpected result: %+v", result) + } + if provider.calls != 2 { + t.Fatalf("provider.calls = %d, want 2", provider.calls) + } + if len(notices) != 1 { + t.Fatalf("notices = %d, want 1", len(notices)) + } + if !strings.Contains(strings.ToLower(notices[0]), "retry") { + t.Fatalf("notice = %q, want 
retry hint", notices[0]) + } +} + func TestRunToolLoop_NonRetryableError_NoRetry(t *testing.T) { provider := &flakyToolLoopProvider{ errors: []error{ diff --git a/pkg/utils/llm_retry.go b/pkg/utils/llm_retry.go index dabf2a15b..a53725667 100644 --- a/pkg/utils/llm_retry.go +++ b/pkg/utils/llm_retry.go @@ -120,7 +120,10 @@ func DoWithRetry[T any](ctx context.Context, policy RetryPolicy, fn func(context if attemptTimeout := policy.AttemptTimeouts[attempt]; attemptTimeout > 0 { timeout, ok := boundedAttemptTimeout(runCtx, attemptTimeout) if !ok { - return zero, runCtx.Err() + if err := runCtx.Err(); err != nil { + return zero, err + } + return zero, context.DeadlineExceeded } attemptCtx, cancelAttempt = context.WithTimeout(runCtx, timeout) } diff --git a/pkg/utils/llm_retry_test.go b/pkg/utils/llm_retry_test.go index c51729598..e83e50065 100644 --- a/pkg/utils/llm_retry_test.go +++ b/pkg/utils/llm_retry_test.go @@ -9,6 +9,19 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" ) +type staleDeadlineContext struct { + context.Context + deadline time.Time +} + +func (c staleDeadlineContext) Deadline() (time.Time, bool) { + return c.deadline, true +} + +func (c staleDeadlineContext) Err() error { + return nil +} + func TestLLMRetry_ClassifyRetryDecision_429WithRetryAfter(t *testing.T) { err := errors.New("API request failed:\n Status: 429\n Retry-After: 7") decision := ClassifyRetryDecision(err) @@ -163,3 +176,25 @@ func TestLLMRetry_ExtractRetryAfter_HTTPDate(t *testing.T) { t.Fatalf("delay = %v, want 1m", delay) } } + +func TestLLMRetry_DoWithRetry_ExpiredDeadlineReturnsDeadlineExceeded(t *testing.T) { + ctx := staleDeadlineContext{ + Context: context.Background(), + deadline: time.Now().Add(-time.Second), + } + + calls := 0 + _, err := DoWithRetry(ctx, RetryPolicy{ + AttemptTimeouts: []time.Duration{time.Second}, + }, func(context.Context) (string, error) { + calls++ + return "ok", nil + }) + + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("err = %v, want 
context deadline exceeded", err) + } + if calls != 0 { + t.Fatalf("calls = %d, want 0", calls) + } +}