Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ import (

// FileConfig represents the YAML configuration file structure
type FileConfig struct {
Host string `yaml:"host"`
Port int `yaml:"port"`
DataDir string `yaml:"data_dir"`
TLS TLSConfig `yaml:"tls"`
Users map[string]string `yaml:"users"`
RateLimit RateLimitFileConfig `yaml:"rate_limit"`
Extensions []string `yaml:"extensions"`
DuckLake DuckLakeFileConfig `yaml:"ducklake"`
Host string `yaml:"host"`
Port int `yaml:"port"`
DataDir string `yaml:"data_dir"`
TLS TLSConfig `yaml:"tls"`
Users map[string]string `yaml:"users"`
RateLimit RateLimitFileConfig `yaml:"rate_limit"`
Extensions []string `yaml:"extensions"`
DuckLake DuckLakeFileConfig `yaml:"ducklake"`
MaxConnections int `yaml:"max_connections"` // Maximum concurrent connections (0 = unlimited)
}

type TLSConfig struct {
Expand Down Expand Up @@ -109,12 +110,13 @@ func main() {
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nEnvironment variables:\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_CONFIG Path to YAML config file\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_HOST Host to bind to (default: 0.0.0.0)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_PORT Port to listen on (default: 5432)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_DATA_DIR Directory for DuckDB files (default: ./data)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_CERT TLS certificate file (default: ./certs/server.crt)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_KEY TLS private key file (default: ./certs/server.key)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_CONFIG Path to YAML config file\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_HOST Host to bind to (default: 0.0.0.0)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_PORT Port to listen on (default: 5432)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_DATA_DIR Directory for DuckDB files (default: ./data)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_CERT TLS certificate file (default: ./certs/server.crt)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_KEY TLS private key file (default: ./certs/server.key)\n")
fmt.Fprintf(os.Stderr, " DUCKGRES_MAX_CONNECTIONS Maximum concurrent connections (default: 0, unlimited)\n")
fmt.Fprintf(os.Stderr, "\nPrecedence: CLI flags > environment variables > config file > defaults\n")
}

Expand Down Expand Up @@ -229,6 +231,11 @@ func main() {
if fileCfg.DuckLake.S3Profile != "" {
cfg.DuckLake.S3Profile = fileCfg.DuckLake.S3Profile
}

// Apply connection limit config
if fileCfg.MaxConnections > 0 {
cfg.MaxConnections = fileCfg.MaxConnections
}
}

// Apply environment variables (override config file)
Expand Down Expand Up @@ -282,6 +289,11 @@ func main() {
if v := os.Getenv("DUCKGRES_DUCKLAKE_S3_PROFILE"); v != "" {
cfg.DuckLake.S3Profile = v
}
if v := os.Getenv("DUCKGRES_MAX_CONNECTIONS"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n >= 0 {
cfg.MaxConnections = n
}
}

// Apply CLI flags (highest priority)
if *host != "" {
Expand Down
24 changes: 24 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ var rateLimitedIPsGauge = promauto.NewGauge(prometheus.GaugeOpts{
Help: "Number of currently rate-limited IP addresses",
})

var connectionLimitRejectsCounter = promauto.NewCounter(prometheus.CounterOpts{
Name: "duckgres_connection_limit_rejects_total",
Help: "Total number of connections rejected due to max_connections limit",
})

func redactConnectionString(connStr string) string {
return passwordPattern.ReplaceAllString(connStr, "${1}[REDACTED]")
}
Expand Down Expand Up @@ -81,6 +86,11 @@ type Config struct {
// This prevents accumulation of zombie connections from clients that disconnect
// uncleanly. Default: 10 minutes. Set to 0 to disable.
IdleTimeout time.Duration

// MaxConnections is the maximum number of concurrent client connections.
// New connections are rejected when this limit is reached.
// Default: 0 (unlimited).
MaxConnections int
}

// DuckLakeConfig configures DuckLake catalog attachment
Expand Down Expand Up @@ -170,6 +180,9 @@ func New(cfg Config) (*Server, error) {

slog.Info("TLS enabled.", "cert_file", cfg.TLSCertFile)
slog.Info("Rate limiting enabled.", "max_failed_attempts", cfg.RateLimit.MaxFailedAttempts, "window", cfg.RateLimit.FailedAttemptWindow, "ban_duration", cfg.RateLimit.BanDuration)
if cfg.MaxConnections > 0 {
slog.Info("Connection limit enabled.", "max_connections", cfg.MaxConnections)
}
return s, nil
}

Expand Down Expand Up @@ -618,6 +631,17 @@ func (s *Server) buildCredentialChainSecret() string {
func (s *Server) handleConnection(conn net.Conn) {
remoteAddr := conn.RemoteAddr()

// Check global connection limit
if s.cfg.MaxConnections > 0 {
currentConns := atomic.LoadInt64(&s.activeConns)
if currentConns >= int64(s.cfg.MaxConnections) {
slog.Warn("Connection rejected: max connections reached.", "remote_addr", remoteAddr, "current", currentConns, "max", s.cfg.MaxConnections)
connectionLimitRejectsCounter.Inc()
_ = conn.Close()
return
}
}

// Check rate limiting before doing anything
if msg := s.rateLimiter.CheckConnection(remoteAddr); msg != "" {
// Send PostgreSQL error and close
Expand Down