Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ type AgentLoop struct {
tools *tools.ToolRegistry
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
fallback *providers.FallbackChain
candidates []providers.FallbackCandidate
providerFactory func(provider, model string) (providers.LLMProvider, error)
}

// processOptions configures how a message is processed
Expand Down Expand Up @@ -95,17 +98,31 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry)

// Set up fallback chain
cooldown := providers.NewCooldownTracker()
fallbackChain := providers.NewFallbackChain(cooldown)

// Resolve model candidates from config
modelCfg := providers.ModelConfig{
Primary: cfg.Agents.Defaults.Model,
Fallbacks: cfg.Agents.Defaults.ModelFallbacks,
}
defaultProvider := cfg.Agents.Defaults.Provider
candidates := providers.ResolveCandidates(modelCfg, defaultProvider)

return &AgentLoop{
bus: msgBus,
provider: provider,
workspace: workspace,
model: cfg.Agents.Defaults.Model,
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
contextWindow: cfg.Agents.Defaults.MaxTokens,
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager,
contextBuilder: contextBuilder,
tools: toolsRegistry,
summarizing: sync.Map{},
fallback: fallbackChain,
candidates: candidates,
}
}

Expand Down Expand Up @@ -341,11 +358,35 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"tools_json": formatToolsForLog(providerToolDefs),
})

// Call LLM
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
// Call LLM with fallback chain if candidates are configured.
var response *providers.LLMResponse
var err error

if len(al.candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, al.candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return al.provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
},
)
if fbErr != nil {
err = fbErr
} else {
response = fbResult.Response
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
map[string]interface{}{"iteration": iteration})
}
}
} else {
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
}

if err != nil {
logger.ErrorCF("agent", "LLM call failed",
Expand Down
43 changes: 36 additions & 7 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@ type AgentsConfig struct {
}

type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
}

type ChannelsConfig struct {
Expand Down Expand Up @@ -348,6 +351,32 @@ func (c *Config) GetAPIBase() string {
return ""
}

// ModelConfig holds primary model and fallback list.
type ModelConfig struct {
Primary string
Fallbacks []string
}

// GetModelConfig returns the text model configuration with fallbacks.
func (c *Config) GetModelConfig() ModelConfig {
c.mu.RLock()
defer c.mu.RUnlock()
return ModelConfig{
Primary: c.Agents.Defaults.Model,
Fallbacks: c.Agents.Defaults.ModelFallbacks,
}
}

// GetImageModelConfig returns the image model configuration with fallbacks.
func (c *Config) GetImageModelConfig() ModelConfig {
c.mu.RLock()
defer c.mu.RUnlock()
return ModelConfig{
Primary: c.Agents.Defaults.ImageModel,
Fallbacks: c.Agents.Defaults.ImageModelFallbacks,
}
}

func expandHome(path string) string {
if path == "" {
return path
Expand Down
207 changes: 207 additions & 0 deletions pkg/providers/cooldown.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package providers

import (
"math"
"sync"
"time"
)

const (
defaultFailureWindow = 24 * time.Hour
)

// CooldownTracker manages per-provider cooldown state for the fallback chain.
// Thread-safe via sync.RWMutex. In-memory only (resets on restart).
type CooldownTracker struct {
mu sync.RWMutex
entries map[string]*cooldownEntry
failureWindow time.Duration
nowFunc func() time.Time // for testing
}

type cooldownEntry struct {
ErrorCount int
FailureCounts map[FailoverReason]int
CooldownEnd time.Time // standard cooldown expiry
DisabledUntil time.Time // billing-specific disable expiry
DisabledReason FailoverReason // reason for disable (billing)
LastFailure time.Time
}

// NewCooldownTracker creates a tracker with default 24h failure window.
func NewCooldownTracker() *CooldownTracker {
return &CooldownTracker{
entries: make(map[string]*cooldownEntry),
failureWindow: defaultFailureWindow,
nowFunc: time.Now,
}
}

// MarkFailure records a failure for a provider and sets appropriate cooldown.
// Resets error counts if last failure was more than failureWindow ago.
func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) {
ct.mu.Lock()
defer ct.mu.Unlock()

now := ct.nowFunc()
entry := ct.getOrCreate(provider)

// 24h failure window reset: if no failure in failureWindow, reset counters.
if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow {
entry.ErrorCount = 0
entry.FailureCounts = make(map[FailoverReason]int)
}

entry.ErrorCount++
entry.FailureCounts[reason]++
entry.LastFailure = now

if reason == FailoverBilling {
billingCount := entry.FailureCounts[FailoverBilling]
entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount))
entry.DisabledReason = FailoverBilling
} else {
entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount))
}
}

