-
-
Notifications
You must be signed in to change notification settings - Fork 29
🛡️ Sentinel: Mitigate SSRF in user-supplied content fetching #303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
fb02f51
568bf18
da0f31c
25b27dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,14 +3,17 @@ package client | |
| import ( | ||
| "crypto/tls" | ||
| "fmt" | ||
| "net" | ||
| "net/http" | ||
| "net/url" | ||
| "syscall" | ||
| "time" | ||
|
|
||
| "github.com/Laisky/zap" | ||
|
|
||
| "github.com/songquanpeng/one-api/common/config" | ||
| "github.com/songquanpeng/one-api/common/logger" | ||
| "github.com/songquanpeng/one-api/common/network" | ||
| ) | ||
|
|
||
| // HTTPClient is the default outbound client used for relay requests. | ||
|
|
@@ -24,9 +27,30 @@ var UserContentRequestHTTPClient *http.Client | |
|
|
||
| // Init builds the shared HTTP clients with proxy and timeout settings derived from configuration. | ||
| func Init() { | ||
| // Create a transport with HTTP/2 disabled to avoid stream errors in CI environments | ||
| createTransport := func(proxyURL *url.URL) *http.Transport { | ||
| // Create a transport with HTTP/2 disabled to avoid stream errors in CI environments. | ||
| // Optionally blocks internal IP addresses to mitigate SSRF risks. | ||
| createTransport := func(proxyURL *url.URL, blockInternal bool) *http.Transport { | ||
| dialer := &net.Dialer{ | ||
| Timeout: 30 * time.Second, | ||
| KeepAlive: 30 * time.Second, | ||
| } | ||
|
|
||
| if blockInternal { | ||
| dialer.Control = func(networkName, address string, c syscall.RawConn) error { | ||
| host, _, err := net.SplitHostPort(address) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| ip := net.ParseIP(host) | ||
| if ip != nil && network.IsInternalIP(ip) { | ||
| return fmt.Errorf("SSRF protection: internal IP %s is blocked", ip) | ||
| } | ||
| return nil | ||
| } | ||
| } | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| transport := &http.Transport{ | ||
| DialContext: dialer.DialContext, | ||
| TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), // Disable HTTP/2 | ||
| } | ||
| if proxyURL != nil { | ||
|
|
@@ -42,12 +66,12 @@ func Init() { | |
| logger.Logger.Fatal(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy)) | ||
| } | ||
| UserContentRequestHTTPClient = &http.Client{ | ||
| Transport: createTransport(proxyURL), | ||
| Transport: createTransport(proxyURL, config.BlockInternalUserContentRequests), | ||
| Timeout: time.Second * time.Duration(config.UserContentRequestTimeout), | ||
| } | ||
| } else { | ||
| UserContentRequestHTTPClient = &http.Client{ | ||
| Transport: createTransport(nil), | ||
| Transport: createTransport(nil, config.BlockInternalUserContentRequests), | ||
| Timeout: 30 * time.Second, // Set a reasonable default timeout | ||
|
||
| } | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
@@ -58,9 +82,9 @@ func Init() { | |
| if err != nil { | ||
| logger.Logger.Fatal(fmt.Sprintf("RELAY_PROXY set but invalid: %s", config.RelayProxy)) | ||
| } | ||
| transport = createTransport(proxyURL) | ||
| transport = createTransport(proxyURL, false) | ||
| } else { | ||
| transport = createTransport(nil) | ||
| transport = createTransport(nil, false) | ||
| } | ||
|
|
||
| if config.RelayTimeout == 0 { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,3 +28,16 @@ func TestInit(t *testing.T) { | |
| require.NotNil(t, HTTPClient) | ||
| require.NotNil(t, ImpatientHTTPClient) | ||
| } | ||
|
|
||
| func TestUserContentRequestHTTPClient_SSRF(t *testing.T) { | ||
| // Test that UserContentRequestHTTPClient blocks internal IPs | ||
| Init() | ||
|
|
||
| // Try to fetch from localhost (which is an internal IP) | ||
| // We use a random port that is likely not listening to avoid connection refused | ||
| // but the DialControl should block it before it even tries to connect. | ||
| _, err := UserContentRequestHTTPClient.Get("http://127.0.0.1:12345") | ||
| require.Error(t, err) | ||
| require.Contains(t, err.Error(), "SSRF protection") | ||
| require.Contains(t, err.Error(), "blocked") | ||
| } | ||
|
Comment on lines
33
to
80
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The SSRF guard only blocks when the dial target host is a literal IP (net.ParseIP(host) != nil). Requests to hostnames like
localhost,*.internal, or DNS rebinding domains will bypass this check because the dialer will resolve them later. Consider resolvinghostvianet.Resolver.LookupIPAddr(and blocking if any resolved IP is internal), or performing the check earlier at the request layer usingreq.URL.Hostname()(including during redirects).