diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 29827d0b2..504ce5c38 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -115,7 +115,7 @@ func registerSharedTools( }); searchTool != nil { agent.Tools.Register(searchTool) } - agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)) + agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)) // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms agent.Tools.Register(tools.NewI2CTool()) diff --git a/pkg/config/config.go b/pkg/config/config.go index d84772d2b..55d0cfb2c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -523,7 +523,8 @@ type WebToolsConfig struct { Perplexity PerplexityConfig `json:"perplexity"` // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. - Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` + FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"` } type CronToolsConfig struct { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index ebb924859..a2977b17e 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -299,7 +299,8 @@ func DefaultConfig() *Config { Interval: 5, }, Web: WebToolsConfig{ - Proxy: "", + Proxy: "", + FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default Brave: BraveConfig{ Enabled: false, APIKey: "", diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 8ba2a723a..695cc07c5 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -506,26 +507,35 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR } type WebFetchTool struct { - maxChars int - proxy string + maxChars int + proxy string + fetchLimitBytes int64 } -func NewWebFetchTool(maxChars int) *WebFetchTool { +func NewWebFetchTool(maxChars int, fetchLimitBytes int64) *WebFetchTool { if maxChars <= 0 { maxChars = 50000 } + if fetchLimitBytes <= 0 { + fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback + } return &WebFetchTool{ - maxChars: maxChars, + maxChars: maxChars, + fetchLimitBytes: fetchLimitBytes, } } -func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool { +func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) *WebFetchTool { if maxChars <= 0 { maxChars = 50000 } + if fetchLimitBytes <= 0 { + fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback + } return &WebFetchTool{ - maxChars: maxChars, - proxy: proxy, + maxChars: maxChars, + proxy: proxy, + fetchLimitBytes: fetchLimitBytes, } } @@ -605,10 +615,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe if err != nil { return ErrorResult(fmt.Sprintf("request failed: %v", err)) } + + resp.Body = http.MaxBytesReader(nil, resp.Body, t.fetchLimitBytes) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + return ErrorResult(fmt.Sprintf("failed to read response: size exceeded %d bytes limit", t.fetchLimitBytes)) + } return ErrorResult(fmt.Sprintf("failed to read response: %v", err)) } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 2cd79eb24..299b911fd 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -1,8 +1,10 @@ package tools import ( + "bytes" "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -10,6 +12,8 @@ import ( "time" ) +const testFetchLimit = int64(10 * 1024 * 1024) + // TestWebTool_WebFetch_Success verifies successful URL fetching func TestWebTool_WebFetch_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -19,7 +23,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + tool := NewWebFetchTool(50000, testFetchLimit) ctx := context.Background() args := map[string]any{ "url": server.URL, @@ -55,7 +59,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + tool := NewWebFetchTool(50000, testFetchLimit) ctx := context.Background() args := map[string]any{ "url": server.URL, @@ -76,7 +80,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) { // TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL func TestWebTool_WebFetch_InvalidURL(t *testing.T) { - tool := NewWebFetchTool(50000) + tool := NewWebFetchTool(50000, testFetchLimit) ctx := context.Background() args := map[string]any{ "url": "not-a-valid-url", @@ -97,7 +101,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) { // TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) { - tool := NewWebFetchTool(50000) + tool := NewWebFetchTool(50000, testFetchLimit) ctx := context.Background() args := map[string]any{ "url": "ftp://example.com/file.txt", @@ -118,7 +122,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) { // TestWebTool_WebFetch_MissingURL verifies error handling for missing URL func TestWebTool_WebFetch_MissingURL(t *testing.T) { - tool := NewWebFetchTool(50000) + tool := NewWebFetchTool(50000, testFetchLimit) ctx := context.Background() args := map[string]any{} @@ -146,7 +150,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(1000) // Limit to 1000 chars + tool := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars ctx := context.Background() args := map[string]any{ "url": server.URL, @@ -174,6 +178,46 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { } } +func TestWebFetchTool_PayloadTooLarge(t *testing.T) { + // Create a mock HTTP server + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + + // Generate a payload intentionally larger than our limit. + // Limit: 10 * 1024 * 1024 (10MB). We generate 10MB + 100 bytes of the letter 'A'. + largeData := bytes.Repeat([]byte("A"), int(testFetchLimit)+100) + + w.Write(largeData) + })) + // Ensure the server is shut down at the end of the test + defer ts.Close() + + // Initialize the tool + tool := NewWebFetchTool(50000, testFetchLimit) + + // Prepare the arguments pointing to the URL of our local mock server + args := map[string]any{ + "url": ts.URL, + } + + // Execute the tool + ctx := context.Background() + result := tool.Execute(ctx, args) + + // Assuming ErrorResult sets the ForLLM field with the error text. + if result == nil { + t.Fatal("expected a ToolResult, got nil") + } + + // Search for the exact error string we set earlier in the Execute method + expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit) + + if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) { + t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result) + } +} + // TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing func TestWebTool_WebSearch_NoApiKey(t *testing.T) { tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) @@ -215,7 +259,7 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + tool := NewWebFetchTool(50000, testFetchLimit) ctx := context.Background() args := map[string]any{ "url": server.URL, @@ -316,7 +360,7 @@ func TestWebFetchTool_extractText(t *testing.T) { // TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain func TestWebTool_WebFetch_MissingDomain(t *testing.T) { - tool := NewWebFetchTool(50000) + tool := NewWebFetchTool(50000, testFetchLimit) ctx := context.Background() args := map[string]any{ "url": "https://", @@ -438,7 +482,7 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) { } func TestNewWebFetchToolWithProxy(t *testing.T) { - tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890") + tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit) if tool.maxChars != 1024 { t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024) } @@ -446,7 +490,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) { t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890") } - tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890") + tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit) if tool.maxChars != 50000 { t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000) }