Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/config.example.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"discord": {
"enabled": false,
"token": "YOUR_DISCORD_BOT_TOKEN",
"proxy": "",
"allow_from": [],
"mention_only": false
},
Expand Down
41 changes: 41 additions & 0 deletions pkg/channels/discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package channels
import (
"context"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"

"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"

"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
Expand Down Expand Up @@ -39,6 +42,10 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
return nil, fmt.Errorf("failed to create discord session: %w", err)
}

if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
return nil, err
}

base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom)

return &DiscordChannel{
Expand Down Expand Up @@ -357,9 +364,43 @@ func (c *DiscordChannel) stopTyping(chatID string) {
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
ProxyURL: c.config.Proxy,
})
}

func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
var proxyFunc func(*http.Request) (*url.URL, error)
if proxyAddr != "" {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
}
proxyFunc = http.ProxyURL(proxyURL)
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
proxyFunc = http.ProxyFromEnvironment
}

if proxyFunc == nil {
return nil
}

transport := &http.Transport{Proxy: proxyFunc}
session.Client = &http.Client{
Timeout: 20 * time.Second,
Transport: transport,
}

if session.Dialer != nil {
dialerCopy := *session.Dialer
dialerCopy.Proxy = proxyFunc
session.Dialer = &dialerCopy
} else {
session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
}

return nil
}

// stripBotMention removes the bot mention from the message content.
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
func (c *DiscordChannel) stripBotMention(text string) string {
Expand Down
94 changes: 94 additions & 0 deletions pkg/channels/discord_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//go:build discord_proxy
// +build discord_proxy

package channels

import (
"net/http"
"net/url"
"testing"

"github.com/bwmarrin/discordgo"
)

func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}

if err := applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil {
t.Fatalf("applyDiscordProxy() error: %v", err)
}

req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}

restProxy := session.Client.Transport.(*http.Transport).Proxy
restProxyURL, err := restProxy(req)
if err != nil {
t.Fatalf("rest proxy func error: %v", err)
}
if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want {
t.Fatalf("REST proxy = %q, want %q", got, want)
}

wsProxyURL, err := session.Dialer.Proxy(req)
if err != nil {
t.Fatalf("ws proxy func error: %v", err)
}
if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want {
t.Fatalf("WS proxy = %q, want %q", got, want)
}
}

func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
t.Setenv("http_proxy", "http://127.0.0.1:8888")
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
t.Setenv("https_proxy", "http://127.0.0.1:8888")
t.Setenv("ALL_PROXY", "")
t.Setenv("all_proxy", "")
t.Setenv("NO_PROXY", "")
t.Setenv("no_proxy", "")

session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}

if err := applyDiscordProxy(session, ""); err != nil {
t.Fatalf("applyDiscordProxy() error: %v", err)
}

req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}

gotURL, err := session.Dialer.Proxy(req)
if err != nil {
t.Fatalf("ws proxy func error: %v", err)
}

wantURL, err := url.Parse("http://127.0.0.1:8888")
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
if gotURL.String() != wantURL.String() {
t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String())
}
}

func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}

if err := applyDiscordProxy(session, "://bad-proxy"); err == nil {
t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
}
}
1 change: 1 addition & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ type FeishuConfig struct {
type DiscordConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
}
Expand Down
23 changes: 19 additions & 4 deletions pkg/utils/media.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package utils
import (
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -52,11 +53,12 @@ type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
ProxyURL string
}

// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
Expand All @@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)

// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{
"error": err.Error(),
Expand All @@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
}

client := &http.Client{Timeout: opts.Timeout}
if opts.ProxyURL != "" {
proxyURL, parseErr := url.Parse(opts.ProxyURL)
if parseErr != nil {
logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{
"error": parseErr.Error(),
"proxy": opts.ProxyURL,
})
return ""
}
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{
"error": err.Error(),
"url": url,
"url": urlStr,
})
return ""
}
Expand All @@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{
"status": resp.StatusCode,
"url": url,
"url": urlStr,
})
return ""
}
Expand Down