// MarkSuccess resets all counters and cooldowns for a provider.
func (ct *CooldownTracker) MarkSuccess(provider string) {
ct.mu.Lock()
defer ct.mu.Unlock()

entry := ct.entries[provider]
if entry == nil {
return
}

entry.ErrorCount = 0
entry.FailureCounts = make(map[FailoverReason]int)
entry.CooldownEnd = time.Time{}
entry.DisabledUntil = time.Time{}
entry.DisabledReason = ""
}

// IsAvailable returns true if the provider is not in cooldown or disabled.
func (ct *CooldownTracker) IsAvailable(provider string) bool {
ct.mu.RLock()
defer ct.mu.RUnlock()

entry := ct.entries[provider]
if entry == nil {
return true
}

now := ct.nowFunc()

// Billing disable takes precedence (longer cooldown).
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
return false
}

// Standard cooldown.
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
return false
}

return true
}

// CooldownRemaining returns how long until the provider becomes available.
// Returns 0 if already available.
func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration {
ct.mu.RLock()
defer ct.mu.RUnlock()

entry := ct.entries[provider]
if entry == nil {
return 0
}

now := ct.nowFunc()
var remaining time.Duration

if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
d := entry.DisabledUntil.Sub(now)
if d > remaining {
remaining = d
}
}

if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
d := entry.CooldownEnd.Sub(now)
if d > remaining {
remaining = d
}
}

return remaining
}

// ErrorCount returns the current error count for a provider.
func (ct *CooldownTracker) ErrorCount(provider string) int {
ct.mu.RLock()
defer ct.mu.RUnlock()

entry := ct.entries[provider]
if entry == nil {
return 0
}
return entry.ErrorCount
}

// FailureCount returns the failure count for a specific reason.
func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int {
ct.mu.RLock()
defer ct.mu.RUnlock()

entry := ct.entries[provider]
if entry == nil {
return 0
}
return entry.FailureCounts[reason]
}

func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry {
entry := ct.entries[provider]
if entry == nil {
entry = &cooldownEntry{
FailureCounts: make(map[FailoverReason]int),
}
ct.entries[provider] = entry
}
return entry
}

// calculateStandardCooldown computes standard exponential backoff.
// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3))
//
// 1 error → 1 min
// 2 errors → 5 min
// 3 errors → 25 min
// 4+ errors → 1 hour (cap)
func calculateStandardCooldown(errorCount int) time.Duration {
n := max(1, errorCount)
exp := min(n-1, 3)
ms := 60_000 * int(math.Pow(5, float64(exp)))
ms = min(3_600_000, ms) // cap at 1 hour
return time.Duration(ms) * time.Millisecond
}

// calculateBillingCooldown computes billing-specific exponential backoff.
// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10))
//
// 1 error → 5 hours
// 2 errors → 10 hours
// 3 errors → 20 hours
// 4+ errors → 24 hours (cap)
func calculateBillingCooldown(billingErrorCount int) time.Duration {
const baseMs = 5 * 60 * 60 * 1000 // 5 hours
const maxMs = 24 * 60 * 60 * 1000 // 24 hours

n := max(1, billingErrorCount)
exp := min(n-1, 10)
raw := float64(baseMs) * math.Pow(2, float64(exp))
ms := int(math.Min(float64(maxMs), raw))
return time.Duration(ms) * time.Millisecond
}
Loading
Loading