diff --git a/cmd/picoclaw/archive_cmd.go b/cmd/picoclaw/archive_cmd.go new file mode 100644 index 0000000..5c6b4ef --- /dev/null +++ b/cmd/picoclaw/archive_cmd.go @@ -0,0 +1,272 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/sipeed/picoclaw/pkg/archive/discordarchive" + "github.com/sipeed/picoclaw/pkg/session" +) + +type archiveRunOptions struct { + SessionKey string + All bool + OverLimit bool + DryRun bool +} + +type archiveRecallOptions struct { + SessionKey string + TopK int + MaxChars int + JSON bool + Query string +} + +func archiveCmd() { + if len(os.Args) < 3 { + archiveHelp() + return + } + switch strings.ToLower(strings.TrimSpace(os.Args[2])) { + case "discord": + archiveDiscordCmd(os.Args[3:]) + case "help", "--help", "-h": + archiveHelp() + default: + fmt.Printf("Unknown archive command: %s\n", os.Args[2]) + archiveHelp() + } +} + +func archiveDiscordCmd(args []string) { + if len(args) == 0 { + archiveDiscordHelp() + return + } + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + workspace := cfg.WorkspacePath() + sm := session.NewSessionManager(filepath.Join(workspace, "sessions")) + manager := discordarchive.NewManager(workspace, sm, cfg.Channels.Discord.Archive) + + sub := strings.ToLower(strings.TrimSpace(args[0])) + switch sub { + case "list": + overLimitOnly := false + jsonOut := false + for _, arg := range args[1:] { + switch strings.ToLower(strings.TrimSpace(arg)) { + case "--over-limit": + overLimitOnly = true + case "--json": + jsonOut = true + } + } + stats := manager.ListDiscordSessions(overLimitOnly) + if jsonOut { + out, _ := json.MarshalIndent(stats, "", " ") + fmt.Println(string(out)) + return + } + if len(stats) == 0 { + fmt.Println("No Discord sessions found.") + return + } + fmt.Println("Discord sessions:") + for _, stat := range stats { + over := "no" + if stat.OverLimit { + over = "yes" + } + fmt.Printf(" - 
%s | messages=%d tokens~%d over_limit=%s\n", stat.SessionKey, stat.Messages, stat.Tokens, over) + } + case "run": + opts, err := parseArchiveRunOptions(args[1:]) + if err != nil { + fmt.Printf("Error: %v\n", err) + archiveDiscordHelp() + return + } + if opts.SessionKey == "" && !opts.All && !opts.OverLimit { + opts.OverLimit = true + } + if opts.All { + opts.OverLimit = false + } + + if opts.SessionKey != "" { + result, err := manager.ArchiveSession(opts.SessionKey, opts.DryRun) + if err != nil { + fmt.Printf("Archive failed: %v\n", err) + return + } + if result == nil { + fmt.Println("No archive action taken.") + return + } + printArchiveResult(*result) + return + } + + results, err := manager.ArchiveAll(opts.OverLimit, opts.DryRun) + if err != nil { + fmt.Printf("Archive failed: %v\n", err) + return + } + if len(results) == 0 { + fmt.Println("No sessions archived.") + return + } + for _, result := range results { + printArchiveResult(result) + } + case "recall": + opts, err := parseArchiveRecallOptions(args[1:], cfg.Channels.Discord.Archive.RecallTopK, cfg.Channels.Discord.Archive.RecallMaxChars) + if err != nil { + fmt.Printf("Error: %v\n", err) + archiveDiscordHelp() + return + } + hits := manager.Recall(opts.Query, opts.SessionKey, opts.TopK, opts.MaxChars) + if opts.JSON { + out, _ := json.MarshalIndent(hits, "", " ") + fmt.Println(string(out)) + return + } + if len(hits) == 0 { + fmt.Println("No recall hits.") + return + } + for i, hit := range hits { + fmt.Printf("%d) score=%d session=%s file=%s\n", i+1, hit.Score, hit.SessionKey, hit.SourcePath) + fmt.Printf(" %s\n\n", hit.Text) + } + case "index": + // Phase 1 uses lexical recall directly over archive markdown. 
+ fmt.Println("Index step is not required for phase-1 lexical recall (on-demand scan).") + case "help", "--help", "-h": + archiveDiscordHelp() + default: + fmt.Printf("Unknown archive discord command: %s\n", sub) + archiveDiscordHelp() + } +} + +func parseArchiveRunOptions(args []string) (archiveRunOptions, error) { + opts := archiveRunOptions{} + for i := 0; i < len(args); i++ { + switch args[i] { + case "--session-key": + if i+1 >= len(args) { + return opts, fmt.Errorf("--session-key requires a value") + } + opts.SessionKey = strings.TrimSpace(args[i+1]) + i++ + case "--all": + opts.All = true + case "--over-limit": + opts.OverLimit = true + case "--dry-run": + opts.DryRun = true + default: + return opts, fmt.Errorf("unknown option: %s", args[i]) + } + } + return opts, nil +} + +func parseArchiveRecallOptions(args []string, defaultTopK, defaultMaxChars int) (archiveRecallOptions, error) { + opts := archiveRecallOptions{ + TopK: defaultTopK, + MaxChars: defaultMaxChars, + } + queryParts := make([]string, 0, len(args)) + + for i := 0; i < len(args); i++ { + switch args[i] { + case "--top-k": + if i+1 >= len(args) { + return opts, fmt.Errorf("--top-k requires a value") + } + n, err := strconv.Atoi(args[i+1]) + if err != nil || n <= 0 { + return opts, fmt.Errorf("invalid --top-k value: %s", args[i+1]) + } + opts.TopK = n + i++ + case "--max-chars": + if i+1 >= len(args) { + return opts, fmt.Errorf("--max-chars requires a value") + } + n, err := strconv.Atoi(args[i+1]) + if err != nil || n <= 0 { + return opts, fmt.Errorf("invalid --max-chars value: %s", args[i+1]) + } + opts.MaxChars = n + i++ + case "--session-key": + if i+1 >= len(args) { + return opts, fmt.Errorf("--session-key requires a value") + } + opts.SessionKey = strings.TrimSpace(args[i+1]) + i++ + case "--json": + opts.JSON = true + default: + queryParts = append(queryParts, args[i]) + } + } + + opts.Query = strings.TrimSpace(strings.Join(queryParts, " ")) + if opts.Query == "" { + return opts, 
fmt.Errorf("query is required") + } + if opts.TopK <= 0 { + opts.TopK = 6 + } + if opts.MaxChars <= 0 { + opts.MaxChars = 3000 + } + return opts, nil +} + +func printArchiveResult(result discordarchive.ArchiveResult) { + mode := "archived" + if result.DryRun { + mode = "dry-run" + } + fmt.Printf( + "%s: %s | archived=%d kept=%d tokens~%d->%d file=%s\n", + mode, + result.SessionKey, + result.ArchivedMessages, + result.KeptMessages, + result.TokensBefore, + result.TokensAfter, + result.ArchivePath, + ) +} + +func archiveHelp() { + commandName := invokedCLIName() + fmt.Println("\nArchive commands:") + fmt.Printf(" %s archive discord \n", commandName) + fmt.Printf(" %s archive discord help\n", commandName) +} + +func archiveDiscordHelp() { + commandName := invokedCLIName() + fmt.Println("\nArchive Discord commands:") + fmt.Printf(" %s archive discord list [--over-limit] [--json]\n", commandName) + fmt.Printf(" %s archive discord run [--session-key | --all | --over-limit] [--dry-run]\n", commandName) + fmt.Printf(" %s archive discord recall [--top-k ] [--max-chars ] [--session-key ] [--json]\n", commandName) + fmt.Printf(" %s archive discord index\n", commandName) +} diff --git a/cmd/picoclaw/archive_cmd_test.go b/cmd/picoclaw/archive_cmd_test.go new file mode 100644 index 0000000..c1d3e22 --- /dev/null +++ b/cmd/picoclaw/archive_cmd_test.go @@ -0,0 +1,54 @@ +package main + +import "testing" + +func TestParseArchiveRunOptions(t *testing.T) { + opts, err := parseArchiveRunOptions([]string{"--session-key", "discord:1", "--dry-run"}) + if err != nil { + t.Fatalf("parseArchiveRunOptions error: %v", err) + } + if opts.SessionKey != "discord:1" { + t.Fatalf("unexpected session key: %q", opts.SessionKey) + } + if !opts.DryRun { + t.Fatal("expected dry-run=true") + } +} + +func TestParseArchiveRunOptionsUnknown(t *testing.T) { + if _, err := parseArchiveRunOptions([]string{"--nope"}); err == nil { + t.Fatal("expected error for unknown option") + } +} + +func 
TestParseArchiveRecallOptions(t *testing.T) { + opts, err := parseArchiveRecallOptions( + []string{"alpha", "token", "--top-k", "4", "--max-chars", "1200", "--session-key", "discord:1", "--json"}, + 6, + 3000, + ) + if err != nil { + t.Fatalf("parseArchiveRecallOptions error: %v", err) + } + if opts.Query != "alpha token" { + t.Fatalf("unexpected query: %q", opts.Query) + } + if opts.TopK != 4 { + t.Fatalf("unexpected top-k: %d", opts.TopK) + } + if opts.MaxChars != 1200 { + t.Fatalf("unexpected max chars: %d", opts.MaxChars) + } + if opts.SessionKey != "discord:1" { + t.Fatalf("unexpected session key: %q", opts.SessionKey) + } + if !opts.JSON { + t.Fatal("expected json=true") + } +} + +func TestParseArchiveRecallOptionsMissingQuery(t *testing.T) { + if _, err := parseArchiveRecallOptions([]string{"--top-k", "2"}, 6, 3000); err == nil { + t.Fatal("expected missing query error") + } +} diff --git a/cmd/picoclaw/doctor_cmd.go b/cmd/picoclaw/doctor_cmd.go index 0b6b0c0..7a3abdc 100644 --- a/cmd/picoclaw/doctor_cmd.go +++ b/cmd/picoclaw/doctor_cmd.go @@ -1032,11 +1032,14 @@ func resolveBundledSkillsDir() string { if err != nil { realExe = exe } - // Typical Homebrew layout: .../Cellar/sciclaw//bin/sciclaw - // Share dir: .../Cellar/sciclaw//share/sciclaw/skills - share := filepath.Clean(filepath.Join(filepath.Dir(realExe), "..", "share", primaryCLIName, "skills")) - if fileExists(share) { - return share + return resolveBundledSkillsDirForExecutable(realExe) +} + +func resolveBundledSkillsDirForExecutable(exePath string) string { + for _, dir := range skillSourceDirsForExecutable(exePath) { + if dirHasSkillMarkdown(dir) { + return dir + } } return "" } diff --git a/cmd/picoclaw/doctor_skills_test.go b/cmd/picoclaw/doctor_skills_test.go new file mode 100644 index 0000000..5f95f97 --- /dev/null +++ b/cmd/picoclaw/doctor_skills_test.go @@ -0,0 +1,34 @@ +package main + +import ( + "os" + "path/filepath" + "testing" +) + +func 
TestResolveBundledSkillsDirForExecutableFindsDevFormulaShareDir(t *testing.T) { + root := t.TempDir() + exePath := filepath.Join(root, "Cellar", "sciclaw-dev", "0.1.66-dev.15", "bin", "sciclaw") + want := filepath.Join(root, "Cellar", "sciclaw-dev", "0.1.66-dev.15", "share", "sciclaw-dev", "skills") + + if err := os.MkdirAll(filepath.Dir(exePath), 0o755); err != nil { + t.Fatalf("mkdir exe dir: %v", err) + } + if err := os.WriteFile(exePath, []byte(""), 0o755); err != nil { + t.Fatalf("write exe: %v", err) + } + + skillDir := filepath.Join(want, "pandoc-docx") + if err := os.MkdirAll(skillDir, 0o755); err != nil { + t.Fatalf("mkdir skills dir: %v", err) + } + if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# test"), 0o644); err != nil { + t.Fatalf("write SKILL.md: %v", err) + } + + got := resolveBundledSkillsDirForExecutable(exePath) + if filepath.Clean(got) != filepath.Clean(want) { + t.Fatalf("resolveBundledSkillsDirForExecutable() = %q, want %q", got, want) + } +} + diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 16651e9..c766491 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -235,6 +235,8 @@ func main() { } case "backup": backupCmd() + case "archive": + archiveCmd() case "version", "--version", "-v": printVersion() default: @@ -269,6 +271,7 @@ func printHelp() { fmt.Println(" migrate Migrate from OpenClaw to sciClaw") fmt.Println(" skills Manage skills (install, list, remove)") fmt.Println(" backup Backup key sciClaw config/workspace files") + fmt.Println(" archive Manage Discord archive/recall memory") fmt.Println(" version Show version information") fmt.Println() fmt.Println("Agent flags:") diff --git a/cmd/picoclaw/tui/tab_doctor.go b/cmd/picoclaw/tui/tab_doctor.go index 1d56e89..8640c10 100644 --- a/cmd/picoclaw/tui/tab_doctor.go +++ b/cmd/picoclaw/tui/tab_doctor.go @@ -240,15 +240,19 @@ func runDoctorCmd(exec Executor) tea.Cmd { return func() tea.Msg { cmd := "HOME=" + exec.HomePath() + " " + 
shellEscape(exec.BinaryPath()) + " doctor --json 2>&1" out, err := exec.ExecShell(90*time.Second, cmd) - if err != nil { - return doctorDoneMsg{err: fmt.Errorf("command failed: %w", err)} - } - var rep doctorReport - if err := json.Unmarshal([]byte(out), &rep); err != nil { + if parseErr := json.Unmarshal([]byte(out), &rep); parseErr == nil { + // doctor --json exits non-zero when checks contain hard errors. + // If JSON parsed successfully, surface the report instead of a generic exit-status error. + return doctorDoneMsg{report: &rep} + } else if err == nil { // If JSON parsing fails, the output might be plain text. - return doctorDoneMsg{err: fmt.Errorf("failed to parse results: %w\n\nRaw output:\n%s", err, out)} + return doctorDoneMsg{err: fmt.Errorf("failed to parse results: %w\n\nRaw output:\n%s", parseErr, out)} + } + + if strings.TrimSpace(out) != "" { + return doctorDoneMsg{err: fmt.Errorf("command failed: %w\n\nRaw output:\n%s", err, out)} } - return doctorDoneMsg{report: &rep} + return doctorDoneMsg{err: fmt.Errorf("command failed: %w", err)} } } diff --git a/cmd/picoclaw/tui/tab_doctor_test.go b/cmd/picoclaw/tui/tab_doctor_test.go new file mode 100644 index 0000000..364010f --- /dev/null +++ b/cmd/picoclaw/tui/tab_doctor_test.go @@ -0,0 +1,101 @@ +package tui + +import ( + "errors" + "os" + "os/exec" + "strings" + "testing" + "time" +) + +type doctorTestExec struct { + out string + err error + lastShell string +} + +func (e *doctorTestExec) Mode() Mode { return ModeLocal } + +func (e *doctorTestExec) ExecShell(_ time.Duration, shellCmd string) (string, error) { + e.lastShell = shellCmd + return e.out, e.err +} + +func (e *doctorTestExec) ExecCommand(_ time.Duration, _ ...string) (string, error) { return "", nil } + +func (e *doctorTestExec) ReadFile(_ string) (string, error) { return "", os.ErrNotExist } + +func (e *doctorTestExec) WriteFile(_ string, _ []byte, _ os.FileMode) error { return nil } + +func (e *doctorTestExec) ConfigPath() string { return 
"/tmp/config.json" } + +func (e *doctorTestExec) AuthPath() string { return "/tmp/auth.json" } + +func (e *doctorTestExec) HomePath() string { return "/home/tester" } + +func (e *doctorTestExec) BinaryPath() string { return "sciclaw" } + +func (e *doctorTestExec) AgentVersion() string { return "vtest" } + +func (e *doctorTestExec) ServiceInstalled() bool { return false } + +func (e *doctorTestExec) ServiceActive() bool { return false } + +func (e *doctorTestExec) InteractiveProcess(_ ...string) *exec.Cmd { return exec.Command("true") } + +func TestRunDoctorCmd_ParsesJSONOnNonZeroExit(t *testing.T) { + execStub := &doctorTestExec{ + out: `{ + "cli":"sciclaw", + "version":"0.1.66-dev.14", + "os":"linux", + "arch":"amd64", + "timestamp":"2026-02-27T04:21:32Z", + "checks":[ + {"name":"auth.openai","status":"error","message":"expired (oauth)"}, + {"name":"workspace","status":"ok","message":"/home/ernie/sciclaw"} + ] +}`, + err: errors.New("exit status 1"), + } + + msg, ok := runDoctorCmd(execStub)().(doctorDoneMsg) + if !ok { + t.Fatalf("unexpected message type: %T", msg) + } + if msg.err != nil { + t.Fatalf("expected nil error when JSON report is valid, got %v", msg.err) + } + if msg.report == nil { + t.Fatalf("expected report to be present") + } + if len(msg.report.Checks) != 2 { + t.Fatalf("expected 2 checks, got %d", len(msg.report.Checks)) + } + if msg.report.Checks[0].Status != dcErr { + t.Fatalf("expected first check status %q, got %q", dcErr, msg.report.Checks[0].Status) + } + if !strings.Contains(execStub.lastShell, "doctor --json") { + t.Fatalf("expected doctor command, got %q", execStub.lastShell) + } +} + +func TestRunDoctorCmd_ReportsCommandFailureWhenOutputIsNotJSON(t *testing.T) { + execStub := &doctorTestExec{ + out: "some plain text error", + err: errors.New("exit status 1"), + } + + msg := runDoctorCmd(execStub)().(doctorDoneMsg) + if msg.err == nil { + t.Fatalf("expected error for non-JSON output") + } + if !strings.Contains(msg.err.Error(), "command 
failed: exit status 1") { + t.Fatalf("expected wrapped exit status, got %v", msg.err) + } + if !strings.Contains(msg.err.Error(), "some plain text error") { + t.Fatalf("expected raw output in error, got %v", msg.err) + } +} + diff --git a/config/config.example.json b/config/config.example.json index d61868b..d19f09b 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -19,7 +19,18 @@ "discord": { "enabled": false, "token": "YOUR_DISCORD_BOT_TOKEN", - "allow_from": [] + "allow_from": [], + "archive": { + "enabled": true, + "auto_archive": true, + "max_session_tokens": 24000, + "max_session_messages": 120, + "keep_user_pairs": 12, + "min_tail_messages": 4, + "recall_top_k": 6, + "recall_max_chars": 3000, + "recall_min_score": 0.2 + } }, "maixcam": { "enabled": false, diff --git a/pkg/agent/deterministic_fallback_test.go b/pkg/agent/deterministic_fallback_test.go index f45401f..ef47605 100644 --- a/pkg/agent/deterministic_fallback_test.go +++ b/pkg/agent/deterministic_fallback_test.go @@ -86,7 +86,7 @@ func TestDeterministicFallback_NonIntentUsesDefaultResponse(t *testing.T) { if err != nil { t.Fatalf("ProcessDirectWithChannel error: %v", err) } - want := "I've completed processing but have no response to give." 
+ want := defaultEmptyAssistantResponse if resp != want { t.Fatalf("expected default response %q, got %q", want, resp) } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index df2edd8..4c90df3 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -9,6 +9,7 @@ package agent import ( "context" "encoding/json" + "errors" "fmt" "os" "path/filepath" @@ -17,6 +18,7 @@ import ( "sync/atomic" "time" + "github.com/sipeed/picoclaw/pkg/archive/discordarchive" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" @@ -47,6 +49,12 @@ type AgentLoop struct { state *state.Manager contextBuilder *ContextBuilder tools *tools.ToolRegistry + discordArchive *discordarchive.Manager + archiveEnabled bool + archiveAuto bool + recallTopK int + recallMaxChars int + discordRecallFn func(query, sessionKey string, topK, maxChars int) ([]discordarchive.RecallHit, error) hooks *hooks.Dispatcher hookAuditPath string turnCounter uint64 @@ -67,6 +75,12 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } +const defaultEmptyAssistantResponse = "I completed the turn but did not produce a user-facing reply. Ask me for a summary of what was done." + +var errDiscordAutoArchiveTimedOut = errors.New("discord auto-archive timed out") + +const discordAutoArchiveTimeout = 750 * time.Millisecond + // createToolRegistry creates a tool registry with common tools. // This is shared between main agent and subagents. 
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { @@ -179,6 +193,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(subagentTool) sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) + discordArchiveMgr := discordarchive.NewManager(workspace, sessionsManager, cfg.Channels.Discord.Archive) // Create state manager for atomic state persistence stateManager := state.NewManager(workspace) @@ -216,7 +231,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers hookDispatcher = hooks.NewDispatcher(auditSink) if hookPolicy.Enabled || hookPolicyErr != nil { provenanceHandler := &builtin.ProvenanceHandler{} - policyHandler := builtin.NewPolicyHandler(workspace) + policyHandler := builtin.NewPolicyHandler(hookPolicy, hookDiag, hookPolicyErr) if hookPolicyErr != nil { for _, ev := range hooks.KnownEvents() { hookDispatcher.Register(ev, provenanceHandler) @@ -253,9 +268,20 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers state: stateManager, contextBuilder: contextBuilder, tools: toolsRegistry, - hooks: hookDispatcher, - hookAuditPath: hookAuditPath, - summarizing: sync.Map{}, + discordArchive: discordArchiveMgr, + archiveEnabled: cfg.Channels.Discord.Archive.Enabled, + archiveAuto: cfg.Channels.Discord.Archive.AutoArchive, + recallTopK: cfg.Channels.Discord.Archive.RecallTopK, + recallMaxChars: cfg.Channels.Discord.Archive.RecallMaxChars, + discordRecallFn: func(query, sessionKey string, topK, maxChars int) ([]discordarchive.RecallHit, error) { + if discordArchiveMgr == nil { + return nil, nil + } + return discordArchiveMgr.Recall(query, sessionKey, topK, maxChars), nil + }, + hooks: hookDispatcher, + hookAuditPath: hookAuditPath, + summarizing: sync.Map{}, } } @@ -326,6 +352,15 @@ func (al *AgentLoop) HandleInbound(ctx context.Context, msg 
bus.InboundMessage) } if response == "" { + if msg.Channel != "system" { + logger.InfoCF("agent", "No outbound response generated", + map[string]interface{}{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "sender_id": msg.SenderID, + "session_key": msg.SessionKey, + }) + } return } @@ -344,6 +379,15 @@ func (al *AgentLoop) HandleInbound(ctx context.Context, msg bus.InboundMessage) ChatID: msg.ChatID, Content: response, }) + } else { + logger.InfoCF("agent", "Suppressing outbound response because message tool already sent content", + map[string]interface{}{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "sender_id": msg.SenderID, + "session_key": msg.SessionKey, + "response_preview": utils.Truncate(response, 120), + }) } } @@ -351,6 +395,15 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) { al.tools.Register(tool) } +func (al *AgentLoop) messageToolSentInRound() bool { + if tool, ok := al.tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + return mt.HasSentInRound() + } + } + return false +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. 
func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -387,7 +440,7 @@ func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, cha Channel: channel, ChatID: chatID, UserMessage: content, - DefaultResponse: "I've completed processing but have no response to give.", + DefaultResponse: defaultEmptyAssistantResponse, EnableSummary: false, SendResponse: false, NoHistory: true, // Don't load session history for heartbeat @@ -421,7 +474,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, - DefaultResponse: "I've completed processing but have no response to give.", + DefaultResponse: defaultEmptyAssistantResponse, EnableSummary: true, SendResponse: false, }) @@ -485,6 +538,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str if opts.TurnID == "" { opts.TurnID = al.nextTurnID() } + turnStartedAt := time.Now() // 0. Record last channel for heartbeat notifications (skip internal channels) if opts.Channel != "" && opts.ChatID != "" { @@ -515,6 +569,65 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str if !opts.NoHistory { history = al.sessions.GetHistory(opts.SessionKey) summary = al.sessions.GetSummary(opts.SessionKey) + if opts.Channel == "discord" && al.archiveEnabled && al.archiveAuto && al.discordArchive != nil { + archiveStartedAt := time.Now() + logger.InfoCF("archive", "discord auto-archive start", map[string]interface{}{ + "session_key": opts.SessionKey, + }) + archived, err := al.maybeArchiveDiscordSession(opts.SessionKey) + if err != nil { + logger.WarnCF("archive", fmt.Sprintf("discord auto-archive failed: %v", err), map[string]interface{}{ + "session_key": opts.SessionKey, + "error": err.Error(), + "duration_ms": time.Since(archiveStartedAt).Milliseconds(), + }) + } else { + logger.InfoCF("archive", "discord auto-archive complete", map[string]interface{}{ + 
"session_key": opts.SessionKey, + "archived": archived, + "duration_ms": time.Since(archiveStartedAt).Milliseconds(), + }) + if archived { + // Reload post-archive state to keep prompt context bounded. + history = al.sessions.GetHistory(opts.SessionKey) + summary = al.sessions.GetSummary(opts.SessionKey) + } + } + } + if opts.Channel == "discord" && al.archiveEnabled && al.discordRecallFn != nil { + recallStartedAt := time.Now() + logger.InfoCF("archive", "discord auto-recall start", map[string]interface{}{ + "session_key": opts.SessionKey, + "query_chars": len(opts.UserMessage), + }) + recallSection, hits, err := al.buildDiscordRecallContext(opts.UserMessage, opts.SessionKey) + if err != nil { + logger.WarnCF("archive", fmt.Sprintf("discord auto-recall failed: %v", err), map[string]interface{}{ + "session_key": opts.SessionKey, + "error": err.Error(), + "duration_ms": time.Since(recallStartedAt).Milliseconds(), + }) + } else if recallSection != "" { + if strings.TrimSpace(summary) == "" { + summary = recallSection + } else { + summary = strings.TrimSpace(summary) + "\n\n" + recallSection + } + logger.InfoCF("archive", "discord auto-recall injected", + map[string]interface{}{ + "session_key": opts.SessionKey, + "hits": hits, + "chars": len(recallSection), + "duration_ms": time.Since(recallStartedAt).Milliseconds(), + }) + } else { + logger.InfoCF("archive", "discord auto-recall complete with no hits", + map[string]interface{}{ + "session_key": opts.SessionKey, + "duration_ms": time.Since(recallStartedAt).Milliseconds(), + }) + } + } } messages := al.contextBuilder.BuildMessages( history, @@ -524,6 +637,17 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str opts.Channel, opts.ChatID, ) + logger.InfoCF("agent", "Turn context prepared", + map[string]interface{}{ + "turn_id": opts.TurnID, + "channel": opts.Channel, + "chat_id": opts.ChatID, + "session_key": opts.SessionKey, + "history_count": len(history), + "summary_chars": 
len(summary), + "user_chars": len(opts.UserMessage), + "prompt_messages": len(messages), + }) // 3. Save user message to session al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) @@ -546,6 +670,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str }) return "", err } + messageToolSent := al.messageToolSentInRound() // If last tool had ForUser content and we already sent it, we might not need to send final response // This is controlled by the tool's Silent flag and ForUser content @@ -554,13 +679,33 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str if finalContent == "" { if recovered, ok := al.tryDeterministicFallback(ctx, opts); ok { finalContent = recovered + } else if messageToolSent { + // Message tool already delivered content to user; don't add placeholder noise. + logger.InfoCF("agent", "Suppressing default empty-response placeholder after message tool send", + map[string]interface{}{ + "channel": opts.Channel, + "chat_id": opts.ChatID, + "session_key": opts.SessionKey, + "iterations": iteration, + }) } else { finalContent = opts.DefaultResponse } } // 6. Save final assistant message to session - al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) + if strings.TrimSpace(finalContent) != "" { + al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) + } else { + logger.InfoCF("agent", "Skipping assistant session write for empty final content", + map[string]interface{}{ + "channel": opts.Channel, + "chat_id": opts.ChatID, + "session_key": opts.SessionKey, + "iterations": iteration, + "message_tool_sent": messageToolSent, + }) + } al.sessions.Save(opts.SessionKey) // 7. Optional: summarization @@ -569,7 +714,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // 8. 
Optional: send response via bus - if opts.SendResponse { + if opts.SendResponse && strings.TrimSpace(finalContent) != "" { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, @@ -578,13 +723,25 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // 9. Log response - responsePreview := utils.Truncate(finalContent, 120) - logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), - map[string]interface{}{ - "session_key": opts.SessionKey, - "iterations": iteration, - "final_length": len(finalContent), - }) + if strings.TrimSpace(finalContent) == "" { + logger.InfoCF("agent", "No final assistant text emitted", + map[string]interface{}{ + "session_key": opts.SessionKey, + "iterations": iteration, + "message_tool_sent": messageToolSent, + "turn_ms": time.Since(turnStartedAt).Milliseconds(), + }) + } else { + responsePreview := utils.Truncate(finalContent, 120) + logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), + map[string]interface{}{ + "session_key": opts.SessionKey, + "iterations": iteration, + "final_length": len(finalContent), + "message_tool_sent": messageToolSent, + "turn_ms": time.Since(turnStartedAt).Milliseconds(), + }) + } al.dispatchHook(ctx, hooks.EventAfterTurn, hooks.Context{ Timestamp: time.Now(), TurnID: opts.TurnID, @@ -595,7 +752,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str UserMessage: sanitizeHookText(opts.UserMessage), LLMResponseSummary: sanitizeHookText(finalContent), Metadata: map[string]any{ - "iterations": iteration, + "iterations": iteration, + "message_tool_sent": messageToolSent, + "turn_ms": time.Since(turnStartedAt).Milliseconds(), }, }) @@ -608,7 +767,6 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M iteration := 0 var finalContent string var lastMessageToolContent string - messageToolFallbackEligible := constants.IsInternalChannel(opts.Channel) for { if 
al.maxIterations > 0 && iteration >= al.maxIterations { @@ -618,7 +776,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "max": al.maxIterations, }) if finalContent == "" { - if messageToolFallbackEligible && lastMessageToolContent != "" { + if lastMessageToolContent != "" { finalContent = lastMessageToolContent } else { finalContent = fmt.Sprintf("Iteration limit reached (%d) before task completion. Increase `agents.defaults.max_tool_iterations` or set it to 0 for no hard cap.", al.maxIterations) @@ -684,6 +842,18 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M if al.reasoningEffort != "" { llmOpts["reasoning_effort"] = al.reasoningEffort } + llmCallStartedAt := time.Now() + logger.InfoCF("agent", "LLM call start", + map[string]interface{}{ + "turn_id": opts.TurnID, + "iteration": iteration, + "model": al.model, + "channel": opts.Channel, + "chat_id": opts.ChatID, + "session_key": opts.SessionKey, + "messages_count": len(messages), + "tools_count": len(providerToolDefs), + }) response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, llmOpts) if err != nil { @@ -691,9 +861,18 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M map[string]interface{}{ "iteration": iteration, "error": err.Error(), + "duration": time.Since(llmCallStartedAt).String(), }) return "", iteration, fmt.Errorf("LLM call failed: %w", err) } + logger.InfoCF("agent", "LLM call complete", + map[string]interface{}{ + "turn_id": opts.TurnID, + "iteration": iteration, + "duration": time.Since(llmCallStartedAt).String(), + "response_chars": len(response.Content), + "tool_calls_count": len(response.ToolCalls), + }) al.dispatchHook(ctx, hooks.EventAfterLLM, hooks.Context{ Timestamp: time.Now(), TurnID: opts.TurnID, @@ -714,11 +893,11 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M finalContent = response.Content trimmedFinal := 
strings.TrimSpace(finalContent) trimmedDefault := strings.TrimSpace(opts.DefaultResponse) - if messageToolFallbackEligible && lastMessageToolContent != "" && (trimmedFinal == "" || (trimmedDefault != "" && trimmedFinal == trimmedDefault)) { + if lastMessageToolContent != "" && (trimmedFinal == "" || (trimmedDefault != "" && trimmedFinal == trimmedDefault)) { finalContent = lastMessageToolContent - logger.InfoCF("agent", "Using message tool fallback content for internal channel", + logger.InfoCF("agent", "Using message tool fallback content", map[string]interface{}{ - "channel": opts.Channel, + "channel": opts.Channel, "content_chars": len(finalContent), }) } @@ -807,11 +986,23 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) - if messageToolFallbackEligible && tc.Name == "message" { + if tc.Name == "message" { if content, ok := tc.Arguments["content"].(string); ok { trimmed := strings.TrimSpace(content) if trimmed != "" { lastMessageToolContent = trimmed + attachmentCount := 0 + if rawAttachments, ok := tc.Arguments["attachments"].([]interface{}); ok { + attachmentCount = len(rawAttachments) + } + logger.InfoCF("agent", "Captured message tool content for fallback", + map[string]interface{}{ + "channel": opts.Channel, + "chat_id": opts.ChatID, + "content_chars": len(trimmed), + "attachments_count": attachmentCount, + "iteration": iteration, + }) } } } @@ -885,13 +1076,149 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } } - if strings.TrimSpace(finalContent) == "" && messageToolFallbackEligible && lastMessageToolContent != "" { + if strings.TrimSpace(finalContent) == "" && lastMessageToolContent != "" { finalContent = lastMessageToolContent } return finalContent, iteration, nil } +func (al *AgentLoop) maybeArchiveDiscordSession(sessionKey string) (bool, error) { + if al == nil || 
al.discordArchive == nil { + return false, nil + } + + done := make(chan struct { + archived bool + err error + }, 1) + go func() { + result, err := al.discordArchive.MaybeArchiveSession(sessionKey) + done <- struct { + archived bool + err error + }{ + archived: result != nil, + err: err, + } + }() + + select { + case out := <-done: + return out.archived, out.err + case <-time.After(discordAutoArchiveTimeout): + return false, errDiscordAutoArchiveTimedOut + } +} + +func (al *AgentLoop) buildDiscordRecallContext(query, sessionKey string) (string, int, error) { + if al == nil || al.discordRecallFn == nil { + return "", 0, nil + } + query = strings.TrimSpace(query) + if query == "" { + return "", 0, nil + } + + topK := al.recallTopK + if topK <= 0 { + topK = 6 + } + maxChars := al.recallMaxChars + if maxChars <= 0 { + maxChars = 3000 + } + maxTokens := maxChars / 4 + if maxTokens <= 0 { + maxTokens = 1 + } + + hits, err := al.discordRecallFn(query, sessionKey, topK, maxChars) + if err != nil { + return "", 0, err + } + if len(hits) == 0 { + return "", 0, nil + } + + const header = "## Discord Archive Recall (Auto)\nUse this archived context only when relevant to the current user query." 
+ var b strings.Builder + writeWithCap(&b, header, maxChars) + tokenEstimate := len(b.String()) / 4 + added := 0 + + for i, hit := range hits { + entry := fmt.Sprintf( + "\n\n### Hit %d\n- score: %d\n- session: %s\n- source: %s\n%s", + i+1, + hit.Score, + strings.TrimSpace(hit.SessionKey), + strings.TrimSpace(hit.SourcePath), + strings.TrimSpace(hit.Text), + ) + if entry == "" { + continue + } + entryChars := len(entry) + entryTokens := entryChars / 4 + if entryTokens <= 0 { + entryTokens = 1 + } + + remainingChars := maxChars - len(b.String()) + remainingTokens := maxTokens - tokenEstimate + if remainingChars <= 0 || remainingTokens <= 0 { + break + } + if entryChars > remainingChars || entryTokens > remainingTokens { + writeWithCap(&b, entry, remainingChars) + if strings.TrimSpace(hit.Text) != "" { + added++ + } + break + } + + b.WriteString(entry) + tokenEstimate += entryTokens + added++ + } + + section := strings.TrimSpace(b.String()) + if section == "" { + return "", 0, nil + } + if len(section) > maxChars { + section = truncateRunes(section, maxChars) + } + return section, added, nil +} + +func writeWithCap(b *strings.Builder, text string, maxChars int) { + if b == nil || maxChars <= 0 || text == "" { + return + } + remaining := maxChars - len(b.String()) + if remaining <= 0 { + return + } + if len(text) <= remaining { + b.WriteString(text) + return + } + b.WriteString(truncateRunes(text, remaining)) +} + +func truncateRunes(text string, max int) string { + if max <= 0 { + return "" + } + runes := []rune(text) + if len(runes) <= max { + return text + } + return string(runes[:max]) +} + // updateToolContexts updates the context for tools that need channel/chatID info. 
func (al *AgentLoop) updateToolContexts(channel, chatID string) { // Use ContextualTool interface instead of type assertions diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 2adca7c..558efb3 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -2,12 +2,15 @@ package agent import ( "context" + "errors" "os" "path/filepath" "strings" + "sync" "testing" "time" + "github.com/sipeed/picoclaw/pkg/archive/discordarchive" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" @@ -338,6 +341,259 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) { } } +func TestProcessDirectWithChannel_DiscordAutoArchiveTrim(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-discord-archive-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Channels.Discord.Archive.Enabled = true + cfg.Channels.Discord.Archive.AutoArchive = true + cfg.Channels.Discord.Archive.MaxSessionMessages = 8 + cfg.Channels.Discord.Archive.MaxSessionTokens = 60 + cfg.Channels.Discord.Archive.KeepUserPairs = 2 + cfg.Channels.Discord.Archive.MinTailMessages = 4 + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + sessionKey := "discord:test-channel" + for i := 0; i < 8; i++ { + _, err := al.ProcessDirectWithChannel(context.Background(), "alpha archival pressure message", sessionKey, "discord", "test-channel", "user-1") + if err != nil { + t.Fatalf("ProcessDirectWithChannel failed at turn %d: %v", i, err) + } + } + + history := al.sessions.GetHistory(sessionKey) + if len(history) >= 16 { + t.Fatalf("expected trimmed history, got %d messages", len(history)) + } + + archiveDir := filepath.Join(tmpDir, "memory", "archive", "discord", "sessions") + 
entries, err := os.ReadDir(archiveDir) + if err != nil { + t.Fatalf("expected archive directory %s: %v", archiveDir, err) + } + mdCount := 0 + for _, entry := range entries { + if strings.HasSuffix(entry.Name(), ".md") { + mdCount++ + } + } + if mdCount == 0 { + t.Fatalf("expected at least one archive markdown file in %s", archiveDir) + } +} + +func TestProcessDirectWithChannel_DiscordAutoRecallInjectedAfterArchiveExists(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-discord-recall-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Channels.Discord.Archive.Enabled = true + cfg.Channels.Discord.Archive.AutoArchive = true + cfg.Channels.Discord.Archive.MaxSessionMessages = 8 + cfg.Channels.Discord.Archive.MaxSessionTokens = 60 + cfg.Channels.Discord.Archive.KeepUserPairs = 2 + cfg.Channels.Discord.Archive.MinTailMessages = 4 + cfg.Channels.Discord.Archive.RecallTopK = 3 + cfg.Channels.Discord.Archive.RecallMaxChars = 1000 + + msgBus := bus.NewMessageBus() + provider := &captureMockProvider{response: "Mock response"} + al := NewAgentLoop(cfg, msgBus, provider) + + sessionKey := "discord:test-channel" + marker := "rare-marker-phase2" + for i := 0; i < 8; i++ { + _, err := al.ProcessDirectWithChannel( + context.Background(), + "alpha "+marker+" archival pressure message", + sessionKey, + "discord", + "test-channel", + "user-1", + ) + if err != nil { + t.Fatalf("ProcessDirectWithChannel failed at turn %d: %v", i, err) + } + } + + _, err = al.ProcessDirectWithChannel( + context.Background(), + "please recall "+marker, + sessionKey, + "discord", + "test-channel", + "user-1", + ) + if err != nil { + t.Fatalf("ProcessDirectWithChannel recall turn failed: %v", err) + } + + systemPrompt := provider.LastSystemPrompt() + if !strings.Contains(systemPrompt, 
"## Discord Archive Recall (Auto)") { + t.Fatal("expected auto recall section in system prompt for discord turn") + } + if !strings.Contains(strings.ToLower(systemPrompt), strings.ToLower(marker)) { + t.Fatalf("expected recall section to contain marker %q", marker) + } +} + +func TestProcessDirectWithChannel_DiscordAutoRecallInjectionCappedByBudget(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-discord-recall-cap-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Channels.Discord.Archive.Enabled = true + cfg.Channels.Discord.Archive.AutoArchive = true + cfg.Channels.Discord.Archive.RecallTopK = 3 + cfg.Channels.Discord.Archive.RecallMaxChars = 220 + cfg.Channels.Discord.Archive.MaxSessionMessages = 8 + cfg.Channels.Discord.Archive.MaxSessionTokens = 60 + + msgBus := bus.NewMessageBus() + provider := &captureMockProvider{response: "Mock response"} + al := NewAgentLoop(cfg, msgBus, provider) + + al.discordRecallFn = func(query, sessionKey string, topK, maxChars int) ([]discordarchive.RecallHit, error) { + return []discordarchive.RecallHit{ + { + SessionKey: sessionKey, + SourcePath: "/tmp/archive-one.md", + Score: 99, + Text: strings.Repeat("token-heavy-recall-context ", 40), + }, + }, nil + } + + _, err = al.ProcessDirectWithChannel( + context.Background(), + "token-heavy recall test", + "discord:test-channel", + "discord", + "test-channel", + "user-1", + ) + if err != nil { + t.Fatalf("ProcessDirectWithChannel failed: %v", err) + } + + systemPrompt := provider.LastSystemPrompt() + const recallHeader = "## Discord Archive Recall (Auto)" + start := strings.Index(systemPrompt, recallHeader) + if start < 0 { + t.Fatal("expected auto recall header in system prompt") + } + recallSection := systemPrompt[start:] + if len(recallSection) > cfg.Channels.Discord.Archive.RecallMaxChars { + 
t.Fatalf( + "expected recall section <= %d chars, got %d", + cfg.Channels.Discord.Archive.RecallMaxChars, + len(recallSection), + ) + } +} + +func TestProcessDirectWithChannel_DiscordAutoRecallFailureIsFailOpen(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-discord-recall-failopen-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Channels.Discord.Archive.Enabled = true + + msgBus := bus.NewMessageBus() + provider := &captureMockProvider{response: "LLM still replied"} + al := NewAgentLoop(cfg, msgBus, provider) + + al.discordRecallFn = func(query, sessionKey string, topK, maxChars int) ([]discordarchive.RecallHit, error) { + return nil, errors.New("synthetic recall failure") + } + + got, err := al.ProcessDirectWithChannel( + context.Background(), + "does recall failure block chat", + "discord:test-channel", + "discord", + "test-channel", + "user-1", + ) + if err != nil { + t.Fatalf("expected fail-open behavior, got error: %v", err) + } + if got != "LLM still replied" { + t.Fatalf("unexpected response after recall failure: %q", got) + } +} + +func TestProcessDirectWithChannel_NonDiscordNoAutoRecallInjection(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-recall-nondiscord-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Channels.Discord.Archive.Enabled = true + + msgBus := bus.NewMessageBus() + provider := &captureMockProvider{response: "Mock response"} + al := NewAgentLoop(cfg, msgBus, provider) + + al.discordRecallFn = func(query, sessionKey string, topK, maxChars int) ([]discordarchive.RecallHit, error) { + return []discordarchive.RecallHit{ + { + SessionKey: "discord:test-channel", + 
SourcePath: "/tmp/archive.md", + Score: 7, + Text: "discord-only auto recall context", + }, + }, nil + } + + _, err = al.ProcessDirectWithChannel( + context.Background(), + "telegram turn should not inject discord archive", + "telegram:test-channel", + "telegram", + "test-channel", + "user-1", + ) + if err != nil { + t.Fatalf("ProcessDirectWithChannel failed: %v", err) + } + + systemPrompt := provider.LastSystemPrompt() + if strings.Contains(systemPrompt, "## Discord Archive Recall (Auto)") { + t.Fatal("did not expect discord recall section for non-discord channel") + } +} + func TestAgentLoop_HooksCanBeDisabledByPolicy(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { @@ -416,16 +672,23 @@ func TestAgentLoop_WritesHookAuditEvents(t *testing.T) { } auditPath := filepath.Join(tmpDir, "hooks", "hook-events.jsonl") - data, err := os.ReadFile(auditPath) - if err != nil { - t.Fatalf("read hook audit file: %v", err) - } - body := string(data) - if !strings.Contains(body, "\"event\":\"before_turn\"") { - t.Fatalf("expected before_turn event in audit file") - } - if !strings.Contains(body, "\"event\":\"after_turn\"") { - t.Fatalf("expected after_turn event in audit file") + deadline := time.Now().Add(2 * time.Second) + var body string + for { + data, readErr := os.ReadFile(auditPath) + if readErr == nil { + body = string(data) + if strings.Contains(body, "\"event\":\"before_turn\"") && strings.Contains(body, "\"event\":\"after_turn\"") { + break + } + } + if time.Now().After(deadline) { + if readErr != nil { + t.Fatalf("read hook audit file after wait: %v", readErr) + } + t.Fatalf("expected before_turn and after_turn events in audit file, got: %s", body) + } + time.Sleep(10 * time.Millisecond) } } @@ -480,6 +743,53 @@ func (m *simpleMockProvider) GetDefaultModel() string { return "mock-model" } +type captureMockProvider struct { + response string + mu sync.Mutex + last []providers.Message +} + +func (m *captureMockProvider) Chat(ctx 
context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + m.mu.Lock() + m.last = cloneMessages(messages) + m.mu.Unlock() + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *captureMockProvider) GetDefaultModel() string { + return "mock-model" +} + +func (m *captureMockProvider) LastSystemPrompt() string { + m.mu.Lock() + defer m.mu.Unlock() + if len(m.last) == 0 { + return "" + } + if m.last[0].Role != "system" { + return "" + } + return m.last[0].Content +} + +func cloneMessages(in []providers.Message) []providers.Message { + if len(in) == 0 { + return nil + } + out := make([]providers.Message, len(in)) + for i, msg := range in { + out[i] = msg + if len(msg.ToolCalls) > 0 { + out[i].ToolCalls = make([]providers.ToolCall, len(msg.ToolCalls)) + copy(out[i].ToolCalls, msg.ToolCalls) + } + } + return out +} + // mockCustomTool is a simple mock tool for registration testing type mockCustomTool struct{} diff --git a/pkg/agent/message_tool_fallback_test.go b/pkg/agent/message_tool_fallback_test.go index 9634f13..4003255 100644 --- a/pkg/agent/message_tool_fallback_test.go +++ b/pkg/agent/message_tool_fallback_test.go @@ -3,6 +3,7 @@ package agent import ( "context" "os" + "strings" "testing" "github.com/sipeed/picoclaw/pkg/bus" @@ -115,7 +116,7 @@ func TestProcessDirect_MessageToolFallbackWhenLLMReturnsDefaultPlaceholder(t *te }, }, { - Content: "I've completed processing but have no response to give.", + Content: defaultEmptyAssistantResponse, ToolCalls: nil, }, }, @@ -132,3 +133,57 @@ func TestProcessDirect_MessageToolFallbackWhenLLMReturnsDefaultPlaceholder(t *te t.Fatalf("unexpected response: got %q want %q", got, want) } } + +func TestProcessDirect_MessageToolFallbackForExternalChannel(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-message-fallback-external-*") + if err != nil { + 
t.Fatalf("failed to create temp dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(tmpDir) }) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "mock-model", + MaxTokens: 4096, + MaxToolIterations: 3, + }, + }, + } + + provider := &sequenceProvider{ + responses: []*providers.LLMResponse{ + { + Content: "", + ToolCalls: []providers.ToolCall{ + { + ID: "call-1", + Name: "message", + Arguments: map[string]interface{}{ + "content": "The docx abstract has been saved to disk.", + }, + }, + }, + }, + { + Content: "", + ToolCalls: nil, + }, + }, + } + + al := NewAgentLoop(cfg, bus.NewMessageBus(), provider) + got, err := al.ProcessDirectWithChannel(context.Background(), "save abstract", "discord:test-session", "discord", "chat-123", "user-1") + if err != nil { + t.Fatalf("ProcessDirectWithChannel error: %v", err) + } + + want := "The docx abstract has been saved to disk." + if got != want { + t.Fatalf("unexpected response: got %q want %q", got, want) + } + if strings.Contains(strings.ToLower(got), "no response to give") { + t.Fatalf("unexpected placeholder fallback in response: %q", got) + } +} diff --git a/pkg/archive/discordarchive/manager.go b/pkg/archive/discordarchive/manager.go new file mode 100644 index 0000000..194c81b --- /dev/null +++ b/pkg/archive/discordarchive/manager.go @@ -0,0 +1,517 @@ +package discordarchive + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/session" +) + +const ( + maxArchiveMessageChars = 2000 +) + +type Manager struct { + workspace string + cfg config.DiscordArchiveConfig + sessions *session.SessionManager +} + +type SessionStat struct { + SessionKey string + Messages int + Tokens int + OverLimit bool +} + +type ArchiveResult struct { + SessionKey string + ArchivedMessages int + 
KeptMessages int + ArchivePath string + TokensBefore int + TokensAfter int + OverLimit bool + DryRun bool +} + +type RecallHit struct { + SessionKey string `json:"session_key"` + SourcePath string `json:"source_path"` + Score int `json:"score"` + Text string `json:"text"` +} + +type archiveState struct { + Sessions map[string]archiveSessionState `json:"sessions"` +} + +type archiveSessionState struct { + LastArchivedAt string `json:"last_archived_at"` + LastArchivePath string `json:"last_archive_path"` + ArchivedMessages int `json:"archived_messages"` + KeptMessages int `json:"kept_messages"` + TokensBefore int `json:"tokens_before"` + TokensAfter int `json:"tokens_after"` + LastOverLimitState bool `json:"last_over_limit_state"` +} + +func NewManager(workspace string, sm *session.SessionManager, cfg config.DiscordArchiveConfig) *Manager { + return &Manager{ + workspace: workspace, + cfg: cfg, + sessions: sm, + } +} + +func (m *Manager) ListDiscordSessions(overLimitOnly bool) []SessionStat { + if m == nil || m.sessions == nil { + return nil + } + stats := make([]SessionStat, 0) + for _, key := range m.sessions.ListKeys() { + if !strings.HasPrefix(key, "discord:") { + continue + } + snap, ok := m.sessions.Snapshot(key) + if !ok { + continue + } + msgCount := len(snap.Messages) + tokenCount := estimateTokens(snap.Messages) + over := tokenCount >= m.cfg.MaxSessionTokens || msgCount >= m.cfg.MaxSessionMessages + if overLimitOnly && !over { + continue + } + stats = append(stats, SessionStat{ + SessionKey: key, + Messages: msgCount, + Tokens: tokenCount, + OverLimit: over, + }) + } + sort.Slice(stats, func(i, j int) bool { + if stats[i].Tokens == stats[j].Tokens { + return stats[i].SessionKey < stats[j].SessionKey + } + return stats[i].Tokens > stats[j].Tokens + }) + return stats +} + +func (m *Manager) MaybeArchiveSession(sessionKey string) (*ArchiveResult, error) { + if m == nil || m.sessions == nil { + return nil, nil + } + if !strings.HasPrefix(sessionKey, "discord:") 
{ + return nil, nil + } + snap, ok := m.sessions.Snapshot(sessionKey) + if !ok { + return nil, nil + } + over := len(snap.Messages) >= m.cfg.MaxSessionMessages || estimateTokens(snap.Messages) >= m.cfg.MaxSessionTokens + if !over { + // Avoid per-message disk writes on cloud-backed filesystems when session + // remains under limits. This metadata is not required for recall behavior. + return nil, nil + } + result, err := m.archiveSnapshot(sessionKey, snap.Messages, false) + if err != nil { + return nil, err + } + return result, nil +} + +func (m *Manager) ArchiveAll(overLimitOnly bool, dryRun bool) ([]ArchiveResult, error) { + if m == nil || m.sessions == nil { + return nil, nil + } + results := make([]ArchiveResult, 0) + for _, stat := range m.ListDiscordSessions(false) { + if overLimitOnly && !stat.OverLimit { + continue + } + snap, ok := m.sessions.Snapshot(stat.SessionKey) + if !ok { + continue + } + result, err := m.archiveSnapshot(stat.SessionKey, snap.Messages, dryRun) + if err != nil { + return results, err + } + if result != nil { + results = append(results, *result) + } + } + return results, nil +} + +func (m *Manager) ArchiveSession(sessionKey string, dryRun bool) (*ArchiveResult, error) { + if m == nil || m.sessions == nil { + return nil, nil + } + snap, ok := m.sessions.Snapshot(sessionKey) + if !ok { + return nil, fmt.Errorf("session not found: %s", sessionKey) + } + return m.archiveSnapshot(sessionKey, snap.Messages, dryRun) +} + +func (m *Manager) Recall(query, sessionKey string, topK, maxChars int) []RecallHit { + query = strings.TrimSpace(query) + if query == "" { + return nil + } + if topK <= 0 { + topK = m.cfg.RecallTopK + if topK <= 0 { + topK = 6 + } + } + if maxChars <= 0 { + maxChars = m.cfg.RecallMaxChars + if maxChars <= 0 { + maxChars = 3000 + } + } + + files, err := os.ReadDir(m.archiveSessionsDir()) + if err != nil { + return nil + } + safeKey := "" + if strings.TrimSpace(sessionKey) != "" { + safeKey = sanitizeSessionKey(sessionKey) + 
} + + terms := tokenize(query) + type scored struct { + hit RecallHit + size int + } + scoredHits := make([]scored, 0) + for _, entry := range files { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".md") { + continue + } + if safeKey != "" && !strings.Contains(name, safeKey) { + continue + } + path := filepath.Join(m.archiveSessionsDir(), name) + data, err := os.ReadFile(path) + if err != nil { + continue + } + text := string(data) + score := lexicalScore(text, terms) + if score <= 0 { + continue + } + snippet := summarizeText(text, 420) + scoredHits = append(scoredHits, scored{ + hit: RecallHit{ + SessionKey: extractSessionKeyFromArchive(text), + SourcePath: path, + Score: score, + Text: snippet, + }, + size: len(snippet), + }) + } + sort.Slice(scoredHits, func(i, j int) bool { + if scoredHits[i].hit.Score == scoredHits[j].hit.Score { + return scoredHits[i].hit.SourcePath < scoredHits[j].hit.SourcePath + } + return scoredHits[i].hit.Score > scoredHits[j].hit.Score + }) + + out := make([]RecallHit, 0, topK) + remaining := maxChars + for _, candidate := range scoredHits { + if len(out) >= topK { + break + } + if candidate.size > remaining { + continue + } + out = append(out, candidate.hit) + remaining -= candidate.size + } + return out +} + +func (m *Manager) archiveSnapshot(sessionKey string, allMessages []providers.Message, dryRun bool) (*ArchiveResult, error) { + if !strings.HasPrefix(sessionKey, "discord:") { + return nil, nil + } + if len(allMessages) == 0 { + return nil, nil + } + + tokensBefore := estimateTokens(allMessages) + keepStart := calculateKeepStart(allMessages, m.cfg.KeepUserPairs, m.cfg.MinTailMessages) + if keepStart <= 0 { + return nil, nil + } + if keepStart > len(allMessages) { + keepStart = len(allMessages) + } + + archiveSlice := allMessages[:keepStart] + keptSlice := allMessages[keepStart:] + archiveForMarkdown := selectArchiveMessages(archiveSlice) + if len(archiveForMarkdown) == 0 { + return nil, 
nil
+	}
+
+	result := &ArchiveResult{
+		SessionKey:       sessionKey,
+		ArchivedMessages: len(archiveForMarkdown),
+		KeptMessages:     len(keptSlice),
+		TokensBefore:     tokensBefore,
+		TokensAfter:      estimateTokens(keptSlice),
+		OverLimit:        tokensBefore >= m.cfg.MaxSessionTokens || len(allMessages) >= m.cfg.MaxSessionMessages,
+		DryRun:           dryRun,
+	}
+
+	archivePath := m.archivePathFor(sessionKey)
+	result.ArchivePath = archivePath
+	if dryRun {
+		return result, nil
+	}
+
+	if err := os.MkdirAll(filepath.Dir(archivePath), 0755); err != nil {
+		return nil, err
+	}
+	md := buildMarkdown(sessionKey, archiveForMarkdown)
+	if err := os.WriteFile(archivePath, []byte(md), 0644); err != nil {
+		return nil, err
+	}
+
+	m.sessions.ReplaceHistory(sessionKey, keptSlice)
+	if err := m.sessions.Save(sessionKey); err != nil {
+		return nil, err
+	}
+
+	if err := m.writeState(sessionKey, archiveSessionState{
+		LastArchivedAt:     time.Now().UTC().Format(time.RFC3339),
+		LastArchivePath:    archivePath,
+		ArchivedMessages:   result.ArchivedMessages,
+		KeptMessages:       result.KeptMessages,
+		TokensBefore:       result.TokensBefore,
+		TokensAfter:        result.TokensAfter,
+		LastOverLimitState: result.OverLimit,
+	}); err != nil {
+		return nil, err
+	}
+	return result, nil
+}
+
+func (m *Manager) writeState(sessionKey string, stateEntry archiveSessionState) error {
+	path := filepath.Join(m.archiveBaseDir(), ".archive-state.json")
+	current := archiveState{Sessions: map[string]archiveSessionState{}}
+
+	if data, err := os.ReadFile(path); err == nil {
+		_ = json.Unmarshal(data, &current)
+		if current.Sessions == nil {
+			current.Sessions = map[string]archiveSessionState{}
+		}
+	}
+	current.Sessions[sessionKey] = stateEntry
+
+	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
+		return err
+	}
+	data, err := json.MarshalIndent(current, "", "  ")
+	if err != nil {
+		return err
+	}
+	tmp := path + ".tmp"
+	if err := os.WriteFile(tmp, data, 0644); err != nil {
+		return err
+	}
+	return os.Rename(tmp, path)
+}
+
+func (m 
*Manager) archiveBaseDir() string { + return filepath.Join(m.workspace, "memory", "archive", "discord") +} + +func (m *Manager) archiveSessionsDir() string { + return filepath.Join(m.archiveBaseDir(), "sessions") +} + +func (m *Manager) archivePathFor(sessionKey string) string { + now := time.Now().UTC() + return filepath.Join( + m.archiveSessionsDir(), + fmt.Sprintf("%s-discord-session-%s-%d.md", now.Format("2006-01-02"), sanitizeSessionKey(sessionKey), now.UnixNano()), + ) +} + +func calculateKeepStart(messages []providers.Message, keepUserPairs, minTailMessages int) int { + if len(messages) == 0 { + return 0 + } + if keepUserPairs <= 0 { + keepUserPairs = 12 + } + if minTailMessages <= 0 { + minTailMessages = 4 + } + + keepStart := 0 + userCount := 0 + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + userCount++ + if userCount >= keepUserPairs { + keepStart = i + break + } + } + } + if userCount < keepUserPairs { + keepStart = 0 + } + + minKeepStart := len(messages) - minTailMessages + if minKeepStart < 0 { + minKeepStart = 0 + } + if keepStart > minKeepStart { + keepStart = minKeepStart + } + return keepStart +} + +func selectArchiveMessages(messages []providers.Message) []providers.Message { + out := make([]providers.Message, 0, len(messages)) + for _, msg := range messages { + if msg.Role != "user" && msg.Role != "assistant" { + continue + } + content := strings.TrimSpace(msg.Content) + if content == "" { + continue + } + out = append(out, providers.Message{ + Role: msg.Role, + Content: truncate(content, maxArchiveMessageChars), + }) + } + return out +} + +func buildMarkdown(sessionKey string, messages []providers.Message) string { + var b strings.Builder + now := time.Now().UTC().Format(time.RFC3339) + b.WriteString("# Discord Session Archive\n\n") + b.WriteString("- **Session Key**: " + sessionKey + "\n") + b.WriteString("- **Archived At**: " + now + "\n") + b.WriteString("- **Messages**: " + fmt.Sprintf("%d", len(messages)) + 
"\n\n") + b.WriteString("---\n\n") + for i, msg := range messages { + label := "Assistant" + if msg.Role == "user" { + label = "User" + } + b.WriteString(fmt.Sprintf("### %d. %s\n\n", i+1, label)) + b.WriteString(msg.Content) + b.WriteString("\n\n") + } + return b.String() +} + +func estimateTokens(messages []providers.Message) int { + total := 0 + for _, msg := range messages { + total += len(msg.Content) / 4 + } + return total +} + +var nonSafeChars = regexp.MustCompile(`[^a-zA-Z0-9._-]+`) + +func sanitizeSessionKey(sessionKey string) string { + key := strings.TrimSpace(sessionKey) + if key == "" { + return "unknown" + } + key = strings.ReplaceAll(key, ":", "_") + key = strings.ReplaceAll(key, "@", "_") + key = nonSafeChars.ReplaceAllString(key, "_") + key = strings.Trim(key, "_") + if key == "" { + return "unknown" + } + return key +} + +func tokenize(text string) []string { + raw := strings.Fields(strings.ToLower(text)) + out := make([]string, 0, len(raw)) + seen := map[string]struct{}{} + for _, token := range raw { + token = strings.Trim(token, ".,!?;:\"'()[]{}<>") + if len(token) < 2 { + continue + } + if _, ok := seen[token]; ok { + continue + } + seen[token] = struct{}{} + out = append(out, token) + } + return out +} + +func lexicalScore(text string, terms []string) int { + lower := strings.ToLower(text) + score := 0 + for _, term := range terms { + score += strings.Count(lower, term) + } + return score +} + +func summarizeText(text string, max int) string { + text = strings.TrimSpace(text) + text = strings.ReplaceAll(text, "\n", " ") + text = strings.Join(strings.Fields(text), " ") + return truncate(text, max) +} + +func extractSessionKeyFromArchive(text string) string { + for _, line := range strings.Split(text, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "- **Session Key**:") { + return strings.TrimSpace(strings.TrimPrefix(line, "- **Session Key**:")) + } + } + return "" +} + +func truncate(text string, max int) string { + if 
max <= 0 { + return text + } + runes := []rune(text) + if len(runes) <= max { + return text + } + return string(runes[:max]) + "... [truncated]" +} diff --git a/pkg/archive/discordarchive/manager_test.go b/pkg/archive/discordarchive/manager_test.go new file mode 100644 index 0000000..359a92d --- /dev/null +++ b/pkg/archive/discordarchive/manager_test.go @@ -0,0 +1,206 @@ +package discordarchive + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/session" +) + +func TestArchiveAndRecallRoundTrip(t *testing.T) { + workspace := t.TempDir() + sm := session.NewSessionManager(filepath.Join(workspace, "sessions")) + + key := "discord:12345" + for i := 0; i < 20; i++ { + sm.AddMessage(key, "user", "discuss alpha design and token pressure") + sm.AddMessage(key, "assistant", "alpha response with implementation details") + } + if err := sm.Save(key); err != nil { + t.Fatalf("save session: %v", err) + } + + cfg := config.DiscordArchiveConfig{ + Enabled: true, + AutoArchive: true, + MaxSessionTokens: 50, + MaxSessionMessages: 8, + KeepUserPairs: 3, + MinTailMessages: 4, + RecallTopK: 5, + RecallMaxChars: 2000, + RecallMinScore: 0.2, + } + mgr := NewManager(workspace, sm, cfg) + + result, err := mgr.MaybeArchiveSession(key) + if err != nil { + t.Fatalf("MaybeArchiveSession error: %v", err) + } + if result == nil { + t.Fatal("expected archive result, got nil") + } + if result.ArchivedMessages == 0 { + t.Fatal("expected archived messages > 0") + } + if result.KeptMessages == 0 { + t.Fatal("expected kept messages > 0") + } + if result.TokensAfter >= result.TokensBefore { + t.Fatalf("expected token reduction, before=%d after=%d", result.TokensBefore, result.TokensAfter) + } + if _, err := os.Stat(result.ArchivePath); err != nil { + t.Fatalf("expected archive file at %s: %v", result.ArchivePath, err) + } + + history := sm.GetHistory(key) + if 
len(history) == 0 { + t.Fatal("expected non-empty trimmed history") + } + if len(history) >= 40 { + t.Fatalf("expected trimmed history < 40, got %d", len(history)) + } + + hits := mgr.Recall("alpha token", key, 3, 600) + if len(hits) == 0 { + t.Fatal("expected recall hits") + } + if hits[0].Score <= 0 { + t.Fatalf("expected positive recall score, got %d", hits[0].Score) + } + if !strings.Contains(strings.ToLower(hits[0].Text), "alpha") { + t.Fatalf("expected hit text to mention alpha, got %q", hits[0].Text) + } +} + +func TestListDiscordSessionsOverLimit(t *testing.T) { + workspace := t.TempDir() + sm := session.NewSessionManager(filepath.Join(workspace, "sessions")) + + overKey := "discord:over" + for i := 0; i < 10; i++ { + sm.AddMessage(overKey, "user", strings.Repeat("x", 80)) + } + sm.AddMessage("discord:small", "user", "tiny") + sm.AddMessage("telegram:small", "user", "tiny") + + cfg := config.DiscordArchiveConfig{ + MaxSessionTokens: 30, + MaxSessionMessages: 8, + KeepUserPairs: 2, + MinTailMessages: 2, + RecallTopK: 3, + RecallMaxChars: 1000, + } + mgr := NewManager(workspace, sm, cfg) + + all := mgr.ListDiscordSessions(false) + if len(all) != 2 { + t.Fatalf("expected 2 discord sessions, got %d", len(all)) + } + overOnly := mgr.ListDiscordSessions(true) + if len(overOnly) != 1 { + t.Fatalf("expected 1 over-limit session, got %d", len(overOnly)) + } + if overOnly[0].SessionKey != overKey { + t.Fatalf("expected over-limit key %q, got %q", overKey, overOnly[0].SessionKey) + } +} + +func TestArchiveSessionDryRunDoesNotMutateHistory(t *testing.T) { + workspace := t.TempDir() + sm := session.NewSessionManager(filepath.Join(workspace, "sessions")) + key := "discord:dryrun" + for i := 0; i < 12; i++ { + sm.AddMessage(key, "user", "dry run content") + sm.AddMessage(key, "assistant", "dry run answer") + } + + cfg := config.DiscordArchiveConfig{ + MaxSessionTokens: 10, + MaxSessionMessages: 8, + KeepUserPairs: 2, + MinTailMessages: 4, + RecallTopK: 3, + 
RecallMaxChars: 1000, + } + mgr := NewManager(workspace, sm, cfg) + before := sm.GetHistory(key) + result, err := mgr.ArchiveSession(key, true) + if err != nil { + t.Fatalf("ArchiveSession dry-run error: %v", err) + } + if result == nil || !result.DryRun { + t.Fatalf("expected dry-run result, got %#v", result) + } + after := sm.GetHistory(key) + if len(before) != len(after) { + t.Fatalf("dry-run should not mutate history length: before=%d after=%d", len(before), len(after)) + } +} + +func TestRecallWithoutSessionKeyScansAllArchives(t *testing.T) { + workspace := t.TempDir() + sm := session.NewSessionManager(filepath.Join(workspace, "sessions")) + + keyA := "discord:alpha" + keyB := "discord:beta" + for i := 0; i < 10; i++ { + sm.AddMessage(keyA, "user", "alpha memory token") + sm.AddMessage(keyA, "assistant", "alpha assistant memory token") + sm.AddMessage(keyB, "user", "beta channel noise") + sm.AddMessage(keyB, "assistant", "beta response noise") + } + if err := sm.Save(keyA); err != nil { + t.Fatalf("save session A: %v", err) + } + if err := sm.Save(keyB); err != nil { + t.Fatalf("save session B: %v", err) + } + + cfg := config.DiscordArchiveConfig{ + Enabled: true, + AutoArchive: true, + MaxSessionTokens: 40, + MaxSessionMessages: 8, + KeepUserPairs: 3, + MinTailMessages: 4, + RecallTopK: 5, + RecallMaxChars: 2000, + } + mgr := NewManager(workspace, sm, cfg) + + if _, err := mgr.MaybeArchiveSession(keyA); err != nil { + t.Fatalf("archive session A: %v", err) + } + if _, err := mgr.MaybeArchiveSession(keyB); err != nil { + t.Fatalf("archive session B: %v", err) + } + + hits := mgr.Recall("alpha token", "", 5, 2000) + if len(hits) == 0 { + t.Fatal("expected recall hits without session-key filter") + } + if !strings.Contains(strings.ToLower(hits[0].Text), "alpha") { + t.Fatalf("expected alpha content in top hit, got %q", hits[0].Text) + } +} + +func TestCalculateKeepStartContinuityFloor(t *testing.T) { + msgs := []providers.Message{ + {Role: "user", Content: 
"u1"}, + {Role: "assistant", Content: "a1"}, + {Role: "user", Content: "u2"}, + {Role: "assistant", Content: "a2"}, + {Role: "user", Content: "u3"}, + {Role: "assistant", Content: "a3"}, + } + got := calculateKeepStart(msgs, 1, 4) + if got != 2 { + t.Fatalf("keepStart=%d, want 2", got) + } +} diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index 364a846..6e9c07c 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -267,6 +267,9 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "mention_roles": len(m.MentionRoles), "bot_user_id": c.botUserID, "guild_id": m.GuildID, + "channel_id": m.ChannelID, + "message_id": m.ID, + "sender_id": senderID, "has_msg_ref": m.MessageReference != nil, "has_ref_msg": m.ReferencedMessage != nil, }) diff --git a/pkg/config/config.go b/pkg/config/config.go index d352ce7..bbde62f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -134,10 +134,23 @@ type FeishuConfig struct { AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` } +type DiscordArchiveConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_ENABLED"` + AutoArchive bool `json:"auto_archive" env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_AUTO_ARCHIVE"` + MaxSessionTokens int `json:"max_session_tokens" env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_MAX_SESSION_TOKENS"` + MaxSessionMessages int `json:"max_session_messages" env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_MAX_SESSION_MESSAGES"` + KeepUserPairs int `json:"keep_user_pairs" env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_KEEP_USER_PAIRS"` + MinTailMessages int `json:"min_tail_messages" env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_MIN_TAIL_MESSAGES"` + RecallTopK int `json:"recall_top_k" env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_RECALL_TOP_K"` + RecallMaxChars int `json:"recall_max_chars" env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_RECALL_MAX_CHARS"` + RecallMinScore float64 `json:"recall_min_score" 
env:"PICOCLAW_CHANNELS_DISCORD_ARCHIVE_RECALL_MIN_SCORE"` +} + type DiscordConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + Archive DiscordArchiveConfig `json:"archive"` } type MaixCamConfig struct { @@ -280,6 +293,17 @@ func DefaultConfig() *Config { Enabled: false, Token: "", AllowFrom: FlexibleStringSlice{}, + Archive: DiscordArchiveConfig{ + Enabled: true, + AutoArchive: true, + MaxSessionTokens: 24000, + MaxSessionMessages: 120, + KeepUserPairs: 12, + MinTailMessages: 4, + RecallTopK: 6, + RecallMaxChars: 3000, + RecallMinScore: 0.20, + }, }, MaixCam: MaixCamConfig{ Enabled: false, @@ -381,6 +405,7 @@ func LoadConfig(path string) (*Config, error) { if err := env.Parse(cfg); err != nil { return nil, err } + normalizeDiscordArchiveConfig(&cfg.Channels.Discord.Archive) if cfg.Agents.Defaults.MaxToolIterations < 0 { cfg.Agents.Defaults.MaxToolIterations = 0 } @@ -488,6 +513,33 @@ func expandHome(path string) string { return path } +func normalizeDiscordArchiveConfig(cfg *DiscordArchiveConfig) { + if cfg == nil { + return + } + if cfg.MaxSessionTokens <= 0 { + cfg.MaxSessionTokens = 24000 + } + if cfg.MaxSessionMessages <= 0 { + cfg.MaxSessionMessages = 120 + } + if cfg.KeepUserPairs <= 0 { + cfg.KeepUserPairs = 12 + } + if cfg.MinTailMessages <= 0 { + cfg.MinTailMessages = 4 + } + if cfg.RecallTopK <= 0 { + cfg.RecallTopK = 6 + } + if cfg.RecallMaxChars <= 0 { + cfg.RecallMaxChars = 3000 + } + if cfg.RecallMinScore < 0 || cfg.RecallMinScore > 1 { + cfg.RecallMinScore = 0.20 + } +} + func ValidateRoutingConfig(r RoutingConfig) error { 
behavior := strings.TrimSpace(r.UnmappedBehavior) if behavior == "" { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 596f525..1121fde 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -148,6 +148,92 @@ func TestDefaultConfig_Channels(t *testing.T) { } } +func TestDefaultConfig_DiscordArchive(t *testing.T) { + cfg := DefaultConfig() + archive := cfg.Channels.Discord.Archive + if !archive.Enabled { + t.Fatal("Discord archive should be enabled by default") + } + if !archive.AutoArchive { + t.Fatal("Discord auto_archive should be enabled by default") + } + if archive.MaxSessionTokens != 24000 { + t.Fatalf("MaxSessionTokens=%d, want 24000", archive.MaxSessionTokens) + } + if archive.MaxSessionMessages != 120 { + t.Fatalf("MaxSessionMessages=%d, want 120", archive.MaxSessionMessages) + } + if archive.KeepUserPairs != 12 { + t.Fatalf("KeepUserPairs=%d, want 12", archive.KeepUserPairs) + } + if archive.MinTailMessages != 4 { + t.Fatalf("MinTailMessages=%d, want 4", archive.MinTailMessages) + } + if archive.RecallTopK != 6 { + t.Fatalf("RecallTopK=%d, want 6", archive.RecallTopK) + } + if archive.RecallMaxChars != 3000 { + t.Fatalf("RecallMaxChars=%d, want 3000", archive.RecallMaxChars) + } + if archive.RecallMinScore != 0.20 { + t.Fatalf("RecallMinScore=%f, want 0.20", archive.RecallMinScore) + } +} + +func TestLoadConfig_NormalizesDiscordArchive(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "config.json") + + raw := `{ + "channels": { + "discord": { + "archive": { + "enabled": true, + "auto_archive": true, + "max_session_tokens": 0, + "max_session_messages": -1, + "keep_user_pairs": 0, + "min_tail_messages": -5, + "recall_top_k": 0, + "recall_max_chars": -1, + "recall_min_score": 2.4 + } + } + } +}` + if err := os.WriteFile(path, []byte(raw), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := LoadConfig(path) + if err != nil { + t.Fatalf("LoadConfig: %v", err) + } + + got 
:= cfg.Channels.Discord.Archive + if got.MaxSessionTokens != 24000 { + t.Fatalf("MaxSessionTokens=%d, want 24000", got.MaxSessionTokens) + } + if got.MaxSessionMessages != 120 { + t.Fatalf("MaxSessionMessages=%d, want 120", got.MaxSessionMessages) + } + if got.KeepUserPairs != 12 { + t.Fatalf("KeepUserPairs=%d, want 12", got.KeepUserPairs) + } + if got.MinTailMessages != 4 { + t.Fatalf("MinTailMessages=%d, want 4", got.MinTailMessages) + } + if got.RecallTopK != 6 { + t.Fatalf("RecallTopK=%d, want 6", got.RecallTopK) + } + if got.RecallMaxChars != 3000 { + t.Fatalf("RecallMaxChars=%d, want 3000", got.RecallMaxChars) + } + if got.RecallMinScore != 0.20 { + t.Fatalf("RecallMinScore=%f, want 0.20", got.RecallMinScore) + } +} + // TestDefaultConfig_WebTools verifies web tools config func TestDefaultConfig_WebTools(t *testing.T) { cfg := DefaultConfig() diff --git a/pkg/hooks/audit_jsonl.go b/pkg/hooks/audit_jsonl.go index ded298b..2cf7548 100644 --- a/pkg/hooks/audit_jsonl.go +++ b/pkg/hooks/audit_jsonl.go @@ -5,13 +5,17 @@ import ( "fmt" "os" "path/filepath" - "sync" +) + +const ( + // Buffer audit writes so hook dispatch never blocks on slow filesystems. + auditQueueSize = 256 ) // JSONLAuditSink appends hook entries as JSONL. 
type JSONLAuditSink struct { - mu sync.Mutex - path string + path string + queue chan []byte } func NewJSONLAuditSink(workspace string) (*JSONLAuditSink, error) { @@ -23,7 +27,12 @@ func NewJSONLAuditSinkAt(path string) (*JSONLAuditSink, error) { if err := os.MkdirAll(dir, 0755); err != nil { return nil, fmt.Errorf("create hooks audit dir: %w", err) } - return &JSONLAuditSink{path: path}, nil + sink := &JSONLAuditSink{ + path: path, + queue: make(chan []byte, auditQueueSize), + } + go sink.writeLoop() + return sink, nil } func (s *JSONLAuditSink) Path() string { @@ -31,20 +40,44 @@ func (s *JSONLAuditSink) Path() string { } func (s *JSONLAuditSink) Write(entry AuditEntry) error { - s.mu.Lock() - defer s.mu.Unlock() - - f, err := os.OpenFile(s.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + b, err := json.Marshal(entry) if err != nil { return err } - defer f.Close() - b, err := json.Marshal(entry) + line := append(b, '\n') + select { + case s.queue <- line: + return nil + default: + } + + // Queue full: drop oldest pending line so current hook event can proceed. 
+ select { + case <-s.queue: + default: + } + select { + case s.queue <- line: + default: + } + return nil +} + +func (s *JSONLAuditSink) writeLoop() { + for line := range s.queue { + _ = s.appendLine(line) + } +} + +func (s *JSONLAuditSink) appendLine(line []byte) error { + f, err := os.OpenFile(s.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) if err != nil { return err } - if _, err := f.Write(append(b, '\n')); err != nil { + defer f.Close() + + if _, err := f.Write(line); err != nil { return err } return nil diff --git a/pkg/hooks/audit_jsonl_test.go b/pkg/hooks/audit_jsonl_test.go index b27f350..176adda 100644 --- a/pkg/hooks/audit_jsonl_test.go +++ b/pkg/hooks/audit_jsonl_test.go @@ -18,11 +18,20 @@ func TestJSONLAuditSinkWrite(t *testing.T) { if err := sink.Write(entry); err != nil { t.Fatalf("Write: %v", err) } - data, err := os.ReadFile(filepath.Join(ws, "hooks", "hook-events.jsonl")) - if err != nil { - t.Fatalf("read audit file: %v", err) - } - if !strings.Contains(string(data), "\"turn_id\":\"turn-1\"") { - t.Fatalf("audit content missing turn_id: %s", string(data)) + + auditPath := filepath.Join(ws, "hooks", "hook-events.jsonl") + deadline := time.Now().Add(2 * time.Second) + for { + data, err := os.ReadFile(auditPath) + if err == nil && strings.Contains(string(data), "\"turn_id\":\"turn-1\"") { + return + } + if time.Now().After(deadline) { + if err != nil { + t.Fatalf("read audit file after wait: %v", err) + } + t.Fatalf("audit content missing turn_id after wait: %s", string(data)) + } + time.Sleep(10 * time.Millisecond) } } diff --git a/pkg/hooks/builtin/policy.go b/pkg/hooks/builtin/policy.go index 029035a..f05534c 100644 --- a/pkg/hooks/builtin/policy.go +++ b/pkg/hooks/builtin/policy.go @@ -2,6 +2,7 @@ package builtin import ( "context" + "errors" "github.com/sipeed/picoclaw/pkg/hookpolicy" "github.com/sipeed/picoclaw/pkg/hooks" @@ -9,11 +10,19 @@ import ( // PolicyHandler applies workspace hook policy (HOOKS.md + hooks.yaml). 
type PolicyHandler struct { - workspace string + policy hookpolicy.Policy + warnings []string + loadErr error } -func NewPolicyHandler(workspace string) *PolicyHandler { - return &PolicyHandler{workspace: workspace} +func NewPolicyHandler(policy hookpolicy.Policy, diag hookpolicy.Diagnostics, loadErr error) *PolicyHandler { + warnings := make([]string, 0, len(diag.Warnings)) + warnings = append(warnings, diag.Warnings...) + return &PolicyHandler{ + policy: policy, + warnings: warnings, + loadErr: loadErr, + } } func (h *PolicyHandler) Name() string { @@ -21,28 +30,34 @@ func (h *PolicyHandler) Name() string { } func (h *PolicyHandler) Handle(_ context.Context, ev hooks.Event, data hooks.Context) hooks.Result { - policy, diag, err := hookpolicy.LoadPolicy(h.workspace) - if err != nil { + if h.loadErr != nil { return hooks.Result{ Status: hooks.StatusError, Message: "failed to load hook policy", - Err: err, + Err: h.loadErr, + } + } + if h.policy.Events == nil { + return hooks.Result{ + Status: hooks.StatusError, + Message: "hook policy missing event configuration", + Err: errors.New("hook policy events are not initialized"), } } meta := map[string]any{ - "policy_enabled": policy.Enabled, + "policy_enabled": h.policy.Enabled, "turn_id": data.TurnID, } - if len(diag.Warnings) > 0 { - meta["warnings"] = diag.Warnings + if len(h.warnings) > 0 { + meta["warnings"] = h.warnings } - if !policy.Enabled { + if !h.policy.Enabled { return hooks.Result{Status: hooks.StatusOK, Message: "hooks disabled by policy", Metadata: meta} } - eventPolicy, ok := policy.Events[ev] + eventPolicy, ok := h.policy.Events[ev] if !ok { return hooks.Result{Status: hooks.StatusOK, Message: "event not configured", Metadata: meta} } diff --git a/pkg/routing/resolver.go b/pkg/routing/resolver.go index c323186..710c29c 100644 --- a/pkg/routing/resolver.go +++ b/pkg/routing/resolver.go @@ -202,8 +202,11 @@ func ensureReadableWorkspace(path string) error { if !info.IsDir() { return fmt.Errorf("not a 
directory") } - _, err = os.ReadDir(path) - return err + // Do not call os.ReadDir() on every inbound message. + // On macOS cloud-backed folders (e.g. Dropbox/iCloud), ReadDir can stall for + // minutes and block the single routing dispatcher goroutine, which prevents + // otherwise valid Discord mentions from becoming active tasks. + return nil } func isMentionOrDM(metadata map[string]string) bool { diff --git a/pkg/routing/resolver_test.go b/pkg/routing/resolver_test.go index 55e3f78..3fc1dcb 100644 --- a/pkg/routing/resolver_test.go +++ b/pkg/routing/resolver_test.go @@ -170,6 +170,49 @@ func TestResolve_InvalidWorkspace(t *testing.T) { } } +func TestResolve_WorkspaceWithoutReadPermissionStillRoutes(t *testing.T) { + root := t.TempDir() + ws := filepath.Join(root, "no-read") + if err := os.MkdirAll(ws, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.Chmod(ws, 0o111); err != nil { + t.Fatalf("chmod: %v", err) + } + t.Cleanup(func() { + _ = os.Chmod(ws, 0o755) + }) + + cfg := config.DefaultConfig() + cfg.Routing.Enabled = true + cfg.Routing.Mappings = []config.RoutingMapping{ + { + Channel: "discord", + ChatID: "123", + Workspace: ws, + AllowedSenders: []string{"u1"}, + }, + } + + resolver, err := NewResolver(cfg) + if err != nil { + t.Fatalf("NewResolver error: %v", err) + } + + d := resolver.Resolve(bus.InboundMessage{ + Channel: "discord", + ChatID: "123", + SenderID: "u1", + Metadata: map[string]string{"is_mention": "true"}, + }) + if !d.Allowed { + t.Fatalf("expected allowed decision, got %+v", d) + } + if d.Event != EventRouteMatch { + t.Fatalf("unexpected event: %s", d.Event) + } +} + func TestResolve_SystemMessageUsesOriginMapping(t *testing.T) { ws := t.TempDir() cfg := config.DefaultConfig() diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 193ad2b..dfc1e16 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -2,12 +2,15 @@ package session import ( "encoding/json" + "errors" "os" "path/filepath" + 
"sort" "strings" "sync" "time" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -25,6 +28,15 @@ type SessionManager struct { storage string } +var ( + // Keep gateway startup/routing responsive even if cloud-backed folders stall. + sessionLoadTimeout = 750 * time.Millisecond + sessionSaveWarnTime = 750 * time.Millisecond + errSessionLoadTimed = errors.New("session load timed out") + readDir = os.ReadDir + readFile = os.ReadFile +) + func NewSessionManager(storage string) *SessionManager { sm := &SessionManager{ sessions: make(map[string]*Session), @@ -33,7 +45,12 @@ func NewSessionManager(storage string) *SessionManager { if storage != "" { os.MkdirAll(storage, 0755) - sm.loadSessions() + if err := sm.loadSessionsWithTimeout(sessionLoadTimeout); err != nil { + logger.WarnCF("session", "Session preload skipped", map[string]interface{}{ + "storage": storage, + "error": err.Error(), + }) + } } return sm @@ -100,6 +117,67 @@ func (sm *SessionManager) GetHistory(key string) []providers.Message { return history } +// ListKeys returns all known session keys in stable order. +func (sm *SessionManager) ListKeys() []string { + sm.mu.RLock() + defer sm.mu.RUnlock() + + keys := make([]string, 0, len(sm.sessions)) + for key := range sm.sessions { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} + +// Snapshot returns a deep copy of one session if it exists. 
+func (sm *SessionManager) Snapshot(key string) (Session, bool) { + sm.mu.RLock() + defer sm.mu.RUnlock() + + stored, ok := sm.sessions[key] + if !ok || stored == nil { + return Session{}, false + } + + out := Session{ + Key: stored.Key, + Summary: stored.Summary, + Created: stored.Created, + Updated: stored.Updated, + } + if len(stored.Messages) > 0 { + out.Messages = make([]providers.Message, len(stored.Messages)) + copy(out.Messages, stored.Messages) + } else { + out.Messages = []providers.Message{} + } + return out, true +} + +// ReplaceHistory replaces the message history for a session. +func (sm *SessionManager) ReplaceHistory(sessionKey string, history []providers.Message) { + sm.mu.Lock() + defer sm.mu.Unlock() + + stored, ok := sm.sessions[sessionKey] + if !ok || stored == nil { + stored = &Session{ + Key: sessionKey, + Created: time.Now(), + } + sm.sessions[sessionKey] = stored + } + + if len(history) == 0 { + stored.Messages = []providers.Message{} + } else { + stored.Messages = make([]providers.Message, len(history)) + copy(stored.Messages, history) + } + stored.Updated = time.Now() +} + func (sm *SessionManager) GetSummary(key string) string { sm.mu.RLock() defer sm.mu.RUnlock() @@ -149,12 +227,20 @@ func (sm *SessionManager) Save(key string) error { if sm.storage == "" { return nil } + saveStartedAt := time.Now() // Validate key to avoid invalid filenames and path traversal. if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") { return os.ErrInvalid } + if strings.HasPrefix(key, "discord:") { + logger.InfoCF("session", "Session save start", map[string]interface{}{ + "session_key": key, + "storage": sm.storage, + }) + } + // Snapshot under read lock, then perform slow file I/O after unlock. 
sm.mu.RLock() stored, ok := sm.sessions[key] @@ -179,12 +265,26 @@ func (sm *SessionManager) Save(key string) error { data, err := json.MarshalIndent(snapshot, "", " ") if err != nil { + if strings.HasPrefix(key, "discord:") { + logger.WarnCF("session", "Session save marshal failed", map[string]interface{}{ + "session_key": key, + "error": err.Error(), + "duration_ms": time.Since(saveStartedAt).Milliseconds(), + }) + } return err } sessionPath := filepath.Join(sm.storage, key+".json") tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp") if err != nil { + if strings.HasPrefix(key, "discord:") { + logger.WarnCF("session", "Session save temp file create failed", map[string]interface{}{ + "session_key": key, + "error": err.Error(), + "duration_ms": time.Since(saveStartedAt).Milliseconds(), + }) + } return err } @@ -198,29 +298,78 @@ func (sm *SessionManager) Save(key string) error { if _, err := tmpFile.Write(data); err != nil { _ = tmpFile.Close() + if strings.HasPrefix(key, "discord:") { + logger.WarnCF("session", "Session save write failed", map[string]interface{}{ + "session_key": key, + "error": err.Error(), + "duration_ms": time.Since(saveStartedAt).Milliseconds(), + }) + } return err } if err := tmpFile.Chmod(0644); err != nil { _ = tmpFile.Close() + if strings.HasPrefix(key, "discord:") { + logger.WarnCF("session", "Session save chmod failed", map[string]interface{}{ + "session_key": key, + "error": err.Error(), + "duration_ms": time.Since(saveStartedAt).Milliseconds(), + }) + } return err } if err := tmpFile.Sync(); err != nil { _ = tmpFile.Close() + if strings.HasPrefix(key, "discord:") { + logger.WarnCF("session", "Session save fsync failed", map[string]interface{}{ + "session_key": key, + "error": err.Error(), + "duration_ms": time.Since(saveStartedAt).Milliseconds(), + }) + } return err } if err := tmpFile.Close(); err != nil { + if strings.HasPrefix(key, "discord:") { + logger.WarnCF("session", "Session save close failed", map[string]interface{}{ + 
"session_key": key, + "error": err.Error(), + "duration_ms": time.Since(saveStartedAt).Milliseconds(), + }) + } return err } if err := os.Rename(tmpPath, sessionPath); err != nil { + if strings.HasPrefix(key, "discord:") { + logger.WarnCF("session", "Session save rename failed", map[string]interface{}{ + "session_key": key, + "path": sessionPath, + "error": err.Error(), + "duration_ms": time.Since(saveStartedAt).Milliseconds(), + }) + } return err } cleanup = false + if strings.HasPrefix(key, "discord:") { + elapsed := time.Since(saveStartedAt) + fields := map[string]interface{}{ + "session_key": key, + "path": sessionPath, + "duration_ms": elapsed.Milliseconds(), + } + if elapsed >= sessionSaveWarnTime { + logger.WarnCF("session", "Session save completed slowly", fields) + } else { + logger.InfoCF("session", "Session save complete", fields) + } + } return nil } func (sm *SessionManager) loadSessions() error { - files, err := os.ReadDir(sm.storage) + files, err := readDir(sm.storage) if err != nil { return err } @@ -235,7 +384,7 @@ func (sm *SessionManager) loadSessions() error { } sessionPath := filepath.Join(sm.storage, file.Name()) - data, err := os.ReadFile(sessionPath) + data, err := readFile(sessionPath) if err != nil { continue } @@ -245,8 +394,28 @@ func (sm *SessionManager) loadSessions() error { continue } + sm.mu.Lock() sm.sessions[session.Key] = &session + sm.mu.Unlock() } return nil } + +func (sm *SessionManager) loadSessionsWithTimeout(timeout time.Duration) error { + if timeout <= 0 { + return sm.loadSessions() + } + + done := make(chan error, 1) + go func() { + done <- sm.loadSessions() + }() + + select { + case err := <-done: + return err + case <-time.After(timeout): + return errSessionLoadTimed + } +} diff --git a/pkg/session/manager_test.go b/pkg/session/manager_test.go new file mode 100644 index 0000000..57d9019 --- /dev/null +++ b/pkg/session/manager_test.go @@ -0,0 +1,133 @@ +package session + +import ( + "encoding/json" + "os" + 
"path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func TestSessionManagerListKeysSorted(t *testing.T) { + sm := NewSessionManager("") + sm.AddMessage("discord:z", "user", "z") + sm.AddMessage("discord:a", "user", "a") + sm.AddMessage("discord:m", "user", "m") + + keys := sm.ListKeys() + if len(keys) != 3 { + t.Fatalf("expected 3 keys, got %d", len(keys)) + } + if keys[0] != "discord:a" || keys[1] != "discord:m" || keys[2] != "discord:z" { + t.Fatalf("unexpected key order: %#v", keys) + } +} + +func TestSessionManagerSnapshotDeepCopy(t *testing.T) { + sm := NewSessionManager("") + sm.AddMessage("discord:test", "user", "hello") + sm.AddMessage("discord:test", "assistant", "world") + + snap, ok := sm.Snapshot("discord:test") + if !ok { + t.Fatal("expected snapshot to exist") + } + if len(snap.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(snap.Messages)) + } + + // Mutate snapshot and ensure manager state is unchanged. + snap.Messages[0].Content = "mutated" + history := sm.GetHistory("discord:test") + if history[0].Content != "hello" { + t.Fatalf("manager history should remain unchanged, got %q", history[0].Content) + } +} + +func TestSessionManagerReplaceHistory(t *testing.T) { + sm := NewSessionManager("") + sm.AddMessage("discord:test", "user", "old") + + newHistory := []providers.Message{ + {Role: "user", Content: "u1"}, + {Role: "assistant", Content: "a1"}, + } + sm.ReplaceHistory("discord:test", newHistory) + + history := sm.GetHistory("discord:test") + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Content != "u1" || history[1].Content != "a1" { + t.Fatalf("unexpected replaced history: %#v", history) + } + + // Mutating caller slice should not mutate stored history. 
+ newHistory[0].Content = "changed" + history = sm.GetHistory("discord:test") + if history[0].Content != "u1" { + t.Fatalf("stored history mutated by caller slice change: %#v", history) + } +} + +func TestSessionManagerPreloadFastPath(t *testing.T) { + dir := t.TempDir() + payload := Session{ + Key: "discord:123@abc", + Messages: []providers.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "world"}, + }, + Created: time.Now().UTC(), + Updated: time.Now().UTC(), + } + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(filepath.Join(dir, payload.Key+".json"), data, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + sm := NewSessionManager(dir) + history := sm.GetHistory(payload.Key) + if len(history) != 2 { + t.Fatalf("expected 2 messages preloaded, got %d", len(history)) + } + if history[0].Content != "hello" || history[1].Content != "world" { + t.Fatalf("unexpected preloaded history: %#v", history) + } +} + +func TestSessionManagerPreloadTimeoutDoesNotBlockConstructor(t *testing.T) { + dir := t.TempDir() + + prevTimeout := sessionLoadTimeout + prevReadDir := readDir + prevReadFile := readFile + defer func() { + sessionLoadTimeout = prevTimeout + readDir = prevReadDir + readFile = prevReadFile + }() + + sessionLoadTimeout = 20 * time.Millisecond + release := make(chan struct{}) + readDir = func(string) ([]os.DirEntry, error) { + <-release + return nil, nil + } + + start := time.Now() + _ = NewSessionManager(dir) + elapsed := time.Since(start) + if elapsed > 120*time.Millisecond { + t.Fatalf("constructor blocked too long: %v", elapsed) + } + + // Let background preload goroutine finish before restoring globals. 
+ close(release) + time.Sleep(5 * time.Millisecond) +} diff --git a/pkg/state/state.go b/pkg/state/state.go index 0bb9cd4..198fb63 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -29,8 +29,14 @@ type Manager struct { state *State mu sync.RWMutex stateFile string + saveQueue chan State } +var ( + stateReadFile = os.ReadFile + stateBootstrapTimeout = 750 * time.Millisecond +) + // NewManager creates a new state manager for the given workspace. func NewManager(workspace string) *Manager { stateDir := filepath.Join(workspace, "state") @@ -44,23 +50,23 @@ func NewManager(workspace string) *Manager { workspace: workspace, stateFile: stateFile, state: &State{}, + saveQueue: make(chan State, 1), } - // Try to load from new location first - if _, err := os.Stat(stateFile); os.IsNotExist(err) { - // New file doesn't exist, try migrating from old location - if data, err := os.ReadFile(oldStateFile); err == nil { - if err := json.Unmarshal(data, sm.state); err == nil { - // Migrate to new location - sm.saveAtomic() - log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile) - } + loadedState, loadedFromLegacy, err := loadBootstrapWithTimeout(stateFile, oldStateFile, stateBootstrapTimeout) + if err != nil { + log.Printf("[WARN] state: bootstrap skipped for %s: %v", workspace, err) + } else if loadedState != nil { + sm.state = loadedState + if loadedFromLegacy { + // Keep startup non-blocking on cloud-backed filesystems. + // The state will be persisted in the new location on next write. + log.Printf("[INFO] state: loaded legacy state from %s", oldStateFile) } - } else { - // Load from new location - sm.load() } + go sm.saveLoop() + return sm } @@ -69,34 +75,22 @@ func NewManager(workspace string) *Manager { // ensuring that the state file is never corrupted even if the process crashes. 
func (sm *Manager) SetLastChannel(channel string) error { sm.mu.Lock() - defer sm.mu.Unlock() - - // Update state sm.state.LastChannel = channel sm.state.Timestamp = time.Now() - - // Atomic save using temp file + rename - if err := sm.saveAtomic(); err != nil { - return fmt.Errorf("failed to save state atomically: %w", err) - } - + snapshot := *sm.state + sm.mu.Unlock() + sm.enqueueSave(snapshot) return nil } // SetLastChatID atomically updates the last chat ID and saves the state. func (sm *Manager) SetLastChatID(chatID string) error { sm.mu.Lock() - defer sm.mu.Unlock() - - // Update state sm.state.LastChatID = chatID sm.state.Timestamp = time.Now() - - // Atomic save using temp file + rename - if err := sm.saveAtomic(); err != nil { - return fmt.Errorf("failed to save state atomically: %w", err) - } - + snapshot := *sm.state + sm.mu.Unlock() + sm.enqueueSave(snapshot) return nil } @@ -121,19 +115,17 @@ func (sm *Manager) GetTimestamp() time.Time { return sm.state.Timestamp } -// saveAtomic performs an atomic save using temp file + rename. +// saveAtomicSnapshot performs an atomic save using temp file + rename. // This ensures that the state file is never corrupted: // 1. Write to a temp file // 2. Rename temp file to target (atomic on POSIX systems) // 3. If rename fails, cleanup the temp file -// -// Must be called with the lock held. 
-func (sm *Manager) saveAtomic() error { +func (sm *Manager) saveAtomicSnapshot(snapshot State) error { // Create temp file in the same directory as the target tempFile := sm.stateFile + ".tmp" // Marshal state to JSON - data, err := json.MarshalIndent(sm.state, "", " ") + data, err := json.MarshalIndent(snapshot, "", " ") if err != nil { return fmt.Errorf("failed to marshal state: %w", err) } @@ -153,20 +145,101 @@ func (sm *Manager) saveAtomic() error { return nil } +func (sm *Manager) saveLoop() { + for snapshot := range sm.saveQueue { + if err := sm.saveAtomicSnapshot(snapshot); err != nil { + log.Printf("[WARN] state: async save failed for %s: %v", sm.workspace, err) + } + } +} + +func (sm *Manager) enqueueSave(snapshot State) { + select { + case sm.saveQueue <- snapshot: + return + default: + } + + // Queue already has an older snapshot; drop it and enqueue the latest. + select { + case <-sm.saveQueue: + default: + } + select { + case sm.saveQueue <- snapshot: + default: + } +} + // load loads the state from disk. 
func (sm *Manager) load() error { - data, err := os.ReadFile(sm.stateFile) + loaded, err := loadStateFromPath(sm.stateFile) if err != nil { - // File doesn't exist yet, that's OK - if os.IsNotExist(err) { - return nil + return err + } + if loaded != nil { + sm.state = loaded + } + return nil +} + +func loadBootstrapWithTimeout(stateFile, oldStateFile string, timeout time.Duration) (*State, bool, error) { + if timeout <= 0 { + return loadBootstrap(stateFile, oldStateFile) + } + + type result struct { + state *State + fromLegacy bool + err error + } + + done := make(chan result, 1) + go func() { + st, legacy, err := loadBootstrap(stateFile, oldStateFile) + done <- result{ + state: st, + fromLegacy: legacy, + err: err, } - return fmt.Errorf("failed to read state file: %w", err) + }() + + select { + case out := <-done: + return out.state, out.fromLegacy, out.err + case <-time.After(timeout): + return nil, false, fmt.Errorf("state load timed out") + } +} + +func loadBootstrap(stateFile, oldStateFile string) (*State, bool, error) { + if st, err := loadStateFromPath(stateFile); err != nil { + return nil, false, err + } else if st != nil { + return st, false, nil } - if err := json.Unmarshal(data, sm.state); err != nil { - return fmt.Errorf("failed to unmarshal state: %w", err) + if st, err := loadStateFromPath(oldStateFile); err != nil { + return nil, false, err + } else if st != nil { + return st, true, nil } - return nil + return nil, false, nil +} + +func loadStateFromPath(path string) (*State, error) { + data, err := stateReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to read state file %s: %w", path, err) + } + + var st State + if err := json.Unmarshal(data, &st); err != nil { + return nil, fmt.Errorf("failed to unmarshal state %s: %w", path, err) + } + return &st, nil } diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index ce3dd72..3d7101f 100644 --- a/pkg/state/state_test.go +++ 
b/pkg/state/state_test.go @@ -6,8 +6,21 @@ import ( "os" "path/filepath" "testing" + "time" ) +func waitForCondition(t *testing.T, timeout time.Duration, cond func() bool, msg string) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal(msg) +} + func TestAtomicSave(t *testing.T) { // Create temp workspace tmpDir, err := os.MkdirTemp("", "state-test-*") @@ -37,15 +50,16 @@ func TestAtomicSave(t *testing.T) { // Verify state file exists stateFile := filepath.Join(tmpDir, "state", "state.json") - if _, err := os.Stat(stateFile); os.IsNotExist(err) { - t.Error("Expected state file to exist") - } + waitForCondition(t, time.Second, func() bool { + _, err := os.Stat(stateFile) + return err == nil + }, "expected state file to exist") // Create a new manager to verify persistence - sm2 := NewManager(tmpDir) - if sm2.GetLastChannel() != "test-channel" { - t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel()) - } + waitForCondition(t, time.Second, func() bool { + sm2 := NewManager(tmpDir) + return sm2.GetLastChannel() == "test-channel" + }, "expected persistent channel 'test-channel'") } func TestSetLastChatID(t *testing.T) { @@ -75,10 +89,10 @@ func TestSetLastChatID(t *testing.T) { } // Create a new manager to verify persistence - sm2 := NewManager(tmpDir) - if sm2.GetLastChatID() != "test-chat-id" { - t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID()) - } + waitForCondition(t, time.Second, func() bool { + sm2 := NewManager(tmpDir) + return sm2.GetLastChatID() == "test-chat-id" + }, "expected persistent chat ID 'test-chat-id'") } func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) { @@ -156,6 +170,11 @@ func TestConcurrentAccess(t *testing.T) { // Verify state file is valid JSON stateFile := filepath.Join(tmpDir, "state", "state.json") + waitForCondition(t, time.Second, func() bool { + _, 
err := os.Stat(stateFile) + return err == nil + }, "expected state file to exist after concurrent writes") + data, err := os.ReadFile(stateFile) if err != nil { t.Fatalf("Failed to read state file: %v", err) @@ -179,17 +198,12 @@ func TestNewManager_ExistingState(t *testing.T) { sm1.SetLastChannel("existing-channel") sm1.SetLastChatID("existing-chat-id") - // Create new manager with same workspace - sm2 := NewManager(tmpDir) - - // Verify state was loaded - if sm2.GetLastChannel() != "existing-channel" { - t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel()) - } - - if sm2.GetLastChatID() != "existing-chat-id" { - t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID()) - } + // Create new manager with same workspace once persistence catches up. + waitForCondition(t, time.Second, func() bool { + sm2 := NewManager(tmpDir) + return sm2.GetLastChannel() == "existing-channel" && + sm2.GetLastChatID() == "existing-chat-id" + }, "expected existing state to be loaded") } func TestNewManager_EmptyWorkspace(t *testing.T) { @@ -214,3 +228,64 @@ func TestNewManager_EmptyWorkspace(t *testing.T) { t.Error("Expected zero timestamp for new state") } } + +func TestNewManager_LoadsLegacyStateFile(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + legacy := State{ + LastChannel: "legacy-channel", + LastChatID: "legacy-chat", + Timestamp: time.Now().UTC(), + } + data, err := json.Marshal(legacy) + if err != nil { + t.Fatalf("marshal legacy state: %v", err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "state.json"), data, 0o644); err != nil { + t.Fatalf("write legacy state: %v", err) + } + + sm := NewManager(tmpDir) + if sm.GetLastChannel() != "legacy-channel" { + t.Fatalf("expected legacy channel, got %q", sm.GetLastChannel()) + } + if sm.GetLastChatID() != "legacy-chat" { + t.Fatalf("expected legacy chat id, got 
%q", sm.GetLastChatID()) + } +} + +func TestNewManager_BootstrapTimeoutDoesNotBlock(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + prevRead := stateReadFile + prevTimeout := stateBootstrapTimeout + defer func() { + stateReadFile = prevRead + stateBootstrapTimeout = prevTimeout + }() + + block := make(chan struct{}) + stateReadFile = func(string) ([]byte, error) { + <-block + return nil, os.ErrNotExist + } + stateBootstrapTimeout = 20 * time.Millisecond + + start := time.Now() + _ = NewManager(tmpDir) + elapsed := time.Since(start) + if elapsed > 150*time.Millisecond { + t.Fatalf("expected constructor to return quickly, took %v", elapsed) + } + + close(block) + time.Sleep(5 * time.Millisecond) +} diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 545e343..9a85e8e 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -2,10 +2,12 @@ package tools import ( "context" + "errors" "fmt" "os" "path/filepath" "strings" + "time" ) type AccessMode int @@ -21,6 +23,15 @@ type allowedRoot struct { readOnly bool } +var ( + fileToolOpTimeout = 1500 * time.Millisecond + fileToolReadFile = os.ReadFile + fileToolReadDir = os.ReadDir + + errFileReadTimedOut = errors.New("file read timed out") + errDirReadTimedOut = errors.New("directory read timed out") +) + // validatePath ensures the given path is within the workspace if restrict is true. 
func validatePath(path, workspace string, restrict bool) (string, error) { return validatePathWithPolicy(path, workspace, restrict, AccessRead, "", false) @@ -209,8 +220,15 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) return UserErrorResult(err.Error()) } - content, err := os.ReadFile(resolvedPath) + content, err := readFileWithTimeout(ctx, resolvedPath, fileToolOpTimeout) if err != nil { + if errors.Is(err, errFileReadTimedOut) { + return ErrorResult(fmt.Sprintf( + "failed to read file: timed out after %dms (path=%s)", + fileToolOpTimeout.Milliseconds(), + resolvedPath, + )) + } return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) } @@ -334,8 +352,15 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) return UserErrorResult(err.Error()) } - entries, err := os.ReadDir(resolvedPath) + entries, err := readDirWithTimeout(ctx, resolvedPath, fileToolOpTimeout) if err != nil { + if errors.Is(err, errDirReadTimedOut) { + return ErrorResult(fmt.Sprintf( + "failed to read directory: timed out after %dms (path=%s)", + fileToolOpTimeout.Milliseconds(), + resolvedPath, + )) + } return ErrorResult(fmt.Sprintf("failed to read directory: %v", err)) } @@ -350,3 +375,59 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) return NewToolResult(result) } + +func readFileWithTimeout(ctx context.Context, path string, timeout time.Duration) ([]byte, error) { + if timeout <= 0 { + return fileToolReadFile(path) + } + + type fileReadResult struct { + content []byte + err error + } + done := make(chan fileReadResult, 1) + go func() { + content, err := fileToolReadFile(path) + done <- fileReadResult{content: content, err: err} + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case out := <-done: + return out.content, out.err + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return nil, errFileReadTimedOut + } +} + +func 
readDirWithTimeout(ctx context.Context, path string, timeout time.Duration) ([]os.DirEntry, error) { + if timeout <= 0 { + return fileToolReadDir(path) + } + + type dirReadResult struct { + entries []os.DirEntry + err error + } + done := make(chan dirReadResult, 1) + go func() { + entries, err := fileToolReadDir(path) + done <- dirReadResult{entries: entries, err: err} + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case out := <-done: + return out.entries, out.err + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return nil, errDirReadTimedOut + } +} diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 684e568..ccca6f3 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strings" "testing" + "time" ) // TestFilesystemTool_ReadFile_Success verifies successful file reading @@ -248,6 +249,54 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { } } +func TestFilesystemTool_ReadFile_Timeout(t *testing.T) { + origReadFile := fileToolReadFile + origTimeout := fileToolOpTimeout + t.Cleanup(func() { + fileToolReadFile = origReadFile + fileToolOpTimeout = origTimeout + }) + + fileToolOpTimeout = 20 * time.Millisecond + fileToolReadFile = func(path string) ([]byte, error) { + time.Sleep(200 * time.Millisecond) + return []byte("late"), nil + } + + tool := &ReadFileTool{} + result := tool.Execute(context.Background(), map[string]interface{}{"path": "/tmp/slow-file.txt"}) + if !result.IsError { + t.Fatalf("expected timeout error") + } + if !strings.Contains(result.ForLLM, "timed out") { + t.Fatalf("expected timeout message, got: %s", result.ForLLM) + } +} + +func TestFilesystemTool_ListDir_Timeout(t *testing.T) { + origReadDir := fileToolReadDir + origTimeout := fileToolOpTimeout + t.Cleanup(func() { + fileToolReadDir = origReadDir + fileToolOpTimeout = origTimeout + }) + + fileToolOpTimeout = 20 * time.Millisecond + fileToolReadDir 
= func(path string) ([]os.DirEntry, error) { + time.Sleep(200 * time.Millisecond) + return nil, nil + } + + tool := &ListDirTool{} + result := tool.Execute(context.Background(), map[string]interface{}{"path": "/tmp/slow-dir"}) + if !result.IsError { + t.Fatalf("expected timeout error") + } + if !strings.Contains(result.ForLLM, "timed out") { + t.Fatalf("expected timeout message, got: %s", result.ForLLM) + } +} + // Block paths that look inside workspace but point outside via symlink. func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) { diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 976b0e8..6456cfc 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -6,12 +6,15 @@ import ( "os" "path/filepath" "strings" + "unicode/utf8" "github.com/sipeed/picoclaw/pkg/bus" ) type SendCallback func(channel, chatID, content string, attachments []bus.OutboundAttachment) error +const messageStatusPreviewRunes = 80 + type MessageTool struct { sendCallback SendCallback defaultChannel string @@ -148,9 +151,16 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) t.sentInRound = true // Silent: user already received the message directly - status := fmt.Sprintf("Message sent to %s:%s", channel, chatID) + status := fmt.Sprintf("Message sent to %s:%s (chars=%d", channel, chatID, utf8.RuneCountInString(content)) + if preview := summarizeMessageContentForStatus(content); preview != "" { + status += fmt.Sprintf(", preview=%q", preview) + } + status += ")" if len(attachments) > 0 { status += fmt.Sprintf(" with %d attachment(s)", len(attachments)) + if attachmentNames := summarizeAttachmentNames(attachments, 3); attachmentNames != "" { + status += fmt.Sprintf(" [%s]", attachmentNames) + } } return &ToolResult{ ForLLM: status, @@ -213,3 +223,49 @@ func (t *MessageTool) parseAttachments(args map[string]interface{}) ([]bus.Outbo return attachments, nil } + +func summarizeMessageContentForStatus(content string) string { + 
compact := strings.Join(strings.Fields(strings.TrimSpace(content)), " ") + if compact == "" { + return "" + } + runes := []rune(compact) + if len(runes) <= messageStatusPreviewRunes { + return compact + } + return string(runes[:messageStatusPreviewRunes]) + "..." +} + +func summarizeAttachmentNames(attachments []bus.OutboundAttachment, max int) string { + if len(attachments) == 0 || max <= 0 { + return "" + } + + names := make([]string, 0, minInt(len(attachments), max)) + for i, attachment := range attachments { + if i >= max { + break + } + name := strings.TrimSpace(attachment.Filename) + if name == "" { + name = filepath.Base(attachment.Path) + } + if name != "" { + names = append(names, name) + } + } + if len(names) == 0 { + return "" + } + if len(attachments) > max { + names = append(names, fmt.Sprintf("+%d more", len(attachments)-max)) + } + return strings.Join(names, ", ") +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go index a1b1a7f..70c2911 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message_test.go @@ -51,8 +51,14 @@ func TestMessageTool_Execute_Success(t *testing.T) { } // - ForLLM contains send status description - if result.ForLLM != "Message sent to test-channel:test-chat-id" { - t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM) + if !strings.Contains(result.ForLLM, "Message sent to test-channel:test-chat-id") { + t.Errorf("Expected channel/chat in ForLLM, got '%s'", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "chars=13") { + t.Errorf("Expected char count in ForLLM, got '%s'", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "preview=\"Hello, world!\"") { + t.Errorf("Expected content preview in ForLLM, got '%s'", result.ForLLM) } // - ForUser is empty (user already received message directly) @@ -97,8 +103,8 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { if 
!result.Silent { t.Error("Expected Silent=true") } - if result.ForLLM != "Message sent to custom-channel:custom-chat-id" { - t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM) + if !strings.Contains(result.ForLLM, "Message sent to custom-channel:custom-chat-id") { + t.Errorf("Expected channel/chat in ForLLM, got '%s'", result.ForLLM) } } @@ -243,6 +249,9 @@ func TestMessageTool_Execute_WithAttachments(t *testing.T) { if !strings.Contains(result.ForLLM, "with 1 attachment(s)") { t.Fatalf("expected attachment status in ForLLM, got %q", result.ForLLM) } + if !strings.Contains(result.ForLLM, "final-report.docx") { + t.Fatalf("expected attachment filename in ForLLM, got %q", result.ForLLM) + } } func TestMessageTool_Execute_AttachmentOutsideWorkspaceBlocked(t *testing.T) { diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 69b0085..9c79b07 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -14,6 +14,8 @@ import ( "strconv" "strings" "time" + + "github.com/sipeed/picoclaw/pkg/logger" ) type ExecTool struct { @@ -28,7 +30,7 @@ type ExecTool struct { } var ( - shellPathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`) + shellPathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`) shellURLPattern = regexp.MustCompile("https?://[^\\s\"'`]+") shellMutatingCommandPattern = regexp.MustCompile(`(?i)(^|[;&|()\s])(touch|mkdir|rmdir|rm|mv|cp|install|chmod|chown|truncate|tee|sed\s+-i|perl\s+-i|pandoc)([;&|()\s]|$)`) shellWriteRedirectPattern = regexp.MustCompile(`(^|[^0-9])>>?`) @@ -43,6 +45,14 @@ var ( } ) +const ( + execSlowLogThreshold = 2 * time.Second + execStageLogThreshold = 250 * time.Millisecond + execCommandPreviewMax = 220 + execStderrPreviewMax = 320 + execWorkingDirMaxChars = 220 +) + //go:embed assets/nih-standard.docx var embeddedNIHTemplate []byte @@ -111,39 +121,108 @@ func (t *ExecTool) Parameters() map[string]interface{} { } func (t *ExecTool) Execute(ctx context.Context, 
args map[string]interface{}) *ToolResult { + execStartedAt := time.Now() + totalDuration := func() time.Duration { return time.Since(execStartedAt) } + commandPreview := "" + commandExecPreview := "" + requestedCWD := t.workingDir + resolvedCWD := "" + cwdSource := "tool_default" + var cwdResolveDuration time.Duration + var cwdValidateDuration time.Duration + var guardDuration time.Duration + var pandocDuration time.Duration + var startDuration time.Duration + var waitDuration time.Duration + var terminateDuration time.Duration + + logFields := func(extra map[string]interface{}) map[string]interface{} { + fields := map[string]interface{}{ + "command": commandPreview, + "command_exec": commandExecPreview, + "requested_cwd": truncateForLog(requestedCWD, execWorkingDirMaxChars), + "resolved_cwd": truncateForLog(resolvedCWD, execWorkingDirMaxChars), + "cwd_source": cwdSource, + "total_ms": totalDuration().Milliseconds(), + "cwd_resolve_ms": cwdResolveDuration.Milliseconds(), + "cwd_validate_ms": cwdValidateDuration.Milliseconds(), + "guard_ms": guardDuration.Milliseconds(), + "pandoc_ms": pandocDuration.Milliseconds(), + "start_ms": startDuration.Milliseconds(), + "wait_ms": waitDuration.Milliseconds(), + "terminate_ms": terminateDuration.Milliseconds(), + "restrict_enabled": t.restrictToWorkspace, + } + for k, v := range extra { + fields[k] = v + } + return fields + } + command, ok := args["command"].(string) if !ok { return ErrorResult("command is required") } + commandPreview = truncateForLog(strings.TrimSpace(command), execCommandPreviewMax) + commandExecPreview = commandPreview cwd := t.workingDir if wd, ok := args["working_dir"].(string); ok && wd != "" { cwd = wd + requestedCWD = wd + cwdSource = "tool_arg" } if cwd == "" { + cwdResolveStartedAt := time.Now() wd, err := os.Getwd() + cwdResolveDuration = time.Since(cwdResolveStartedAt) if err == nil { cwd = wd + requestedCWD = wd + cwdSource = "os_getwd" + } else { + logger.WarnCF("tool.exec", "Exec could not 
resolve default working directory", logFields(map[string]interface{}{ + "error": err.Error(), + })) } } + resolvedCWD = cwd if t.restrictToWorkspace && strings.TrimSpace(cwd) != "" { + cwdValidateStartedAt := time.Now() resolvedCWD, err := validatePathWithPolicy(cwd, t.workingDir, true, AccessRead, t.sharedWorkspace, t.sharedWorkspaceReadOnly) + cwdValidateDuration = time.Since(cwdValidateStartedAt) if err != nil { + logger.WarnCF("tool.exec", "Exec blocked: invalid working directory", logFields(map[string]interface{}{ + "error": err.Error(), + })) return UserErrorResult("Command blocked by safety guard (" + err.Error() + ")") } cwd = resolvedCWD + resolvedCWD = cwd } - if guardError := t.guardCommand(command, cwd); guardError != "" { + guardStartedAt := time.Now() + guardError := t.guardCommand(command, cwd) + guardDuration = time.Since(guardStartedAt) + if guardError != "" { + logger.WarnCF("tool.exec", "Exec blocked by safety guard", logFields(map[string]interface{}{ + "error": guardError, + })) return UserErrorResult(guardError) } + pandocStartedAt := time.Now() withPandocDefaults, pandocErr := t.commandWithPandocDefaults(command) + pandocDuration = time.Since(pandocStartedAt) if pandocErr != nil { + logger.ErrorCF("tool.exec", "Exec failed while preparing pandoc defaults", logFields(map[string]interface{}{ + "error": pandocErr.Error(), + })) return ErrorResult(pandocErr.Error()) } command = withPandocDefaults + commandExecPreview = truncateForLog(strings.TrimSpace(command), execCommandPreviewMax) cmdCtx, cancel := context.WithTimeout(ctx, t.timeout) defer cancel() @@ -174,9 +253,15 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To cmd.Stdout = &stdout cmd.Stderr = &stderr + startStartedAt := time.Now() if err := cmd.Start(); err != nil { + startDuration = time.Since(startStartedAt) + logger.ErrorCF("tool.exec", "Exec failed to start process", logFields(map[string]interface{}{ + "error": err.Error(), + })) return 
ErrorResult(fmt.Sprintf("failed to start command: %v", err)) } + startDuration = time.Since(startStartedAt) done := make(chan error, 1) go func() { @@ -184,19 +269,28 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To }() var err error + waitStartedAt := time.Now() + timedOut := false + forceKilled := false select { case err = <-done: case <-cmdCtx.Done(): + timedOut = errors.Is(cmdCtx.Err(), context.DeadlineExceeded) + terminateStartedAt := time.Now() _ = terminateProcessTree(cmd) + terminateDuration = time.Since(terminateStartedAt) select { case err = <-done: case <-time.After(2 * time.Second): + forceKilled = true if cmd.Process != nil { _ = cmd.Process.Kill() } err = <-done } } + waitDuration = time.Since(waitStartedAt) + cmdCtxErr := cmdCtx.Err() output := stdout.String() if stderr.Len() > 0 { @@ -205,14 +299,10 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To if err != nil { if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) { - msg := fmt.Sprintf("Command timed out after %v", t.timeout) - return &ToolResult{ - ForLLM: msg, - ForUser: msg, - IsError: true, - } + output = fmt.Sprintf("Command timed out after %v", t.timeout) + } else { + output += fmt.Sprintf("\nExit code: %v", err) } - output += fmt.Sprintf("\nExit code: %v", err) } if output == "" { @@ -224,13 +314,41 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To output = output[:maxLen] + fmt.Sprintf("\n... 
(truncated, %d more chars)", len(output)-maxLen) } + shouldLogSlow := totalDuration() >= execSlowLogThreshold || + cwdValidateDuration >= execStageLogThreshold || + guardDuration >= execStageLogThreshold + stderrPreview := "" + if stderr.Len() > 0 { + stderrPreview = truncateForLog(strings.TrimSpace(stderr.String()), execStderrPreviewMax) + } + if err != nil { + logMessage := "Exec command failed" + if timedOut { + logMessage = "Exec command timed out" + } + extra := map[string]interface{}{ + "error": err.Error(), + "timed_out": timedOut, + "force_killed": forceKilled, + "ctx_error": "", + "stderr_preview": stderrPreview, + } + if cmdCtxErr != nil { + extra["ctx_error"] = cmdCtxErr.Error() + } + logger.ErrorCF("tool.exec", logMessage, logFields(extra)) return &ToolResult{ ForLLM: output, ForUser: output, IsError: true, } } + if shouldLogSlow { + logger.WarnCF("tool.exec", "Exec command slow", logFields(map[string]interface{}{ + "stderr_preview": stderrPreview, + })) + } return &ToolResult{ ForLLM: output, @@ -239,6 +357,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To } } +func truncateForLog(input string, max int) string { + if max <= 0 { + return "" + } + input = strings.TrimSpace(input) + if len(input) <= max { + return input + } + if max <= 3 { + return input[:max] + } + return input[:max-3] + "..." +} + func mergeEnv(base []string, overrides map[string]string) []string { if len(overrides) == 0 { return base